Add int search space
This commit is contained in:
		| @@ -35,9 +35,15 @@ def visualize_relative_info(api, vis_save_dir, indicator): | ||||
|     # print ('{:} start to visualize {:} information'.format(time_string(), api)) | ||||
|     vis_save_dir.mkdir(parents=True, exist_ok=True) | ||||
|  | ||||
|     cifar010_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("cifar10", indicator) | ||||
|     cifar100_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("cifar100", indicator) | ||||
|     imagenet_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("ImageNet16-120", indicator) | ||||
|     cifar010_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format( | ||||
|         "cifar10", indicator | ||||
|     ) | ||||
|     cifar100_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format( | ||||
|         "cifar100", indicator | ||||
|     ) | ||||
|     imagenet_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format( | ||||
|         "ImageNet16-120", indicator | ||||
|     ) | ||||
|     cifar010_info = torch.load(cifar010_cache_path) | ||||
|     cifar100_info = torch.load(cifar100_cache_path) | ||||
|     imagenet_info = torch.load(imagenet_cache_path) | ||||
| @@ -65,8 +71,15 @@ def visualize_relative_info(api, vis_save_dir, indicator): | ||||
|     plt.xlim(min(indexes), max(indexes)) | ||||
|     plt.ylim(min(indexes), max(indexes)) | ||||
|     # plt.ylabel('y').set_rotation(30) | ||||
|     plt.yticks(np.arange(min(indexes), max(indexes), max(indexes) // 3), fontsize=LegendFontsize, rotation="vertical") | ||||
|     plt.xticks(np.arange(min(indexes), max(indexes), max(indexes) // 5), fontsize=LegendFontsize) | ||||
|     plt.yticks( | ||||
|         np.arange(min(indexes), max(indexes), max(indexes) // 3), | ||||
|         fontsize=LegendFontsize, | ||||
|         rotation="vertical", | ||||
|     ) | ||||
|     plt.xticks( | ||||
|         np.arange(min(indexes), max(indexes), max(indexes) // 5), | ||||
|         fontsize=LegendFontsize, | ||||
|     ) | ||||
|     ax.scatter(indexes, cifar100_labels, marker="^", s=0.5, c="tab:green", alpha=0.8) | ||||
|     ax.scatter(indexes, imagenet_labels, marker="*", s=0.5, c="tab:red", alpha=0.8) | ||||
|     ax.scatter(indexes, indexes, marker="o", s=0.5, c="tab:blue", alpha=0.8) | ||||
| @@ -102,7 +115,9 @@ def visualize_sss_info(api, dataset, vis_save_dir): | ||||
|             train_accs.append(info["train-accuracy"]) | ||||
|             test_accs.append(info["test-accuracy"]) | ||||
|             if dataset == "cifar10": | ||||
|                 info = api.get_more_info(index, "cifar10-valid", hp="90", is_random=False) | ||||
|                 info = api.get_more_info( | ||||
|                     index, "cifar10-valid", hp="90", is_random=False | ||||
|                 ) | ||||
|                 valid_accs.append(info["valid-accuracy"]) | ||||
|             else: | ||||
|                 valid_accs.append(info["valid-accuracy"]) | ||||
| @@ -263,7 +278,9 @@ def visualize_tss_info(api, dataset, vis_save_dir): | ||||
|             train_accs.append(info["train-accuracy"]) | ||||
|             test_accs.append(info["test-accuracy"]) | ||||
|             if dataset == "cifar10": | ||||
|                 info = api.get_more_info(index, "cifar10-valid", hp="200", is_random=False) | ||||
|                 info = api.get_more_info( | ||||
|                     index, "cifar10-valid", hp="200", is_random=False | ||||
|                 ) | ||||
|                 valid_accs.append(info["valid-accuracy"]) | ||||
|             else: | ||||
|                 valid_accs.append(info["valid-accuracy"]) | ||||
| @@ -288,7 +305,9 @@ def visualize_tss_info(api, dataset, vis_save_dir): | ||||
|         ) | ||||
|     print("{:} collect data done.".format(time_string())) | ||||
|  | ||||
|     resnet = ["|nor_conv_3x3~0|+|none~0|nor_conv_3x3~1|+|skip_connect~0|none~1|skip_connect~2|"] | ||||
|     resnet = [ | ||||
|         "|nor_conv_3x3~0|+|none~0|nor_conv_3x3~1|+|skip_connect~0|none~1|skip_connect~2|" | ||||
|     ] | ||||
|     resnet_indexes = [api.query_index_by_arch(x) for x in resnet] | ||||
|     largest_indexes = [ | ||||
|         api.query_index_by_arch( | ||||
| @@ -415,9 +434,15 @@ def visualize_rank_info(api, vis_save_dir, indicator): | ||||
|     # print ('{:} start to visualize {:} information'.format(time_string(), api)) | ||||
|     vis_save_dir.mkdir(parents=True, exist_ok=True) | ||||
|  | ||||
|     cifar010_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("cifar10", indicator) | ||||
|     cifar100_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("cifar100", indicator) | ||||
|     imagenet_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("ImageNet16-120", indicator) | ||||
|     cifar010_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format( | ||||
|         "cifar10", indicator | ||||
|     ) | ||||
|     cifar100_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format( | ||||
|         "cifar100", indicator | ||||
|     ) | ||||
|     imagenet_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format( | ||||
|         "ImageNet16-120", indicator | ||||
|     ) | ||||
|     cifar010_info = torch.load(cifar010_cache_path) | ||||
|     cifar100_info = torch.load(cifar100_cache_path) | ||||
|     imagenet_info = torch.load(imagenet_cache_path) | ||||
| @@ -452,8 +477,17 @@ def visualize_rank_info(api, vis_save_dir, indicator): | ||||
|         ax.xaxis.set_ticks(np.arange(min(indexes), max(indexes), max(indexes) // 5)) | ||||
|         ax.scatter(indexes, labels, marker="^", s=0.5, c="tab:green", alpha=0.8) | ||||
|         ax.scatter(indexes, indexes, marker="o", s=0.5, c="tab:blue", alpha=0.8) | ||||
|         ax.scatter([-1], [-1], marker="^", s=100, c="tab:green", label="{:} test".format(name)) | ||||
|         ax.scatter([-1], [-1], marker="o", s=100, c="tab:blue", label="{:} validation".format(name)) | ||||
|         ax.scatter( | ||||
|             [-1], [-1], marker="^", s=100, c="tab:green", label="{:} test".format(name) | ||||
|         ) | ||||
|         ax.scatter( | ||||
|             [-1], | ||||
|             [-1], | ||||
|             marker="o", | ||||
|             s=100, | ||||
|             c="tab:blue", | ||||
|             label="{:} validation".format(name), | ||||
|         ) | ||||
|         ax.legend(loc=4, fontsize=LegendFontsize) | ||||
|         ax.set_xlabel("ranking on the {:} validation".format(name), fontsize=LabelSize) | ||||
|         ax.set_ylabel("architecture ranking", fontsize=LabelSize) | ||||
| @@ -465,9 +499,13 @@ def visualize_rank_info(api, vis_save_dir, indicator): | ||||
|     labels = get_labels(imagenet_info) | ||||
|     plot_ax(labels, ax3, "ImageNet-16-120") | ||||
|  | ||||
|     save_path = (vis_save_dir / "{:}-same-relative-rank.pdf".format(indicator)).resolve() | ||||
|     save_path = ( | ||||
|         vis_save_dir / "{:}-same-relative-rank.pdf".format(indicator) | ||||
|     ).resolve() | ||||
|     fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="pdf") | ||||
|     save_path = (vis_save_dir / "{:}-same-relative-rank.png".format(indicator)).resolve() | ||||
|     save_path = ( | ||||
|         vis_save_dir / "{:}-same-relative-rank.png".format(indicator) | ||||
|     ).resolve() | ||||
|     fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png") | ||||
|     print("{:} save into {:}".format(time_string(), save_path)) | ||||
|     plt.close("all") | ||||
| @@ -496,9 +534,15 @@ def visualize_all_rank_info(api, vis_save_dir, indicator): | ||||
|     # print ('{:} start to visualize {:} information'.format(time_string(), api)) | ||||
|     vis_save_dir.mkdir(parents=True, exist_ok=True) | ||||
|  | ||||
|     cifar010_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("cifar10", indicator) | ||||
|     cifar100_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("cifar100", indicator) | ||||
|     imagenet_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format("ImageNet16-120", indicator) | ||||
|     cifar010_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format( | ||||
|         "cifar10", indicator | ||||
|     ) | ||||
|     cifar100_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format( | ||||
|         "cifar100", indicator | ||||
|     ) | ||||
|     imagenet_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format( | ||||
|         "ImageNet16-120", indicator | ||||
|     ) | ||||
|     cifar010_info = torch.load(cifar010_cache_path) | ||||
|     cifar100_info = torch.load(cifar100_cache_path) | ||||
|     imagenet_info = torch.load(imagenet_cache_path) | ||||
| @@ -564,7 +608,9 @@ def visualize_all_rank_info(api, vis_save_dir, indicator): | ||||
|         yticklabels=["C10-V", "C10-T", "C100-V", "C100-T", "I120-V", "I120-T"], | ||||
|     ) | ||||
|     ax1.set_title("Correlation coefficient over ALL candidates") | ||||
|     ax2.set_title("Correlation coefficient over candidates with accuracy > {:}%".format(acc_bar)) | ||||
|     ax2.set_title( | ||||
|         "Correlation coefficient over candidates with accuracy > {:}%".format(acc_bar) | ||||
|     ) | ||||
|     save_path = (vis_save_dir / "{:}-all-relative-rank.png".format(indicator)).resolve() | ||||
|     fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png") | ||||
|     print("{:} save into {:}".format(time_string(), save_path)) | ||||
| @@ -572,9 +618,14 @@ def visualize_all_rank_info(api, vis_save_dir, indicator): | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     parser = argparse.ArgumentParser(description="NATS-Bench", formatter_class=argparse.ArgumentDefaultsHelpFormatter) | ||||
|     parser = argparse.ArgumentParser( | ||||
|         description="NATS-Bench", formatter_class=argparse.ArgumentDefaultsHelpFormatter | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--save_dir", type=str, default="output/vis-nas-bench", help="Folder to save checkpoints and log." | ||||
|         "--save_dir", | ||||
|         type=str, | ||||
|         default="output/vis-nas-bench", | ||||
|         help="Folder to save checkpoints and log.", | ||||
|     ) | ||||
|     # use for train the model | ||||
|     args = parser.parse_args() | ||||
|   | ||||
		Reference in New Issue
	
	Block a user