From 25e529f788f1f66903ea34f720e702ee48ddcc69 Mon Sep 17 00:00:00 2001 From: D-X-Y <280835372@qq.com> Date: Sun, 2 Feb 2020 17:20:38 +1100 Subject: [PATCH] support get_net_config for NAS-Bench-201 --- docs/NAS-Bench-201.md | 11 +++++++- lib/models/__init__.py | 8 +++++- lib/nas_201_api/api.py | 59 +++++++++++++++++++++++++++++++++++++----- 3 files changed, 69 insertions(+), 9 deletions(-) diff --git a/docs/NAS-Bench-201.md b/docs/NAS-Bench-201.md index 31cdd82..7042bca 100644 --- a/docs/NAS-Bench-201.md +++ b/docs/NAS-Bench-201.md @@ -72,7 +72,16 @@ index = api.query_index_by_arch('|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1 api.show(index) ``` -5. For other usages, please see `lib/nas_201_api/api.py`. We provide some usage information in the comments for the corresponding functions. If what you want is not provided, please feel free to open an issue for discussion, and I am happy to answer any questions regarding NAS-Bench-201. +5. Create the network from api: +``` +config = api.get_net_config(123, 'cifar10') # obtain the network configuration for the 123-th architecture on the CIFAR-10 dataset +from models import get_cell_based_tiny_net # this module is in AutoDL-Projects/lib/models +network = get_cell_based_tiny_net(config) # create the network from configuration +print(network) # show the structure of this architecture +``` +If you want to load the trained weights of this created network, you need to use `api.get_net_param(123, ...)` to obtain the weights and then load it to the network. + +6. For other usages, please see `lib/nas_201_api/api.py`. We provide some usage information in the comments for the corresponding functions. If what you want is not provided, please feel free to open an issue for discussion, and I am happy to answer any questions regarding NAS-Bench-201. 
### Detailed Instruction diff --git a/lib/models/__init__.py b/lib/models/__init__.py index 34087b4..456f26c 100644 --- a/lib/models/__init__.py +++ b/lib/models/__init__.py @@ -16,6 +16,7 @@ from .cell_searchs import CellStructure, CellArchitectures # Cell-based NAS Models def get_cell_based_tiny_net(config): + if isinstance(config, dict): config = dict2config(config, None) # to support the argument being a dict super_type = getattr(config, 'super_type', 'basic') group_names = ['DARTS-V1', 'DARTS-V2', 'GDAS', 'SETN', 'ENAS', 'RANDOM'] if super_type == 'basic' and config.name in group_names: @@ -30,7 +31,12 @@ def get_cell_based_tiny_net(config): config.stem_multiplier, config.num_classes, config.space, config.affine, config.track_running_stats) elif config.name == 'infer.tiny': from .cell_infers import TinyNetwork - return TinyNetwork(config.C, config.N, config.genotype, config.num_classes) + if hasattr(config, 'genotype'): + genotype = config.genotype + elif hasattr(config, 'arch_str'): + genotype = CellStructure.str2structure(config.arch_str) + else: raise ValueError('Can not find genotype from this config : {:}'.format(config)) + return TinyNetwork(config.C, config.N, genotype, config.num_classes) else: raise ValueError('invalid network name : {:}'.format(config.name)) diff --git a/lib/nas_201_api/api.py b/lib/nas_201_api/api.py index 2222e71..9712fe2 100644 --- a/lib/nas_201_api/api.py +++ b/lib/nas_201_api/api.py @@ -93,6 +93,8 @@ class NASBench201API(object): else: arch_index = -1 return arch_index + # Overwrite all information of the 'index'-th architecture in the search space. + # It will load its data from 'archive_root'. 
def reload(self, archive_root, index): assert os.path.isdir(archive_root), 'invalid directory : {:}'.format(archive_root) xfile_path = os.path.join(archive_root, '{:06d}-FULL.pth'.format(index)) @@ -123,9 +125,18 @@ class NASBench201API(object): print ('Find this arch-index : {:}, but this arch is not evaluated.'.format(arch_index)) return None - # query information with the training of 12 epochs or 200 epochs - # if dataname is None, return the ArchResults + # This 'query_by_index' function is used to query information with the training of 12 epochs or 200 epochs. + # ------ + # If use_12epochs_result=True, we train the model by 12 epochs (see config in configs/nas-benchmark/LESS.config) + # If use_12epochs_result=False, we train the model by 200 epochs (see config in configs/nas-benchmark/CIFAR.config) + # ------ + # If dataname is None, return the ArchResults # else, return a dict with all trials on that dataset (the key is the seed) + # Options are 'cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120'. + # -- cifar10-valid : training the model on the CIFAR-10 training set. + # -- cifar10 : training the model on the CIFAR-10 training + validation set. + # -- cifar100 : training the model on the CIFAR-100 training set. + # -- ImageNet16-120 : training the model on the ImageNet16-120 training set. def query_by_index(self, arch_index, dataname=None, use_12epochs_result=False): if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less else : basestr, arch2infos = '200epochs', self.arch2infos_full @@ -166,12 +177,40 @@ class NASBench201API(object): assert 0 <= index < len(self.meta_archs), 'invalid index : {:} vs. 
{:}.'.format(index, len(self.meta_archs)) return copy.deepcopy(self.meta_archs[index]) - # obtain the trained weights of the `index`-th architecture on `dataset` with the seed of `seed` + """ + This function is used to obtain the trained weights of the `index`-th architecture on `dataset` with the seed of `seed` + Args [seed]: + -- None : return a dict containing the trained weights of all trials, where each key is a seed and its corresponding value is the weights. + -- an integer : return the weights of a specific trial, whose seed is this integer. + Args [use_12epochs_result]: + -- True : train the model by 12 epochs + -- False : train the model by 200 epochs + """ def get_net_param(self, index, dataset, seed, use_12epochs_result=False): if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less else : basestr, arch2infos = '200epochs', self.arch2infos_full archresult = arch2infos[index] return archresult.get_net_param(dataset, seed) + + """ + This function is used to obtain the configuration for the `index`-th architecture on `dataset`. + Args [dataset] (4 possible options): + -- cifar10-valid : training the model on the CIFAR-10 training set. + -- cifar10 : training the model on the CIFAR-10 training + validation set. + -- cifar100 : training the model on the CIFAR-100 training set. + -- ImageNet16-120 : training the model on the ImageNet16-120 training set. + This function will return a dict. 
+ ========= Some examples for using this function: + config = api.get_net_config(128, 'cifar10') + """ + def get_net_config(self, index, dataset): + archresult = self.arch2infos_full[index] + all_results = archresult.query(dataset, None) + if len(all_results) == 0: raise ValueError('can not find one valid trial for the {:}-th architecture on {:}'.format(index, dataset)) + for seed, result in all_results.items(): + return result.get_config(None) + #print ('SEED [{:}] : {:}'.format(seed, result)) + raise ValueError('Impossible to reach here!') # obtain the cost metric for the `index`-th architecture on a dataset def get_cost_info(self, index, dataset, use_12epochs_result=False): @@ -333,6 +372,7 @@ class NASBench201API(object): + class ArchResults(object): def __init__(self, arch_index, arch_str): @@ -615,11 +655,16 @@ class ResultsCount(object): def get_net_param(self): return self.net_state_dict + # This function is used to obtain the config dict for this architecture. def get_config(self, str2structure): - #return copy.deepcopy(self.arch_config) - return {'name': 'infer.tiny', 'C': self.arch_config['channel'], \ - 'N' : self.arch_config['num_cells'], \ - 'genotype': str2structure(self.arch_config['arch_str']), 'num_classes': self.arch_config['class_num']} + if str2structure is None: + return {'name': 'infer.tiny', 'C': self.arch_config['channel'], \ + 'N' : self.arch_config['num_cells'], \ + 'arch_str': self.arch_config['arch_str'], 'num_classes': self.arch_config['class_num']} + else: + return {'name': 'infer.tiny', 'C': self.arch_config['channel'], \ + 'N' : self.arch_config['num_cells'], \ + 'genotype': str2structure(self.arch_config['arch_str']), 'num_classes': self.arch_config['class_num']} def state_dict(self): _state_dict = {key: value for key, value in self.__dict__.items()}