##############################################################################
# NATS-Bench: Benchmarking NAS Algorithms for Architecture Topology and Size #
##############################################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.08                          #
##############################################################################
# Usage: python exps/NATS-Bench/tss-file-manager.py --mode check             #
##############################################################################
					
						
							|  |  |  | import os, sys, time, torch, argparse | 
					
						
							|  |  |  | from typing import List, Text, Dict, Any | 
					
						
							|  |  |  | from shutil import copyfile | 
					
						
							|  |  |  | from collections import defaultdict | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  | from copy import deepcopy | 
					
						
							| 
									
										
										
										
											2020-09-02 07:34:12 +00:00
										 |  |  | from pathlib import Path | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  | lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() | 
					
						
							|  |  |  | if str(lib_dir) not in sys.path: | 
					
						
							|  |  |  |     sys.path.insert(0, str(lib_dir)) | 
					
						
							| 
									
										
										
										
											2020-09-02 07:34:12 +00:00
										 |  |  | from config_utils import dict2config, load_config | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  | from procedures import bench_evaluate_for_seed | 
					
						
							|  |  |  | from procedures import get_machine_info | 
					
						
							|  |  |  | from datasets import get_datasets | 
					
						
							|  |  |  | from log_utils import Logger, AverageMeter, time_string, convert_secs2time | 
					
						
							| 
									
										
										
										
											2020-09-02 07:34:12 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def obtain_valid_ckp(save_dir: Text, total: int, possible_seeds: List[int]): | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  |     seed2ckps = defaultdict(list) | 
					
						
							|  |  |  |     miss2ckps = defaultdict(list) | 
					
						
							|  |  |  |     for i in range(total): | 
					
						
							|  |  |  |         for seed in possible_seeds: | 
					
						
							|  |  |  |             path = os.path.join(save_dir, "arch-{:06d}-seed-{:04d}.pth".format(i, seed)) | 
					
						
							|  |  |  |             if os.path.exists(path): | 
					
						
							|  |  |  |                 seed2ckps[seed].append(i) | 
					
						
							|  |  |  |             else: | 
					
						
							|  |  |  |                 miss2ckps[seed].append(i) | 
					
						
							|  |  |  |     for seed, xlist in seed2ckps.items(): | 
					
						
							|  |  |  |         print( | 
					
						
							|  |  |  |             "[{:}] [seed={:}] has {:5d}/{:5d} | miss {:5d}/{:5d}".format( | 
					
						
							|  |  |  |                 save_dir, seed, len(xlist), total, total - len(xlist), total | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |     return dict(seed2ckps), dict(miss2ckps) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-09-02 07:34:12 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  | def copy_data(source_dir, target_dir, meta_path): | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  |     target_dir = Path(target_dir) | 
					
						
							|  |  |  |     target_dir.mkdir(parents=True, exist_ok=True) | 
					
						
							|  |  |  |     miss2ckps = torch.load(meta_path)["miss2ckps"] | 
					
						
							|  |  |  |     s2t = {} | 
					
						
							|  |  |  |     for seed, xlist in miss2ckps.items(): | 
					
						
							|  |  |  |         for i in xlist: | 
					
						
							|  |  |  |             file_name = "arch-{:06d}-seed-{:04d}.pth".format(i, seed) | 
					
						
							|  |  |  |             source_path = os.path.join(source_dir, file_name) | 
					
						
							|  |  |  |             target_path = os.path.join(target_dir, file_name) | 
					
						
							|  |  |  |             if os.path.exists(source_path): | 
					
						
							|  |  |  |                 s2t[source_path] = target_path | 
					
						
							| 
									
										
										
										
											2021-03-18 16:02:55 +08:00
										 |  |  |     print( | 
					
						
							|  |  |  |         "Map from {:} to {:}, find {:} missed ckps.".format( | 
					
						
							|  |  |  |             source_dir, target_dir, len(s2t) | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  |     for s, t in s2t.items(): | 
					
						
							|  |  |  |         copyfile(s, t) | 
					
						
							| 
									
										
										
										
											2020-09-02 07:34:12 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  | if __name__ == "__main__": | 
					
						
							|  |  |  |     parser = argparse.ArgumentParser( | 
					
						
							|  |  |  |         description="NATS-Bench (topology search space) file manager.", | 
					
						
							|  |  |  |         formatter_class=argparse.ArgumentDefaultsHelpFormatter, | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     parser.add_argument( | 
					
						
							| 
									
										
										
										
											2021-03-18 16:02:55 +08:00
										 |  |  |         "--mode", | 
					
						
							|  |  |  |         type=str, | 
					
						
							|  |  |  |         required=True, | 
					
						
							|  |  |  |         choices=["check", "copy"], | 
					
						
							|  |  |  |         help="The script mode.", | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     parser.add_argument( | 
					
						
							|  |  |  |         "--save_dir", | 
					
						
							|  |  |  |         type=str, | 
					
						
							|  |  |  |         default="output/NATS-Bench-topology", | 
					
						
							|  |  |  |         help="Folder to save checkpoints and log.", | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  |     ) | 
					
						
							|  |  |  |     parser.add_argument("--check_N", type=int, default=15625, help="For safety.") | 
					
						
							|  |  |  |     # use for train the model | 
					
						
							|  |  |  |     args = parser.parse_args() | 
					
						
							|  |  |  |     possible_configs = ["12", "200"] | 
					
						
							|  |  |  |     possible_seedss = [[111, 777], [777, 888, 999]] | 
					
						
							|  |  |  |     if args.mode == "check": | 
					
						
							|  |  |  |         for config, possible_seeds in zip(possible_configs, possible_seedss): | 
					
						
							|  |  |  |             cur_save_dir = "{:}/raw-data-{:}".format(args.save_dir, config) | 
					
						
							| 
									
										
										
										
											2021-03-18 16:02:55 +08:00
										 |  |  |             seed2ckps, miss2ckps = obtain_valid_ckp( | 
					
						
							|  |  |  |                 cur_save_dir, args.check_N, possible_seeds | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  |             torch.save( | 
					
						
							|  |  |  |                 dict(seed2ckps=seed2ckps, miss2ckps=miss2ckps), | 
					
						
							|  |  |  |                 "{:}/meta-{:}.pth".format(args.save_dir, config), | 
					
						
							|  |  |  |             ) | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  |     elif args.mode == "copy": | 
					
						
							|  |  |  |         for config in possible_configs: | 
					
						
							|  |  |  |             cur_save_dir = "{:}/raw-data-{:}".format(args.save_dir, config) | 
					
						
							|  |  |  |             cur_copy_dir = "{:}/copy-{:}".format(args.save_dir, config) | 
					
						
							|  |  |  |             cur_meta_path = "{:}/meta-{:}.pth".format(args.save_dir, config) | 
					
						
							|  |  |  |             if os.path.exists(cur_meta_path): | 
					
						
							|  |  |  |                 copy_data(cur_save_dir, cur_copy_dir, cur_meta_path) | 
					
						
							|  |  |  |             else: | 
					
						
							|  |  |  |                 print("Do not find : {:}".format(cur_meta_path)) | 
					
						
							|  |  |  |     else: | 
					
						
							|  |  |  |         raise ValueError("invalid mode : {:}".format(args.mode)) |