Add more algorithms

others/GDAS/lib/nas/CifarNet.py (new file, 89 lines)
@@ -0,0 +1,89 @@
import torch
import torch.nn as nn
from .construct_utils import Cell, Transition

class AuxiliaryHeadCIFAR(nn.Module):

  def __init__(self, C, num_classes):
    """assuming input size 8x8"""
    super(AuxiliaryHeadCIFAR, self).__init__()
    self.features = nn.Sequential(
      nn.ReLU(inplace=True),
      nn.AvgPool2d(5, stride=3, padding=0, count_include_pad=False), # image size = 2 x 2
      nn.Conv2d(C, 128, 1, bias=False),
      nn.BatchNorm2d(128),
      nn.ReLU(inplace=True),
      nn.Conv2d(128, 768, 2, bias=False),
      nn.BatchNorm2d(768),
      nn.ReLU(inplace=True)
    )
    self.classifier = nn.Linear(768, num_classes)

  def forward(self, x):
    x = self.features(x)
    x = self.classifier(x.view(x.size(0), -1))
    return x


class NetworkCIFAR(nn.Module):

  def __init__(self, C, num_classes, layers, auxiliary, genotype):
    super(NetworkCIFAR, self).__init__()
    self._layers = layers

    stem_multiplier = 3
    C_curr = stem_multiplier * C
    self.stem = nn.Sequential(
      nn.Conv2d(3, C_curr, 3, padding=1, bias=False),
      nn.BatchNorm2d(C_curr)
    )

    C_prev_prev, C_prev, C_curr = C_curr, C_curr, C
    self.cells = nn.ModuleList()
    reduction_prev = False
    for i in range(layers):
      if i in [layers//3, 2*layers//3]:
        C_curr *= 2
        reduction = True
      else:
        reduction = False
      if reduction and genotype.reduce is None:
        cell = Transition(C_prev_prev, C_prev, C_curr, reduction_prev)
      else:
        cell = Cell(genotype, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
      reduction_prev = reduction
      self.cells.append(cell)
      C_prev_prev, C_prev = C_prev, cell.multiplier * C_curr
      if i == 2*layers//3:
        C_to_auxiliary = C_prev

    if auxiliary:
      self.auxiliary_head = AuxiliaryHeadCIFAR(C_to_auxiliary, num_classes)
    else:
      self.auxiliary_head = None
    self.global_pooling = nn.AdaptiveAvgPool2d(1)
    self.classifier = nn.Linear(C_prev, num_classes)
    self.drop_path_prob = -1

  def update_drop_path(self, drop_path_prob):
    self.drop_path_prob = drop_path_prob

  def auxiliary_param(self):
    if self.auxiliary_head is None: return []
    else: return list(self.auxiliary_head.parameters())

  def forward(self, inputs):
    s0 = s1 = self.stem(inputs)
    for i, cell in enumerate(self.cells):
      s0, s1 = s1, cell(s0, s1, self.drop_path_prob)
      if i == 2*self._layers//3:
        if self.auxiliary_head and self.training:
          logits_aux = self.auxiliary_head(s1)
    out = self.global_pooling(s1)
    out = out.view(out.size(0), -1)
    logits = self.classifier(out)

    if self.auxiliary_head and self.training:
      return logits, logits_aux
    else:
      return logits
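
A minimal usage sketch (assumptions: the repository root is on the Python path, and C=36 / layers=20 are illustrative DARTS-style evaluation values, not fixed by this commit):

import torch
from lib.nas import NetworkCIFAR, model_types  # import path assumed

net = NetworkCIFAR(C=36, num_classes=10, layers=20, auxiliary=True,
                   genotype=model_types['GDAS_V1'])
net.update_drop_path(0.2)  # enable drop-path regularization during training
net.train()
logits, logits_aux = net(torch.randn(2, 3, 32, 32))  # training mode: both heads
net.eval()
logits = net(torch.randn(2, 3, 32, 32))              # eval mode: main head only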

others/GDAS/lib/nas/ImageNet.py (new file, 104 lines)
@@ -0,0 +1,104 @@
import torch
import torch.nn as nn
from .construct_utils import Cell, Transition

class AuxiliaryHeadImageNet(nn.Module):

  def __init__(self, C, num_classes):
|     """assuming input size 14x14""" | ||||
    super(AuxiliaryHeadImageNet, self).__init__()
    self.features = nn.Sequential(
      nn.ReLU(inplace=True),
      nn.AvgPool2d(5, stride=2, padding=0, count_include_pad=False),
      nn.Conv2d(C, 128, 1, bias=False),
      nn.BatchNorm2d(128),
      nn.ReLU(inplace=True),
      nn.Conv2d(128, 768, 2, bias=False),
      # NOTE: This batchnorm was omitted in my earlier implementation due to a typo.
      # Commenting it out for consistency with the experiments in the paper.
      # nn.BatchNorm2d(768),
      nn.ReLU(inplace=True)
    )
    self.classifier = nn.Linear(768, num_classes)

  def forward(self, x):
    x = self.features(x)
    x = self.classifier(x.view(x.size(0), -1))
    return x


class NetworkImageNet(nn.Module):

  def __init__(self, C, num_classes, layers, auxiliary, genotype):
    super(NetworkImageNet, self).__init__()
    self._layers = layers

    self.stem0 = nn.Sequential(
      nn.Conv2d(3, C // 2, kernel_size=3, stride=2, padding=1, bias=False),
      nn.BatchNorm2d(C // 2),
      nn.ReLU(inplace=True),
      nn.Conv2d(C // 2, C, 3, stride=2, padding=1, bias=False),
      nn.BatchNorm2d(C),
    )

    self.stem1 = nn.Sequential(
      nn.ReLU(inplace=True),
      nn.Conv2d(C, C, 3, stride=2, padding=1, bias=False),
      nn.BatchNorm2d(C),
    )

    C_prev_prev, C_prev, C_curr = C, C, C

    self.cells = nn.ModuleList()
    reduction_prev = True
    for i in range(layers):
      if i in [layers // 3, 2 * layers // 3]:
        C_curr *= 2
        reduction = True
      else:
        reduction = False
      if reduction and genotype.reduce is None:
        cell = Transition(C_prev_prev, C_prev, C_curr, reduction_prev)
      else:
        cell = Cell(genotype, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
      reduction_prev = reduction
      self.cells += [cell]
      C_prev_prev, C_prev = C_prev, cell.multiplier * C_curr
      if i == 2 * layers // 3:
        C_to_auxiliary = C_prev

    if auxiliary:
      self.auxiliary_head = AuxiliaryHeadImageNet(C_to_auxiliary, num_classes)
    else:
      self.auxiliary_head = None
    self.global_pooling = nn.AvgPool2d(7)
    self.classifier = nn.Linear(C_prev, num_classes)
    self.drop_path_prob = -1

  def update_drop_path(self, drop_path_prob):
    self.drop_path_prob = drop_path_prob

  def get_drop_path(self):
    return self.drop_path_prob

  def auxiliary_param(self):
    if self.auxiliary_head is None: return []
    else: return list(self.auxiliary_head.parameters())

  def forward(self, inputs):
    s0 = self.stem0(inputs)
    s1 = self.stem1(s0)
    for i, cell in enumerate(self.cells):
      s0, s1 = s1, cell(s0, s1, self.drop_path_prob)
      if i == 2 * self._layers // 3:
        if self.auxiliary_head and self.training:
          logits_aux = self.auxiliary_head(s1)
    out = self.global_pooling(s1)
    logits = self.classifier(out.view(out.size(0), -1))
    if self.auxiliary_head and self.training:
      return logits, logits_aux
    else:
      return logits
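
For orientation: with a 224x224 input, stem0 and stem1 bring the resolution to 56x56 (s0) and 28x28 (s1); the two reduction cells then halve s1 to 14x14 and 7x7, so the auxiliary head taps a 7x7 map and AvgPool2d(7) collapses it to 1x1. A minimal sketch, assuming the repository root is on the Python path and illustrative values C=48, layers=14:

import torch
from lib.nas import NetworkImageNet, model_types  # import path assumed

net = NetworkImageNet(C=48, num_classes=1000, layers=14, auxiliary=False,
                      genotype=model_types['GDAS_V1'])
net.eval()
logits = net(torch.randn(1, 3, 224, 224))  # -> torch.Size([1, 1000])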

others/GDAS/lib/nas/SE_Module.py (new file, 27 lines)
@@ -0,0 +1,27 @@
import torch
import torch.nn as nn
import torch.nn.functional as F  # needed for F.avg_pool2d in forward
# Squeeze and Excitation module

class SqEx(nn.Module):

  def __init__(self, n_features, reduction=16):
    super(SqEx, self).__init__()

    if n_features % reduction != 0:
      raise ValueError('n_features must be divisible by reduction (default = 16)')

    self.linear1 = nn.Linear(n_features, n_features // reduction, bias=True)
    self.nonlin1 = nn.ReLU(inplace=True)
    self.linear2 = nn.Linear(n_features // reduction, n_features, bias=True)
    self.nonlin2 = nn.Sigmoid()

  def forward(self, x):
    # squeeze: global average pool to (N, C, 1, 1); excite: two-layer gate over channels
    y = F.avg_pool2d(x, kernel_size=x.size()[2:4])
    y = y.permute(0, 2, 3, 1)
    y = self.nonlin1(self.linear1(y))
    y = self.nonlin2(self.linear2(y))
    y = y.permute(0, 3, 1, 2)
    y = x * y
    return y
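
A quick shape check (illustrative, not part of the file): the block rescales each channel by a learned gate in (0, 1), so the output shape matches the input; n_features must be divisible by reduction:

import torch

se = SqEx(n_features=64, reduction=16)
x = torch.randn(8, 64, 32, 32)
y = se(x)                  # each channel scaled by a learned gate in (0, 1)
assert y.shape == x.shape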

others/GDAS/lib/nas/__init__.py (new file, 10 lines)
@@ -0,0 +1,10 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
from .CifarNet        import NetworkCIFAR
from .ImageNet        import NetworkImageNet

# genotypes
from .genotypes       import model_types

from .construct_utils import return_alphas_str

others/GDAS/lib/nas/construct_utils.py (new file, 152 lines)
@@ -0,0 +1,152 @@
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from .operations import OPS, FactorizedReduce, ReLUConvBN, Identity


def random_select(length, ratio):
  # Select each position with probability `ratio`, guaranteeing at least one pick.
  clist = []
  index = random.randint(0, length-1)
  for i in range(length):
    if i == index or random.random() < ratio:
      clist.append(1)
    else:
      clist.append(0)
  return clist


def all_select(length):
  return [1 for i in range(length)]


def drop_path(x, drop_prob):
  # Zero an entire sample's branch with probability drop_prob (in place) and
  # rescale survivors by 1/keep_prob so the expected activation is unchanged.
  if drop_prob > 0.:
    keep_prob = 1. - drop_prob
    mask = x.new_zeros(x.size(0), 1, 1, 1)
    mask = mask.bernoulli_(keep_prob)
    x.div_(keep_prob)
    x.mul_(mask)
  return x


def return_alphas_str(basemodel):
  string = 'normal : {:}'.format(F.softmax(basemodel.alphas_normal, dim=-1))
  if hasattr(basemodel, 'alphas_reduce'):
    string = string + '\nreduce : {:}'.format(F.softmax(basemodel.alphas_reduce, dim=-1))
  return string


class Cell(nn.Module):

  def __init__(self, genotype, C_prev_prev, C_prev, C, reduction, reduction_prev):
    super(Cell, self).__init__()
    print(C_prev_prev, C_prev, C)

    if reduction_prev:
      self.preprocess0 = FactorizedReduce(C_prev_prev, C)
    else:
      self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0)
    self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0)

    if reduction:
      op_names, indices, values = zip(*genotype.reduce)
      concat = genotype.reduce_concat
    else:
      op_names, indices, values = zip(*genotype.normal)
      concat = genotype.normal_concat
    self._compile(C, op_names, indices, values, concat, reduction)

  def _compile(self, C, op_names, indices, values, concat, reduction):
    assert len(op_names) == len(indices)
    self._steps = len(op_names) // 2
    self._concat = concat
    self.multiplier = len(concat)

    self._ops = nn.ModuleList()
    for name, index in zip(op_names, indices):
      stride = 2 if reduction and index < 2 else 1
      op = OPS[name](C, stride, True)
      self._ops.append(op)
    self._indices = indices
    self._values  = values

  def forward(self, s0, s1, drop_prob):
    s0 = self.preprocess0(s0)
    s1 = self.preprocess1(s1)

    states = [s0, s1]
    for i in range(self._steps):
      h1 = states[self._indices[2*i]]
      h2 = states[self._indices[2*i+1]]
      op1 = self._ops[2*i]
      op2 = self._ops[2*i+1]
      h1 = op1(h1)
      h2 = op2(h2)
      if self.training and drop_prob > 0.:
        if not isinstance(op1, Identity):
          h1 = drop_path(h1, drop_prob)
        if not isinstance(op2, Identity):
          h2 = drop_path(h2, drop_prob)

      s = h1 + h2

      states += [s]
    return torch.cat([states[i] for i in self._concat], dim=1)


class Transition(nn.Module):

  def __init__(self, C_prev_prev, C_prev, C, reduction_prev, multiplier=4):
    super(Transition, self).__init__()
    if reduction_prev:
      self.preprocess0 = FactorizedReduce(C_prev_prev, C)
    else:
      self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0)
    self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0)
    self.multiplier  = multiplier

    self.reduction = True
    self.ops1 = nn.ModuleList(
                  [nn.Sequential(
                      nn.ReLU(inplace=False),
                      nn.Conv2d(C, C, (1, 3), stride=(1, 2), padding=(0, 1), groups=8, bias=False),
                      nn.Conv2d(C, C, (3, 1), stride=(2, 1), padding=(1, 0), groups=8, bias=False),
                      nn.BatchNorm2d(C, affine=True),
                      nn.ReLU(inplace=False),
                      nn.Conv2d(C, C, 1, stride=1, padding=0, bias=False),
                      nn.BatchNorm2d(C, affine=True)),
                   nn.Sequential(
                      nn.ReLU(inplace=False),
                      nn.Conv2d(C, C, (1, 3), stride=(1, 2), padding=(0, 1), groups=8, bias=False),
                      nn.Conv2d(C, C, (3, 1), stride=(2, 1), padding=(1, 0), groups=8, bias=False),
                      nn.BatchNorm2d(C, affine=True),
                      nn.ReLU(inplace=False),
                      nn.Conv2d(C, C, 1, stride=1, padding=0, bias=False),
                      nn.BatchNorm2d(C, affine=True))])

    self.ops2 = nn.ModuleList(
                  [nn.Sequential(
                      nn.MaxPool2d(3, stride=2, padding=1),
                      nn.BatchNorm2d(C, affine=True)),
                   nn.Sequential(
                      nn.MaxPool2d(3, stride=2, padding=1),
                      nn.BatchNorm2d(C, affine=True))])

  def forward(self, s0, s1, drop_prob=-1):
    s0 = self.preprocess0(s0)
    s1 = self.preprocess1(s1)

    X0 = self.ops1[0](s0)
    X1 = self.ops1[1](s1)
    if self.training and drop_prob > 0.:
      X0, X1 = drop_path(X0, drop_prob), drop_path(X1, drop_prob)

    #X2 = self.ops2[0](X0+X1)
    X2 = self.ops2[0](s0)
    X3 = self.ops2[1](s1)
    if self.training and drop_prob > 0.:
      X2, X3 = drop_path(X2, drop_prob), drop_path(X3, drop_prob)
    return torch.cat([X0, X1, X2, X3], dim=1)
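
A small demonstration of the drop_path semantics used by Cell.forward and Transition.forward above (illustrative): a sample's whole branch is zeroed with probability drop_prob and survivors are rescaled by 1/keep_prob, keeping the expected activation unchanged:

import torch

x = torch.ones(10000, 8, 4, 4)
y = drop_path(x, drop_prob=0.25)  # note: updates x in place (div_/mul_)
print(y.mean().item())            # ~1.0: zeroed samples balance the 1/0.75 rescale
print((y.view(10000, -1).sum(dim=1) == 0).float().mean().item())  # ~0.25 of samples dropped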

others/GDAS/lib/nas/genotypes.py (new file, 245 lines)
@@ -0,0 +1,245 @@
from collections import namedtuple

Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')

PRIMITIVES = [
    'none',
    'max_pool_3x3',
    'avg_pool_3x3',
    'skip_connect',
    'sep_conv_3x3',
    'sep_conv_5x5',
    'dil_conv_3x3',
    'dil_conv_5x5'
]

NASNet = Genotype(
  normal = [
    ('sep_conv_5x5', 1, 1.0),
    ('sep_conv_3x3', 0, 1.0),
    ('sep_conv_5x5', 0, 1.0),
    ('sep_conv_3x3', 0, 1.0),
    ('avg_pool_3x3', 1, 1.0),
    ('skip_connect', 0, 1.0),
    ('avg_pool_3x3', 0, 1.0),
    ('avg_pool_3x3', 0, 1.0),
    ('sep_conv_3x3', 1, 1.0),
    ('skip_connect', 1, 1.0),
  ],
  normal_concat = [2, 3, 4, 5, 6],
  reduce = [
    ('sep_conv_5x5', 1, 1.0),
    ('sep_conv_7x7', 0, 1.0),
    ('max_pool_3x3', 1, 1.0),
    ('sep_conv_7x7', 0, 1.0),
    ('avg_pool_3x3', 1, 1.0),
    ('sep_conv_5x5', 0, 1.0),
    ('skip_connect', 3, 1.0),
    ('avg_pool_3x3', 2, 1.0),
    ('sep_conv_3x3', 2, 1.0),
    ('max_pool_3x3', 1, 1.0),
  ],
  reduce_concat = [4, 5, 6],
)

AmoebaNet = Genotype(
  normal = [
    ('avg_pool_3x3', 0, 1.0),
    ('max_pool_3x3', 1, 1.0),
    ('sep_conv_3x3', 0, 1.0),
    ('sep_conv_5x5', 2, 1.0),
    ('sep_conv_3x3', 0, 1.0),
    ('avg_pool_3x3', 3, 1.0),
    ('sep_conv_3x3', 1, 1.0),
    ('skip_connect', 1, 1.0),
    ('skip_connect', 0, 1.0),
    ('avg_pool_3x3', 1, 1.0),
  ],
  normal_concat = [4, 5, 6],
  reduce = [
    ('avg_pool_3x3', 0, 1.0),
    ('sep_conv_3x3', 1, 1.0),
    ('max_pool_3x3', 0, 1.0),
    ('sep_conv_7x7', 2, 1.0),
    ('sep_conv_7x7', 0, 1.0),
    ('avg_pool_3x3', 1, 1.0),
    ('max_pool_3x3', 0, 1.0),
    ('max_pool_3x3', 1, 1.0),
    ('conv_7x1_1x7', 0, 1.0),
    ('sep_conv_3x3', 5, 1.0),
  ],
  reduce_concat = [3, 4, 6]
)

DARTS_V1 = Genotype(
  normal=[
    ('sep_conv_3x3', 1, 1.0),
    ('sep_conv_3x3', 0, 1.0),
    ('skip_connect', 0, 1.0),
    ('sep_conv_3x3', 1, 1.0),
    ('skip_connect', 0, 1.0),
    ('sep_conv_3x3', 1, 1.0),
    ('sep_conv_3x3', 0, 1.0),
    ('skip_connect', 2, 1.0)],
  normal_concat=[2, 3, 4, 5],
  reduce=[
    ('max_pool_3x3', 0, 1.0),
    ('max_pool_3x3', 1, 1.0),
    ('skip_connect', 2, 1.0),
    ('max_pool_3x3', 0, 1.0),
    ('max_pool_3x3', 0, 1.0),
    ('skip_connect', 2, 1.0),
    ('skip_connect', 2, 1.0),
    ('avg_pool_3x3', 0, 1.0)],
  reduce_concat=[2, 3, 4, 5]
)

DARTS_V2 = Genotype(
  normal=[
    ('sep_conv_3x3', 0, 1.0),
    ('sep_conv_3x3', 1, 1.0),
    ('sep_conv_3x3', 0, 1.0),
    ('sep_conv_3x3', 1, 1.0),
    ('sep_conv_3x3', 1, 1.0),
    ('skip_connect', 0, 1.0),
    ('skip_connect', 0, 1.0),
    ('dil_conv_3x3', 2, 1.0)],
  normal_concat=[2, 3, 4, 5],
  reduce=[
    ('max_pool_3x3', 0, 1.0),
    ('max_pool_3x3', 1, 1.0),
    ('skip_connect', 2, 1.0),
    ('max_pool_3x3', 1, 1.0),
    ('max_pool_3x3', 0, 1.0),
    ('skip_connect', 2, 1.0),
    ('skip_connect', 2, 1.0),
    ('max_pool_3x3', 1, 1.0)],
  reduce_concat=[2, 3, 4, 5]
)

PNASNet = Genotype(
  normal = [
    ('sep_conv_5x5', 0, 1.0),
    ('max_pool_3x3', 0, 1.0),
    ('sep_conv_7x7', 1, 1.0),
    ('max_pool_3x3', 1, 1.0),
    ('sep_conv_5x5', 1, 1.0),
    ('sep_conv_3x3', 1, 1.0),
    ('sep_conv_3x3', 4, 1.0),
    ('max_pool_3x3', 1, 1.0),
    ('sep_conv_3x3', 0, 1.0),
    ('skip_connect', 1, 1.0),
  ],
  normal_concat = [2, 3, 4, 5, 6],
  reduce = [
    ('sep_conv_5x5', 0, 1.0),
    ('max_pool_3x3', 0, 1.0),
    ('sep_conv_7x7', 1, 1.0),
    ('max_pool_3x3', 1, 1.0),
    ('sep_conv_5x5', 1, 1.0),
    ('sep_conv_3x3', 1, 1.0),
    ('sep_conv_3x3', 4, 1.0),
    ('max_pool_3x3', 1, 1.0),
    ('sep_conv_3x3', 0, 1.0),
    ('skip_connect', 1, 1.0),
  ],
  reduce_concat = [2, 3, 4, 5, 6],
)

# https://arxiv.org/pdf/1802.03268.pdf
ENASNet = Genotype(
  normal = [
    ('sep_conv_3x3', 1, 1.0),
    ('skip_connect', 1, 1.0),
    ('sep_conv_5x5', 1, 1.0),
    ('skip_connect', 0, 1.0),
    ('avg_pool_3x3', 0, 1.0),
    ('sep_conv_3x3', 1, 1.0),
    ('sep_conv_3x3', 0, 1.0),
    ('avg_pool_3x3', 1, 1.0),
    ('sep_conv_5x5', 1, 1.0),
    ('avg_pool_3x3', 0, 1.0),
  ],
  normal_concat = [2, 3, 4, 5, 6],
  reduce = [
    ('sep_conv_5x5', 0, 1.0),
    ('sep_conv_3x3', 1, 1.0), # 2
    ('sep_conv_3x3', 1, 1.0),
    ('avg_pool_3x3', 1, 1.0), # 3
    ('sep_conv_3x3', 1, 1.0),
    ('avg_pool_3x3', 1, 1.0), # 4
    ('avg_pool_3x3', 1, 1.0),
    ('sep_conv_5x5', 4, 1.0), # 5
    ('sep_conv_3x3', 5, 1.0),
    ('sep_conv_5x5', 0, 1.0),
  ],
  reduce_concat = [2, 3, 4, 5, 6],
)

DARTS = DARTS_V2

# Searched over both the normal and the reduction cell
GDAS_V1 = Genotype(
  normal=[('skip_connect', 0, 0.13017432391643524), ('skip_connect', 1, 0.12947972118854523), ('skip_connect', 0, 0.13062666356563568), ('sep_conv_5x5', 2, 0.12980839610099792), ('sep_conv_3x3', 3, 0.12923765182495117), ('skip_connect', 0, 0.12901571393013), ('sep_conv_5x5', 4, 0.12938997149467468), ('sep_conv_3x3', 3, 0.1289220005273819)],
  normal_concat=range(2, 6),
  reduce=[('sep_conv_5x5', 0, 0.12862831354141235), ('sep_conv_3x3', 1, 0.12783904373645782), ('sep_conv_5x5', 2, 0.12725995481014252), ('sep_conv_5x5', 1, 0.12705285847187042), ('dil_conv_5x5', 2, 0.12797553837299347), ('sep_conv_3x3', 1, 0.12737272679805756), ('sep_conv_5x5', 0, 0.12833961844444275), ('sep_conv_5x5', 1, 0.12758426368236542)],
  reduce_concat=range(2, 6)
)

# Searched over the normal cell only, with a fixed reduction cell (reduce=None selects Transition)
GDAS_F1 = Genotype(
  normal=[('skip_connect', 0, 0.16), ('skip_connect', 1, 0.13), ('skip_connect', 0, 0.17), ('sep_conv_3x3', 2, 0.15), ('skip_connect', 0, 0.17), ('sep_conv_3x3', 2, 0.15), ('skip_connect', 0, 0.16), ('sep_conv_3x3', 2, 0.15)],
  normal_concat=[2, 3, 4, 5],
  reduce=None,
  reduce_concat=[2, 3, 4, 5],
)

# Combinations of GDAS_V1 and GDAS_F1: GDAS_GF pairs the GDAS_V1 normal cell with
# the fixed reduction cell; GDAS_FG pairs the GDAS_F1 normal cell with the
# GDAS_V1 reduction cell.
GDAS_GF = Genotype(
  normal=[('skip_connect', 0, 0.13017432391643524), ('skip_connect', 1, 0.12947972118854523), ('skip_connect', 0, 0.13062666356563568), ('sep_conv_5x5', 2, 0.12980839610099792), ('sep_conv_3x3', 3, 0.12923765182495117), ('skip_connect', 0, 0.12901571393013), ('sep_conv_5x5', 4, 0.12938997149467468), ('sep_conv_3x3', 3, 0.1289220005273819)],
  normal_concat=range(2, 6),
  reduce=None,
  reduce_concat=range(2, 6)
)
GDAS_FG = Genotype(
  normal=[('skip_connect', 0, 0.16), ('skip_connect', 1, 0.13), ('skip_connect', 0, 0.17), ('sep_conv_3x3', 2, 0.15), ('skip_connect', 0, 0.17), ('sep_conv_3x3', 2, 0.15), ('skip_connect', 0, 0.16), ('sep_conv_3x3', 2, 0.15)],
  normal_concat=range(2, 6),
  reduce=[('sep_conv_5x5', 0, 0.12862831354141235), ('sep_conv_3x3', 1, 0.12783904373645782), ('sep_conv_5x5', 2, 0.12725995481014252), ('sep_conv_5x5', 1, 0.12705285847187042), ('dil_conv_5x5', 2, 0.12797553837299347), ('sep_conv_3x3', 1, 0.12737272679805756), ('sep_conv_5x5', 0, 0.12833961844444275), ('sep_conv_5x5', 1, 0.12758426368236542)],
  reduce_concat=range(2, 6)
)

PDARTS = Genotype(
  normal=[
    ('skip_connect', 0, 1.0),
    ('dil_conv_3x3', 1, 1.0),
    ('skip_connect', 0, 1.0),
    ('sep_conv_3x3', 1, 1.0),
    ('sep_conv_3x3', 1, 1.0),
    ('sep_conv_3x3', 3, 1.0),
    ('sep_conv_3x3', 0, 1.0),
    ('dil_conv_5x5', 4, 1.0)],
  normal_concat=range(2, 6),
  reduce=[
    ('avg_pool_3x3', 0, 1.0),
    ('sep_conv_5x5', 1, 1.0),
    ('sep_conv_3x3', 0, 1.0),
    ('dil_conv_5x5', 2, 1.0),
    ('max_pool_3x3', 0, 1.0),
    ('dil_conv_3x3', 1, 1.0),
    ('dil_conv_3x3', 1, 1.0),
    ('dil_conv_5x5', 3, 1.0)],
  reduce_concat=range(2, 6)
)


model_types = {'DARTS_V1' : DARTS_V1,
               'DARTS_V2' : DARTS_V2,
               'NASNet'   : NASNet,
               'PNASNet'  : PNASNet,
               'AmoebaNet': AmoebaNet,
               'ENASNet'  : ENASNet,
               'PDARTS'   : PDARTS,
               'GDAS_V1'  : GDAS_V1,
               'GDAS_F1'  : GDAS_F1,
               'GDAS_GF'  : GDAS_GF,
               'GDAS_FG'  : GDAS_FG}
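
How to read these genotypes: each entry is an (operation, input_node, weight) triple. Nodes 0 and 1 are the two cell inputs, entries 2*i and 2*i+1 feed intermediate node i+2 (see Cell._compile in construct_utils.py), and the third field is the architecture weight recorded at search time (1.0 for hand-specified genotypes); Cell stores these values but does not use them in the forward pass. An illustrative walk:

genotype = model_types['GDAS_V1']
for k, (name, index, value) in enumerate(genotype.normal):
  print('node {:} <- node {:} via {:} (weight {:.3f})'.format(k // 2 + 2, index, name, value))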

others/GDAS/lib/nas/head_utils.py (new file, 19 lines)
@@ -0,0 +1,19 @@
import torch
import torch.nn as nn


class ImageNetHEAD(nn.Sequential):
  def __init__(self, C, stride=2):
    super(ImageNetHEAD, self).__init__()
    self.add_module('conv1', nn.Conv2d(3, C // 2, kernel_size=3, stride=2, padding=1, bias=False))
    self.add_module('bn1'  , nn.BatchNorm2d(C // 2))
    self.add_module('relu1', nn.ReLU(inplace=True))
    self.add_module('conv2', nn.Conv2d(C // 2, C, kernel_size=3, stride=stride, padding=1, bias=False))
    self.add_module('bn2'  , nn.BatchNorm2d(C))


class CifarHEAD(nn.Sequential):
  def __init__(self, C):
    super(CifarHEAD, self).__init__()
    self.add_module('conv', nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False))
    self.add_module('bn', nn.BatchNorm2d(C))
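
These heads mirror the stems above (ImageNetHEAD matches stem0 of NetworkImageNet; CifarHEAD is the CIFAR stem without the stem multiplier), packaged as reusable nn.Sequential modules. A quick check (illustrative):

import torch

stem = CifarHEAD(C=48)
print(stem(torch.randn(2, 3, 32, 32)).shape)  # torch.Size([2, 48, 32, 32])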

others/GDAS/lib/nas/operations.py (new file, 122 lines)
@@ -0,0 +1,122 @@
import torch
import torch.nn as nn

OPS = {
  'none'         : lambda C, stride, affine: Zero(stride),
  'avg_pool_3x3' : lambda C, stride, affine: nn.Sequential(
                                               nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False),
                                               nn.BatchNorm2d(C, affine=False)),
  'max_pool_3x3' : lambda C, stride, affine: nn.Sequential(
                                               nn.MaxPool2d(3, stride=stride, padding=1),
                                               nn.BatchNorm2d(C, affine=False)),
  'skip_connect' : lambda C, stride, affine: Identity() if stride == 1 else FactorizedReduce(C, C, affine=affine),
  'sep_conv_3x3' : lambda C, stride, affine: SepConv(C, C, 3, stride, 1, affine=affine),
  'sep_conv_5x5' : lambda C, stride, affine: SepConv(C, C, 5, stride, 2, affine=affine),
  'sep_conv_7x7' : lambda C, stride, affine: SepConv(C, C, 7, stride, 3, affine=affine),
  'dil_conv_3x3' : lambda C, stride, affine: DilConv(C, C, 3, stride, 2, 2, affine=affine),
  'dil_conv_5x5' : lambda C, stride, affine: DilConv(C, C, 5, stride, 4, 2, affine=affine),
  'conv_7x1_1x7' : lambda C, stride, affine: Conv717(C, C, stride, affine),
}

class Conv717(nn.Module):

  def __init__(self, C_in, C_out, stride, affine):
    super(Conv717, self).__init__()
    self.op = nn.Sequential(
      nn.ReLU(inplace=False),
      nn.Conv2d(C_in, C_out, (1, 7), stride=(1, stride), padding=(0, 3), bias=False),
      nn.Conv2d(C_out, C_out, (7, 1), stride=(stride, 1), padding=(3, 0), bias=False),
      nn.BatchNorm2d(C_out, affine=affine)
    )

  def forward(self, x):
    return self.op(x)


class ReLUConvBN(nn.Module):

  def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
    super(ReLUConvBN, self).__init__()
    self.op = nn.Sequential(
      nn.ReLU(inplace=False),
      nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, bias=False),
      nn.BatchNorm2d(C_out, affine=affine)
    )

  def forward(self, x):
    return self.op(x)


class DilConv(nn.Module):

  def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True):
    super(DilConv, self).__init__()
    self.op = nn.Sequential(
      nn.ReLU(inplace=False),
      nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=C_in, bias=False),
      nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
      nn.BatchNorm2d(C_out, affine=affine),
    )

  def forward(self, x):
    return self.op(x)


class SepConv(nn.Module):

  def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
    super(SepConv, self).__init__()
    self.op = nn.Sequential(
      nn.ReLU(inplace=False),
      nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, groups=C_in, bias=False),
      nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False),
      nn.BatchNorm2d(C_in, affine=affine),
      nn.ReLU(inplace=False),
      nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=1, padding=padding, groups=C_in, bias=False),
      nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
      nn.BatchNorm2d(C_out, affine=affine),
    )

  def forward(self, x):
    return self.op(x)


class Identity(nn.Module):

  def __init__(self):
    super(Identity, self).__init__()

  def forward(self, x):
    return x


class Zero(nn.Module):

  def __init__(self, stride):
    super(Zero, self).__init__()
    self.stride = stride

  def forward(self, x):
    # Multiply by zero (after striding) so the output has the correct shape
    # while the operation contributes nothing.
    if self.stride == 1:
      return x.mul(0.)
    return x[:, :, ::self.stride, ::self.stride].mul(0.)


class FactorizedReduce(nn.Module):

  def __init__(self, C_in, C_out, affine=True):
    super(FactorizedReduce, self).__init__()
    assert C_out % 2 == 0
    self.relu = nn.ReLU(inplace=False)
    self.conv_1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
    self.conv_2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
    self.bn = nn.BatchNorm2d(C_out, affine=affine)
    self.pad = nn.ConstantPad2d((0, 1, 0, 1), 0)

  def forward(self, x):
    x = self.relu(x)
    y = self.pad(x)
    # conv_2 sees the input shifted by one pixel, so the two stride-2
    # convolutions sample complementary spatial positions before concatenation.
    out = torch.cat([self.conv_1(x), self.conv_2(y[:, :, 1:, 1:])], dim=1)
    out = self.bn(out)
    return out
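
Every OPS entry shares the signature (C, stride, affine) -> nn.Module and preserves the channel count C, which is what lets Cell._compile in construct_utils.py instantiate operations from genotype names alone. A quick check (illustrative):

import torch

op = OPS['sep_conv_3x3'](16, 2, True)  # C=16, stride=2, affine=True
x = torch.randn(2, 16, 32, 32)
print(op(x).shape)                     # torch.Size([2, 16, 16, 16])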