support get_net_config for NAS-Bench-201
This commit is contained in:
		| @@ -16,6 +16,7 @@ from .cell_searchs import CellStructure, CellArchitectures | ||||
|  | ||||
| # Cell-based NAS Models | ||||
| def get_cell_based_tiny_net(config): | ||||
|   if isinstance(config, dict): config = dict2config(config, None) # to support the argument being a dict | ||||
|   super_type = getattr(config, 'super_type', 'basic') | ||||
|   group_names = ['DARTS-V1', 'DARTS-V2', 'GDAS', 'SETN', 'ENAS', 'RANDOM'] | ||||
|   if super_type == 'basic' and config.name in group_names: | ||||
| @@ -30,7 +31,12 @@ def get_cell_based_tiny_net(config): | ||||
|                     config.stem_multiplier, config.num_classes, config.space, config.affine, config.track_running_stats) | ||||
|   elif config.name == 'infer.tiny': | ||||
|     from .cell_infers import TinyNetwork | ||||
|     return TinyNetwork(config.C, config.N, config.genotype, config.num_classes) | ||||
|     if hasattr(config, 'genotype'): | ||||
|       genotype = config.genotype | ||||
|     elif hasattr(config, 'arch_str'): | ||||
|       genotype = CellStructure.str2structure(config.arch_str) | ||||
|     else: raise ValueError('Can not find genotype from this config : {:}'.format(config)) | ||||
|     return TinyNetwork(config.C, config.N, genotype, config.num_classes) | ||||
|   else: | ||||
|     raise ValueError('invalid network name : {:}'.format(config.name)) | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user