Upgrade API of NAS-Bench-201
This commit is contained in:
		| @@ -4,7 +4,6 @@ | ||||
| import os, sys, time, argparse, collections | ||||
| from copy import deepcopy | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| from pathlib import Path | ||||
| from collections import defaultdict | ||||
| lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() | ||||
| @@ -15,8 +14,7 @@ from datasets     import get_datasets | ||||
| # NAS-Bench-201 related module or function | ||||
| from models       import CellStructure, get_cell_based_tiny_net | ||||
| from nas_201_api  import ArchResults, ResultsCount | ||||
| from functions    import pure_evaluate | ||||
|  | ||||
| from procedures   import bench_pure_evaluate as pure_evaluate | ||||
|  | ||||
|  | ||||
| def create_result_count(used_seed, dataset, arch_config, results, dataloader_dict): | ||||
| @@ -69,7 +67,6 @@ def account_one_arch(arch_index, arch_str, checkpoints, datasets, dataloader_dic | ||||
|   return information | ||||
|  | ||||
|  | ||||
|  | ||||
| def GET_DataLoaders(workers): | ||||
|  | ||||
|   torch.set_num_threads(workers) | ||||
| @@ -137,7 +134,6 @@ def GET_DataLoaders(workers): | ||||
|   return loaders | ||||
|  | ||||
|  | ||||
|  | ||||
| def simplify(save_dir, meta_file, basestr, target_dir): | ||||
|   meta_infos     = torch.load(meta_file, map_location='cpu') | ||||
|   meta_archs     = meta_infos['archs'] # a list of architecture strings | ||||
| @@ -221,7 +217,6 @@ def simplify(save_dir, meta_file, basestr, target_dir): | ||||
|   print ('Save {:} / {:} architecture results into {:}.'.format(len(evaluated_indexes), meta_num_archs, save_file_name)) | ||||
|  | ||||
|  | ||||
|  | ||||
| def merge_all(save_dir, meta_file, basestr): | ||||
|   meta_infos     = torch.load(meta_file, map_location='cpu') | ||||
|   meta_archs     = meta_infos['archs'] | ||||
| @@ -268,7 +263,6 @@ def merge_all(save_dir, meta_file, basestr): | ||||
|   print ('Save {:} / {:} architecture results into {:}.'.format(len(evaluated_indexes), meta_num_archs, save_file_name)) | ||||
|  | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|  | ||||
|   parser = argparse.ArgumentParser(description='NAS-BENCH-201', formatter_class=argparse.ArgumentDefaultsHelpFormatter) | ||||
| @@ -280,7 +274,7 @@ if __name__ == '__main__': | ||||
|   parser.add_argument('--num_cells'    ,  type=int, default=5,                           help='The number of cells in one stage.') | ||||
|   args = parser.parse_args() | ||||
|    | ||||
|   save_dir  = Path( args.base_save_dir ) | ||||
|   save_dir  = Path(args.base_save_dir) | ||||
|   meta_path = save_dir / 'meta-node-{:}.pth'.format(args.max_node) | ||||
|   assert save_dir.exists(),  'invalid save dir path : {:}'.format(save_dir) | ||||
|   assert meta_path.exists(), 'invalid saved meta path : {:}'.format(meta_path) | ||||
| @@ -292,4 +286,4 @@ if __name__ == '__main__': | ||||
|   elif args.mode == 'merge': | ||||
|     merge_all(save_dir, meta_path, basestr) | ||||
|   else: | ||||
|     raise ValueError('invalid mode : {:}'.format(args.mode)) | ||||
|     raise ValueError('invalid mode : {:}'.format(args.mode)) | ||||
		Reference in New Issue
	
	Block a user