Upgrade API of NAS-Bench-201
This commit is contained in:
		| @@ -1,7 +1,7 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 # | ||||
| ##################################################### | ||||
| # python exps/NAS-Bench-201/check.py --base_save_dir  | ||||
| # python exps/NAS-Bench-201/check.py --base_str C16-N5-LESS | ||||
| ##################################################### | ||||
| import sys, time, argparse, collections | ||||
| import torch | ||||
| @@ -13,10 +13,9 @@ from log_utils    import AverageMeter, time_string, convert_secs2time | ||||
|  | ||||
|  | ||||
| def check_files(save_dir, meta_file, basestr): | ||||
|   meta_infos     = torch.load(meta_file, map_location='cpu') | ||||
|   meta_archs     = meta_infos['archs'] | ||||
|   meta_infos = torch.load(meta_file, map_location='cpu') | ||||
|   meta_archs = meta_infos['archs'] | ||||
|   meta_num_archs = meta_infos['total'] | ||||
|   meta_max_node  = meta_infos['max_node'] | ||||
|   assert meta_num_archs == len(meta_archs), 'invalid number of archs : {:} vs {:}'.format(meta_num_archs, len(meta_archs)) | ||||
|  | ||||
|   sub_model_dirs = sorted(list(save_dir.glob('*-*-{:}'.format(basestr)))) | ||||
| @@ -43,7 +42,12 @@ def check_files(save_dir, meta_file, basestr): | ||||
|   dir2ckps, dir2ckp_exists = dict(), dict() | ||||
|   start_time, epoch_time = time.time(), AverageMeter() | ||||
|   for IDX, (sub_dir, arch_indexes) in enumerate(subdir2archs.items()): | ||||
|     seeds = [777, 888, 999] | ||||
|     if basestr == 'C16-N5': | ||||
|       seeds = [777, 888, 999] | ||||
|     elif basestr == 'C16-N5-LESS': | ||||
|       seeds = [111, 777] | ||||
|     else: | ||||
|       raise ValueError('Invalid base str : {:}'.format(basestr)) | ||||
|     numrs = defaultdict(lambda: 0) | ||||
|     all_checkpoints, all_ckp_exists = [], [] | ||||
|     for arch_index in arch_indexes: | ||||
| @@ -66,17 +70,15 @@ def check_files(save_dir, meta_file, basestr): | ||||
| if __name__ == '__main__': | ||||
|  | ||||
|   parser = argparse.ArgumentParser(description='NAS Benchmark 201', formatter_class=argparse.ArgumentDefaultsHelpFormatter) | ||||
|   parser.add_argument('--base_save_dir',  type=str, default='./output/NAS-BENCH-201-4',     help='The base-name of folder to save checkpoints and log.') | ||||
|   parser.add_argument('--max_node',       type=int, default=4,                                 help='The maximum node in a cell.') | ||||
|   parser.add_argument('--channel',        type=int, default=16,                                help='The number of channels.') | ||||
|   parser.add_argument('--num_cells',      type=int, default=5,                                 help='The number of cells in one stage.') | ||||
|   parser.add_argument('--base_save_dir', type=str, default='./output/NAS-BENCH-201-4', help='The base-name of folder to save checkpoints and log.') | ||||
|   parser.add_argument('--meta_path',     type=str, default='./output/NAS-BENCH-201-4/meta-node-4.pth', help='The meta file path.') | ||||
|   parser.add_argument('--base_str',      type=str, default='C16-N5',                   help='The basic string.') | ||||
|   args = parser.parse_args() | ||||
|    | ||||
|   save_dir  = Path( args.base_save_dir ) | ||||
|   meta_path = save_dir / 'meta-node-{:}.pth'.format(args.max_node) | ||||
|  | ||||
|   save_dir = Path(args.base_save_dir) | ||||
|   meta_path = Path(args.meta_path) | ||||
|   assert save_dir.exists(),  'invalid save dir path : {:}'.format(save_dir) | ||||
|   assert meta_path.exists(), 'invalid saved meta path : {:}'.format(meta_path) | ||||
|   print ('check NAS-Bench-201 in {:}'.format(save_dir)) | ||||
|  | ||||
|   basestr = 'C{:}-N{:}'.format(args.channel, args.num_cells) | ||||
|   check_files(save_dir, meta_path, basestr) | ||||
|   check_files(save_dir, meta_path, args.base_str) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user