107 lines
		
	
	
		
			4.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			107 lines
		
	
	
		
			4.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| ##############################################################################
 | |
| # NATS-Bench: Benchmarking NAS Algorithms for Architecture Topology and Size #
 | |
| ##############################################################################
 | |
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.08                          #
 | |
| ##############################################################################
 | |
| # Usage: python exps/NATS-Bench/sss-file-manager.py --mode check             #
 | |
| ##############################################################################
 | |
| import os, sys, time, torch, argparse
 | |
| from typing import List, Text, Dict, Any
 | |
| from shutil import copyfile
 | |
| from collections import defaultdict
 | |
| from copy import deepcopy
 | |
| from pathlib import Path
 | |
| 
 | |
# Make the repository's `lib` folder (two levels above this file's directory)
# importable, so the project-local modules imported below can be resolved.
# The path is only prepended once, even if this module is imported repeatedly.
lib_dir = Path(__file__).parent.joinpath("..", "..", "lib").resolve()
if str(lib_dir) not in sys.path:
    sys.path.insert(0, str(lib_dir))
 | |
| from config_utils import dict2config, load_config
 | |
| from procedures import bench_evaluate_for_seed
 | |
| from procedures import get_machine_info
 | |
| from datasets import get_datasets
 | |
| from log_utils import Logger, AverageMeter, time_string, convert_secs2time
 | |
| 
 | |
| 
 | |
def obtain_valid_ckp(save_dir: Text, total: int, seeds: tuple = (777, 888, 999)):
    """Scan ``save_dir`` for per-architecture, per-seed checkpoint files.

    The checkpoint of architecture ``i`` trained with seed ``s`` is expected at
    ``<save_dir>/arch-{i:06d}-seed-{s:04d}.pth``. A one-line summary is printed
    for every seed in ``seeds``.

    Args:
      save_dir: folder that holds the checkpoint files.
      total: number of architectures to check (indices ``0 .. total - 1``).
      seeds: seeds to look for; defaults to the three seeds this script has
        always checked.

    Returns:
      ``(seed2ckps, miss2ckps)``: two plain dicts mapping a seed to the
      ascending list of architecture indices whose checkpoint exists / is
      missing. A seed is a key of a dict only if that list is non-empty.
    """
    seed2ckps = defaultdict(list)
    miss2ckps = defaultdict(list)
    for i in range(total):
        for seed in seeds:
            path = os.path.join(save_dir, "arch-{:06d}-seed-{:04d}.pth".format(i, seed))
            if os.path.exists(path):
                seed2ckps[seed].append(i)
            else:
                miss2ckps[seed].append(i)
    # Report every seed, including those with no checkpoint at all. Iterating
    # only over ``seed2ckps`` would silently hide a completely missing seed.
    # ``.get`` is used so the defaultdict does not gain empty entries, which
    # would change the returned dicts.
    for seed in seeds:
        found = len(seed2ckps.get(seed, ()))
        print(
            "[{:}] [seed={:}] has {:5d}/{:5d} | miss {:5d}/{:5d}".format(
                save_dir, seed, found, total, total - found, total
            )
        )
    return dict(seed2ckps), dict(miss2ckps)
 | |
| 
 | |
| 
 | |
def copy_data(source_dir, target_dir, meta_path):
    """Copy checkpoints listed as missing in a meta file into ``target_dir``.

    ``meta_path`` is loaded with ``torch.load`` and its ``"miss2ckps"`` entry
    (seed -> list of architecture indices, as written by this script's
    ``check`` mode) is walked. Every ``arch-XXXXXX-seed-XXXX.pth`` file of that
    list that exists in ``source_dir`` is copied to ``target_dir`` under the
    same name. ``target_dir`` is created if needed.
    """
    destination = Path(target_dir)
    destination.mkdir(parents=True, exist_ok=True)
    missing = torch.load(meta_path)["miss2ckps"]

    # source path -> target path, only for files actually present in source_dir
    pending = {}
    for seed, indices in missing.items():
        for idx in indices:
            name = "arch-{:06d}-seed-{:04d}.pth".format(idx, seed)
            src = os.path.join(source_dir, name)
            if not os.path.exists(src):
                continue
            pending[src] = os.path.join(destination, name)

    message = "Map from {:} to {:}, find {:} missed ckps.".format(
        source_dir, destination, len(pending)
    )
    print(message)
    for src, dst in pending.items():
        copyfile(src, dst)
 | |
| 
 | |
| 
 | |
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="NATS-Bench (size search space) file manager.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--mode",
        type=str,
        required=True,
        choices=["check", "copy"],
        help="The script mode.",
    )
    parser.add_argument(
        "--save_dir",
        type=str,
        default="output/NATS-Bench-size",
        help="Folder to save checkpoints and log.",
    )
    # Number of architecture indices (0 .. check_N - 1) scanned per config in
    # "check" mode; passed as `total` to obtain_valid_ckp.
    parser.add_argument("--check_N", type=int, default=32768, help="For safety.")
    args = parser.parse_args()
    # Suffixes of the per-config folders/files: raw-data-XX, copy-XX, meta-XX.pth.
    possible_configs = ["01", "12", "90"]
    if args.mode == "check":
        # torch.save does not create parent folders; make sure the meta files
        # can be written even if save_dir does not exist yet (in that case
        # every checkpoint is simply reported as missing).
        os.makedirs(args.save_dir, exist_ok=True)
        for config in possible_configs:
            cur_save_dir = "{:}/raw-data-{:}".format(args.save_dir, config)
            seed2ckps, miss2ckps = obtain_valid_ckp(cur_save_dir, args.check_N)
            torch.save(
                dict(seed2ckps=seed2ckps, miss2ckps=miss2ckps),
                "{:}/meta-{:}.pth".format(args.save_dir, config),
            )
    elif args.mode == "copy":
        # Copy the checkpoints recorded as missing in each meta file (produced
        # by "check" mode) from raw-data-XX into copy-XX.
        for config in possible_configs:
            cur_save_dir = "{:}/raw-data-{:}".format(args.save_dir, config)
            cur_copy_dir = "{:}/copy-{:}".format(args.save_dir, config)
            cur_meta_path = "{:}/meta-{:}.pth".format(args.save_dir, config)
            if os.path.exists(cur_meta_path):
                copy_data(cur_save_dir, cur_copy_dir, cur_meta_path)
            else:
                print("Do not find : {:}".format(cur_meta_path))
    else:
        # Unreachable through argparse (`choices` restricts --mode); kept as a
        # guard in case the choices list and this dispatch drift apart.
        raise ValueError("invalid mode : {:}".format(args.mode))
 |