Reformulate via black
This commit is contained in:
		| @@ -7,78 +7,112 @@ import sys, time, argparse, collections | ||||
| import torch | ||||
| from pathlib import Path | ||||
| from collections import defaultdict | ||||
| lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() | ||||
| if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) | ||||
| from log_utils    import AverageMeter, time_string, convert_secs2time | ||||
|  | ||||
| lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() | ||||
| if str(lib_dir) not in sys.path: | ||||
|     sys.path.insert(0, str(lib_dir)) | ||||
| from log_utils import AverageMeter, time_string, convert_secs2time | ||||
|  | ||||
|  | ||||
| def check_files(save_dir, meta_file, basestr): | ||||
|   meta_infos = torch.load(meta_file, map_location='cpu') | ||||
|   meta_archs = meta_infos['archs'] | ||||
|   meta_num_archs = meta_infos['total'] | ||||
|   assert meta_num_archs == len(meta_archs), 'invalid number of archs : {:} vs {:}'.format(meta_num_archs, len(meta_archs)) | ||||
|     meta_infos = torch.load(meta_file, map_location="cpu") | ||||
|     meta_archs = meta_infos["archs"] | ||||
|     meta_num_archs = meta_infos["total"] | ||||
|     assert meta_num_archs == len(meta_archs), "invalid number of archs : {:} vs {:}".format( | ||||
|         meta_num_archs, len(meta_archs) | ||||
|     ) | ||||
|  | ||||
|   sub_model_dirs = sorted(list(save_dir.glob('*-*-{:}'.format(basestr)))) | ||||
|   print ('{:} find {:} directories used to save checkpoints'.format(time_string(), len(sub_model_dirs))) | ||||
|    | ||||
|   subdir2archs, num_evaluated_arch = collections.OrderedDict(), 0 | ||||
|   num_seeds = defaultdict(lambda: 0) | ||||
|   for index, sub_dir in enumerate(sub_model_dirs): | ||||
|     xcheckpoints = list(sub_dir.glob('arch-*-seed-*.pth')) | ||||
|     #xcheckpoints = list(sub_dir.glob('arch-*-seed-0777.pth')) + list(sub_dir.glob('arch-*-seed-0888.pth')) + list(sub_dir.glob('arch-*-seed-0999.pth')) | ||||
|     arch_indexes = set() | ||||
|     for checkpoint in xcheckpoints: | ||||
|       temp_names = checkpoint.name.split('-') | ||||
|       assert len(temp_names) == 4 and temp_names[0] == 'arch' and temp_names[2] == 'seed', 'invalid checkpoint name : {:}'.format(checkpoint.name) | ||||
|       arch_indexes.add( temp_names[1] ) | ||||
|     subdir2archs[sub_dir] = sorted(list(arch_indexes)) | ||||
|     num_evaluated_arch   += len(arch_indexes) | ||||
|     # count number of seeds for each architecture | ||||
|     for arch_index in arch_indexes: | ||||
|       num_seeds[ len(list(sub_dir.glob('arch-{:}-seed-*.pth'.format(arch_index)))) ] += 1 | ||||
|   print('There are {:5d} architectures that have been evaluated ({:} in total, {:} ckps in total).'.format(num_evaluated_arch, meta_num_archs, sum(k*v for k, v in num_seeds.items()))) | ||||
|   for key in sorted( list( num_seeds.keys() ) ): print ('There are {:5d} architectures that are evaluated {:} times.'.format(num_seeds[key], key)) | ||||
|     sub_model_dirs = sorted(list(save_dir.glob("*-*-{:}".format(basestr)))) | ||||
|     print("{:} find {:} directories used to save checkpoints".format(time_string(), len(sub_model_dirs))) | ||||
|  | ||||
|   dir2ckps, dir2ckp_exists = dict(), dict() | ||||
|   start_time, epoch_time = time.time(), AverageMeter() | ||||
|   for IDX, (sub_dir, arch_indexes) in enumerate(subdir2archs.items()): | ||||
|     if basestr == 'C16-N5': | ||||
|       seeds = [777, 888, 999] | ||||
|     elif basestr == 'C16-N5-LESS': | ||||
|       seeds = [111, 777] | ||||
|     else: | ||||
|       raise ValueError('Invalid base str : {:}'.format(basestr)) | ||||
|     numrs = defaultdict(lambda: 0) | ||||
|     all_checkpoints, all_ckp_exists = [], [] | ||||
|     for arch_index in arch_indexes: | ||||
|       checkpoints = ['arch-{:}-seed-{:04d}.pth'.format(arch_index, seed) for seed in seeds] | ||||
|       ckp_exists  = [(sub_dir/x).exists() for x in checkpoints] | ||||
|       arch_index  = int(arch_index) | ||||
|       assert 0 <= arch_index < len(meta_archs), 'invalid arch-index {:} (not found in meta_archs)'.format(arch_index) | ||||
|       all_checkpoints += checkpoints | ||||
|       all_ckp_exists  += ckp_exists | ||||
|       numrs[sum(ckp_exists)] += 1 | ||||
|     dir2ckps[ str(sub_dir) ]       = all_checkpoints | ||||
|     dir2ckp_exists[ str(sub_dir) ] = all_ckp_exists | ||||
|     # measure time | ||||
|     epoch_time.update(time.time() - start_time) | ||||
|     start_time = time.time() | ||||
|     numrstr = ', '.join( ['{:}: {:03d}'.format(x, numrs[x]) for x in sorted(numrs.keys())] ) | ||||
|     print('{:} load [{:2d}/{:2d}] [{:03d} archs] [{:04d}->{:04d} ckps] {:} done, need {:}. {:}'.format(time_string(), IDX+1, len(subdir2archs), len(arch_indexes), len(all_checkpoints), sum(all_ckp_exists), sub_dir, convert_secs2time(epoch_time.avg * (len(subdir2archs)-IDX-1), True), numrstr)) | ||||
|     subdir2archs, num_evaluated_arch = collections.OrderedDict(), 0 | ||||
|     num_seeds = defaultdict(lambda: 0) | ||||
|     for index, sub_dir in enumerate(sub_model_dirs): | ||||
|         xcheckpoints = list(sub_dir.glob("arch-*-seed-*.pth")) | ||||
|         # xcheckpoints = list(sub_dir.glob('arch-*-seed-0777.pth')) + list(sub_dir.glob('arch-*-seed-0888.pth')) + list(sub_dir.glob('arch-*-seed-0999.pth')) | ||||
|         arch_indexes = set() | ||||
|         for checkpoint in xcheckpoints: | ||||
|             temp_names = checkpoint.name.split("-") | ||||
|             assert ( | ||||
|                 len(temp_names) == 4 and temp_names[0] == "arch" and temp_names[2] == "seed" | ||||
|             ), "invalid checkpoint name : {:}".format(checkpoint.name) | ||||
|             arch_indexes.add(temp_names[1]) | ||||
|         subdir2archs[sub_dir] = sorted(list(arch_indexes)) | ||||
|         num_evaluated_arch += len(arch_indexes) | ||||
|         # count number of seeds for each architecture | ||||
|         for arch_index in arch_indexes: | ||||
|             num_seeds[len(list(sub_dir.glob("arch-{:}-seed-*.pth".format(arch_index))))] += 1 | ||||
|     print( | ||||
|         "There are {:5d} architectures that have been evaluated ({:} in total, {:} ckps in total).".format( | ||||
|             num_evaluated_arch, meta_num_archs, sum(k * v for k, v in num_seeds.items()) | ||||
|         ) | ||||
|     ) | ||||
|     for key in sorted(list(num_seeds.keys())): | ||||
|         print("There are {:5d} architectures that are evaluated {:} times.".format(num_seeds[key], key)) | ||||
|  | ||||
|     dir2ckps, dir2ckp_exists = dict(), dict() | ||||
|     start_time, epoch_time = time.time(), AverageMeter() | ||||
|     for IDX, (sub_dir, arch_indexes) in enumerate(subdir2archs.items()): | ||||
|         if basestr == "C16-N5": | ||||
|             seeds = [777, 888, 999] | ||||
|         elif basestr == "C16-N5-LESS": | ||||
|             seeds = [111, 777] | ||||
|         else: | ||||
|             raise ValueError("Invalid base str : {:}".format(basestr)) | ||||
|         numrs = defaultdict(lambda: 0) | ||||
|         all_checkpoints, all_ckp_exists = [], [] | ||||
|         for arch_index in arch_indexes: | ||||
|             checkpoints = ["arch-{:}-seed-{:04d}.pth".format(arch_index, seed) for seed in seeds] | ||||
|             ckp_exists = [(sub_dir / x).exists() for x in checkpoints] | ||||
|             arch_index = int(arch_index) | ||||
|             assert 0 <= arch_index < len(meta_archs), "invalid arch-index {:} (not found in meta_archs)".format( | ||||
|                 arch_index | ||||
|             ) | ||||
|             all_checkpoints += checkpoints | ||||
|             all_ckp_exists += ckp_exists | ||||
|             numrs[sum(ckp_exists)] += 1 | ||||
|         dir2ckps[str(sub_dir)] = all_checkpoints | ||||
|         dir2ckp_exists[str(sub_dir)] = all_ckp_exists | ||||
|         # measure time | ||||
|         epoch_time.update(time.time() - start_time) | ||||
|         start_time = time.time() | ||||
|         numrstr = ", ".join(["{:}: {:03d}".format(x, numrs[x]) for x in sorted(numrs.keys())]) | ||||
|         print( | ||||
|             "{:} load [{:2d}/{:2d}] [{:03d} archs] [{:04d}->{:04d} ckps] {:} done, need {:}. {:}".format( | ||||
|                 time_string(), | ||||
|                 IDX + 1, | ||||
|                 len(subdir2archs), | ||||
|                 len(arch_indexes), | ||||
|                 len(all_checkpoints), | ||||
|                 sum(all_ckp_exists), | ||||
|                 sub_dir, | ||||
|                 convert_secs2time(epoch_time.avg * (len(subdir2archs) - IDX - 1), True), | ||||
|                 numrstr, | ||||
|             ) | ||||
|         ) | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
| if __name__ == "__main__": | ||||
|  | ||||
|   parser = argparse.ArgumentParser(description='NAS Benchmark 201', formatter_class=argparse.ArgumentDefaultsHelpFormatter) | ||||
|   parser.add_argument('--base_save_dir', type=str, default='./output/NAS-BENCH-201-4', help='The base-name of folder to save checkpoints and log.') | ||||
|   parser.add_argument('--meta_path',     type=str, default='./output/NAS-BENCH-201-4/meta-node-4.pth', help='The meta file path.') | ||||
|   parser.add_argument('--base_str',      type=str, default='C16-N5',                   help='The basic string.') | ||||
|   args = parser.parse_args() | ||||
|     parser = argparse.ArgumentParser( | ||||
|         description="NAS Benchmark 201", formatter_class=argparse.ArgumentDefaultsHelpFormatter | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--base_save_dir", | ||||
|         type=str, | ||||
|         default="./output/NAS-BENCH-201-4", | ||||
|         help="The base-name of folder to save checkpoints and log.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--meta_path", type=str, default="./output/NAS-BENCH-201-4/meta-node-4.pth", help="The meta file path." | ||||
|     ) | ||||
|     parser.add_argument("--base_str", type=str, default="C16-N5", help="The basic string.") | ||||
|     args = parser.parse_args() | ||||
|  | ||||
|   save_dir = Path(args.base_save_dir) | ||||
|   meta_path = Path(args.meta_path) | ||||
|   assert save_dir.exists(),  'invalid save dir path : {:}'.format(save_dir) | ||||
|   assert meta_path.exists(), 'invalid saved meta path : {:}'.format(meta_path) | ||||
|   print ('check NAS-Bench-201 in {:}'.format(save_dir)) | ||||
|     save_dir = Path(args.base_save_dir) | ||||
|     meta_path = Path(args.meta_path) | ||||
|     assert save_dir.exists(), "invalid save dir path : {:}".format(save_dir) | ||||
|     assert meta_path.exists(), "invalid saved meta path : {:}".format(meta_path) | ||||
|     print("check NAS-Bench-201 in {:}".format(save_dir)) | ||||
|  | ||||
|   check_files(save_dir, meta_path, args.base_str) | ||||
|     check_files(save_dir, meta_path, args.base_str) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user