Add int search space
@@ -31,24 +31,34 @@ from utils import get_md5_file
 NATS_SSS_BASE_NAME = "NATS-sss-v1_0"  # 2020.08.28
 
 
-def account_one_arch(arch_index: int, arch_str: Text, checkpoints: List[Text], datasets: List[Text]) -> ArchResults:
+def account_one_arch(
+    arch_index: int, arch_str: Text, checkpoints: List[Text], datasets: List[Text]
+) -> ArchResults:
     information = ArchResults(arch_index, arch_str)
 
     for checkpoint_path in checkpoints:
         try:
             checkpoint = torch.load(checkpoint_path, map_location="cpu")
         except:
-            raise ValueError("This checkpoint failed to be loaded : {:}".format(checkpoint_path))
+            raise ValueError(
+                "This checkpoint failed to be loaded : {:}".format(checkpoint_path)
+            )
         used_seed = checkpoint_path.name.split("-")[-1].split(".")[0]
         ok_dataset = 0
         for dataset in datasets:
             if dataset not in checkpoint:
-                print("Can not find {:} in arch-{:} from {:}".format(dataset, arch_index, checkpoint_path))
+                print(
+                    "Can not find {:} in arch-{:} from {:}".format(
+                        dataset, arch_index, checkpoint_path
+                    )
+                )
                 continue
             else:
                 ok_dataset += 1
             results = checkpoint[dataset]
-            assert results["finish-train"], "This {:} arch seed={:} does not finish train on {:} ::: {:}".format(
+            assert results[
+                "finish-train"
+            ], "This {:} arch seed={:} does not finish train on {:} ::: {:}".format(
                 arch_index, used_seed, dataset, checkpoint_path
             )
             arch_config = {
@@ -71,13 +81,20 @@ def account_one_arch(arch_index: int, arch_str: Text, checkpoints: List[Text], d
                 None,
             )
             xresult.update_train_info(
-                results["train_acc1es"], results["train_acc5es"], results["train_losses"], results["train_times"]
+                results["train_acc1es"],
+                results["train_acc5es"],
+                results["train_losses"],
+                results["train_times"],
+            )
+            xresult.update_eval(
+                results["valid_acc1es"], results["valid_losses"], results["valid_times"]
             )
-            xresult.update_eval(results["valid_acc1es"], results["valid_losses"], results["valid_times"])
             information.update(dataset, int(used_seed), xresult)
         if ok_dataset < len(datasets):
             raise ValueError(
-                "{:} does find enought data : {:} vs {:}".format(checkpoint_path, ok_dataset, len(datasets))
+                "{:} does find enought data : {:} vs {:}".format(
+                    checkpoint_path, ok_dataset, len(datasets)
+                )
             )
     return information
 
@@ -107,7 +124,9 @@ def correct_time_related_info(hp2info: Dict[Text, ArchResults]):
         arch_info.reset_latency("ImageNet16-120", None, image_latency)
 
     # CIFAR10 VALID
-    train_per_epoch_time = list(hp2info["01"].query("cifar10-valid", 777).train_times.values())
+    train_per_epoch_time = list(
+        hp2info["01"].query("cifar10-valid", 777).train_times.values()
+    )
     train_per_epoch_time = sum(train_per_epoch_time) / len(train_per_epoch_time)
     eval_ori_test_time, eval_x_valid_time = [], []
     for key, value in hp2info["01"].query("cifar10-valid", 777).eval_times.items():
@@ -121,11 +140,17 @@ def correct_time_related_info(hp2info: Dict[Text, ArchResults]):
     eval_x_valid_time = sum(eval_x_valid_time) / len(eval_x_valid_time)
     for hp, arch_info in hp2info.items():
         arch_info.reset_pseudo_train_times("cifar10-valid", None, train_per_epoch_time)
-        arch_info.reset_pseudo_eval_times("cifar10-valid", None, "x-valid", eval_x_valid_time)
-        arch_info.reset_pseudo_eval_times("cifar10-valid", None, "ori-test", eval_ori_test_time)
+        arch_info.reset_pseudo_eval_times(
+            "cifar10-valid", None, "x-valid", eval_x_valid_time
+        )
+        arch_info.reset_pseudo_eval_times(
+            "cifar10-valid", None, "ori-test", eval_ori_test_time
+        )
 
     # CIFAR10
-    train_per_epoch_time = list(hp2info["01"].query("cifar10", 777).train_times.values())
+    train_per_epoch_time = list(
+        hp2info["01"].query("cifar10", 777).train_times.values()
+    )
     train_per_epoch_time = sum(train_per_epoch_time) / len(train_per_epoch_time)
     eval_ori_test_time = []
     for key, value in hp2info["01"].query("cifar10", 777).eval_times.items():
@@ -136,10 +161,14 @@ def correct_time_related_info(hp2info: Dict[Text, ArchResults]):
     eval_ori_test_time = sum(eval_ori_test_time) / len(eval_ori_test_time)
     for hp, arch_info in hp2info.items():
         arch_info.reset_pseudo_train_times("cifar10", None, train_per_epoch_time)
-        arch_info.reset_pseudo_eval_times("cifar10", None, "ori-test", eval_ori_test_time)
+        arch_info.reset_pseudo_eval_times(
+            "cifar10", None, "ori-test", eval_ori_test_time
+        )
 
     # CIFAR100
-    train_per_epoch_time = list(hp2info["01"].query("cifar100", 777).train_times.values())
+    train_per_epoch_time = list(
+        hp2info["01"].query("cifar100", 777).train_times.values()
+    )
     train_per_epoch_time = sum(train_per_epoch_time) / len(train_per_epoch_time)
     eval_ori_test_time, eval_x_valid_time, eval_x_test_time = [], [], []
     for key, value in hp2info["01"].query("cifar100", 777).eval_times.items():
@@ -156,12 +185,18 @@ def correct_time_related_info(hp2info: Dict[Text, ArchResults]):
     eval_x_test_time = sum(eval_x_test_time) / len(eval_x_test_time)
     for hp, arch_info in hp2info.items():
         arch_info.reset_pseudo_train_times("cifar100", None, train_per_epoch_time)
-        arch_info.reset_pseudo_eval_times("cifar100", None, "x-valid", eval_x_valid_time)
+        arch_info.reset_pseudo_eval_times(
+            "cifar100", None, "x-valid", eval_x_valid_time
+        )
         arch_info.reset_pseudo_eval_times("cifar100", None, "x-test", eval_x_test_time)
-        arch_info.reset_pseudo_eval_times("cifar100", None, "ori-test", eval_ori_test_time)
+        arch_info.reset_pseudo_eval_times(
+            "cifar100", None, "ori-test", eval_ori_test_time
+        )
 
     # ImageNet16-120
-    train_per_epoch_time = list(hp2info["01"].query("ImageNet16-120", 777).train_times.values())
+    train_per_epoch_time = list(
+        hp2info["01"].query("ImageNet16-120", 777).train_times.values()
+    )
     train_per_epoch_time = sum(train_per_epoch_time) / len(train_per_epoch_time)
     eval_ori_test_time, eval_x_valid_time, eval_x_test_time = [], [], []
     for key, value in hp2info["01"].query("ImageNet16-120", 777).eval_times.items():
@@ -178,9 +213,15 @@ def correct_time_related_info(hp2info: Dict[Text, ArchResults]):
     eval_x_test_time = sum(eval_x_test_time) / len(eval_x_test_time)
     for hp, arch_info in hp2info.items():
         arch_info.reset_pseudo_train_times("ImageNet16-120", None, train_per_epoch_time)
-        arch_info.reset_pseudo_eval_times("ImageNet16-120", None, "x-valid", eval_x_valid_time)
-        arch_info.reset_pseudo_eval_times("ImageNet16-120", None, "x-test", eval_x_test_time)
-        arch_info.reset_pseudo_eval_times("ImageNet16-120", None, "ori-test", eval_ori_test_time)
+        arch_info.reset_pseudo_eval_times(
+            "ImageNet16-120", None, "x-valid", eval_x_valid_time
+        )
+        arch_info.reset_pseudo_eval_times(
+            "ImageNet16-120", None, "x-test", eval_x_test_time
+        )
+        arch_info.reset_pseudo_eval_times(
+            "ImageNet16-120", None, "ori-test", eval_ori_test_time
+        )
     return hp2info
 
 
@@ -200,7 +241,9 @@ def simplify(save_dir, save_name, nets, total):
             seeds.add(seed)
             nums.append(len(xlist))
             print("  [seed={:}] there are {:} checkpoints.".format(seed, len(xlist)))
-        assert len(nets) == total == max(nums), "there are some missed files : {:} vs {:}".format(max(nums), total)
+        assert (
+            len(nets) == total == max(nums)
+        ), "there are some missed files : {:} vs {:}".format(max(nums), total)
     print("{:} start simplify the checkpoint.".format(time_string()))
 
     datasets = ("cifar10-valid", "cifar10", "cifar100", "ImageNet16-120")
@@ -225,7 +268,10 @@ def simplify(save_dir, save_name, nets, total):
 
         for hp in hps:
             sub_save_dir = save_dir / "raw-data-{:}".format(hp)
-            ckps = [sub_save_dir / "arch-{:06d}-seed-{:}.pth".format(index, seed) for seed in seeds]
+            ckps = [
+                sub_save_dir / "arch-{:06d}-seed-{:}.pth".format(index, seed)
+                for seed in seeds
+            ]
             ckps = [x for x in ckps if x.exists()]
             if len(ckps) == 0:
                 raise ValueError("Invalid data : index={:}, hp={:}".format(index, hp))
@@ -238,21 +284,31 @@ def simplify(save_dir, save_name, nets, total):
 
         hp2info["01"].clear_params()  # to save some spaces...
         to_save_data = OrderedDict(
-            {"01": hp2info["01"].state_dict(), "12": hp2info["12"].state_dict(), "90": hp2info["90"].state_dict()}
+            {
+                "01": hp2info["01"].state_dict(),
+                "12": hp2info["12"].state_dict(),
+                "90": hp2info["90"].state_dict(),
+            }
         )
         pickle_save(to_save_data, str(full_save_path))
 
         for hp in hps:
             hp2info[hp].clear_params()
         to_save_data = OrderedDict(
-            {"01": hp2info["01"].state_dict(), "12": hp2info["12"].state_dict(), "90": hp2info["90"].state_dict()}
+            {
+                "01": hp2info["01"].state_dict(),
+                "12": hp2info["12"].state_dict(),
+                "90": hp2info["90"].state_dict(),
+            }
         )
         pickle_save(to_save_data, str(simple_save_path))
         arch2infos[index] = to_save_data
         # measure elapsed time
         arch_time.update(time.time() - end_time)
         end_time = time.time()
-        need_time = "{:}".format(convert_secs2time(arch_time.avg * (total - index - 1), True))
+        need_time = "{:}".format(
+            convert_secs2time(arch_time.avg * (total - index - 1), True)
+        )
         # print('{:} {:06d}/{:06d} : still need {:}'.format(time_string(), index, total, need_time))
     print("{:} {:} done.".format(time_string(), save_name))
     final_infos = {
@@ -297,7 +353,8 @@ def traverse_net(candidates: List[int], N: int):
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
-        description="NATS-Bench (size search space)", formatter_class=argparse.ArgumentDefaultsHelpFormatter
+        description="NATS-Bench (size search space)",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
     )
     parser.add_argument(
         "--base_save_dir",
@@ -305,15 +362,27 @@ if __name__ == "__main__":
         default="./output/NATS-Bench-size",
         help="The base-name of folder to save checkpoints and log.",
     )
-    parser.add_argument("--candidateC", type=int, nargs="+", default=[8, 16, 24, 32, 40, 48, 56, 64], help=".")
-    parser.add_argument("--num_layers", type=int, default=5, help="The number of layers in a network.")
+    parser.add_argument(
+        "--candidateC",
+        type=int,
+        nargs="+",
+        default=[8, 16, 24, 32, 40, 48, 56, 64],
+        help=".",
+    )
+    parser.add_argument(
+        "--num_layers", type=int, default=5, help="The number of layers in a network."
+    )
     parser.add_argument("--check_N", type=int, default=32768, help="For safety.")
-    parser.add_argument("--save_name", type=str, default="process", help="The save directory.")
+    parser.add_argument(
+        "--save_name", type=str, default="process", help="The save directory."
+    )
     args = parser.parse_args()
 
     nets = traverse_net(args.candidateC, args.num_layers)
     if len(nets) != args.check_N:
-        raise ValueError("Pre-num-check failed : {:} vs {:}".format(len(nets), args.check_N))
+        raise ValueError(
+            "Pre-num-check failed : {:} vs {:}".format(len(nets), args.check_N)
+        )
 
     save_dir = Path(args.base_save_dir)
     simplify(save_dir, args.save_name, nets, args.check_N)