Add int search space
This commit is contained in:
		| @@ -30,12 +30,20 @@ from models import get_cell_based_tiny_net | ||||
| from nats_bench import create | ||||
|  | ||||
|  | ||||
| name2label = {"cifar10": "CIFAR-10", "cifar100": "CIFAR-100", "ImageNet16-120": "ImageNet-16-120"} | ||||
| name2label = { | ||||
|     "cifar10": "CIFAR-10", | ||||
|     "cifar100": "CIFAR-100", | ||||
|     "ImageNet16-120": "ImageNet-16-120", | ||||
| } | ||||
|  | ||||
|  | ||||
| def visualize_relative_info(vis_save_dir, search_space, indicator, topk): | ||||
|     vis_save_dir = vis_save_dir.resolve() | ||||
|     print("{:} start to visualize {:} with top-{:} information".format(time_string(), search_space, topk)) | ||||
|     print( | ||||
|         "{:} start to visualize {:} with top-{:} information".format( | ||||
|             time_string(), search_space, topk | ||||
|         ) | ||||
|     ) | ||||
|     vis_save_dir.mkdir(parents=True, exist_ok=True) | ||||
|     cache_file_path = vis_save_dir / "cache-{:}-info.pth".format(search_space) | ||||
|     datasets = ["cifar10", "cifar100", "ImageNet16-120"] | ||||
| @@ -46,8 +54,12 @@ def visualize_relative_info(vis_save_dir, search_space, indicator, topk): | ||||
|             all_info = OrderedDict() | ||||
|             for dataset in datasets: | ||||
|                 info_less = api.get_more_info(index, dataset, hp="12", is_random=False) | ||||
|                 info_more = api.get_more_info(index, dataset, hp=api.full_train_epochs, is_random=False) | ||||
|                 all_info[dataset] = dict(less=info_less["test-accuracy"], more=info_more["test-accuracy"]) | ||||
|                 info_more = api.get_more_info( | ||||
|                     index, dataset, hp=api.full_train_epochs, is_random=False | ||||
|                 ) | ||||
|                 all_info[dataset] = dict( | ||||
|                     less=info_less["test-accuracy"], more=info_more["test-accuracy"] | ||||
|                 ) | ||||
|             all_infos[index] = all_info | ||||
|         torch.save(all_infos, cache_file_path) | ||||
|         print("{:} save all cache data into {:}".format(time_string(), cache_file_path)) | ||||
| @@ -80,12 +92,18 @@ def visualize_relative_info(vis_save_dir, search_space, indicator, topk): | ||||
|         for idx in selected_indexes: | ||||
|             standard_scores.append( | ||||
|                 api.get_more_info( | ||||
|                     idx, dataset, hp=api.full_train_epochs if indicator == "more" else "12", is_random=False | ||||
|                     idx, | ||||
|                     dataset, | ||||
|                     hp=api.full_train_epochs if indicator == "more" else "12", | ||||
|                     is_random=False, | ||||
|                 )["test-accuracy"] | ||||
|             ) | ||||
|             random_scores.append( | ||||
|                 api.get_more_info( | ||||
|                     idx, dataset, hp=api.full_train_epochs if indicator == "more" else "12", is_random=True | ||||
|                     idx, | ||||
|                     dataset, | ||||
|                     hp=api.full_train_epochs if indicator == "more" else "12", | ||||
|                     is_random=True, | ||||
|                 )["test-accuracy"] | ||||
|             ) | ||||
|         indexes = list(range(len(selected_indexes))) | ||||
| @@ -105,11 +123,28 @@ def visualize_relative_info(vis_save_dir, search_space, indicator, topk): | ||||
|         ax.set_xticks(np.arange(min(indexes), max(indexes), max(indexes) // 5)) | ||||
|         ax.scatter(indexes, random_labels, marker="^", s=0.5, c="tab:green", alpha=0.8) | ||||
|         ax.scatter(indexes, indexes, marker="o", s=0.5, c="tab:blue", alpha=0.8) | ||||
|         ax.scatter([-1], [-1], marker="o", s=100, c="tab:blue", label="Average Over Multi-Trials") | ||||
|         ax.scatter([-1], [-1], marker="^", s=100, c="tab:green", label="Randomly Selected Trial") | ||||
|         ax.scatter( | ||||
|             [-1], | ||||
|             [-1], | ||||
|             marker="o", | ||||
|             s=100, | ||||
|             c="tab:blue", | ||||
|             label="Average Over Multi-Trials", | ||||
|         ) | ||||
|         ax.scatter( | ||||
|             [-1], | ||||
|             [-1], | ||||
|             marker="^", | ||||
|             s=100, | ||||
|             c="tab:green", | ||||
|             label="Randomly Selected Trial", | ||||
|         ) | ||||
|  | ||||
|         coef, p = scipy.stats.kendalltau(standard_scores, random_scores) | ||||
|         ax.set_xlabel("architecture ranking in {:}".format(name2label[dataset]), fontsize=LabelSize) | ||||
|         ax.set_xlabel( | ||||
|             "architecture ranking in {:}".format(name2label[dataset]), | ||||
|             fontsize=LabelSize, | ||||
|         ) | ||||
|         if dataset == "cifar10": | ||||
|             ax.set_ylabel("architecture ranking", fontsize=LabelSize) | ||||
|         ax.legend(loc=4, fontsize=LegendFontsize) | ||||
| @@ -117,17 +152,27 @@ def visualize_relative_info(vis_save_dir, search_space, indicator, topk): | ||||
|  | ||||
|     for dataset, ax in zip(datasets, axs): | ||||
|         rank_coef = sub_plot_fn(ax, dataset, indicator) | ||||
|         print("sub-plot {:} on {:} done, the ranking coefficient is {:.4f}.".format(dataset, search_space, rank_coef)) | ||||
|         print( | ||||
|             "sub-plot {:} on {:} done, the ranking coefficient is {:.4f}.".format( | ||||
|                 dataset, search_space, rank_coef | ||||
|             ) | ||||
|         ) | ||||
|  | ||||
|     save_path = (vis_save_dir / "{:}-rank-{:}-top{:}.pdf".format(search_space, indicator, topk)).resolve() | ||||
|     save_path = ( | ||||
|         vis_save_dir / "{:}-rank-{:}-top{:}.pdf".format(search_space, indicator, topk) | ||||
|     ).resolve() | ||||
|     fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="pdf") | ||||
|     save_path = (vis_save_dir / "{:}-rank-{:}-top{:}.png".format(search_space, indicator, topk)).resolve() | ||||
|     save_path = ( | ||||
|         vis_save_dir / "{:}-rank-{:}-top{:}.png".format(search_space, indicator, topk) | ||||
|     ).resolve() | ||||
|     fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png") | ||||
|     print("Save into {:}".format(save_path)) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     parser = argparse.ArgumentParser(description="NATS-Bench", formatter_class=argparse.ArgumentDefaultsHelpFormatter) | ||||
|     parser = argparse.ArgumentParser( | ||||
|         description="NATS-Bench", formatter_class=argparse.ArgumentDefaultsHelpFormatter | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--save_dir", | ||||
|         type=str, | ||||
|   | ||||
		Reference in New Issue
	
	Block a user