support get_net_config for NAS-Bench-201
This commit is contained in:
parent
133fd21ecc
commit
25e529f788
@ -72,7 +72,16 @@ index = api.query_index_by_arch('|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1
|
|||||||
api.show(index)
|
api.show(index)
|
||||||
```
|
```
|
||||||
|
|
||||||
5. For other usages, please see `lib/nas_201_api/api.py`. We provide some usage information in the comments for the corresponding functions. If what you want is not provided, please feel free to open an issue for discussion, and I am happy to answer any questions regarding NAS-Bench-201.
|
5. Create the network from api:
|
||||||
|
```
|
||||||
|
config = api.get_net_config(123, 'cifar10') # obtain the network configuration for the 123-th architecture on the CIFAR-10 dataset
|
||||||
|
from models import get_cell_based_tiny_net # this module is in AutoDL-Projects/lib/models
|
||||||
|
network = get_cell_based_tiny_net(config) # create the network from the configuration
|
||||||
|
print(network) # show the structure of this architecture
|
||||||
|
```
|
||||||
|
If you want to load the trained weights of this created network, you need to use `api.get_net_param(123, ...)` to obtain the weights and then load them into the network.
|
||||||
|
|
||||||
|
6. For other usages, please see `lib/nas_201_api/api.py`. We provide some usage information in the comments for the corresponding functions. If what you want is not provided, please feel free to open an issue for discussion, and I am happy to answer any questions regarding NAS-Bench-201.
|
||||||
|
|
||||||
|
|
||||||
### Detailed Instruction
|
### Detailed Instruction
|
||||||
|
@ -16,6 +16,7 @@ from .cell_searchs import CellStructure, CellArchitectures
|
|||||||
|
|
||||||
# Cell-based NAS Models
|
# Cell-based NAS Models
|
||||||
def get_cell_based_tiny_net(config):
|
def get_cell_based_tiny_net(config):
|
||||||
|
if isinstance(config, dict): config = dict2config(config, None) # to support the argument being a dict
|
||||||
super_type = getattr(config, 'super_type', 'basic')
|
super_type = getattr(config, 'super_type', 'basic')
|
||||||
group_names = ['DARTS-V1', 'DARTS-V2', 'GDAS', 'SETN', 'ENAS', 'RANDOM']
|
group_names = ['DARTS-V1', 'DARTS-V2', 'GDAS', 'SETN', 'ENAS', 'RANDOM']
|
||||||
if super_type == 'basic' and config.name in group_names:
|
if super_type == 'basic' and config.name in group_names:
|
||||||
@ -30,7 +31,12 @@ def get_cell_based_tiny_net(config):
|
|||||||
config.stem_multiplier, config.num_classes, config.space, config.affine, config.track_running_stats)
|
config.stem_multiplier, config.num_classes, config.space, config.affine, config.track_running_stats)
|
||||||
elif config.name == 'infer.tiny':
|
elif config.name == 'infer.tiny':
|
||||||
from .cell_infers import TinyNetwork
|
from .cell_infers import TinyNetwork
|
||||||
return TinyNetwork(config.C, config.N, config.genotype, config.num_classes)
|
if hasattr(config, 'genotype'):
|
||||||
|
genotype = config.genotype
|
||||||
|
elif hasattr(config, 'arch_str'):
|
||||||
|
genotype = CellStructure.str2structure(config.arch_str)
|
||||||
|
else: raise ValueError('Can not find genotype from this config : {:}'.format(config))
|
||||||
|
return TinyNetwork(config.C, config.N, genotype, config.num_classes)
|
||||||
else:
|
else:
|
||||||
raise ValueError('invalid network name : {:}'.format(config.name))
|
raise ValueError('invalid network name : {:}'.format(config.name))
|
||||||
|
|
||||||
|
@ -93,6 +93,8 @@ class NASBench201API(object):
|
|||||||
else: arch_index = -1
|
else: arch_index = -1
|
||||||
return arch_index
|
return arch_index
|
||||||
|
|
||||||
|
# Overwrite all information of the 'index'-th architecture in the search space.
|
||||||
|
# It will load its data from 'archive_root'.
|
||||||
def reload(self, archive_root, index):
|
def reload(self, archive_root, index):
|
||||||
assert os.path.isdir(archive_root), 'invalid directory : {:}'.format(archive_root)
|
assert os.path.isdir(archive_root), 'invalid directory : {:}'.format(archive_root)
|
||||||
xfile_path = os.path.join(archive_root, '{:06d}-FULL.pth'.format(index))
|
xfile_path = os.path.join(archive_root, '{:06d}-FULL.pth'.format(index))
|
||||||
@ -123,9 +125,18 @@ class NASBench201API(object):
|
|||||||
print ('Find this arch-index : {:}, but this arch is not evaluated.'.format(arch_index))
|
print ('Find this arch-index : {:}, but this arch is not evaluated.'.format(arch_index))
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# query information with the training of 12 epochs or 200 epochs
|
# This 'query_by_index' function is used to query information with the training of 12 epochs or 200 epochs.
|
||||||
# if dataname is None, return the ArchResults
|
# ------
|
||||||
|
# If use_12epochs_result=True, we train the model by 12 epochs (see config in configs/nas-benchmark/LESS.config)
|
||||||
|
# If use_12epochs_result=False, we train the model by 200 epochs (see config in configs/nas-benchmark/CIFAR.config)
|
||||||
|
# ------
|
||||||
|
# If dataname is None, return the ArchResults
|
||||||
# else, return a dict with all trials on that dataset (the key is the seed)
|
# else, return a dict with all trials on that dataset (the key is the seed)
|
||||||
|
# Options are 'cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120'.
|
||||||
|
# -- cifar10-valid : training the model on the CIFAR-10 training set.
|
||||||
|
# -- cifar10 : training the model on the CIFAR-10 training + validation set.
|
||||||
|
# -- cifar100 : training the model on the CIFAR-100 training set.
|
||||||
|
# -- ImageNet16-120 : training the model on the ImageNet16-120 training set.
|
||||||
def query_by_index(self, arch_index, dataname=None, use_12epochs_result=False):
|
def query_by_index(self, arch_index, dataname=None, use_12epochs_result=False):
|
||||||
if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less
|
if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less
|
||||||
else : basestr, arch2infos = '200epochs', self.arch2infos_full
|
else : basestr, arch2infos = '200epochs', self.arch2infos_full
|
||||||
@ -166,13 +177,41 @@ class NASBench201API(object):
|
|||||||
assert 0 <= index < len(self.meta_archs), 'invalid index : {:} vs. {:}.'.format(index, len(self.meta_archs))
|
assert 0 <= index < len(self.meta_archs), 'invalid index : {:} vs. {:}.'.format(index, len(self.meta_archs))
|
||||||
return copy.deepcopy(self.meta_archs[index])
|
return copy.deepcopy(self.meta_archs[index])
|
||||||
|
|
||||||
# obtain the trained weights of the `index`-th architecture on `dataset` with the seed of `seed`
|
"""
|
||||||
|
This function is used to obtain the trained weights of the `index`-th architecture on `dataset` with the seed of `seed`
|
||||||
|
Args [seed]:
|
||||||
|
-- None : return a dict containing the trained weights of all trials, where each key is a seed and its corresponding value is the weights.
|
||||||
|
-- an integer : return the weights of a specific trial, whose seed is this integer.
|
||||||
|
Args [use_12epochs_result]:
|
||||||
|
-- True : train the model by 12 epochs
|
||||||
|
-- False : train the model by 200 epochs
|
||||||
|
"""
|
||||||
def get_net_param(self, index, dataset, seed, use_12epochs_result=False):
|
def get_net_param(self, index, dataset, seed, use_12epochs_result=False):
|
||||||
if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less
|
if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less
|
||||||
else : basestr, arch2infos = '200epochs', self.arch2infos_full
|
else : basestr, arch2infos = '200epochs', self.arch2infos_full
|
||||||
archresult = arch2infos[index]
|
archresult = arch2infos[index]
|
||||||
return archresult.get_net_param(dataset, seed)
|
return archresult.get_net_param(dataset, seed)
|
||||||
|
|
||||||
|
"""
|
||||||
|
This function is used to obtain the configuration for the `index`-th architecture on `dataset`.
|
||||||
|
Args [dataset] (4 possible options):
|
||||||
|
-- cifar10-valid : training the model on the CIFAR-10 training set.
|
||||||
|
-- cifar10 : training the model on the CIFAR-10 training + validation set.
|
||||||
|
-- cifar100 : training the model on the CIFAR-100 training set.
|
||||||
|
-- ImageNet16-120 : training the model on the ImageNet16-120 training set.
|
||||||
|
This function will return a dict.
|
||||||
|
========= Some examples for using this function:
|
||||||
|
config = api.get_net_config(128, 'cifar10')
|
||||||
|
"""
|
||||||
|
def get_net_config(self, index, dataset):
|
||||||
|
archresult = self.arch2infos_full[index]
|
||||||
|
all_results = archresult.query(dataset, None)
|
||||||
|
if len(all_results) == 0: raise ValueError('can not find one valid trial for the {:}-th architecture on {:}'.format(index, dataset))
|
||||||
|
for seed, result in all_results.items():
|
||||||
|
return result.get_config(None)
|
||||||
|
#print ('SEED [{:}] : {:}'.format(seed, result))
|
||||||
|
raise ValueError('Impossible to reach here!')
|
||||||
|
|
||||||
# obtain the cost metric for the `index`-th architecture on a dataset
|
# obtain the cost metric for the `index`-th architecture on a dataset
|
||||||
def get_cost_info(self, index, dataset, use_12epochs_result=False):
|
def get_cost_info(self, index, dataset, use_12epochs_result=False):
|
||||||
if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less
|
if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less
|
||||||
@ -333,6 +372,7 @@ class NASBench201API(object):
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ArchResults(object):
|
class ArchResults(object):
|
||||||
|
|
||||||
def __init__(self, arch_index, arch_str):
|
def __init__(self, arch_index, arch_str):
|
||||||
@ -615,8 +655,13 @@ class ResultsCount(object):
|
|||||||
def get_net_param(self):
|
def get_net_param(self):
|
||||||
return self.net_state_dict
|
return self.net_state_dict
|
||||||
|
|
||||||
|
# This function is used to obtain the config dict for this architecture.
|
||||||
def get_config(self, str2structure):
|
def get_config(self, str2structure):
|
||||||
#return copy.deepcopy(self.arch_config)
|
if str2structure is None:
|
||||||
|
return {'name': 'infer.tiny', 'C': self.arch_config['channel'], \
|
||||||
|
'N' : self.arch_config['num_cells'], \
|
||||||
|
'arch_str': self.arch_config['arch_str'], 'num_classes': self.arch_config['class_num']}
|
||||||
|
else:
|
||||||
return {'name': 'infer.tiny', 'C': self.arch_config['channel'], \
|
return {'name': 'infer.tiny', 'C': self.arch_config['channel'], \
|
||||||
'N' : self.arch_config['num_cells'], \
|
'N' : self.arch_config['num_cells'], \
|
||||||
'genotype': str2structure(self.arch_config['arch_str']), 'num_classes': self.arch_config['class_num']}
|
'genotype': str2structure(self.arch_config['arch_str']), 'num_classes': self.arch_config['class_num']}
|
||||||
|
Loading…
Reference in New Issue
Block a user