update CVPR-2019-GDAS re-train NASNet-search-space searched models
This commit is contained in:
		
							
								
								
									
										10
									
								
								configs/archs/NAS-CIFAR-none.config
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										10
									
								
								configs/archs/NAS-CIFAR-none.config
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,10 @@ | |||||||
|  | { | ||||||
|  |   "super_type": ["str",  "infer-nasnet.cifar"], | ||||||
|  |   "genotype"  : ["none", "none"], | ||||||
|  |   "dataset"   : ["str",  "cifar"], | ||||||
|  |   "ichannel"  : ["int",   33], | ||||||
|  |   "layers"    : ["int",    6], | ||||||
|  |   "stem_multi": ["int",    3], | ||||||
|  |   "auxiliary" : ["bool",   1], | ||||||
|  |   "drop_path_prob": ["float", 0.2] | ||||||
|  | } | ||||||
							
								
								
									
										9
									
								
								configs/archs/NAS-IMAGENET-none.config
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										9
									
								
								configs/archs/NAS-IMAGENET-none.config
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,9 @@ | |||||||
|  | { | ||||||
|  |   "super_type": ["str",  "infer-nasnet.imagenet"], | ||||||
|  |   "genotype"  : ["none", "none"], | ||||||
|  |   "dataset"   : ["str",  "imagenet"], | ||||||
|  |   "ichannel"  : ["int",   50], | ||||||
|  |   "layers"    : ["int",    4], | ||||||
|  |   "auxiliary" : ["bool",   1], | ||||||
|  |   "drop_path_prob": ["float", 0] | ||||||
|  | } | ||||||
| @@ -41,7 +41,16 @@ Please use the following scripts to use GDAS to search as in the original paper: | |||||||
| ``` | ``` | ||||||
| CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/GDAS-search-NASNet-space.sh cifar10 1 -1 | CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/GDAS-search-NASNet-space.sh cifar10 1 -1 | ||||||
| ``` | ``` | ||||||
| If you want to train the searched architecture found by the above scripts, you need to add the config of that architecture (will be printed in log) in [genotypes.py](https://github.com/D-X-Y/AutoDL-Projects/blob/master/lib/nas_infer_model/DXYs/genotypes.py). |  | ||||||
|  | **After searching***, if you want to re-train the searched architecture found by the above script, you can use the following script: | ||||||
|  | ``` | ||||||
|  | CUDA_VISIBLE_DEVICES=0 bash ./scripts/retrain-searched-net.sh cifar10 gdas-searched \ | ||||||
|  | 		     output/search-cell-darts/GDAS-cifar10-BN1/checkpoint/seed-945-basic.pth 96 -1 | ||||||
|  | ``` | ||||||
|  | Note that `gdas-searched` is a string to indicate the name of the saved dir and `output/search-cell-darts/GDAS-cifar10-BN1/checkpoint/seed-945-basic.pth` is the file path that the searching algorithm generated. | ||||||
|  |  | ||||||
|  | The above script does not apply heavy augmentation to train the model, so the accuracy will be lower than the original paper. | ||||||
|  | If you want to change the default hyper-parameter for re-training, please have a look at `./scripts/retrain-searched-net.sh` and `configs/archs/NAS-*-none.config`. | ||||||
|  |  | ||||||
| ### Searching on a small search space (NAS-Bench-201) | ### Searching on a small search space (NAS-Bench-201) | ||||||
| The GDAS searching codes on a small search space: | The GDAS searching codes on a small search space: | ||||||
|   | |||||||
| @@ -39,7 +39,9 @@ def main(args): | |||||||
|   if args.model_source == 'normal': |   if args.model_source == 'normal': | ||||||
|     base_model   = obtain_model(model_config) |     base_model   = obtain_model(model_config) | ||||||
|   elif args.model_source == 'nas': |   elif args.model_source == 'nas': | ||||||
|     base_model   = obtain_nas_infer_model(model_config) |     base_model   = obtain_nas_infer_model(model_config, args.extra_model_path) | ||||||
|  |   elif args.model_source == 'autodl-searched': | ||||||
|  |     base_model   = obtain_model(model_config, args.extra_model_path) | ||||||
|   else: |   else: | ||||||
|     raise ValueError('invalid model-source : {:}'.format(args.model_source)) |     raise ValueError('invalid model-source : {:}'.format(args.model_source)) | ||||||
|   flop, param  = get_model_infos(base_model, xshape) |   flop, param  = get_model_infos(base_model, xshape) | ||||||
|   | |||||||
| @@ -12,6 +12,7 @@ def obtain_basic_args(): | |||||||
|   parser.add_argument('--optim_config',     type=str,                   help='The path to the optimizer configuration') |   parser.add_argument('--optim_config',     type=str,                   help='The path to the optimizer configuration') | ||||||
|   parser.add_argument('--procedure'   ,     type=str,                   help='The procedure basic prefix.') |   parser.add_argument('--procedure'   ,     type=str,                   help='The procedure basic prefix.') | ||||||
|   parser.add_argument('--model_source',     type=str,  default='normal',help='The source of model defination.') |   parser.add_argument('--model_source',     type=str,  default='normal',help='The source of model defination.') | ||||||
|  |   parser.add_argument('--extra_model_path', type=str,  default=None,    help='The extra model ckp file (help to indicate the searched architecture).') | ||||||
|   add_shared_args( parser ) |   add_shared_args( parser ) | ||||||
|   # Optimization options |   # Optimization options | ||||||
|   parser.add_argument('--batch_size',       type=int,  default=2,       help='Batch size for training.') |   parser.add_argument('--batch_size',       type=int,  default=2,       help='Batch size for training.') | ||||||
|   | |||||||
| @@ -29,7 +29,8 @@ def convert_param(original_lists): | |||||||
|     elif ctype == 'float': |     elif ctype == 'float': | ||||||
|       x = float(x) |       x = float(x) | ||||||
|     elif ctype == 'none': |     elif ctype == 'none': | ||||||
|       assert x == 'None', 'for none type, the value must be None instead of {:}'.format(x) |       if x.lower() != 'none': | ||||||
|  |         raise ValueError('For the none type, the value must be none instead of {:}'.format(x)) | ||||||
|       x = None |       x = None | ||||||
|     else: |     else: | ||||||
|       raise TypeError('Does not know this type : {:}'.format(ctype)) |       raise TypeError('Does not know this type : {:}'.format(ctype)) | ||||||
|   | |||||||
| @@ -3,6 +3,7 @@ | |||||||
| ################################################## | ################################################## | ||||||
| from os import path as osp | from os import path as osp | ||||||
| from typing import List, Text | from typing import List, Text | ||||||
|  | import torch | ||||||
|  |  | ||||||
| __all__ = ['change_key', 'get_cell_based_tiny_net', 'get_search_spaces', 'get_cifar_models', 'get_imagenet_models', \ | __all__ = ['change_key', 'get_cell_based_tiny_net', 'get_search_spaces', 'get_cifar_models', 'get_imagenet_models', \ | ||||||
|            'obtain_model', 'obtain_search_model', 'load_net_from_checkpoint', \ |            'obtain_model', 'obtain_search_model', 'load_net_from_checkpoint', \ | ||||||
| @@ -38,6 +39,9 @@ def get_cell_based_tiny_net(config): | |||||||
|       genotype = CellStructure.str2structure(config.arch_str) |       genotype = CellStructure.str2structure(config.arch_str) | ||||||
|     else: raise ValueError('Can not find genotype from this config : {:}'.format(config)) |     else: raise ValueError('Can not find genotype from this config : {:}'.format(config)) | ||||||
|     return TinyNetwork(config.C, config.N, genotype, config.num_classes) |     return TinyNetwork(config.C, config.N, genotype, config.num_classes) | ||||||
|  |   elif config.name == 'infer.nasnet-cifar': | ||||||
|  |     from .cell_infers import NASNetonCIFAR | ||||||
|  |     raise NotImplementedError | ||||||
|   else: |   else: | ||||||
|     raise ValueError('invalid network name : {:}'.format(config.name)) |     raise ValueError('invalid network name : {:}'.format(config.name)) | ||||||
|  |  | ||||||
| @@ -52,13 +56,12 @@ def get_search_spaces(xtype, name) -> List[Text]: | |||||||
|     raise ValueError('invalid search-space type is {:}'.format(xtype)) |     raise ValueError('invalid search-space type is {:}'.format(xtype)) | ||||||
|  |  | ||||||
|  |  | ||||||
| def get_cifar_models(config): | def get_cifar_models(config, extra_path=None): | ||||||
|   from .CifarResNet      import CifarResNet |  | ||||||
|   from .CifarDenseNet    import DenseNet |  | ||||||
|   from .CifarWideResNet  import CifarWideResNet |  | ||||||
|    |  | ||||||
|   super_type = getattr(config, 'super_type', 'basic') |   super_type = getattr(config, 'super_type', 'basic') | ||||||
|   if super_type == 'basic': |   if super_type == 'basic': | ||||||
|  |     from .CifarResNet      import CifarResNet | ||||||
|  |     from .CifarDenseNet    import DenseNet | ||||||
|  |     from .CifarWideResNet  import CifarWideResNet | ||||||
|     if config.arch == 'resnet': |     if config.arch == 'resnet': | ||||||
|       return CifarResNet(config.module, config.depth, config.class_num, config.zero_init_residual) |       return CifarResNet(config.module, config.depth, config.class_num, config.zero_init_residual) | ||||||
|     elif config.arch == 'densenet': |     elif config.arch == 'densenet': | ||||||
| @@ -71,6 +74,7 @@ def get_cifar_models(config): | |||||||
|     from .shape_infers import InferWidthCifarResNet |     from .shape_infers import InferWidthCifarResNet | ||||||
|     from .shape_infers import InferDepthCifarResNet |     from .shape_infers import InferDepthCifarResNet | ||||||
|     from .shape_infers import InferCifarResNet |     from .shape_infers import InferCifarResNet | ||||||
|  |     from .cell_infers import NASNetonCIFAR | ||||||
|     assert len(super_type.split('-')) == 2, 'invalid super_type : {:}'.format(super_type) |     assert len(super_type.split('-')) == 2, 'invalid super_type : {:}'.format(super_type) | ||||||
|     infer_mode = super_type.split('-')[1] |     infer_mode = super_type.split('-')[1] | ||||||
|     if infer_mode == 'width': |     if infer_mode == 'width': | ||||||
| @@ -79,6 +83,16 @@ def get_cifar_models(config): | |||||||
|       return InferDepthCifarResNet(config.module, config.depth, config.xblocks, config.class_num, config.zero_init_residual) |       return InferDepthCifarResNet(config.module, config.depth, config.xblocks, config.class_num, config.zero_init_residual) | ||||||
|     elif infer_mode == 'shape': |     elif infer_mode == 'shape': | ||||||
|       return InferCifarResNet(config.module, config.depth, config.xblocks, config.xchannels, config.class_num, config.zero_init_residual) |       return InferCifarResNet(config.module, config.depth, config.xblocks, config.xchannels, config.class_num, config.zero_init_residual) | ||||||
|  |     elif infer_mode == 'nasnet.cifar': | ||||||
|  |       genotype = config.genotype | ||||||
|  |       if extra_path is not None:  # reload genotype by extra_path | ||||||
|  |         if not osp.isfile(extra_path): raise ValueError('invalid extra_path : {:}'.format(extra_path)) | ||||||
|  |         xdata = torch.load(extra_path) | ||||||
|  |         current_epoch = xdata['epoch'] | ||||||
|  |         genotype = xdata['genotypes'][current_epoch-1] | ||||||
|  |       C = config.C if hasattr(config, 'C') else config.ichannel | ||||||
|  |       N = config.N if hasattr(config, 'N') else config.layers | ||||||
|  |       return NASNetonCIFAR(C, N, config.stem_multi, config.class_num, genotype, config.auxiliary) | ||||||
|     else: |     else: | ||||||
|       raise ValueError('invalid infer-mode : {:}'.format(infer_mode)) |       raise ValueError('invalid infer-mode : {:}'.format(infer_mode)) | ||||||
|   else: |   else: | ||||||
| @@ -111,9 +125,10 @@ def get_imagenet_models(config): | |||||||
|     raise ValueError('invalid super-type : {:}'.format(super_type)) |     raise ValueError('invalid super-type : {:}'.format(super_type)) | ||||||
|  |  | ||||||
|  |  | ||||||
| def obtain_model(config): | # Try to obtain the network by config. | ||||||
|  | def obtain_model(config, extra_path=None): | ||||||
|   if config.dataset == 'cifar': |   if config.dataset == 'cifar': | ||||||
|     return get_cifar_models(config) |     return get_cifar_models(config, extra_path) | ||||||
|   elif config.dataset == 'imagenet': |   elif config.dataset == 'imagenet': | ||||||
|     return get_imagenet_models(config) |     return get_imagenet_models(config) | ||||||
|   else: |   else: | ||||||
| @@ -152,7 +167,6 @@ def obtain_search_model(config): | |||||||
|  |  | ||||||
|  |  | ||||||
| def load_net_from_checkpoint(checkpoint): | def load_net_from_checkpoint(checkpoint): | ||||||
|   import torch |  | ||||||
|   assert osp.isfile(checkpoint), 'checkpoint {:} does not exist'.format(checkpoint) |   assert osp.isfile(checkpoint), 'checkpoint {:} does not exist'.format(checkpoint) | ||||||
|   checkpoint   = torch.load(checkpoint) |   checkpoint   = torch.load(checkpoint) | ||||||
|   model_config = dict2config(checkpoint['model-config'], None) |   model_config = dict2config(checkpoint['model-config'], None) | ||||||
|   | |||||||
| @@ -2,3 +2,4 @@ | |||||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 # | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 # | ||||||
| ##################################################### | ##################################################### | ||||||
| from .tiny_network import TinyNetwork | from .tiny_network import TinyNetwork | ||||||
|  | from .nasnet_cifar import NASNetonCIFAR | ||||||
|   | |||||||
| @@ -2,6 +2,7 @@ | |||||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 # | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 # | ||||||
| ##################################################### | ##################################################### | ||||||
|  |  | ||||||
|  | import torch | ||||||
| import torch.nn as nn | import torch.nn as nn | ||||||
| from copy import deepcopy | from copy import deepcopy | ||||||
| from ..cell_operations import OPS | from ..cell_operations import OPS | ||||||
| @@ -50,3 +51,70 @@ class InferCell(nn.Module): | |||||||
|       node_feature = sum( self.layers[_il](nodes[_ii]) for _il, _ii in zip(node_layers, node_innods) ) |       node_feature = sum( self.layers[_il](nodes[_ii]) for _il, _ii in zip(node_layers, node_innods) ) | ||||||
|       nodes.append( node_feature ) |       nodes.append( node_feature ) | ||||||
|     return nodes[-1] |     return nodes[-1] | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # Learning Transferable Architectures for Scalable Image Recognition, CVPR 2018 | ||||||
|  | class NASNetInferCell(nn.Module): | ||||||
|  |  | ||||||
|  |   def __init__(self, genotype, C_prev_prev, C_prev, C, reduction, reduction_prev, affine, track_running_stats): | ||||||
|  |     super(NASNetInferCell, self).__init__() | ||||||
|  |     self.reduction = reduction | ||||||
|  |     if reduction_prev: self.preprocess0 = OPS['skip_connect'](C_prev_prev, C, 2, affine, track_running_stats) | ||||||
|  |     else             : self.preprocess0 = OPS['nor_conv_1x1'](C_prev_prev, C, 1, affine, track_running_stats) | ||||||
|  |     self.preprocess1 = OPS['nor_conv_1x1'](C_prev, C, 1, affine, track_running_stats) | ||||||
|  |  | ||||||
|  |     if not reduction: | ||||||
|  |       nodes, concats = genotype['normal'], genotype['normal_concat'] | ||||||
|  |     else: | ||||||
|  |       nodes, concats = genotype['reduce'], genotype['reduce_concat'] | ||||||
|  |     self._multiplier = len(concats) | ||||||
|  |     self._concats = concats | ||||||
|  |     self._steps = len(nodes) | ||||||
|  |     self._nodes = nodes | ||||||
|  |     self.edges = nn.ModuleDict() | ||||||
|  |     for i, node in enumerate(nodes): | ||||||
|  |       for in_node in node: | ||||||
|  |         name, j = in_node[0], in_node[1] | ||||||
|  |         stride = 2 if reduction and j < 2 else 1 | ||||||
|  |         node_str = '{:}<-{:}'.format(i+2, j) | ||||||
|  |         self.edges[node_str] = OPS[name](C, C, stride, affine, track_running_stats) | ||||||
|  |  | ||||||
|  |   # [TODO] to support drop_prob in this function.. | ||||||
|  |   def forward(self, s0, s1, unused_drop_prob): | ||||||
|  |     s0 = self.preprocess0(s0) | ||||||
|  |     s1 = self.preprocess1(s1) | ||||||
|  |  | ||||||
|  |     states = [s0, s1] | ||||||
|  |     for i, node in enumerate(self._nodes): | ||||||
|  |       clist = [] | ||||||
|  |       for in_node in node: | ||||||
|  |         name, j = in_node[0], in_node[1] | ||||||
|  |         node_str = '{:}<-{:}'.format(i+2, j) | ||||||
|  |         op = self.edges[ node_str ] | ||||||
|  |         clist.append( op(states[j]) ) | ||||||
|  |       states.append( sum(clist) ) | ||||||
|  |     return torch.cat([states[x] for x in self._concats], dim=1) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class AuxiliaryHeadCIFAR(nn.Module): | ||||||
|  |  | ||||||
|  |   def __init__(self, C, num_classes): | ||||||
|  |     """assuming input size 8x8""" | ||||||
|  |     super(AuxiliaryHeadCIFAR, self).__init__() | ||||||
|  |     self.features = nn.Sequential( | ||||||
|  |       nn.ReLU(inplace=True), | ||||||
|  |       nn.AvgPool2d(5, stride=3, padding=0, count_include_pad=False), # image size = 2 x 2 | ||||||
|  |       nn.Conv2d(C, 128, 1, bias=False), | ||||||
|  |       nn.BatchNorm2d(128), | ||||||
|  |       nn.ReLU(inplace=True), | ||||||
|  |       nn.Conv2d(128, 768, 2, bias=False), | ||||||
|  |       nn.BatchNorm2d(768), | ||||||
|  |       nn.ReLU(inplace=True) | ||||||
|  |     ) | ||||||
|  |     self.classifier = nn.Linear(768, num_classes) | ||||||
|  |  | ||||||
|  |   def forward(self, x): | ||||||
|  |     x = self.features(x) | ||||||
|  |     x = self.classifier(x.view(x.size(0),-1)) | ||||||
|  |     return x | ||||||
|   | |||||||
							
								
								
									
										71
									
								
								lib/models/cell_infers/nasnet_cifar.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										71
									
								
								lib/models/cell_infers/nasnet_cifar.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,71 @@ | |||||||
|  | ##################################################### | ||||||
|  | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 # | ||||||
|  | ##################################################### | ||||||
|  | import torch | ||||||
|  | import torch.nn as nn | ||||||
|  | from copy import deepcopy | ||||||
|  | from .cells import NASNetInferCell as InferCell, AuxiliaryHeadCIFAR | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # The macro structure is based on NASNet | ||||||
|  | class NASNetonCIFAR(nn.Module): | ||||||
|  |  | ||||||
|  |   def __init__(self, C, N, stem_multiplier, num_classes, genotype, auxiliary, affine=True, track_running_stats=True): | ||||||
|  |     super(NASNetonCIFAR, self).__init__() | ||||||
|  |     self._C        = C | ||||||
|  |     self._layerN   = N | ||||||
|  |     self.stem = nn.Sequential( | ||||||
|  |                     nn.Conv2d(3, C*stem_multiplier, kernel_size=3, padding=1, bias=False), | ||||||
|  |                     nn.BatchNorm2d(C*stem_multiplier)) | ||||||
|  |    | ||||||
|  |     # config for each layer | ||||||
|  |     layer_channels   = [C    ] * N + [C*2 ] + [C*2  ] * (N-1) + [C*4 ] + [C*4  ] * (N-1) | ||||||
|  |     layer_reductions = [False] * N + [True] + [False] * (N-1) + [True] + [False] * (N-1) | ||||||
|  |  | ||||||
|  |     C_prev_prev, C_prev, C_curr, reduction_prev = C*stem_multiplier, C*stem_multiplier, C, False | ||||||
|  |     self.auxiliary_index = None | ||||||
|  |     self.auxiliary_head  = None | ||||||
|  |     self.cells = nn.ModuleList() | ||||||
|  |     for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)): | ||||||
|  |       cell = InferCell(genotype, C_prev_prev, C_prev, C_curr, reduction, reduction_prev, affine, track_running_stats) | ||||||
|  |       self.cells.append( cell ) | ||||||
|  |       C_prev_prev, C_prev, reduction_prev = C_prev, cell._multiplier*C_curr, reduction | ||||||
|  |       if reduction and C_curr == C*4 and auxiliary: | ||||||
|  |         self.auxiliary_head = AuxiliaryHeadCIFAR(C_prev, num_classes) | ||||||
|  |         self.auxiliary_index = index | ||||||
|  |     self._Layer     = len(self.cells) | ||||||
|  |     self.lastact    = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True)) | ||||||
|  |     self.global_pooling = nn.AdaptiveAvgPool2d(1) | ||||||
|  |     self.classifier = nn.Linear(C_prev, num_classes) | ||||||
|  |     self.drop_path_prob = -1 | ||||||
|  |  | ||||||
|  |   def update_drop_path(self, drop_path_prob): | ||||||
|  |     self.drop_path_prob = drop_path_prob | ||||||
|  |  | ||||||
|  |   def auxiliary_param(self): | ||||||
|  |     if self.auxiliary_head is None: return [] | ||||||
|  |     else: return list( self.auxiliary_head.parameters() ) | ||||||
|  |  | ||||||
|  |   def get_message(self): | ||||||
|  |     string = self.extra_repr() | ||||||
|  |     for i, cell in enumerate(self.cells): | ||||||
|  |       string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr()) | ||||||
|  |     return string | ||||||
|  |  | ||||||
|  |   def extra_repr(self): | ||||||
|  |     return ('{name}(C={_C}, N={_layerN}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__)) | ||||||
|  |  | ||||||
|  |   def forward(self, inputs): | ||||||
|  |     stem_feature, logits_aux = self.stem(inputs), None | ||||||
|  |     cell_results = [stem_feature, stem_feature] | ||||||
|  |     for i, cell in enumerate(self.cells): | ||||||
|  |       cell_feature = cell(cell_results[-2], cell_results[-1], self.drop_path_prob) | ||||||
|  |       cell_results.append( cell_feature ) | ||||||
|  |       if self.auxiliary_index is not None and i == self.auxiliary_index and self.training: | ||||||
|  |         logits_aux = self.auxiliary_head( cell_results[-1] ) | ||||||
|  |     out = self.lastact(cell_results[-1]) | ||||||
|  |     out = self.global_pooling( out ) | ||||||
|  |     out = out.view(out.size(0), -1) | ||||||
|  |     logits = self.classifier(out) | ||||||
|  |     if logits_aux is None: return out, logits | ||||||
|  |     else: return out, [logits, logits_aux] | ||||||
| @@ -155,7 +155,7 @@ class NASNetSearchCell(nn.Module): | |||||||
|     self.edges     = nn.ModuleDict() |     self.edges     = nn.ModuleDict() | ||||||
|     for i in range(self._steps): |     for i in range(self._steps): | ||||||
|       for j in range(2+i): |       for j in range(2+i): | ||||||
|         node_str = '{:}<-{:}'.format(i, j) |         node_str = '{:}<-{:}'.format(i, j)  # indicate the edge from node-(j) to node-(i+2) | ||||||
|         stride = 2 if reduction and j < 2 else 1 |         stride = 2 if reduction and j < 2 else 1 | ||||||
|         op = MixedOp(space, C, stride, affine, track_running_stats) |         op = MixedOp(space, C, stride, affine, track_running_stats) | ||||||
|         self.edges[ node_str ] = op |         self.edges[ node_str ] = op | ||||||
|   | |||||||
| @@ -5,8 +5,7 @@ import torch | |||||||
| import torch.nn as nn | import torch.nn as nn | ||||||
| from copy import deepcopy | from copy import deepcopy | ||||||
| from typing import List, Text, Dict | from typing import List, Text, Dict | ||||||
| from .search_cells     import NASNetSearchCell as SearchCell | from .search_cells import NASNetSearchCell as SearchCell | ||||||
| from .genotypes        import Structure |  | ||||||
|  |  | ||||||
|  |  | ||||||
| # The macro structure is based on NASNet | # The macro structure is based on NASNet | ||||||
|   | |||||||
| @@ -4,8 +4,7 @@ | |||||||
| import torch | import torch | ||||||
| import torch.nn as nn | import torch.nn as nn | ||||||
| from copy import deepcopy | from copy import deepcopy | ||||||
| from .search_cells     import NASNetSearchCell as SearchCell | from .search_cells import NASNetSearchCell as SearchCell | ||||||
| from .genotypes        import Structure |  | ||||||
|  |  | ||||||
|  |  | ||||||
| # The macro structure is based on NASNet | # The macro structure is based on NASNet | ||||||
|   | |||||||
| @@ -168,5 +168,15 @@ Networks = {'DARTS_V1': DARTS_V1, | |||||||
|             'SETN'    : SETN, |             'SETN'    : SETN, | ||||||
|            } |            } | ||||||
|  |  | ||||||
|  | # This function will return a Genotype from a dict. | ||||||
| def build_genotype_from_dict(xdict): | def build_genotype_from_dict(xdict): | ||||||
|   import pdb; pdb.set_trace() |   def remove_value(nodes): | ||||||
|  |     return [tuple([(x[0], x[1]) for x in node]) for node in nodes] | ||||||
|  |   genotype = Genotype( | ||||||
|  |       normal=remove_value(xdict['normal']), | ||||||
|  |       normal_concat=xdict['normal_concat'], | ||||||
|  |       reduce=remove_value(xdict['reduce']), | ||||||
|  |       reduce_concat=xdict['reduce_concat'], | ||||||
|  |       connectN=None, connects=None | ||||||
|  |       ) | ||||||
|  |   return genotype | ||||||
|   | |||||||
| @@ -6,12 +6,22 @@ | |||||||
| # Currently, this package is used to reproduce the results in GDAS (Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019). | # Currently, this package is used to reproduce the results in GDAS (Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019). | ||||||
| ################################################## | ################################################## | ||||||
|  |  | ||||||
| import torch | import os, torch | ||||||
|  |  | ||||||
| def obtain_nas_infer_model(config): | def obtain_nas_infer_model(config, extra_model_path=None): | ||||||
|  |    | ||||||
|   if config.arch == 'dxys': |   if config.arch == 'dxys': | ||||||
|     from .DXYs import CifarNet, ImageNet, Networks |     from .DXYs import CifarNet, ImageNet, Networks | ||||||
|     genotype = Networks[config.genotype] |     from .DXYs import build_genotype_from_dict | ||||||
|  |     if config.genotype is None: | ||||||
|  |       if extra_model_path is not None and not os.path.isfile(extra_model_path): | ||||||
|  |         raise ValueError('When genotype in confiig is None, extra_model_path must be set as a path instead of {:}'.format(extra_model_path)) | ||||||
|  |       xdata = torch.load(extra_model_path) | ||||||
|  |       current_epoch = xdata['epoch'] | ||||||
|  |       genotype_dict = xdata['genotypes'][current_epoch-1] | ||||||
|  |       genotype = build_genotype_from_dict(genotype_dict) | ||||||
|  |     else: | ||||||
|  |       genotype = Networks[config.genotype] | ||||||
|     if config.dataset == 'cifar': |     if config.dataset == 'cifar': | ||||||
|       return CifarNet(config.ichannel, config.layers, config.stem_multi, config.auxiliary, genotype, config.class_num) |       return CifarNet(config.ichannel, config.layers, config.stem_multi, config.auxiliary, genotype, config.class_num) | ||||||
|     elif config.dataset == 'imagenet': |     elif config.dataset == 'imagenet': | ||||||
|   | |||||||
| @@ -4,7 +4,7 @@ echo script name: $0 | |||||||
| echo $# arguments | echo $# arguments | ||||||
| if [ "$#" -ne 4 ] ;then | if [ "$#" -ne 4 ] ;then | ||||||
|   echo "Input illegal number of parameters " $# |   echo "Input illegal number of parameters " $# | ||||||
|   echo "Need 4 parameters for dataset and the-model-name and epochs and LR and the-batch-size and the-random-seed" |   echo "Need 4 parameters for dataset, the-model-name, the-batch-size and the-random-seed" | ||||||
|   exit 1 |   exit 1 | ||||||
| fi | fi | ||||||
| if [ "$TORCH_HOME" = "" ]; then | if [ "$TORCH_HOME" = "" ]; then | ||||||
|   | |||||||
							
								
								
									
										53
									
								
								scripts/retrain-searched-net.sh
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										53
									
								
								scripts/retrain-searched-net.sh
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,53 @@ | |||||||
|  | #!/bin/bash | ||||||
|  | # bash ./scripts/retrain-searched-net.sh cifar10 ${NAME} ${PATH} 256 -1 | ||||||
|  | echo script name: $0 | ||||||
|  | echo $# arguments | ||||||
|  | if [ "$#" -ne 5 ] ;then | ||||||
|  |   echo "Input illegal number of parameters " $# | ||||||
|  |   echo "Need 5 parameters for dataset, the save dir base name, the model path, the batch size, the random seed" | ||||||
|  |   exit 1 | ||||||
|  | fi | ||||||
|  | if [ "$TORCH_HOME" = "" ]; then | ||||||
|  |   echo "Must set TORCH_HOME envoriment variable for data dir saving" | ||||||
|  |   exit 1 | ||||||
|  | else | ||||||
|  |   echo "TORCH_HOME : $TORCH_HOME" | ||||||
|  | fi | ||||||
|  |  | ||||||
|  | dataset=$1 | ||||||
|  | save_name=$2 | ||||||
|  | model_path=$3 | ||||||
|  | batch=$4 | ||||||
|  | rseed=$5 | ||||||
|  |  | ||||||
|  | if [ ${dataset} == 'cifar10' ] || [ ${dataset} == 'cifar100' ]; then | ||||||
|  |   xpath=$TORCH_HOME/cifar.python | ||||||
|  |   base=CIFAR | ||||||
|  |   workers=4 | ||||||
|  |   cutout_length=16 | ||||||
|  | elif [ ${dataset} == 'imagenet-1k' ]; then | ||||||
|  |   xpath=$TORCH_HOME/ILSVRC2012 | ||||||
|  |   base=IMAGENET | ||||||
|  |   workers=28 | ||||||
|  |   cutout_length=-1 | ||||||
|  | else | ||||||
|  |   exit 1 | ||||||
|  |   echo 'Unknown dataset: '${dataset} | ||||||
|  | fi | ||||||
|  |  | ||||||
|  | SAVE_ROOT="./output" | ||||||
|  |  | ||||||
|  | save_dir=${SAVE_ROOT}/nas-infer/${dataset}-BS${batch}-${save_name} | ||||||
|  |  | ||||||
|  | python --version | ||||||
|  |  | ||||||
|  | python ./exps/basic-main.py --dataset ${dataset} \ | ||||||
|  | 	--data_path ${xpath} --model_source autodl-searched \ | ||||||
|  | 	--model_config ./configs/archs/NAS-${base}-none.config \ | ||||||
|  | 	--optim_config ./configs/opts/NAS-${base}.config \ | ||||||
|  | 	--extra_model_path ${model_path} \ | ||||||
|  | 	--procedure    basic \ | ||||||
|  | 	--save_dir     ${save_dir} \ | ||||||
|  | 	--cutout_length ${cutout_length} \ | ||||||
|  | 	--batch_size  ${batch} --rand_seed ${rseed} --workers ${workers} \ | ||||||
|  | 	--eval_frequency 1 --print_freq 500 --print_freq_eval 1000 | ||||||
		Reference in New Issue
	
	Block a user