add autodl
This commit is contained in:
		
							
								
								
									
										76
									
								
								AutoDL-Projects/xautodl/nas_infer_model/DXYs/CifarNet.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										76
									
								
								AutoDL-Projects/xautodl/nas_infer_model/DXYs/CifarNet.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,76 @@ | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| from .construct_utils import drop_path | ||||
| from .head_utils      import CifarHEAD, AuxiliaryHeadCIFAR | ||||
| from .base_cells      import InferCell | ||||
|  | ||||
|  | ||||
| class NetworkCIFAR(nn.Module): | ||||
|  | ||||
|   def __init__(self, C, N, stem_multiplier, auxiliary, genotype, num_classes): | ||||
|     super(NetworkCIFAR, self).__init__() | ||||
|     self._C               = C | ||||
|     self._layerN          = N | ||||
|     self._stem_multiplier = stem_multiplier | ||||
|  | ||||
|     C_curr = self._stem_multiplier * C | ||||
|     self.stem = CifarHEAD(C_curr) | ||||
|    | ||||
|     layer_channels   = [C    ] * N + [C*2 ] + [C*2  ] * N + [C*4 ] + [C*4  ] * N     | ||||
|     layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N | ||||
|     block_indexs     = [0    ] * N + [-1  ] + [1    ] * N + [-1  ] + [2    ] * N | ||||
|     block2index      = {0:[], 1:[], 2:[]} | ||||
|  | ||||
|     C_prev_prev, C_prev, C_curr = C_curr, C_curr, C | ||||
|     reduction_prev, spatial, dims = False, 1, [] | ||||
|     self.auxiliary_index = None | ||||
|     self.auxiliary_head  = None | ||||
|     self.cells = nn.ModuleList() | ||||
|     for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)): | ||||
|       cell = InferCell(genotype, C_prev_prev, C_prev, C_curr, reduction, reduction_prev) | ||||
|       reduction_prev = reduction | ||||
|       self.cells.append( cell ) | ||||
|       C_prev_prev, C_prev = C_prev, cell._multiplier*C_curr | ||||
|       if reduction and C_curr == C*4: | ||||
|         if auxiliary: | ||||
|           self.auxiliary_head = AuxiliaryHeadCIFAR(C_prev, num_classes) | ||||
|           self.auxiliary_index = index | ||||
|  | ||||
|       if reduction: spatial *= 2 | ||||
|       dims.append( (C_prev, spatial) ) | ||||
|        | ||||
|     self._Layer= len(self.cells) | ||||
|  | ||||
|  | ||||
|     self.global_pooling = nn.AdaptiveAvgPool2d(1) | ||||
|     self.classifier = nn.Linear(C_prev, num_classes) | ||||
|     self.drop_path_prob = -1 | ||||
|  | ||||
|   def update_drop_path(self, drop_path_prob): | ||||
|     self.drop_path_prob = drop_path_prob | ||||
|  | ||||
|   def auxiliary_param(self): | ||||
|     if self.auxiliary_head is None: return [] | ||||
|     else: return list( self.auxiliary_head.parameters() ) | ||||
|  | ||||
|   def get_message(self): | ||||
|     return self.extra_repr() | ||||
|  | ||||
|   def extra_repr(self): | ||||
|     return ('{name}(C={_C}, N={_layerN}, L={_Layer}, stem={_stem_multiplier}, drop-path={drop_path_prob})'.format(name=self.__class__.__name__, **self.__dict__)) | ||||
|  | ||||
|   def forward(self, inputs): | ||||
|     stem_feature, logits_aux = self.stem(inputs), None | ||||
|     cell_results = [stem_feature, stem_feature] | ||||
|     for i, cell in enumerate(self.cells): | ||||
|       cell_feature = cell(cell_results[-2], cell_results[-1], self.drop_path_prob) | ||||
|       cell_results.append( cell_feature ) | ||||
|  | ||||
|       if self.auxiliary_index is not None and i == self.auxiliary_index and self.training: | ||||
|         logits_aux = self.auxiliary_head( cell_results[-1] ) | ||||
|     out = self.global_pooling( cell_results[-1] ) | ||||
|     out = out.view(out.size(0), -1) | ||||
|     logits = self.classifier(out) | ||||
|  | ||||
|     if logits_aux is None: return out, logits | ||||
|     else                 : return out, [logits, logits_aux] | ||||
							
								
								
									
										77
									
								
								AutoDL-Projects/xautodl/nas_infer_model/DXYs/ImageNet.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										77
									
								
								AutoDL-Projects/xautodl/nas_infer_model/DXYs/ImageNet.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,77 @@ | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| from .construct_utils import drop_path | ||||
| from .base_cells import InferCell | ||||
| from .head_utils import ImageNetHEAD, AuxiliaryHeadImageNet | ||||
|  | ||||
|  | ||||
| class NetworkImageNet(nn.Module): | ||||
|  | ||||
|   def __init__(self, C, N, auxiliary, genotype, num_classes): | ||||
|     super(NetworkImageNet, self).__init__() | ||||
|     self._C          = C | ||||
|     self._layerN     = N | ||||
|     layer_channels   = [C    ] * N + [C*2 ] + [C*2  ] * N + [C*4 ] + [C*4] * N | ||||
|     layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N | ||||
|     self.stem0 = nn.Sequential( | ||||
|       nn.Conv2d(3, C // 2, kernel_size=3, stride=2, padding=1, bias=False), | ||||
|       nn.BatchNorm2d(C // 2), | ||||
|       nn.ReLU(inplace=True), | ||||
|       nn.Conv2d(C // 2, C, 3, stride=2, padding=1, bias=False), | ||||
|       nn.BatchNorm2d(C), | ||||
|     ) | ||||
|  | ||||
|     self.stem1 = nn.Sequential( | ||||
|       nn.ReLU(inplace=True), | ||||
|       nn.Conv2d(C, C, 3, stride=2, padding=1, bias=False), | ||||
|       nn.BatchNorm2d(C), | ||||
|     ) | ||||
|  | ||||
|     C_prev_prev, C_prev, C_curr, reduction_prev = C, C, C, True | ||||
|  | ||||
|     self.cells = nn.ModuleList() | ||||
|     self.auxiliary_index = None | ||||
|     for i, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)): | ||||
|       cell = InferCell(genotype, C_prev_prev, C_prev, C_curr, reduction, reduction_prev) | ||||
|       reduction_prev = reduction | ||||
|       self.cells += [cell] | ||||
|       C_prev_prev, C_prev = C_prev, cell._multiplier * C_curr | ||||
|       if reduction and C_curr == C*4: | ||||
|         C_to_auxiliary = C_prev | ||||
|         self.auxiliary_index = i | ||||
|    | ||||
|     self._NNN = len(self.cells) | ||||
|     if auxiliary: | ||||
|       self.auxiliary_head = AuxiliaryHeadImageNet(C_to_auxiliary, num_classes) | ||||
|     else: | ||||
|       self.auxiliary_head = None | ||||
|     self.global_pooling = nn.AvgPool2d(7) | ||||
|     self.classifier     = nn.Linear(C_prev, num_classes) | ||||
|     self.drop_path_prob = -1 | ||||
|  | ||||
|   def update_drop_path(self, drop_path_prob): | ||||
|     self.drop_path_prob = drop_path_prob | ||||
|  | ||||
|   def extra_repr(self): | ||||
|     return ('{name}(C={_C}, N=[{_layerN}, {_NNN}], aux-index={auxiliary_index}, drop-path={drop_path_prob})'.format(name=self.__class__.__name__, **self.__dict__)) | ||||
|  | ||||
|   def get_message(self): | ||||
|     return self.extra_repr() | ||||
|  | ||||
|   def auxiliary_param(self): | ||||
|     if self.auxiliary_head is None: return [] | ||||
|     else: return list( self.auxiliary_head.parameters() ) | ||||
|  | ||||
|   def forward(self, inputs): | ||||
|     s0 = self.stem0(inputs) | ||||
|     s1 = self.stem1(s0) | ||||
|     logits_aux = None | ||||
|     for i, cell in enumerate(self.cells): | ||||
|       s0, s1 = s1, cell(s0, s1, self.drop_path_prob) | ||||
|       if i == self.auxiliary_index and self.auxiliary_head and self.training: | ||||
|         logits_aux = self.auxiliary_head(s1) | ||||
|     out = self.global_pooling(s1) | ||||
|     logits = self.classifier(out.view(out.size(0), -1)) | ||||
|  | ||||
|     if logits_aux is None: return out, logits | ||||
|     else                 : return out, [logits, logits_aux] | ||||
							
								
								
									
										5
									
								
								AutoDL-Projects/xautodl/nas_infer_model/DXYs/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										5
									
								
								AutoDL-Projects/xautodl/nas_infer_model/DXYs/__init__.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,5 @@ | ||||
| # Performance-Aware Template Network for One-Shot Neural Architecture Search | ||||
| from .CifarNet import NetworkCIFAR as CifarNet | ||||
| from .ImageNet import NetworkImageNet as ImageNet | ||||
| from .genotypes import Networks | ||||
| from .genotypes import build_genotype_from_dict | ||||
							
								
								
									
										173
									
								
								AutoDL-Projects/xautodl/nas_infer_model/DXYs/base_cells.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										173
									
								
								AutoDL-Projects/xautodl/nas_infer_model/DXYs/base_cells.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,173 @@ | ||||
| import math | ||||
| from copy import deepcopy | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| import torch.nn.functional as F | ||||
| from .construct_utils import drop_path | ||||
| from ..operations import OPS, Identity, FactorizedReduce, ReLUConvBN | ||||
|  | ||||
|  | ||||
| class MixedOp(nn.Module): | ||||
|  | ||||
|   def __init__(self, C, stride, PRIMITIVES): | ||||
|     super(MixedOp, self).__init__() | ||||
|     self._ops = nn.ModuleList() | ||||
|     self.name2idx = {} | ||||
|     for idx, primitive in enumerate(PRIMITIVES): | ||||
|       op = OPS[primitive](C, C, stride, False) | ||||
|       self._ops.append(op) | ||||
|       assert primitive not in self.name2idx, '{:} has already in'.format(primitive) | ||||
|       self.name2idx[primitive] = idx | ||||
|  | ||||
|   def forward(self, x, weights, op_name): | ||||
|     if op_name is None: | ||||
|       if weights is None: | ||||
|         return [op(x) for op in self._ops] | ||||
|       else: | ||||
|         return sum(w * op(x) for w, op in zip(weights, self._ops)) | ||||
|     else: | ||||
|       op_index = self.name2idx[op_name] | ||||
|       return self._ops[op_index](x) | ||||
|  | ||||
|  | ||||
|  | ||||
| class SearchCell(nn.Module): | ||||
|  | ||||
|   def __init__(self, steps, multiplier, C_prev_prev, C_prev, C, reduction, reduction_prev, PRIMITIVES, use_residual): | ||||
|     super(SearchCell, self).__init__() | ||||
|     self.reduction  = reduction | ||||
|     self.PRIMITIVES = deepcopy(PRIMITIVES) | ||||
|    | ||||
|     if reduction_prev: | ||||
|       self.preprocess0 = FactorizedReduce(C_prev_prev, C, 2, affine=False) | ||||
|     else: | ||||
|       self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0, affine=False) | ||||
|     self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0, affine=False) | ||||
|     self._steps        = steps | ||||
|     self._multiplier   = multiplier | ||||
|     self._use_residual = use_residual | ||||
|  | ||||
|     self._ops = nn.ModuleList() | ||||
|     for i in range(self._steps): | ||||
|       for j in range(2+i): | ||||
|         stride = 2 if reduction and j < 2 else 1 | ||||
|         op = MixedOp(C, stride, self.PRIMITIVES) | ||||
|         self._ops.append(op) | ||||
|  | ||||
|   def extra_repr(self): | ||||
|     return ('{name}(residual={_use_residual}, steps={_steps}, multiplier={_multiplier})'.format(name=self.__class__.__name__, **self.__dict__)) | ||||
|  | ||||
|   def forward(self, S0, S1, weights, connect, adjacency, drop_prob, modes): | ||||
|     if modes[0] is None: | ||||
|       if modes[1] == 'normal': | ||||
|         output = self.__forwardBoth(S0, S1, weights, connect, adjacency, drop_prob) | ||||
|       elif modes[1] == 'only_W': | ||||
|         output = self.__forwardOnlyW(S0, S1, drop_prob) | ||||
|     else: | ||||
|       test_genotype = modes[0] | ||||
|       if self.reduction: operations, concats = test_genotype.reduce, test_genotype.reduce_concat | ||||
|       else             : operations, concats = test_genotype.normal, test_genotype.normal_concat | ||||
|       s0, s1 = self.preprocess0(S0), self.preprocess1(S1) | ||||
|       states, offset = [s0, s1], 0 | ||||
|       assert self._steps == len(operations), '{:} vs. {:}'.format(self._steps, len(operations)) | ||||
|       for i, (opA, opB) in enumerate(operations): | ||||
|         A = self._ops[offset + opA[1]](states[opA[1]], None, opA[0]) | ||||
|         B = self._ops[offset + opB[1]](states[opB[1]], None, opB[0]) | ||||
|         state = A + B | ||||
|         offset += len(states) | ||||
|         states.append(state) | ||||
|       output = torch.cat([states[i] for i in concats], dim=1) | ||||
|     if self._use_residual and S1.size() == output.size(): | ||||
|       return S1 + output | ||||
|     else: return output | ||||
|    | ||||
|   def __forwardBoth(self, S0, S1, weights, connect, adjacency, drop_prob): | ||||
|     s0, s1 = self.preprocess0(S0), self.preprocess1(S1) | ||||
|     states, offset = [s0, s1], 0 | ||||
|     for i in range(self._steps): | ||||
|       clist = [] | ||||
|       for j, h in enumerate(states): | ||||
|         x = self._ops[offset+j](h, weights[offset+j], None) | ||||
|         if self.training and drop_prob > 0.: | ||||
|           x = drop_path(x, math.pow(drop_prob, 1./len(states))) | ||||
|         clist.append( x ) | ||||
|       connection = torch.mm(connect['{:}'.format(i)], adjacency[i]).squeeze(0) | ||||
|       state = sum(w * node for w, node in zip(connection, clist)) | ||||
|       offset += len(states) | ||||
|       states.append(state) | ||||
|     return torch.cat(states[-self._multiplier:], dim=1) | ||||
|  | ||||
|   def __forwardOnlyW(self, S0, S1, drop_prob): | ||||
|     s0, s1 = self.preprocess0(S0), self.preprocess1(S1) | ||||
|     states, offset = [s0, s1], 0 | ||||
|     for i in range(self._steps): | ||||
|       clist = [] | ||||
|       for j, h in enumerate(states): | ||||
|         xs = self._ops[offset+j](h, None, None) | ||||
|         clist += xs | ||||
|       if self.training and drop_prob > 0.: | ||||
|         xlist = [drop_path(x, math.pow(drop_prob, 1./len(states))) for x in clist] | ||||
|       else: xlist = clist | ||||
|       state = sum(xlist) * 2 / len(xlist) | ||||
|       offset += len(states) | ||||
|       states.append(state) | ||||
|     return torch.cat(states[-self._multiplier:], dim=1) | ||||
|  | ||||
|  | ||||
|  | ||||
| class InferCell(nn.Module): | ||||
|  | ||||
|   def __init__(self, genotype, C_prev_prev, C_prev, C, reduction, reduction_prev): | ||||
|     super(InferCell, self).__init__() | ||||
|     print(C_prev_prev, C_prev, C) | ||||
|  | ||||
|     if reduction_prev is None: | ||||
|       self.preprocess0 = Identity() | ||||
|     elif reduction_prev: | ||||
|       self.preprocess0 = FactorizedReduce(C_prev_prev, C, 2) | ||||
|     else: | ||||
|       self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0) | ||||
|     self.preprocess1   = ReLUConvBN(C_prev, C, 1, 1, 0) | ||||
|      | ||||
|     if reduction: step_ops, concat = genotype.reduce, genotype.reduce_concat | ||||
|     else        : step_ops, concat = genotype.normal, genotype.normal_concat | ||||
|     self._steps        = len(step_ops) | ||||
|     self._concat       = concat | ||||
|     self._multiplier   = len(concat) | ||||
|     self._ops          = nn.ModuleList() | ||||
|     self._indices      = [] | ||||
|     for operations in step_ops: | ||||
|       for name, index in operations: | ||||
|         stride = 2 if reduction and index < 2 else 1 | ||||
|         if reduction_prev is None and index == 0: | ||||
|           op = OPS[name](C_prev_prev, C, stride, True) | ||||
|         else: | ||||
|           op = OPS[name](C          , C, stride, True) | ||||
|         self._ops.append( op ) | ||||
|         self._indices.append( index ) | ||||
|  | ||||
|   def extra_repr(self): | ||||
|     return ('{name}(steps={_steps}, concat={_concat})'.format(name=self.__class__.__name__, **self.__dict__)) | ||||
|  | ||||
|   def forward(self, S0, S1, drop_prob): | ||||
|     s0 = self.preprocess0(S0) | ||||
|     s1 = self.preprocess1(S1) | ||||
|  | ||||
|     states = [s0, s1] | ||||
|     for i in range(self._steps): | ||||
|       h1 = states[self._indices[2*i]] | ||||
|       h2 = states[self._indices[2*i+1]] | ||||
|       op1 = self._ops[2*i] | ||||
|       op2 = self._ops[2*i+1] | ||||
|       h1 = op1(h1) | ||||
|       h2 = op2(h2) | ||||
|       if self.training and drop_prob > 0.: | ||||
|         if not isinstance(op1, Identity): | ||||
|           h1 = drop_path(h1, drop_prob) | ||||
|         if not isinstance(op2, Identity): | ||||
|           h2 = drop_path(h2, drop_prob) | ||||
|  | ||||
|       state = h1 + h2 | ||||
|       states += [state] | ||||
|     output = torch.cat([states[i] for i in self._concat], dim=1) | ||||
|     return output | ||||
| @@ -0,0 +1,60 @@ | ||||
| import torch | ||||
| import torch.nn.functional as F | ||||
|  | ||||
|  | ||||
| def drop_path(x, drop_prob): | ||||
|   if drop_prob > 0.: | ||||
|     keep_prob = 1. - drop_prob | ||||
|     mask = x.new_zeros(x.size(0), 1, 1, 1) | ||||
|     mask = mask.bernoulli_(keep_prob) | ||||
|     x = torch.div(x, keep_prob) | ||||
|     x.mul_(mask) | ||||
|   return x | ||||
|  | ||||
|  | ||||
| def return_alphas_str(basemodel): | ||||
|   if hasattr(basemodel, 'alphas_normal'): | ||||
|     string = 'normal [{:}] : \n-->>{:}'.format(basemodel.alphas_normal.size(), F.softmax(basemodel.alphas_normal, dim=-1) ) | ||||
|   else: string = '' | ||||
|   if hasattr(basemodel, 'alphas_reduce'): | ||||
|     string = string + '\nreduce : {:}'.format( F.softmax(basemodel.alphas_reduce, dim=-1) ) | ||||
|  | ||||
|   if hasattr(basemodel, 'get_adjacency'): | ||||
|     adjacency = basemodel.get_adjacency() | ||||
|     for i in range( len(adjacency) ): | ||||
|       weight = F.softmax( basemodel.connect_normal[str(i)], dim=-1 ) | ||||
|       adj = torch.mm(weight, adjacency[i]).view(-1) | ||||
|       adj = ['{:3.3f}'.format(x) for x in adj.cpu().tolist()] | ||||
|       string = string + '\nnormal--{:}-->{:}'.format(i, ', '.join(adj)) | ||||
|     for i in range( len(adjacency) ): | ||||
|       weight = F.softmax( basemodel.connect_reduce[str(i)], dim=-1 ) | ||||
|       adj = torch.mm(weight, adjacency[i]).view(-1) | ||||
|       adj = ['{:3.3f}'.format(x) for x in adj.cpu().tolist()] | ||||
|       string = string + '\nreduce--{:}-->{:}'.format(i, ', '.join(adj)) | ||||
|  | ||||
|   if hasattr(basemodel, 'alphas_connect'): | ||||
|     weight = F.softmax(basemodel.alphas_connect, dim=-1).cpu() | ||||
|     ZERO = ['{:.3f}'.format(x) for x in weight[:,0].tolist()] | ||||
|     IDEN = ['{:.3f}'.format(x) for x in weight[:,1].tolist()] | ||||
|     string = string + '\nconnect [{:}] : \n ->{:}\n ->{:}'.format( list(basemodel.alphas_connect.size()), ZERO, IDEN ) | ||||
|   else: | ||||
|     string = string + '\nconnect = None' | ||||
|    | ||||
|   if hasattr(basemodel, 'get_gcn_out'): | ||||
|     outputs = basemodel.get_gcn_out(True) | ||||
|     for i, output in enumerate(outputs): | ||||
|       string = string + '\nnormal:[{:}] : {:}'.format(i, F.softmax(output, dim=-1) ) | ||||
|  | ||||
|   return string | ||||
|  | ||||
|  | ||||
| def remove_duplicate_archs(all_archs): | ||||
|   archs = [] | ||||
|   str_archs = ['{:}'.format(x) for x in all_archs] | ||||
|   for i, arch_x in enumerate(str_archs): | ||||
|     choose = True | ||||
|     for j in range(i): | ||||
|       if arch_x == str_archs[j]: | ||||
|         choose = False; break | ||||
|     if choose: archs.append(all_archs[i]) | ||||
|   return archs | ||||
							
								
								
									
										182
									
								
								AutoDL-Projects/xautodl/nas_infer_model/DXYs/genotypes.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										182
									
								
								AutoDL-Projects/xautodl/nas_infer_model/DXYs/genotypes.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,182 @@ | ||||
| from collections import namedtuple | ||||
|  | ||||
| Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat connectN connects') | ||||
| #Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat') | ||||
|  | ||||
| PRIMITIVES_small = [ | ||||
|     'max_pool_3x3', | ||||
|     'avg_pool_3x3', | ||||
|     'skip_connect', | ||||
|     'sep_conv_3x3', | ||||
|     'sep_conv_5x5', | ||||
|     'conv_3x1_1x3', | ||||
| ] | ||||
|  | ||||
| PRIMITIVES_large = [ | ||||
|     'max_pool_3x3', | ||||
|     'avg_pool_3x3', | ||||
|     'skip_connect', | ||||
|     'sep_conv_3x3', | ||||
|     'sep_conv_5x5', | ||||
|     'dil_conv_3x3', | ||||
|     'dil_conv_5x5', | ||||
|     'conv_3x1_1x3', | ||||
| ] | ||||
|  | ||||
| PRIMITIVES_huge = [ | ||||
|     'skip_connect', | ||||
|     'nor_conv_1x1', | ||||
|     'max_pool_3x3', | ||||
|     'avg_pool_3x3', | ||||
|     'nor_conv_3x3', | ||||
|     'sep_conv_3x3', | ||||
|     'dil_conv_3x3', | ||||
|     'conv_3x1_1x3', | ||||
|     'sep_conv_5x5', | ||||
|     'dil_conv_5x5', | ||||
|     'sep_conv_7x7', | ||||
|     'conv_7x1_1x7', | ||||
|     'att_squeeze', | ||||
| ] | ||||
|  | ||||
| PRIMITIVES = {'small': PRIMITIVES_small, | ||||
|               'large': PRIMITIVES_large, | ||||
|               'huge' : PRIMITIVES_huge} | ||||
|  | ||||
| NASNet = Genotype( | ||||
|   normal = [ | ||||
|     (('sep_conv_5x5', 1), ('sep_conv_3x3', 0)), | ||||
|     (('sep_conv_5x5', 0), ('sep_conv_3x3', 0)), | ||||
|     (('avg_pool_3x3', 1), ('skip_connect', 0)), | ||||
|     (('avg_pool_3x3', 0), ('avg_pool_3x3', 0)), | ||||
|     (('sep_conv_3x3', 1), ('skip_connect', 1)), | ||||
|   ], | ||||
|   normal_concat = [2, 3, 4, 5, 6], | ||||
|   reduce = [ | ||||
|     (('sep_conv_5x5', 1), ('sep_conv_7x7', 0)), | ||||
|     (('max_pool_3x3', 1), ('sep_conv_7x7', 0)), | ||||
|     (('avg_pool_3x3', 1), ('sep_conv_5x5', 0)), | ||||
|     (('skip_connect', 3), ('avg_pool_3x3', 2)), | ||||
|     (('sep_conv_3x3', 2), ('max_pool_3x3', 1)), | ||||
|   ], | ||||
|   reduce_concat = [4, 5, 6], | ||||
|   connectN=None, connects=None, | ||||
| ) | ||||
|  | ||||
| PNASNet = Genotype( | ||||
|   normal = [ | ||||
|     (('sep_conv_5x5', 0), ('max_pool_3x3', 0)), | ||||
|     (('sep_conv_7x7', 1), ('max_pool_3x3', 1)), | ||||
|     (('sep_conv_5x5', 1), ('sep_conv_3x3', 1)), | ||||
|     (('sep_conv_3x3', 4), ('max_pool_3x3', 1)), | ||||
|     (('sep_conv_3x3', 0), ('skip_connect', 1)), | ||||
|   ], | ||||
|   normal_concat = [2, 3, 4, 5, 6], | ||||
|   reduce = [ | ||||
|     (('sep_conv_5x5', 0), ('max_pool_3x3', 0)), | ||||
|     (('sep_conv_7x7', 1), ('max_pool_3x3', 1)), | ||||
|     (('sep_conv_5x5', 1), ('sep_conv_3x3', 1)), | ||||
|     (('sep_conv_3x3', 4), ('max_pool_3x3', 1)), | ||||
|     (('sep_conv_3x3', 0), ('skip_connect', 1)), | ||||
|   ], | ||||
|   reduce_concat = [2, 3, 4, 5, 6], | ||||
|   connectN=None, connects=None, | ||||
| ) | ||||
|  | ||||
|  | ||||
| DARTS_V1 = Genotype( | ||||
|   normal=[ | ||||
|     (('sep_conv_3x3', 1), ('sep_conv_3x3', 0)), # step 1 | ||||
|     (('skip_connect', 0), ('sep_conv_3x3', 1)), # step 2 | ||||
|     (('skip_connect', 0), ('sep_conv_3x3', 1)), # step 3 | ||||
|     (('sep_conv_3x3', 0), ('skip_connect', 2))  # step 4 | ||||
|   ], | ||||
|   normal_concat=[2, 3, 4, 5], | ||||
|   reduce=[ | ||||
|     (('max_pool_3x3', 0), ('max_pool_3x3', 1)), # step 1 | ||||
|     (('skip_connect', 2), ('max_pool_3x3', 0)), # step 2 | ||||
|     (('max_pool_3x3', 0), ('skip_connect', 2)), # step 3 | ||||
|     (('skip_connect', 2), ('avg_pool_3x3', 0))  # step 4 | ||||
|   ], | ||||
|   reduce_concat=[2, 3, 4, 5], | ||||
|   connectN=None, connects=None, | ||||
| ) | ||||
|  | ||||
| # DARTS: Differentiable Architecture Search, ICLR 2019 | ||||
| DARTS_V2 = Genotype( | ||||
|   normal=[ | ||||
|     (('sep_conv_3x3', 0), ('sep_conv_3x3', 1)), # step 1 | ||||
|     (('sep_conv_3x3', 0), ('sep_conv_3x3', 1)), # step 2 | ||||
|     (('sep_conv_3x3', 1), ('skip_connect', 0)), # step 3 | ||||
|     (('skip_connect', 0), ('dil_conv_3x3', 2))  # step 4 | ||||
|   ], | ||||
|   normal_concat=[2, 3, 4, 5], | ||||
|   reduce=[ | ||||
|     (('max_pool_3x3', 0), ('max_pool_3x3', 1)), # step 1 | ||||
|     (('skip_connect', 2), ('max_pool_3x3', 1)), # step 2 | ||||
|     (('max_pool_3x3', 0), ('skip_connect', 2)), # step 3 | ||||
|     (('skip_connect', 2), ('max_pool_3x3', 1))  # step 4 | ||||
|   ], | ||||
|   reduce_concat=[2, 3, 4, 5], | ||||
|   connectN=None, connects=None, | ||||
| ) | ||||
|  | ||||
|  | ||||
| # One-Shot Neural Architecture Search via Self-Evaluated Template Network, ICCV 2019 | ||||
| SETN = Genotype( | ||||
|   normal=[ | ||||
|     (('skip_connect', 0), ('sep_conv_5x5', 1)), | ||||
|     (('sep_conv_5x5', 0), ('sep_conv_3x3', 1)), | ||||
|     (('sep_conv_5x5', 1), ('sep_conv_5x5', 3)), | ||||
|     (('max_pool_3x3', 1), ('conv_3x1_1x3', 4))], | ||||
|   normal_concat=[2, 3, 4, 5], | ||||
|   reduce=[ | ||||
|     (('sep_conv_3x3', 0), ('sep_conv_5x5', 1)), | ||||
|     (('avg_pool_3x3', 0), ('sep_conv_5x5', 1)), | ||||
|     (('avg_pool_3x3', 0), ('sep_conv_5x5', 1)), | ||||
|     (('avg_pool_3x3', 0), ('skip_connect', 1))], | ||||
|   reduce_concat=[2, 3, 4, 5], | ||||
|   connectN=None, connects=None | ||||
| ) | ||||
|  | ||||
|  | ||||
| # Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019 | ||||
| GDAS_V1 = Genotype( | ||||
|   normal=[ | ||||
|     (('skip_connect', 0), ('skip_connect', 1)), | ||||
|     (('skip_connect', 0), ('sep_conv_5x5', 2)), | ||||
|     (('sep_conv_3x3', 3), ('skip_connect', 0)), | ||||
|     (('sep_conv_5x5', 4), ('sep_conv_3x3', 3))], | ||||
|   normal_concat=[2, 3, 4, 5], | ||||
|   reduce=[ | ||||
|     (('sep_conv_5x5', 0), ('sep_conv_3x3', 1)),  | ||||
|     (('sep_conv_5x5', 2), ('sep_conv_5x5', 1)), | ||||
|     (('dil_conv_5x5', 2), ('sep_conv_3x3', 1)), | ||||
|     (('sep_conv_5x5', 0), ('sep_conv_5x5', 1))], | ||||
|   reduce_concat=[2, 3, 4, 5], | ||||
|   connectN=None, connects=None | ||||
| ) | ||||
|  | ||||
|  | ||||
|  | ||||
| Networks = {'DARTS_V1': DARTS_V1, | ||||
|             'DARTS_V2': DARTS_V2, | ||||
|             'DARTS'   : DARTS_V2, | ||||
|             'NASNet'  : NASNet, | ||||
|             'GDAS_V1' : GDAS_V1, | ||||
|             'PNASNet' : PNASNet, | ||||
|             'SETN'    : SETN, | ||||
|            } | ||||
|  | ||||
| # This function will return a Genotype from a dict. | ||||
| def build_genotype_from_dict(xdict): | ||||
|   def remove_value(nodes): | ||||
|     return [tuple([(x[0], x[1]) for x in node]) for node in nodes] | ||||
|   genotype = Genotype( | ||||
|       normal=remove_value(xdict['normal']), | ||||
|       normal_concat=xdict['normal_concat'], | ||||
|       reduce=remove_value(xdict['reduce']), | ||||
|       reduce_concat=xdict['reduce_concat'], | ||||
|       connectN=None, connects=None | ||||
|       ) | ||||
|   return genotype | ||||
							
								
								
									
										71
									
								
								AutoDL-Projects/xautodl/nas_infer_model/DXYs/head_utils.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										71
									
								
								AutoDL-Projects/xautodl/nas_infer_model/DXYs/head_utils.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,71 @@ | ||||
| import torch | ||||
| import torch.nn as nn | ||||
|  | ||||
|  | ||||
| class ImageNetHEAD(nn.Sequential): | ||||
|     def __init__(self, C, stride=2): | ||||
|         super(ImageNetHEAD, self).__init__() | ||||
|         self.add_module( | ||||
|             "conv1", | ||||
|             nn.Conv2d(3, C // 2, kernel_size=3, stride=2, padding=1, bias=False), | ||||
|         ) | ||||
|         self.add_module("bn1", nn.BatchNorm2d(C // 2)) | ||||
|         self.add_module("relu1", nn.ReLU(inplace=True)) | ||||
|         self.add_module( | ||||
|             "conv2", | ||||
|             nn.Conv2d(C // 2, C, kernel_size=3, stride=stride, padding=1, bias=False), | ||||
|         ) | ||||
|         self.add_module("bn2", nn.BatchNorm2d(C)) | ||||
|  | ||||
|  | ||||
| class CifarHEAD(nn.Sequential): | ||||
|     def __init__(self, C): | ||||
|         super(CifarHEAD, self).__init__() | ||||
|         self.add_module("conv", nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False)) | ||||
|         self.add_module("bn", nn.BatchNorm2d(C)) | ||||
|  | ||||
|  | ||||
| class AuxiliaryHeadCIFAR(nn.Module): | ||||
|     def __init__(self, C, num_classes): | ||||
|         """assuming input size 8x8""" | ||||
|         super(AuxiliaryHeadCIFAR, self).__init__() | ||||
|         self.features = nn.Sequential( | ||||
|             nn.ReLU(inplace=True), | ||||
|             nn.AvgPool2d( | ||||
|                 5, stride=3, padding=0, count_include_pad=False | ||||
|             ),  # image size = 2 x 2 | ||||
|             nn.Conv2d(C, 128, 1, bias=False), | ||||
|             nn.BatchNorm2d(128), | ||||
|             nn.ReLU(inplace=True), | ||||
|             nn.Conv2d(128, 768, 2, bias=False), | ||||
|             nn.BatchNorm2d(768), | ||||
|             nn.ReLU(inplace=True), | ||||
|         ) | ||||
|         self.classifier = nn.Linear(768, num_classes) | ||||
|  | ||||
|     def forward(self, x): | ||||
|         x = self.features(x) | ||||
|         x = self.classifier(x.view(x.size(0), -1)) | ||||
|         return x | ||||
|  | ||||
|  | ||||
| class AuxiliaryHeadImageNet(nn.Module): | ||||
|     def __init__(self, C, num_classes): | ||||
|         """assuming input size 14x14""" | ||||
|         super(AuxiliaryHeadImageNet, self).__init__() | ||||
|         self.features = nn.Sequential( | ||||
|             nn.ReLU(inplace=True), | ||||
|             nn.AvgPool2d(5, stride=2, padding=0, count_include_pad=False), | ||||
|             nn.Conv2d(C, 128, 1, bias=False), | ||||
|             nn.BatchNorm2d(128), | ||||
|             nn.ReLU(inplace=True), | ||||
|             nn.Conv2d(128, 768, 2, bias=False), | ||||
|             nn.BatchNorm2d(768), | ||||
|             nn.ReLU(inplace=True), | ||||
|         ) | ||||
|         self.classifier = nn.Linear(768, num_classes) | ||||
|  | ||||
|     def forward(self, x): | ||||
|         x = self.features(x) | ||||
|         x = self.classifier(x.view(x.size(0), -1)) | ||||
|         return x | ||||
							
								
								
									
										51
									
								
								AutoDL-Projects/xautodl/nas_infer_model/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										51
									
								
								AutoDL-Projects/xautodl/nas_infer_model/__init__.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,51 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 # | ||||
| ##################################################### | ||||
| # I write this package to make AutoDL-Projects to be compatible with the old GDAS projects. | ||||
| # Ideally, this package will be merged into lib/models/cell_infers in future. | ||||
| # Currently, this package is used to reproduce the results in GDAS (Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019). | ||||
| ################################################## | ||||
|  | ||||
| import os, torch | ||||
|  | ||||
|  | ||||
| def obtain_nas_infer_model(config, extra_model_path=None): | ||||
|  | ||||
|     if config.arch == "dxys": | ||||
|         from .DXYs import CifarNet, ImageNet, Networks | ||||
|         from .DXYs import build_genotype_from_dict | ||||
|  | ||||
|         if config.genotype is None: | ||||
|             if extra_model_path is not None and not os.path.isfile(extra_model_path): | ||||
|                 raise ValueError( | ||||
|                     "When genotype in confiig is None, extra_model_path must be set as a path instead of {:}".format( | ||||
|                         extra_model_path | ||||
|                     ) | ||||
|                 ) | ||||
|             xdata = torch.load(extra_model_path) | ||||
|             current_epoch = xdata["epoch"] | ||||
|             genotype_dict = xdata["genotypes"][current_epoch - 1] | ||||
|             genotype = build_genotype_from_dict(genotype_dict) | ||||
|         else: | ||||
|             genotype = Networks[config.genotype] | ||||
|         if config.dataset == "cifar": | ||||
|             return CifarNet( | ||||
|                 config.ichannel, | ||||
|                 config.layers, | ||||
|                 config.stem_multi, | ||||
|                 config.auxiliary, | ||||
|                 genotype, | ||||
|                 config.class_num, | ||||
|             ) | ||||
|         elif config.dataset == "imagenet": | ||||
|             return ImageNet( | ||||
|                 config.ichannel, | ||||
|                 config.layers, | ||||
|                 config.auxiliary, | ||||
|                 genotype, | ||||
|                 config.class_num, | ||||
|             ) | ||||
|         else: | ||||
|             raise ValueError("invalid dataset : {:}".format(config.dataset)) | ||||
|     else: | ||||
|         raise ValueError("invalid nas arch type : {:}".format(config.arch)) | ||||
							
								
								
									
										183
									
								
								AutoDL-Projects/xautodl/nas_infer_model/operations.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										183
									
								
								AutoDL-Projects/xautodl/nas_infer_model/operations.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,183 @@ | ||||
| ############################################################################################## | ||||
| # This code is copied and modified from Hanxiao Liu's work (https://github.com/quark0/darts) # | ||||
| ############################################################################################## | ||||
| import torch | ||||
| import torch.nn as nn | ||||
|  | ||||
| OPS = { | ||||
|   'none'         : lambda C_in, C_out, stride, affine: Zero(stride), | ||||
|   'avg_pool_3x3' : lambda C_in, C_out, stride, affine: POOLING(C_in, C_out, stride, 'avg'), | ||||
|   'max_pool_3x3' : lambda C_in, C_out, stride, affine: POOLING(C_in, C_out, stride, 'max'), | ||||
|   'nor_conv_7x7' : lambda C_in, C_out, stride, affine: ReLUConvBN(C_in, C_out, (7,7), (stride,stride), (3,3), affine), | ||||
|   'nor_conv_3x3' : lambda C_in, C_out, stride, affine: ReLUConvBN(C_in, C_out, (3,3), (stride,stride), (1,1), affine), | ||||
|   'nor_conv_1x1' : lambda C_in, C_out, stride, affine: ReLUConvBN(C_in, C_out, (1,1), (stride,stride), (0,0), affine), | ||||
|   'skip_connect' : lambda C_in, C_out, stride, affine: Identity() if stride == 1 and C_in == C_out else FactorizedReduce(C_in, C_out, stride, affine), | ||||
|   'sep_conv_3x3' : lambda C_in, C_out, stride, affine: SepConv(C_in, C_out, 3, stride, 1, affine=affine), | ||||
|   'sep_conv_5x5' : lambda C_in, C_out, stride, affine: SepConv(C_in, C_out, 5, stride, 2, affine=affine), | ||||
|   'sep_conv_7x7' : lambda C_in, C_out, stride, affine: SepConv(C_in, C_out, 7, stride, 3, affine=affine), | ||||
|   'dil_conv_3x3' : lambda C_in, C_out, stride, affine: DilConv(C_in, C_out, 3, stride, 2, 2, affine=affine), | ||||
|   'dil_conv_5x5' : lambda C_in, C_out, stride, affine: DilConv(C_in, C_out, 5, stride, 4, 2, affine=affine), | ||||
|   'conv_7x1_1x7' : lambda C_in, C_out, stride, affine: Conv717(C_in, C_out, stride, affine), | ||||
|   'conv_3x1_1x3' : lambda C_in, C_out, stride, affine: Conv313(C_in, C_out, stride, affine) | ||||
| } | ||||
|  | ||||
|  | ||||
| class POOLING(nn.Module): | ||||
|  | ||||
|   def __init__(self, C_in, C_out, stride, mode): | ||||
|     super(POOLING, self).__init__() | ||||
|     if C_in == C_out: | ||||
|       self.preprocess = None | ||||
|     else: | ||||
|       self.preprocess = ReLUConvBN(C_in, C_out, 1, 1, 0) | ||||
|     if mode == 'avg'  : self.op = nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False) | ||||
|     elif mode == 'max': self.op = nn.MaxPool2d(3, stride=stride, padding=1) | ||||
|  | ||||
|   def forward(self, inputs): | ||||
|     if self.preprocess is not None: | ||||
|       x = self.preprocess(inputs) | ||||
|     else: x = inputs | ||||
|     return self.op(x) | ||||
|  | ||||
|  | ||||
| class Conv313(nn.Module): | ||||
|  | ||||
|   def __init__(self, C_in, C_out, stride, affine): | ||||
|     super(Conv313, self).__init__() | ||||
|     self.op = nn.Sequential( | ||||
|       nn.ReLU(inplace=False), | ||||
|       nn.Conv2d(C_in , C_out, (1,3), stride=(1, stride), padding=(0, 1), bias=False), | ||||
|       nn.Conv2d(C_out, C_out, (3,1), stride=(stride, 1), padding=(1, 0), bias=False), | ||||
|       nn.BatchNorm2d(C_out, affine=affine) | ||||
|     ) | ||||
|  | ||||
|   def forward(self, x): | ||||
|     return self.op(x) | ||||
|  | ||||
|  | ||||
| class Conv717(nn.Module): | ||||
|  | ||||
|   def __init__(self, C_in, C_out, stride, affine): | ||||
|     super(Conv717, self).__init__() | ||||
|     self.op = nn.Sequential( | ||||
|       nn.ReLU(inplace=False), | ||||
|       nn.Conv2d(C_in , C_out, (1,7), stride=(1, stride), padding=(0, 3), bias=False), | ||||
|       nn.Conv2d(C_out, C_out, (7,1), stride=(stride, 1), padding=(3, 0), bias=False), | ||||
|       nn.BatchNorm2d(C_out, affine=affine) | ||||
|     ) | ||||
|  | ||||
|   def forward(self, x): | ||||
|     return self.op(x) | ||||
|  | ||||
|  | ||||
| class ReLUConvBN(nn.Module): | ||||
|  | ||||
|   def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True): | ||||
|     super(ReLUConvBN, self).__init__() | ||||
|     self.op = nn.Sequential( | ||||
|       nn.ReLU(inplace=False), | ||||
|       nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, bias=False), | ||||
|       nn.BatchNorm2d(C_out, affine=affine) | ||||
|     ) | ||||
|  | ||||
|   def forward(self, x): | ||||
|     return self.op(x) | ||||
|  | ||||
|  | ||||
| class DilConv(nn.Module): | ||||
|      | ||||
|   def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True): | ||||
|     super(DilConv, self).__init__() | ||||
|     self.op = nn.Sequential( | ||||
|       nn.ReLU(inplace=False), | ||||
|       nn.Conv2d(C_in, C_in,  kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=C_in, bias=False), | ||||
|       nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False), | ||||
|       nn.BatchNorm2d(C_out, affine=affine), | ||||
|       ) | ||||
|  | ||||
|   def forward(self, x): | ||||
|     return self.op(x) | ||||
|  | ||||
|  | ||||
| class SepConv(nn.Module): | ||||
|      | ||||
|   def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True): | ||||
|     super(SepConv, self).__init__() | ||||
|     self.op = nn.Sequential( | ||||
|       nn.ReLU(inplace=False), | ||||
|       nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, groups=C_in, bias=False), | ||||
|       nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False), | ||||
|       nn.BatchNorm2d(C_in, affine=affine), | ||||
|       nn.ReLU(inplace=False), | ||||
|       nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=     1, padding=padding, groups=C_in, bias=False), | ||||
|       nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False), | ||||
|       nn.BatchNorm2d(C_out, affine=affine), | ||||
|       ) | ||||
|  | ||||
|   def forward(self, x): | ||||
|     return self.op(x) | ||||
|  | ||||
|  | ||||
| class Identity(nn.Module): | ||||
|  | ||||
|   def __init__(self): | ||||
|     super(Identity, self).__init__() | ||||
|  | ||||
|   def forward(self, x): | ||||
|     return x | ||||
|  | ||||
|  | ||||
| class Zero(nn.Module): | ||||
|  | ||||
|   def __init__(self, stride): | ||||
|     super(Zero, self).__init__() | ||||
|     self.stride = stride | ||||
|  | ||||
|   def forward(self, x): | ||||
|     if self.stride == 1: | ||||
|       return x.mul(0.) | ||||
|     return x[:,:,::self.stride,::self.stride].mul(0.) | ||||
|  | ||||
|   def extra_repr(self): | ||||
|     return 'stride={stride}'.format(**self.__dict__) | ||||
|  | ||||
|  | ||||
| class FactorizedReduce(nn.Module): | ||||
|  | ||||
|   def __init__(self, C_in, C_out, stride, affine=True): | ||||
|     super(FactorizedReduce, self).__init__() | ||||
|     self.stride = stride | ||||
|     self.C_in   = C_in   | ||||
|     self.C_out  = C_out   | ||||
|     self.relu   = nn.ReLU(inplace=False) | ||||
|     if stride == 2: | ||||
|       #assert C_out % 2 == 0, 'C_out : {:}'.format(C_out) | ||||
|       C_outs = [C_out // 2, C_out - C_out // 2] | ||||
|       self.convs = nn.ModuleList() | ||||
|       for i in range(2): | ||||
|         self.convs.append( nn.Conv2d(C_in, C_outs[i], 1, stride=stride, padding=0, bias=False) ) | ||||
|       self.pad = nn.ConstantPad2d((0, 1, 0, 1), 0) | ||||
|     elif stride == 4: | ||||
|       assert C_out % 4 == 0, 'C_out : {:}'.format(C_out) | ||||
|       self.convs = nn.ModuleList() | ||||
|       for i in range(4): | ||||
|         self.convs.append( nn.Conv2d(C_in, C_out // 4, 1, stride=stride, padding=0, bias=False) ) | ||||
|       self.pad = nn.ConstantPad2d((0, 3, 0, 3), 0) | ||||
|     else: | ||||
|       raise ValueError('Invalid stride : {:}'.format(stride)) | ||||
|      | ||||
|     self.bn = nn.BatchNorm2d(C_out, affine=affine) | ||||
|  | ||||
|   def forward(self, x): | ||||
|     x = self.relu(x) | ||||
|     y = self.pad(x) | ||||
|     if self.stride == 2: | ||||
|       out = torch.cat([self.convs[0](x), self.convs[1](y[:,:,1:,1:])], dim=1) | ||||
|     else: | ||||
|       out = torch.cat([self.convs[0](x),            self.convs[1](y[:,:,1:-2,1:-2]), | ||||
|                        self.convs[2](y[:,:,2:-1,2:-1]), self.convs[3](y[:,:,3:,3:])], dim=1) | ||||
|     out = self.bn(out) | ||||
|     return out | ||||
|  | ||||
|   def extra_repr(self): | ||||
|     return 'C_in={C_in}, C_out={C_out}, stride={stride}'.format(**self.__dict__) | ||||
		Reference in New Issue
	
	Block a user