Reformulate via black
This commit is contained in:
		| @@ -9,72 +9,82 @@ import os, sys, time, torch, argparse | ||||
| from typing import List, Text, Dict, Any | ||||
| from shutil import copyfile | ||||
| from collections import defaultdict | ||||
| from copy    import deepcopy | ||||
| from copy import deepcopy | ||||
| from pathlib import Path | ||||
|  | ||||
| lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() | ||||
| if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) | ||||
| lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() | ||||
| if str(lib_dir) not in sys.path: | ||||
|     sys.path.insert(0, str(lib_dir)) | ||||
| from config_utils import dict2config, load_config | ||||
| from procedures   import bench_evaluate_for_seed | ||||
| from procedures   import get_machine_info | ||||
| from datasets     import get_datasets | ||||
| from log_utils    import Logger, AverageMeter, time_string, convert_secs2time | ||||
| from procedures import bench_evaluate_for_seed | ||||
| from procedures import get_machine_info | ||||
| from datasets import get_datasets | ||||
| from log_utils import Logger, AverageMeter, time_string, convert_secs2time | ||||
|  | ||||
|  | ||||
| def obtain_valid_ckp(save_dir: Text, total: int, possible_seeds: List[int]): | ||||
|   seed2ckps = defaultdict(list) | ||||
|   miss2ckps = defaultdict(list) | ||||
|   for i in range(total): | ||||
|     for seed in possible_seeds: | ||||
|       path = os.path.join(save_dir, 'arch-{:06d}-seed-{:04d}.pth'.format(i, seed)) | ||||
|       if os.path.exists(path): | ||||
|         seed2ckps[seed].append(i) | ||||
|       else: | ||||
|         miss2ckps[seed].append(i) | ||||
|   for seed, xlist in seed2ckps.items(): | ||||
|     print('[{:}] [seed={:}] has {:5d}/{:5d} | miss {:5d}/{:5d}'.format(save_dir, seed, len(xlist), total, total-len(xlist), total)) | ||||
|   return dict(seed2ckps), dict(miss2ckps) | ||||
|      | ||||
|     seed2ckps = defaultdict(list) | ||||
|     miss2ckps = defaultdict(list) | ||||
|     for i in range(total): | ||||
|         for seed in possible_seeds: | ||||
|             path = os.path.join(save_dir, "arch-{:06d}-seed-{:04d}.pth".format(i, seed)) | ||||
|             if os.path.exists(path): | ||||
|                 seed2ckps[seed].append(i) | ||||
|             else: | ||||
|                 miss2ckps[seed].append(i) | ||||
|     for seed, xlist in seed2ckps.items(): | ||||
|         print( | ||||
|             "[{:}] [seed={:}] has {:5d}/{:5d} | miss {:5d}/{:5d}".format( | ||||
|                 save_dir, seed, len(xlist), total, total - len(xlist), total | ||||
|             ) | ||||
|         ) | ||||
|     return dict(seed2ckps), dict(miss2ckps) | ||||
|  | ||||
|  | ||||
| def copy_data(source_dir, target_dir, meta_path): | ||||
|   target_dir = Path(target_dir) | ||||
|   target_dir.mkdir(parents=True, exist_ok=True) | ||||
|   miss2ckps = torch.load(meta_path)['miss2ckps'] | ||||
|   s2t = {} | ||||
|   for seed, xlist in miss2ckps.items(): | ||||
|     for i in xlist: | ||||
|       file_name = 'arch-{:06d}-seed-{:04d}.pth'.format(i, seed) | ||||
|       source_path = os.path.join(source_dir, file_name) | ||||
|       target_path = os.path.join(target_dir, file_name) | ||||
|       if os.path.exists(source_path): | ||||
|         s2t[source_path] = target_path | ||||
|   print('Map from {:} to {:}, find {:} missed ckps.'.format(source_dir, target_dir, len(s2t))) | ||||
|   for s, t in s2t.items(): | ||||
|     copyfile(s, t) | ||||
|     target_dir = Path(target_dir) | ||||
|     target_dir.mkdir(parents=True, exist_ok=True) | ||||
|     miss2ckps = torch.load(meta_path)["miss2ckps"] | ||||
|     s2t = {} | ||||
|     for seed, xlist in miss2ckps.items(): | ||||
|         for i in xlist: | ||||
|             file_name = "arch-{:06d}-seed-{:04d}.pth".format(i, seed) | ||||
|             source_path = os.path.join(source_dir, file_name) | ||||
|             target_path = os.path.join(target_dir, file_name) | ||||
|             if os.path.exists(source_path): | ||||
|                 s2t[source_path] = target_path | ||||
|     print("Map from {:} to {:}, find {:} missed ckps.".format(source_dir, target_dir, len(s2t))) | ||||
|     for s, t in s2t.items(): | ||||
|         copyfile(s, t) | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|   parser = argparse.ArgumentParser(description='NATS-Bench (topology search space) file manager.', formatter_class=argparse.ArgumentDefaultsHelpFormatter) | ||||
|   parser.add_argument('--mode',        type=str, required=True, choices=['check', 'copy'], help='The script mode.') | ||||
|   parser.add_argument('--save_dir',    type=str, default='output/NATS-Bench-topology', help='Folder to save checkpoints and log.') | ||||
|   parser.add_argument('--check_N',     type=int, default=15625,  help='For safety.') | ||||
|   # use for train the model | ||||
|   args = parser.parse_args() | ||||
|   possible_configs = ['12', '200'] | ||||
|   possible_seedss = [[111, 777], [777, 888, 999]] | ||||
|   if args.mode == 'check': | ||||
|     for config, possible_seeds in zip(possible_configs, possible_seedss): | ||||
|       cur_save_dir = '{:}/raw-data-{:}'.format(args.save_dir, config) | ||||
|       seed2ckps, miss2ckps = obtain_valid_ckp(cur_save_dir, args.check_N, possible_seeds) | ||||
|       torch.save(dict(seed2ckps=seed2ckps, miss2ckps=miss2ckps), '{:}/meta-{:}.pth'.format(args.save_dir, config)) | ||||
|   elif args.mode == 'copy': | ||||
|     for config in possible_configs: | ||||
|       cur_save_dir = '{:}/raw-data-{:}'.format(args.save_dir, config) | ||||
|       cur_copy_dir = '{:}/copy-{:}'.format(args.save_dir, config) | ||||
|       cur_meta_path = '{:}/meta-{:}.pth'.format(args.save_dir, config) | ||||
|       if os.path.exists(cur_meta_path): | ||||
|         copy_data(cur_save_dir, cur_copy_dir, cur_meta_path) | ||||
|       else: | ||||
|         print('Do not find : {:}'.format(cur_meta_path)) | ||||
|   else: | ||||
|     raise ValueError('invalid mode : {:}'.format(args.mode)) | ||||
| if __name__ == "__main__": | ||||
|     parser = argparse.ArgumentParser( | ||||
|         description="NATS-Bench (topology search space) file manager.", | ||||
|         formatter_class=argparse.ArgumentDefaultsHelpFormatter, | ||||
|     ) | ||||
|     parser.add_argument("--mode", type=str, required=True, choices=["check", "copy"], help="The script mode.") | ||||
|     parser.add_argument( | ||||
|         "--save_dir", type=str, default="output/NATS-Bench-topology", help="Folder to save checkpoints and log." | ||||
|     ) | ||||
|     parser.add_argument("--check_N", type=int, default=15625, help="For safety.") | ||||
|     # use for train the model | ||||
|     args = parser.parse_args() | ||||
|     possible_configs = ["12", "200"] | ||||
|     possible_seedss = [[111, 777], [777, 888, 999]] | ||||
|     if args.mode == "check": | ||||
|         for config, possible_seeds in zip(possible_configs, possible_seedss): | ||||
|             cur_save_dir = "{:}/raw-data-{:}".format(args.save_dir, config) | ||||
|             seed2ckps, miss2ckps = obtain_valid_ckp(cur_save_dir, args.check_N, possible_seeds) | ||||
|             torch.save(dict(seed2ckps=seed2ckps, miss2ckps=miss2ckps), "{:}/meta-{:}.pth".format(args.save_dir, config)) | ||||
|     elif args.mode == "copy": | ||||
|         for config in possible_configs: | ||||
|             cur_save_dir = "{:}/raw-data-{:}".format(args.save_dir, config) | ||||
|             cur_copy_dir = "{:}/copy-{:}".format(args.save_dir, config) | ||||
|             cur_meta_path = "{:}/meta-{:}.pth".format(args.save_dir, config) | ||||
|             if os.path.exists(cur_meta_path): | ||||
|                 copy_data(cur_save_dir, cur_copy_dir, cur_meta_path) | ||||
|             else: | ||||
|                 print("Do not find : {:}".format(cur_meta_path)) | ||||
|     else: | ||||
|         raise ValueError("invalid mode : {:}".format(args.mode)) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user