Add int search space
@@ -32,7 +32,9 @@ from utils import get_md5_file
from nas_201_api import NASBench201API


api = NASBench201API("{:}/.torch/NAS-Bench-201-v1_0-e61699.pth".format(os.environ["HOME"]))
api = NASBench201API(
    "{:}/.torch/NAS-Bench-201-v1_0-e61699.pth".format(os.environ["HOME"])
)

NATS_TSS_BASE_NAME = "NATS-tss-v1_0"  # 2020.08.28

@@ -68,35 +70,58 @@ def create_result_count(
    )
    if "train_times" in results:  # new version
        xresult.update_train_info(
            results["train_acc1es"], results["train_acc5es"], results["train_losses"], results["train_times"]
            results["train_acc1es"],
            results["train_acc5es"],
            results["train_losses"],
            results["train_times"],
        )
        xresult.update_eval(
            results["valid_acc1es"], results["valid_losses"], results["valid_times"]
        )
        xresult.update_eval(results["valid_acc1es"], results["valid_losses"], results["valid_times"])
    else:
        network = get_cell_based_tiny_net(net_config)
        network.load_state_dict(xresult.get_net_param())
        if dataset == "cifar10-valid":
            xresult.update_OLD_eval("x-valid", results["valid_acc1es"], results["valid_losses"])
            xresult.update_OLD_eval(
                "x-valid", results["valid_acc1es"], results["valid_losses"]
            )
            loss, top1, top5, latencies = pure_evaluate(
                dataloader_dict["{:}@{:}".format("cifar10", "test")], network.cuda()
            )
            xresult.update_OLD_eval("ori-test", {results["total_epoch"] - 1: top1}, {results["total_epoch"] - 1: loss})
            xresult.update_OLD_eval(
                "ori-test",
                {results["total_epoch"] - 1: top1},
                {results["total_epoch"] - 1: loss},
            )
            xresult.update_latency(latencies)
        elif dataset == "cifar10":
            xresult.update_OLD_eval("ori-test", results["valid_acc1es"], results["valid_losses"])
            xresult.update_OLD_eval(
                "ori-test", results["valid_acc1es"], results["valid_losses"]
            )
            loss, top1, top5, latencies = pure_evaluate(
                dataloader_dict["{:}@{:}".format(dataset, "test")], network.cuda()
            )
            xresult.update_latency(latencies)
        elif dataset == "cifar100" or dataset == "ImageNet16-120":
            xresult.update_OLD_eval("ori-test", results["valid_acc1es"], results["valid_losses"])
            xresult.update_OLD_eval(
                "ori-test", results["valid_acc1es"], results["valid_losses"]
            )
            loss, top1, top5, latencies = pure_evaluate(
                dataloader_dict["{:}@{:}".format(dataset, "valid")], network.cuda()
            )
            xresult.update_OLD_eval("x-valid", {results["total_epoch"] - 1: top1}, {results["total_epoch"] - 1: loss})
            xresult.update_OLD_eval(
                "x-valid",
                {results["total_epoch"] - 1: top1},
                {results["total_epoch"] - 1: loss},
            )
            loss, top1, top5, latencies = pure_evaluate(
                dataloader_dict["{:}@{:}".format(dataset, "test")], network.cuda()
            )
            xresult.update_OLD_eval("x-test", {results["total_epoch"] - 1: top1}, {results["total_epoch"] - 1: loss})
            xresult.update_OLD_eval(
                "x-test",
                {results["total_epoch"] - 1: top1},
                {results["total_epoch"] - 1: loss},
            )
            xresult.update_latency(latencies)
        else:
            raise ValueError("invalid dataset name : {:}".format(dataset))
@@ -112,12 +137,18 @@ def account_one_arch(arch_index, arch_str, checkpoints, datasets, dataloader_dic
        ok_dataset = 0
        for dataset in datasets:
            if dataset not in checkpoint:
                print("Can not find {:} in arch-{:} from {:}".format(dataset, arch_index, checkpoint_path))
                print(
                    "Can not find {:} in arch-{:} from {:}".format(
                        dataset, arch_index, checkpoint_path
                    )
                )
                continue
            else:
                ok_dataset += 1
            results = checkpoint[dataset]
            assert results["finish-train"], "This {:} arch seed={:} does not finish train on {:} ::: {:}".format(
            assert results[
                "finish-train"
            ], "This {:} arch seed={:} does not finish train on {:} ::: {:}".format(
                arch_index, used_seed, dataset, checkpoint_path
            )
            arch_config = {
@@ -127,7 +158,9 @@ def account_one_arch(arch_index, arch_str, checkpoints, datasets, dataloader_dic
                "class_num": results["config"]["class_num"],
            }

            xresult = create_result_count(used_seed, dataset, arch_config, results, dataloader_dict)
            xresult = create_result_count(
                used_seed, dataset, arch_config, results, dataloader_dict
            )
            information.update(dataset, int(used_seed), xresult)
        if ok_dataset == 0:
            raise ValueError("{:} does not find any data".format(checkpoint_path))
@@ -137,7 +170,8 @@ def account_one_arch(arch_index, arch_str, checkpoints, datasets, dataloader_dic
def correct_time_related_info(arch_index: int, arch_infos: Dict[Text, ArchResults]):
    # calibrate the latency based on NAS-Bench-201-v1_0-e61699.pth
    cifar010_latency = (
        api.get_latency(arch_index, "cifar10-valid", hp="200") + api.get_latency(arch_index, "cifar10", hp="200")
        api.get_latency(arch_index, "cifar10-valid", hp="200")
        + api.get_latency(arch_index, "cifar10", hp="200")
    ) / 2
    cifar100_latency = api.get_latency(arch_index, "cifar100", hp="200")
    image_latency = api.get_latency(arch_index, "ImageNet16-120", hp="200")
@@ -147,7 +181,9 @@ def correct_time_related_info(arch_index: int, arch_infos: Dict[Text, ArchResult
        arch_info.reset_latency("cifar100", None, cifar100_latency)
        arch_info.reset_latency("ImageNet16-120", None, image_latency)

    train_per_epoch_time = list(arch_infos["12"].query("cifar10-valid", 777).train_times.values())
    train_per_epoch_time = list(
        arch_infos["12"].query("cifar10-valid", 777).train_times.values()
    )
    train_per_epoch_time = sum(train_per_epoch_time) / len(train_per_epoch_time)
    eval_ori_test_time, eval_x_valid_time = [], []
    for key, value in arch_infos["12"].query("cifar10-valid", 777).eval_times.items():
@@ -157,7 +193,9 @@ def correct_time_related_info(arch_index: int, arch_infos: Dict[Text, ArchResult
            eval_x_valid_time.append(value)
        else:
            raise ValueError("-- {:} --".format(key))
    eval_ori_test_time, eval_x_valid_time = float(np.mean(eval_ori_test_time)), float(np.mean(eval_x_valid_time))
    eval_ori_test_time, eval_x_valid_time = float(np.mean(eval_ori_test_time)), float(
        np.mean(eval_x_valid_time)
    )
    nums = {
        "ImageNet16-120-train": 151700,
        "ImageNet16-120-valid": 3000,
@@ -170,36 +208,72 @@ def correct_time_related_info(arch_index: int, arch_infos: Dict[Text, ArchResult
        "cifar100-test": 10000,
        "cifar100-valid": 5000,
    }
    eval_per_sample = (eval_ori_test_time + eval_x_valid_time) / (nums["cifar10-valid-valid"] + nums["cifar10-test"])
    eval_per_sample = (eval_ori_test_time + eval_x_valid_time) / (
        nums["cifar10-valid-valid"] + nums["cifar10-test"]
    )
    for hp, arch_info in arch_infos.items():
        arch_info.reset_pseudo_train_times(
            "cifar10-valid", None, train_per_epoch_time / nums["cifar10-valid-train"] * nums["cifar10-valid-train"]
            "cifar10-valid",
            None,
            train_per_epoch_time
            / nums["cifar10-valid-train"]
            * nums["cifar10-valid-train"],
        )
        arch_info.reset_pseudo_train_times(
            "cifar10", None, train_per_epoch_time / nums["cifar10-valid-train"] * nums["cifar10-train"]
            "cifar10",
            None,
            train_per_epoch_time / nums["cifar10-valid-train"] * nums["cifar10-train"],
        )
        arch_info.reset_pseudo_train_times(
            "cifar100", None, train_per_epoch_time / nums["cifar10-valid-train"] * nums["cifar100-train"]
            "cifar100",
            None,
            train_per_epoch_time / nums["cifar10-valid-train"] * nums["cifar100-train"],
        )
        arch_info.reset_pseudo_train_times(
            "ImageNet16-120", None, train_per_epoch_time / nums["cifar10-valid-train"] * nums["ImageNet16-120-train"]
            "ImageNet16-120",
            None,
            train_per_epoch_time
            / nums["cifar10-valid-train"]
            * nums["ImageNet16-120-train"],
        )
        arch_info.reset_pseudo_eval_times(
            "cifar10-valid", None, "x-valid", eval_per_sample * nums["cifar10-valid-valid"]
        )
        arch_info.reset_pseudo_eval_times("cifar10-valid", None, "ori-test", eval_per_sample * nums["cifar10-test"])
        arch_info.reset_pseudo_eval_times("cifar10", None, "ori-test", eval_per_sample * nums["cifar10-test"])
        arch_info.reset_pseudo_eval_times("cifar100", None, "x-valid", eval_per_sample * nums["cifar100-valid"])
        arch_info.reset_pseudo_eval_times("cifar100", None, "x-test", eval_per_sample * nums["cifar100-valid"])
        arch_info.reset_pseudo_eval_times("cifar100", None, "ori-test", eval_per_sample * nums["cifar100-test"])
        arch_info.reset_pseudo_eval_times(
            "ImageNet16-120", None, "x-valid", eval_per_sample * nums["ImageNet16-120-valid"]
            "cifar10-valid",
            None,
            "x-valid",
            eval_per_sample * nums["cifar10-valid-valid"],
        )
        arch_info.reset_pseudo_eval_times(
            "ImageNet16-120", None, "x-test", eval_per_sample * nums["ImageNet16-120-valid"]
            "cifar10-valid", None, "ori-test", eval_per_sample * nums["cifar10-test"]
        )
        arch_info.reset_pseudo_eval_times(
            "ImageNet16-120", None, "ori-test", eval_per_sample * nums["ImageNet16-120-test"]
            "cifar10", None, "ori-test", eval_per_sample * nums["cifar10-test"]
        )
        arch_info.reset_pseudo_eval_times(
            "cifar100", None, "x-valid", eval_per_sample * nums["cifar100-valid"]
        )
        arch_info.reset_pseudo_eval_times(
            "cifar100", None, "x-test", eval_per_sample * nums["cifar100-valid"]
        )
        arch_info.reset_pseudo_eval_times(
            "cifar100", None, "ori-test", eval_per_sample * nums["cifar100-test"]
        )
        arch_info.reset_pseudo_eval_times(
            "ImageNet16-120",
            None,
            "x-valid",
            eval_per_sample * nums["ImageNet16-120-valid"],
        )
        arch_info.reset_pseudo_eval_times(
            "ImageNet16-120",
            None,
            "x-test",
            eval_per_sample * nums["ImageNet16-120-valid"],
        )
        arch_info.reset_pseudo_eval_times(
            "ImageNet16-120",
            None,
            "ori-test",
            eval_per_sample * nums["ImageNet16-120-test"],
        )
    return arch_infos

@@ -220,7 +294,9 @@ def simplify(save_dir, save_name, nets, total, sup_config):
            seeds.add(seed)
            nums.append(len(xlist))
            print("  [seed={:}] there are {:} checkpoints.".format(seed, len(xlist)))
        assert len(nets) == total == max(nums), "there are some missed files : {:} vs {:}".format(max(nums), total)
        assert (
            len(nets) == total == max(nums)
        ), "there are some missed files : {:} vs {:}".format(max(nums), total)
    print("{:} start simplify the checkpoint.".format(time_string()))

    datasets = ("cifar10-valid", "cifar10", "cifar100", "ImageNet16-120")
@@ -236,7 +312,12 @@ def simplify(save_dir, save_name, nets, total, sup_config):
    arch2infos, evaluated_indexes = dict(), set()
    end_time, arch_time = time.time(), AverageMeter()
    # save the meta information
    temp_final_infos = {"meta_archs": nets, "total_archs": total, "arch2infos": None, "evaluated_indexes": set()}
    temp_final_infos = {
        "meta_archs": nets,
        "total_archs": total,
        "arch2infos": None,
        "evaluated_indexes": set(),
    }
    pickle_save(temp_final_infos, str(full_save_dir / "meta.pickle"))
    pickle_save(temp_final_infos, str(simple_save_dir / "meta.pickle"))

@@ -248,29 +329,40 @@ def simplify(save_dir, save_name, nets, total, sup_config):
        simple_save_path = simple_save_dir / "{:06d}.pickle".format(index)
        for hp in hps:
            sub_save_dir = save_dir / "raw-data-{:}".format(hp)
            ckps = [sub_save_dir / "arch-{:06d}-seed-{:}.pth".format(index, seed) for seed in seeds]
            ckps = [
                sub_save_dir / "arch-{:06d}-seed-{:}.pth".format(index, seed)
                for seed in seeds
            ]
            ckps = [x for x in ckps if x.exists()]
            if len(ckps) == 0:
                raise ValueError("Invalid data : index={:}, hp={:}".format(index, hp))

            arch_info = account_one_arch(index, arch_str, ckps, datasets, dataloader_dict)
            arch_info = account_one_arch(
                index, arch_str, ckps, datasets, dataloader_dict
            )
            hp2info[hp] = arch_info

        hp2info = correct_time_related_info(index, hp2info)
        evaluated_indexes.add(index)

        to_save_data = OrderedDict({"12": hp2info["12"].state_dict(), "200": hp2info["200"].state_dict()})
        to_save_data = OrderedDict(
            {"12": hp2info["12"].state_dict(), "200": hp2info["200"].state_dict()}
        )
        pickle_save(to_save_data, str(full_save_path))

        for hp in hps:
            hp2info[hp].clear_params()
        to_save_data = OrderedDict({"12": hp2info["12"].state_dict(), "200": hp2info["200"].state_dict()})
        to_save_data = OrderedDict(
            {"12": hp2info["12"].state_dict(), "200": hp2info["200"].state_dict()}
        )
        pickle_save(to_save_data, str(simple_save_path))
        arch2infos[index] = to_save_data
        # measure elapsed time
        arch_time.update(time.time() - end_time)
        end_time = time.time()
        need_time = "{:}".format(convert_secs2time(arch_time.avg * (total - index - 1), True))
        need_time = "{:}".format(
            convert_secs2time(arch_time.avg * (total - index - 1), True)
        )
        # print('{:} {:06d}/{:06d} : still need {:}'.format(time_string(), index, total, need_time))
    print("{:} {:} done.".format(time_string(), save_name))
    final_infos = {
@@ -303,7 +395,11 @@ def simplify(save_dir, save_name, nets, total, sup_config):
def traverse_net(max_node):
    aa_nas_bench_ss = get_search_spaces("cell", "nats-bench")
    archs = CellStructure.gen_all(aa_nas_bench_ss, max_node, False)
    print("There are {:} archs vs {:}.".format(len(archs), len(aa_nas_bench_ss) ** ((max_node - 1) * max_node / 2)))
    print(
        "There are {:} archs vs {:}.".format(
            len(archs), len(aa_nas_bench_ss) ** ((max_node - 1) * max_node / 2)
        )
    )

    random.seed(88)  # please do not change this line for reproducibility
    random.shuffle(archs)
@@ -312,10 +408,12 @@ def traverse_net(max_node):
        == "|avg_pool_3x3~0|+|nor_conv_1x1~0|skip_connect~1|+|nor_conv_1x1~0|skip_connect~1|skip_connect~2|"
    ), "please check the 0-th architecture : {:}".format(archs[0])
    assert (
        archs[9].tostr() == "|avg_pool_3x3~0|+|none~0|none~1|+|skip_connect~0|none~1|nor_conv_3x3~2|"
        archs[9].tostr()
        == "|avg_pool_3x3~0|+|none~0|none~1|+|skip_connect~0|none~1|nor_conv_3x3~2|"
    ), "please check the 9-th architecture : {:}".format(archs[9])
    assert (
        archs[123].tostr() == "|avg_pool_3x3~0|+|avg_pool_3x3~0|nor_conv_1x1~1|+|none~0|avg_pool_3x3~1|nor_conv_3x3~2|"
        archs[123].tostr()
        == "|avg_pool_3x3~0|+|avg_pool_3x3~0|nor_conv_1x1~1|+|none~0|avg_pool_3x3~1|nor_conv_3x3~2|"
    ), "please check the 123-th architecture : {:}".format(archs[123])
    return [x.tostr() for x in archs]

@@ -323,7 +421,8 @@ def traverse_net(max_node):
if __name__ == "__main__":

    parser = argparse.ArgumentParser(
        description="NATS-Bench (topology search space)", formatter_class=argparse.ArgumentDefaultsHelpFormatter
        description="NATS-Bench (topology search space)",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--base_save_dir",
@@ -331,16 +430,26 @@ if __name__ == "__main__":
        default="./output/NATS-Bench-topology",
        help="The base-name of folder to save checkpoints and log.",
    )
    parser.add_argument("--max_node", type=int, default=4, help="The maximum node in a cell.")
    parser.add_argument("--channel", type=int, default=16, help="The number of channels.")
    parser.add_argument("--num_cells", type=int, default=5, help="The number of cells in one stage.")
    parser.add_argument(
        "--max_node", type=int, default=4, help="The maximum node in a cell."
    )
    parser.add_argument(
        "--channel", type=int, default=16, help="The number of channels."
    )
    parser.add_argument(
        "--num_cells", type=int, default=5, help="The number of cells in one stage."
    )
    parser.add_argument("--check_N", type=int, default=15625, help="For safety.")
    parser.add_argument("--save_name", type=str, default="process", help="The save directory.")
    parser.add_argument(
        "--save_name", type=str, default="process", help="The save directory."
    )
    args = parser.parse_args()

    nets = traverse_net(args.max_node)
    if len(nets) != args.check_N:
        raise ValueError("Pre-num-check failed : {:} vs {:}".format(len(nets), args.check_N))
        raise ValueError(
            "Pre-num-check failed : {:} vs {:}".format(len(nets), args.check_N)
        )

    save_dir = Path(args.base_save_dir)
    simplify(