Upgrade NAS-API to v2.0:
we use an abstract class NASBenchMetaAPI to define the spec of an API; it can be inherited to support different kinds of NAS API, while keep the query interface the same.
This commit is contained in:
		| @@ -22,7 +22,7 @@ def create_result_count(used_seed: int, dataset: Text, arch_config: Dict[Text, A | ||||
|                         results: Dict[Text, Any], dataloader_dict: Dict[Text, Any]) -> ResultsCount: | ||||
|   xresult = ResultsCount(dataset, results['net_state_dict'], results['train_acc1es'], results['train_losses'], | ||||
|                          results['param'], results['flop'], arch_config, used_seed, results['total_epoch'], None) | ||||
|   net_config = dict2config({'name': 'infer.tiny', 'C': arch_config['channel'], 'N': arch_config['num_cells'], 'genotype': CellStructure.str2structure(arch_config['arch_str']), 'num_classes':arch_config['class_num']}, None) | ||||
|   net_config = dict2config({'name': 'infer.tiny', 'C': arch_config['channel'], 'N': arch_config['num_cells'], 'genotype': CellStructure.str2structure(arch_config['arch_str']), 'num_classes': arch_config['class_num']}, None) | ||||
|   network = get_cell_based_tiny_net(net_config) | ||||
|   network.load_state_dict(xresult.get_net_param()) | ||||
|   if 'train_times' in results: # new version | ||||
| @@ -126,7 +126,6 @@ def correct_time_related_info(arch_index: int, arch_info_full: ArchResults, arch | ||||
|     arch_info.reset_pseudo_eval_times('ImageNet16-120', None, 'ori-test', eval_per_sample * nums['ImageNet16-120-test']) | ||||
|   # arch_info_full.debug_test() | ||||
|   # arch_info_less.debug_test() | ||||
|   # import pdb; pdb.set_trace() | ||||
|   return arch_info_full, arch_info_less | ||||
|  | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user