Add int search space
This commit is contained in:
		| @@ -53,7 +53,11 @@ def copy_data(source_dir, target_dir, meta_path): | ||||
|             target_path = os.path.join(target_dir, file_name) | ||||
|             if os.path.exists(source_path): | ||||
|                 s2t[source_path] = target_path | ||||
|     print("Map from {:} to {:}, find {:} missed ckps.".format(source_dir, target_dir, len(s2t))) | ||||
|     print( | ||||
|         "Map from {:} to {:}, find {:} missed ckps.".format( | ||||
|             source_dir, target_dir, len(s2t) | ||||
|         ) | ||||
|     ) | ||||
|     for s, t in s2t.items(): | ||||
|         copyfile(s, t) | ||||
|  | ||||
| @@ -63,9 +67,18 @@ if __name__ == "__main__": | ||||
|         description="NATS-Bench (topology search space) file manager.", | ||||
|         formatter_class=argparse.ArgumentDefaultsHelpFormatter, | ||||
|     ) | ||||
|     parser.add_argument("--mode", type=str, required=True, choices=["check", "copy"], help="The script mode.") | ||||
|     parser.add_argument( | ||||
|         "--save_dir", type=str, default="output/NATS-Bench-topology", help="Folder to save checkpoints and log." | ||||
|         "--mode", | ||||
|         type=str, | ||||
|         required=True, | ||||
|         choices=["check", "copy"], | ||||
|         help="The script mode.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--save_dir", | ||||
|         type=str, | ||||
|         default="output/NATS-Bench-topology", | ||||
|         help="Folder to save checkpoints and log.", | ||||
|     ) | ||||
|     parser.add_argument("--check_N", type=int, default=15625, help="For safety.") | ||||
|     # use for train the model | ||||
| @@ -75,8 +88,13 @@ if __name__ == "__main__": | ||||
|     if args.mode == "check": | ||||
|         for config, possible_seeds in zip(possible_configs, possible_seedss): | ||||
|             cur_save_dir = "{:}/raw-data-{:}".format(args.save_dir, config) | ||||
|             seed2ckps, miss2ckps = obtain_valid_ckp(cur_save_dir, args.check_N, possible_seeds) | ||||
|             torch.save(dict(seed2ckps=seed2ckps, miss2ckps=miss2ckps), "{:}/meta-{:}.pth".format(args.save_dir, config)) | ||||
|             seed2ckps, miss2ckps = obtain_valid_ckp( | ||||
|                 cur_save_dir, args.check_N, possible_seeds | ||||
|             ) | ||||
|             torch.save( | ||||
|                 dict(seed2ckps=seed2ckps, miss2ckps=miss2ckps), | ||||
|                 "{:}/meta-{:}.pth".format(args.save_dir, config), | ||||
|             ) | ||||
|     elif args.mode == "copy": | ||||
|         for config in possible_configs: | ||||
|             cur_save_dir = "{:}/raw-data-{:}".format(args.save_dir, config) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user