85 lines
		
	
	
		
			4.6 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			85 lines
		
	
	
		
			4.6 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| #####################################################
 | |
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
 | |
| #####################################################
 | |
| # python exps/NAS-Bench-201/check.py --base_str C16-N5-LESS
 | |
| #####################################################
 | |
| import sys, time, argparse, collections
 | |
| import torch
 | |
| from pathlib import Path
 | |
| from collections import defaultdict
 | |
| lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
 | |
| if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
 | |
| from log_utils    import AverageMeter, time_string, convert_secs2time
 | |
| 
 | |
| 
 | |
# Seeds that each base configuration is expected to have been trained with.
_SEEDS_BY_BASESTR = {
  'C16-N5'     : [777, 888, 999],
  'C16-N5-LESS': [111, 777],
}


def _scan_sub_dir(sub_dir):
  """Count the ``arch-<idx>-seed-<seed>.pth`` checkpoints per arch index in ``sub_dir``.

  Returns a ``collections.Counter`` mapping each arch-index string (exactly as it
  appears in the file name) to the number of seed checkpoints found for it.
  Raises AssertionError on a matching file whose name has an unexpected shape.
  """
  seeds_per_arch = collections.Counter()
  for checkpoint in sub_dir.glob('arch-*-seed-*.pth'):
    temp_names = checkpoint.name.split('-')
    assert len(temp_names) == 4 and temp_names[0] == 'arch' and temp_names[2] == 'seed', 'invalid checkpoint name : {:}'.format(checkpoint.name)
    seeds_per_arch[temp_names[1]] += 1
  return seeds_per_arch


def check_files(save_dir, meta_file, basestr):
  """Report which NAS-Bench-201 checkpoints exist under ``save_dir``.

  Args:
    save_dir: ``pathlib.Path`` whose entries matching ``*-*-<basestr>`` are
      scanned for ``arch-<idx>-seed-<seed>.pth`` checkpoint files.
    meta_file: path passed to ``torch.load``; the loaded object must provide
      ``'archs'`` (a sized sequence) and ``'total'`` (its expected length).
    basestr: configuration name; selects the expected seeds from
      ``_SEEDS_BY_BASESTR``.

  Returns:
    ``(dir2ckps, dir2ckp_exists)``: two dicts keyed by ``str(sub_dir)``. The
    first maps each key to the expected checkpoint file names. The second maps
    it to a parallel list of booleans saying whether each file exists.

  Raises:
    ValueError: if ``basestr`` is not a known configuration.
    AssertionError: if the meta file is inconsistent, a checkpoint name is
      malformed, or an arch index falls outside ``meta_archs``.
  """
  # Validate up front so an unknown basestr fails even when no directory matches.
  if basestr not in _SEEDS_BY_BASESTR:
    raise ValueError('Invalid base str : {:}'.format(basestr))
  seeds = _SEEDS_BY_BASESTR[basestr]

  # NOTE: torch.load unpickles the file, so only point it at trusted meta files.
  meta_infos = torch.load(meta_file, map_location='cpu')
  meta_archs = meta_infos['archs']
  meta_num_archs = meta_infos['total']
  assert meta_num_archs == len(meta_archs), 'invalid number of archs : {:} vs {:}'.format(meta_num_archs, len(meta_archs))

  # Phase 1: discover which archs were evaluated in each sub-directory.
  sub_model_dirs = sorted(save_dir.glob('*-*-{:}'.format(basestr)))
  print ('{:} find {:} directories used to save checkpoints'.format(time_string(), len(sub_model_dirs)))

  subdir2archs, num_evaluated_arch = collections.OrderedDict(), 0
  num_seeds = collections.Counter()  # number of seed ckps -> number of archs having that many
  for sub_dir in sub_model_dirs:
    # A single glob per directory supplies both the arch set and per-arch seed
    # counts, instead of one extra glob per arch.
    seeds_per_arch = _scan_sub_dir(sub_dir)
    subdir2archs[sub_dir] = sorted(seeds_per_arch)
    num_evaluated_arch   += len(seeds_per_arch)
    num_seeds.update(seeds_per_arch.values())
  print('There are {:5d} architectures that have been evaluated ({:} in total, {:} ckps in total).'.format(num_evaluated_arch, meta_num_archs, sum(k*v for k, v in num_seeds.items())))
  for key in sorted(num_seeds):
    print ('There are {:5d} architectures that are evaluated {:} times.'.format(num_seeds[key], key))

  # Phase 2: check every expected (arch, seed) checkpoint file per directory.
  dir2ckps, dir2ckp_exists = dict(), dict()
  start_time, epoch_time = time.time(), AverageMeter()
  for IDX, (sub_dir, arch_indexes) in enumerate(subdir2archs.items()):
    numrs = collections.Counter()  # number of existing seed ckps -> number of archs
    all_checkpoints, all_ckp_exists = [], []
    for arch_index in arch_indexes:
      checkpoints = ['arch-{:}-seed-{:04d}.pth'.format(arch_index, seed) for seed in seeds]
      ckp_exists  = [(sub_dir/x).exists() for x in checkpoints]
      arch_id     = int(arch_index)
      assert 0 <= arch_id < len(meta_archs), 'invalid arch-index {:} (not found in meta_archs)'.format(arch_id)
      all_checkpoints += checkpoints
      all_ckp_exists  += ckp_exists
      numrs[sum(ckp_exists)] += 1
    dir2ckps[ str(sub_dir) ]       = all_checkpoints
    dir2ckp_exists[ str(sub_dir) ] = all_ckp_exists
    # measure time per directory to estimate the remaining time
    epoch_time.update(time.time() - start_time)
    start_time = time.time()
    numrstr = ', '.join( ['{:}: {:03d}'.format(x, numrs[x]) for x in sorted(numrs)] )
    print('{:} load [{:2d}/{:2d}] [{:03d} archs] [{:04d}->{:04d} ckps] {:} done, need {:}. {:}'.format(time_string(), IDX+1, len(subdir2archs), len(arch_indexes), len(all_checkpoints), sum(all_ckp_exists), sub_dir, convert_secs2time(epoch_time.avg * (len(subdir2archs)-IDX-1), True), numrstr))
  return dir2ckps, dir2ckp_exists
 | |
| 
 | |
| 
 | |
if __name__ == '__main__':
  # Command-line entry: parse paths, validate them, then run the checkpoint check.
  parser = argparse.ArgumentParser(description='NAS Benchmark 201', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  parser.add_argument('--base_save_dir', type=str, default='./output/NAS-BENCH-201-4', help='The base-name of folder to save checkpoints and log.')
  parser.add_argument('--meta_path',     type=str, default='./output/NAS-BENCH-201-4/meta-node-4.pth', help='The meta file path.')
  parser.add_argument('--base_str',      type=str, default='C16-N5',                   help='The basic string.')
  args = parser.parse_args()

  save_dir = Path(args.base_save_dir)
  meta_path = Path(args.meta_path)
  # Use parser.error instead of assert: asserts are stripped under `python -O`,
  # and parser.error prints usage and exits with status 2.
  if not save_dir.exists():
    parser.error('invalid save dir path : {:}'.format(save_dir))
  if not meta_path.exists():
    parser.error('invalid saved meta path : {:}'.format(meta_path))
  print ('check NAS-Bench-201 in {:}'.format(save_dir))

  check_files(save_dir, meta_path, args.base_str)
 |