| 
									
										
										
										
											2020-11-30 00:48:10 +08:00
										 |  |  | ############################################################### | 
					
						
							| 
									
										
										
										
											2021-01-25 21:48:14 +08:00
										 |  |  | # NATS-Bench (arxiv.org/pdf/2009.00437.pdf), IEEE TPAMI 2021  # | 
					
						
							| 
									
										
										
										
											2020-11-30 00:48:10 +08:00
										 |  |  | # The code to draw Figure 2 / 3 / 4 / 5 in our paper.         # | 
					
						
							|  |  |  | ############################################################### | 
					
						
							|  |  |  | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.06           # | 
					
						
							|  |  |  | ############################################################### | 
					
						
							|  |  |  | # Usage: python exps/NATS-Bench/draw-fig2_5.py                # | 
					
						
							|  |  |  | ############################################################### | 
					
						
							|  |  |  | import os, sys, time, torch, argparse | 
					
						
							|  |  |  | import scipy | 
					
						
							|  |  |  | import numpy as np | 
					
						
							|  |  |  | from typing import List, Text, Dict, Any | 
					
						
							|  |  |  | from shutil import copyfile | 
					
						
							|  |  |  | from collections import defaultdict | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  | from copy import deepcopy | 
					
						
							| 
									
										
										
										
											2020-11-30 00:48:10 +08:00
										 |  |  | from pathlib import Path | 
					
						
							|  |  |  | import matplotlib | 
					
						
							|  |  |  | import seaborn as sns | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  | matplotlib.use("agg") | 
					
						
							| 
									
										
										
										
											2020-11-30 00:48:10 +08:00
										 |  |  | import matplotlib.pyplot as plt | 
					
						
							|  |  |  | import matplotlib.ticker as ticker | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  | lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() | 
					
						
							|  |  |  | if str(lib_dir) not in sys.path: | 
					
						
							|  |  |  |     sys.path.insert(0, str(lib_dir)) | 
					
						
							| 
									
										
										
										
											2020-11-30 00:48:10 +08:00
										 |  |  | from config_utils import dict2config, load_config | 
					
						
							|  |  |  | from log_utils import time_string | 
					
						
							|  |  |  | from models import get_cell_based_tiny_net | 
					
						
							|  |  |  | from nats_bench import create | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def visualize_relative_info(api, vis_save_dir, indicator): | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  |     vis_save_dir = vis_save_dir.resolve() | 
					
						
							|  |  |  |     # print ('{:} start to visualize {:} information'.format(time_string(), api)) | 
					
						
							|  |  |  |     vis_save_dir.mkdir(parents=True, exist_ok=True) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-03-18 16:02:55 +08:00
										 |  |  |     cifar010_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format( | 
					
						
							|  |  |  |         "cifar10", indicator | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     cifar100_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format( | 
					
						
							|  |  |  |         "cifar100", indicator | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     imagenet_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format( | 
					
						
							|  |  |  |         "ImageNet16-120", indicator | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  |     cifar010_info = torch.load(cifar010_cache_path) | 
					
						
							|  |  |  |     cifar100_info = torch.load(cifar100_cache_path) | 
					
						
							|  |  |  |     imagenet_info = torch.load(imagenet_cache_path) | 
					
						
							|  |  |  |     indexes = list(range(len(cifar010_info["params"]))) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     print("{:} start to visualize relative ranking".format(time_string())) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     cifar010_ord_indexes = sorted(indexes, key=lambda i: cifar010_info["test_accs"][i]) | 
					
						
							|  |  |  |     cifar100_ord_indexes = sorted(indexes, key=lambda i: cifar100_info["test_accs"][i]) | 
					
						
							|  |  |  |     imagenet_ord_indexes = sorted(indexes, key=lambda i: imagenet_info["test_accs"][i]) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     cifar100_labels, imagenet_labels = [], [] | 
					
						
							|  |  |  |     for idx in cifar010_ord_indexes: | 
					
						
							|  |  |  |         cifar100_labels.append(cifar100_ord_indexes.index(idx)) | 
					
						
							|  |  |  |         imagenet_labels.append(imagenet_ord_indexes.index(idx)) | 
					
						
							|  |  |  |     print("{:} prepare data done.".format(time_string())) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     dpi, width, height = 200, 1400, 800 | 
					
						
							|  |  |  |     figsize = width / float(dpi), height / float(dpi) | 
					
						
							|  |  |  |     LabelSize, LegendFontsize = 18, 12 | 
					
						
							|  |  |  |     resnet_scale, resnet_alpha = 120, 0.5 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     fig = plt.figure(figsize=figsize) | 
					
						
							|  |  |  |     ax = fig.add_subplot(111) | 
					
						
							|  |  |  |     plt.xlim(min(indexes), max(indexes)) | 
					
						
							|  |  |  |     plt.ylim(min(indexes), max(indexes)) | 
					
						
							|  |  |  |     # plt.ylabel('y').set_rotation(30) | 
					
						
							| 
									
										
										
										
											2021-03-18 16:02:55 +08:00
										 |  |  |     plt.yticks( | 
					
						
							|  |  |  |         np.arange(min(indexes), max(indexes), max(indexes) // 3), | 
					
						
							|  |  |  |         fontsize=LegendFontsize, | 
					
						
							|  |  |  |         rotation="vertical", | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     plt.xticks( | 
					
						
							|  |  |  |         np.arange(min(indexes), max(indexes), max(indexes) // 5), | 
					
						
							|  |  |  |         fontsize=LegendFontsize, | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  |     ax.scatter(indexes, cifar100_labels, marker="^", s=0.5, c="tab:green", alpha=0.8) | 
					
						
							|  |  |  |     ax.scatter(indexes, imagenet_labels, marker="*", s=0.5, c="tab:red", alpha=0.8) | 
					
						
							|  |  |  |     ax.scatter(indexes, indexes, marker="o", s=0.5, c="tab:blue", alpha=0.8) | 
					
						
							|  |  |  |     ax.scatter([-1], [-1], marker="o", s=100, c="tab:blue", label="CIFAR-10") | 
					
						
							|  |  |  |     ax.scatter([-1], [-1], marker="^", s=100, c="tab:green", label="CIFAR-100") | 
					
						
							|  |  |  |     ax.scatter([-1], [-1], marker="*", s=100, c="tab:red", label="ImageNet-16-120") | 
					
						
							|  |  |  |     plt.grid(zorder=0) | 
					
						
							|  |  |  |     ax.set_axisbelow(True) | 
					
						
							|  |  |  |     plt.legend(loc=0, fontsize=LegendFontsize) | 
					
						
							|  |  |  |     ax.set_xlabel("architecture ranking in CIFAR-10", fontsize=LabelSize) | 
					
						
							|  |  |  |     ax.set_ylabel("architecture ranking", fontsize=LabelSize) | 
					
						
							|  |  |  |     save_path = (vis_save_dir / "{:}-relative-rank.pdf".format(indicator)).resolve() | 
					
						
							|  |  |  |     fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="pdf") | 
					
						
							|  |  |  |     save_path = (vis_save_dir / "{:}-relative-rank.png".format(indicator)).resolve() | 
					
						
							|  |  |  |     fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png") | 
					
						
							|  |  |  |     print("{:} save into {:}".format(time_string(), save_path)) | 
					
						
							| 
									
										
										
										
											2020-11-30 00:48:10 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def visualize_sss_info(api, dataset, vis_save_dir): | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  |     vis_save_dir = vis_save_dir.resolve() | 
					
						
							|  |  |  |     print("{:} start to visualize {:} information".format(time_string(), dataset)) | 
					
						
							|  |  |  |     vis_save_dir.mkdir(parents=True, exist_ok=True) | 
					
						
							|  |  |  |     cache_file_path = vis_save_dir / "{:}-cache-sss-info.pth".format(dataset) | 
					
						
							|  |  |  |     if not cache_file_path.exists(): | 
					
						
							|  |  |  |         print("Do not find cache file : {:}".format(cache_file_path)) | 
					
						
							|  |  |  |         params, flops, train_accs, valid_accs, test_accs = [], [], [], [], [] | 
					
						
							|  |  |  |         for index in range(len(api)): | 
					
						
							|  |  |  |             cost_info = api.get_cost_info(index, dataset, hp="90") | 
					
						
							|  |  |  |             params.append(cost_info["params"]) | 
					
						
							|  |  |  |             flops.append(cost_info["flops"]) | 
					
						
							|  |  |  |             # accuracy | 
					
						
							|  |  |  |             info = api.get_more_info(index, dataset, hp="90", is_random=False) | 
					
						
							|  |  |  |             train_accs.append(info["train-accuracy"]) | 
					
						
							|  |  |  |             test_accs.append(info["test-accuracy"]) | 
					
						
							|  |  |  |             if dataset == "cifar10": | 
					
						
							| 
									
										
										
										
											2021-03-18 16:02:55 +08:00
										 |  |  |                 info = api.get_more_info( | 
					
						
							|  |  |  |                     index, "cifar10-valid", hp="90", is_random=False | 
					
						
							|  |  |  |                 ) | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  |                 valid_accs.append(info["valid-accuracy"]) | 
					
						
							|  |  |  |             else: | 
					
						
							|  |  |  |                 valid_accs.append(info["valid-accuracy"]) | 
					
						
							|  |  |  |         info = { | 
					
						
							|  |  |  |             "params": params, | 
					
						
							|  |  |  |             "flops": flops, | 
					
						
							|  |  |  |             "train_accs": train_accs, | 
					
						
							|  |  |  |             "valid_accs": valid_accs, | 
					
						
							|  |  |  |             "test_accs": test_accs, | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         torch.save(info, cache_file_path) | 
					
						
							|  |  |  |     else: | 
					
						
							|  |  |  |         print("Find cache file : {:}".format(cache_file_path)) | 
					
						
							|  |  |  |         info = torch.load(cache_file_path) | 
					
						
							|  |  |  |         params, flops, train_accs, valid_accs, test_accs = ( | 
					
						
							|  |  |  |             info["params"], | 
					
						
							|  |  |  |             info["flops"], | 
					
						
							|  |  |  |             info["train_accs"], | 
					
						
							|  |  |  |             info["valid_accs"], | 
					
						
							|  |  |  |             info["test_accs"], | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |     print("{:} collect data done.".format(time_string())) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # pyramid = ['8:16:32:48:64', '8:8:16:32:48', '8:8:16:16:32', '8:8:16:16:48', '8:8:16:16:64', '16:16:32:32:64', '32:32:64:64:64'] | 
					
						
							|  |  |  |     pyramid = ["8:16:24:32:40", "8:16:32:48:64", "32:40:48:56:64"] | 
					
						
							|  |  |  |     pyramid_indexes = [api.query_index_by_arch(x) for x in pyramid] | 
					
						
							|  |  |  |     largest_indexes = [api.query_index_by_arch("64:64:64:64:64")] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     indexes = list(range(len(params))) | 
					
						
							|  |  |  |     dpi, width, height = 250, 8500, 1300 | 
					
						
							|  |  |  |     figsize = width / float(dpi), height / float(dpi) | 
					
						
							|  |  |  |     LabelSize, LegendFontsize = 24, 24 | 
					
						
							|  |  |  |     # resnet_scale, resnet_alpha = 120, 0.5 | 
					
						
							|  |  |  |     xscale, xalpha = 120, 0.8 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     fig, axs = plt.subplots(1, 4, figsize=figsize) | 
					
						
							|  |  |  |     # ax1, ax2, ax3, ax4, ax5 = axs | 
					
						
							|  |  |  |     for ax in axs: | 
					
						
							|  |  |  |         for tick in ax.xaxis.get_major_ticks(): | 
					
						
							|  |  |  |             tick.label.set_fontsize(LabelSize) | 
					
						
							|  |  |  |         ax.yaxis.set_major_formatter(ticker.FormatStrFormatter("%.0f")) | 
					
						
							|  |  |  |         for tick in ax.yaxis.get_major_ticks(): | 
					
						
							|  |  |  |             tick.label.set_fontsize(LabelSize) | 
					
						
							|  |  |  |     ax1, ax2, ax3, ax4 = axs | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     ax1.scatter(params, train_accs, marker="o", s=0.5, c="tab:blue") | 
					
						
							|  |  |  |     ax1.scatter( | 
					
						
							|  |  |  |         [params[x] for x in pyramid_indexes], | 
					
						
							|  |  |  |         [train_accs[x] for x in pyramid_indexes], | 
					
						
							|  |  |  |         marker="*", | 
					
						
							|  |  |  |         s=xscale, | 
					
						
							|  |  |  |         c="tab:orange", | 
					
						
							|  |  |  |         label="Pyramid Structure", | 
					
						
							|  |  |  |         alpha=xalpha, | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     ax1.scatter( | 
					
						
							|  |  |  |         [params[x] for x in largest_indexes], | 
					
						
							|  |  |  |         [train_accs[x] for x in largest_indexes], | 
					
						
							|  |  |  |         marker="x", | 
					
						
							|  |  |  |         s=xscale, | 
					
						
							|  |  |  |         c="tab:green", | 
					
						
							|  |  |  |         label="Largest Candidate", | 
					
						
							|  |  |  |         alpha=xalpha, | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     ax1.set_xlabel("#parameters (MB)", fontsize=LabelSize) | 
					
						
							|  |  |  |     ax1.set_ylabel("train accuracy (%)", fontsize=LabelSize) | 
					
						
							|  |  |  |     ax1.legend(loc=4, fontsize=LegendFontsize) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     ax2.scatter(flops, train_accs, marker="o", s=0.5, c="tab:blue") | 
					
						
							|  |  |  |     ax2.scatter( | 
					
						
							|  |  |  |         [flops[x] for x in pyramid_indexes], | 
					
						
							|  |  |  |         [train_accs[x] for x in pyramid_indexes], | 
					
						
							|  |  |  |         marker="*", | 
					
						
							|  |  |  |         s=xscale, | 
					
						
							|  |  |  |         c="tab:orange", | 
					
						
							|  |  |  |         label="Pyramid Structure", | 
					
						
							|  |  |  |         alpha=xalpha, | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     ax2.scatter( | 
					
						
							|  |  |  |         [flops[x] for x in largest_indexes], | 
					
						
							|  |  |  |         [train_accs[x] for x in largest_indexes], | 
					
						
							|  |  |  |         marker="x", | 
					
						
							|  |  |  |         s=xscale, | 
					
						
							|  |  |  |         c="tab:green", | 
					
						
							|  |  |  |         label="Largest Candidate", | 
					
						
							|  |  |  |         alpha=xalpha, | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     ax2.set_xlabel("#FLOPs (M)", fontsize=LabelSize) | 
					
						
							|  |  |  |     # ax2.set_ylabel('train accuracy (%)', fontsize=LabelSize) | 
					
						
							|  |  |  |     ax2.legend(loc=4, fontsize=LegendFontsize) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     ax3.scatter(params, test_accs, marker="o", s=0.5, c="tab:blue") | 
					
						
							|  |  |  |     ax3.scatter( | 
					
						
							|  |  |  |         [params[x] for x in pyramid_indexes], | 
					
						
							|  |  |  |         [test_accs[x] for x in pyramid_indexes], | 
					
						
							|  |  |  |         marker="*", | 
					
						
							|  |  |  |         s=xscale, | 
					
						
							|  |  |  |         c="tab:orange", | 
					
						
							|  |  |  |         label="Pyramid Structure", | 
					
						
							|  |  |  |         alpha=xalpha, | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     ax3.scatter( | 
					
						
							|  |  |  |         [params[x] for x in largest_indexes], | 
					
						
							|  |  |  |         [test_accs[x] for x in largest_indexes], | 
					
						
							|  |  |  |         marker="x", | 
					
						
							|  |  |  |         s=xscale, | 
					
						
							|  |  |  |         c="tab:green", | 
					
						
							|  |  |  |         label="Largest Candidate", | 
					
						
							|  |  |  |         alpha=xalpha, | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     ax3.set_xlabel("#parameters (MB)", fontsize=LabelSize) | 
					
						
							|  |  |  |     ax3.set_ylabel("test accuracy (%)", fontsize=LabelSize) | 
					
						
							|  |  |  |     ax3.legend(loc=4, fontsize=LegendFontsize) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     ax4.scatter(flops, test_accs, marker="o", s=0.5, c="tab:blue") | 
					
						
							|  |  |  |     ax4.scatter( | 
					
						
							|  |  |  |         [flops[x] for x in pyramid_indexes], | 
					
						
							|  |  |  |         [test_accs[x] for x in pyramid_indexes], | 
					
						
							|  |  |  |         marker="*", | 
					
						
							|  |  |  |         s=xscale, | 
					
						
							|  |  |  |         c="tab:orange", | 
					
						
							|  |  |  |         label="Pyramid Structure", | 
					
						
							|  |  |  |         alpha=xalpha, | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     ax4.scatter( | 
					
						
							|  |  |  |         [flops[x] for x in largest_indexes], | 
					
						
							|  |  |  |         [test_accs[x] for x in largest_indexes], | 
					
						
							|  |  |  |         marker="x", | 
					
						
							|  |  |  |         s=xscale, | 
					
						
							|  |  |  |         c="tab:green", | 
					
						
							|  |  |  |         label="Largest Candidate", | 
					
						
							|  |  |  |         alpha=xalpha, | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     ax4.set_xlabel("#FLOPs (M)", fontsize=LabelSize) | 
					
						
							|  |  |  |     # ax4.set_ylabel('test accuracy (%)', fontsize=LabelSize) | 
					
						
							|  |  |  |     ax4.legend(loc=4, fontsize=LegendFontsize) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     save_path = vis_save_dir / "sss-{:}.png".format(dataset.lower()) | 
					
						
							|  |  |  |     fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png") | 
					
						
							|  |  |  |     print("{:} save into {:}".format(time_string(), save_path)) | 
					
						
							|  |  |  |     plt.close("all") | 
					
						
							| 
									
										
										
										
											2020-11-30 00:48:10 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def visualize_tss_info(api, dataset, vis_save_dir): | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  |     vis_save_dir = vis_save_dir.resolve() | 
					
						
							|  |  |  |     print("{:} start to visualize {:} information".format(time_string(), dataset)) | 
					
						
							|  |  |  |     vis_save_dir.mkdir(parents=True, exist_ok=True) | 
					
						
							|  |  |  |     cache_file_path = vis_save_dir / "{:}-cache-tss-info.pth".format(dataset) | 
					
						
							|  |  |  |     if not cache_file_path.exists(): | 
					
						
							|  |  |  |         print("Do not find cache file : {:}".format(cache_file_path)) | 
					
						
							|  |  |  |         params, flops, train_accs, valid_accs, test_accs = [], [], [], [], [] | 
					
						
							|  |  |  |         for index in range(len(api)): | 
					
						
							|  |  |  |             cost_info = api.get_cost_info(index, dataset, hp="12") | 
					
						
							|  |  |  |             params.append(cost_info["params"]) | 
					
						
							|  |  |  |             flops.append(cost_info["flops"]) | 
					
						
							|  |  |  |             # accuracy | 
					
						
							|  |  |  |             info = api.get_more_info(index, dataset, hp="200", is_random=False) | 
					
						
							|  |  |  |             train_accs.append(info["train-accuracy"]) | 
					
						
							|  |  |  |             test_accs.append(info["test-accuracy"]) | 
					
						
							|  |  |  |             if dataset == "cifar10": | 
					
						
							| 
									
										
										
										
											2021-03-18 16:02:55 +08:00
										 |  |  |                 info = api.get_more_info( | 
					
						
							|  |  |  |                     index, "cifar10-valid", hp="200", is_random=False | 
					
						
							|  |  |  |                 ) | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  |                 valid_accs.append(info["valid-accuracy"]) | 
					
						
							|  |  |  |             else: | 
					
						
							|  |  |  |                 valid_accs.append(info["valid-accuracy"]) | 
					
						
							|  |  |  |             print("") | 
					
						
							|  |  |  |         info = { | 
					
						
							|  |  |  |             "params": params, | 
					
						
							|  |  |  |             "flops": flops, | 
					
						
							|  |  |  |             "train_accs": train_accs, | 
					
						
							|  |  |  |             "valid_accs": valid_accs, | 
					
						
							|  |  |  |             "test_accs": test_accs, | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         torch.save(info, cache_file_path) | 
					
						
							|  |  |  |     else: | 
					
						
							|  |  |  |         print("Find cache file : {:}".format(cache_file_path)) | 
					
						
							|  |  |  |         info = torch.load(cache_file_path) | 
					
						
							|  |  |  |         params, flops, train_accs, valid_accs, test_accs = ( | 
					
						
							|  |  |  |             info["params"], | 
					
						
							|  |  |  |             info["flops"], | 
					
						
							|  |  |  |             info["train_accs"], | 
					
						
							|  |  |  |             info["valid_accs"], | 
					
						
							|  |  |  |             info["test_accs"], | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |     print("{:} collect data done.".format(time_string())) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-03-18 16:02:55 +08:00
										 |  |  |     resnet = [ | 
					
						
							|  |  |  |         "|nor_conv_3x3~0|+|none~0|nor_conv_3x3~1|+|skip_connect~0|none~1|skip_connect~2|" | 
					
						
							|  |  |  |     ] | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  |     resnet_indexes = [api.query_index_by_arch(x) for x in resnet] | 
					
						
							|  |  |  |     largest_indexes = [ | 
					
						
							|  |  |  |         api.query_index_by_arch( | 
					
						
							|  |  |  |             "|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|nor_conv_3x3~0|nor_conv_3x3~1|nor_conv_3x3~2|" | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |     ] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     indexes = list(range(len(params))) | 
					
						
							|  |  |  |     dpi, width, height = 250, 8500, 1300 | 
					
						
							|  |  |  |     figsize = width / float(dpi), height / float(dpi) | 
					
						
							|  |  |  |     LabelSize, LegendFontsize = 24, 24 | 
					
						
							|  |  |  |     # resnet_scale, resnet_alpha = 120, 0.5 | 
					
						
							|  |  |  |     xscale, xalpha = 120, 0.8 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     fig, axs = plt.subplots(1, 4, figsize=figsize) | 
					
						
							|  |  |  |     for ax in axs: | 
					
						
							|  |  |  |         for tick in ax.xaxis.get_major_ticks(): | 
					
						
							|  |  |  |             tick.label.set_fontsize(LabelSize) | 
					
						
							|  |  |  |         ax.yaxis.set_major_formatter(ticker.FormatStrFormatter("%.0f")) | 
					
						
							|  |  |  |         for tick in ax.yaxis.get_major_ticks(): | 
					
						
							|  |  |  |             tick.label.set_fontsize(LabelSize) | 
					
						
							|  |  |  |     ax1, ax2, ax3, ax4 = axs | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     ax1.scatter(params, train_accs, marker="o", s=0.5, c="tab:blue") | 
					
						
							|  |  |  |     ax1.scatter( | 
					
						
							|  |  |  |         [params[x] for x in resnet_indexes], | 
					
						
							|  |  |  |         [train_accs[x] for x in resnet_indexes], | 
					
						
							|  |  |  |         marker="*", | 
					
						
							|  |  |  |         s=xscale, | 
					
						
							|  |  |  |         c="tab:orange", | 
					
						
							|  |  |  |         label="ResNet", | 
					
						
							|  |  |  |         alpha=xalpha, | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     ax1.scatter( | 
					
						
							|  |  |  |         [params[x] for x in largest_indexes], | 
					
						
							|  |  |  |         [train_accs[x] for x in largest_indexes], | 
					
						
							|  |  |  |         marker="x", | 
					
						
							|  |  |  |         s=xscale, | 
					
						
							|  |  |  |         c="tab:green", | 
					
						
							|  |  |  |         label="Largest Candidate", | 
					
						
							|  |  |  |         alpha=xalpha, | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     ax1.set_xlabel("#parameters (MB)", fontsize=LabelSize) | 
					
						
							|  |  |  |     ax1.set_ylabel("train accuracy (%)", fontsize=LabelSize) | 
					
						
							|  |  |  |     ax1.legend(loc=4, fontsize=LegendFontsize) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     ax2.scatter(flops, train_accs, marker="o", s=0.5, c="tab:blue") | 
					
						
							|  |  |  |     ax2.scatter( | 
					
						
							|  |  |  |         [flops[x] for x in resnet_indexes], | 
					
						
							|  |  |  |         [train_accs[x] for x in resnet_indexes], | 
					
						
							|  |  |  |         marker="*", | 
					
						
							|  |  |  |         s=xscale, | 
					
						
							|  |  |  |         c="tab:orange", | 
					
						
							|  |  |  |         label="ResNet", | 
					
						
							|  |  |  |         alpha=xalpha, | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     ax2.scatter( | 
					
						
							|  |  |  |         [flops[x] for x in largest_indexes], | 
					
						
							|  |  |  |         [train_accs[x] for x in largest_indexes], | 
					
						
							|  |  |  |         marker="x", | 
					
						
							|  |  |  |         s=xscale, | 
					
						
							|  |  |  |         c="tab:green", | 
					
						
							|  |  |  |         label="Largest Candidate", | 
					
						
							|  |  |  |         alpha=xalpha, | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     ax2.set_xlabel("#FLOPs (M)", fontsize=LabelSize) | 
					
						
							|  |  |  |     # ax2.set_ylabel('train accuracy (%)', fontsize=LabelSize) | 
					
						
							|  |  |  |     ax2.legend(loc=4, fontsize=LegendFontsize) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     ax3.scatter(params, test_accs, marker="o", s=0.5, c="tab:blue") | 
					
						
							|  |  |  |     ax3.scatter( | 
					
						
							|  |  |  |         [params[x] for x in resnet_indexes], | 
					
						
							|  |  |  |         [test_accs[x] for x in resnet_indexes], | 
					
						
							|  |  |  |         marker="*", | 
					
						
							|  |  |  |         s=xscale, | 
					
						
							|  |  |  |         c="tab:orange", | 
					
						
							|  |  |  |         label="ResNet", | 
					
						
							|  |  |  |         alpha=xalpha, | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     ax3.scatter( | 
					
						
							|  |  |  |         [params[x] for x in largest_indexes], | 
					
						
							|  |  |  |         [test_accs[x] for x in largest_indexes], | 
					
						
							|  |  |  |         marker="x", | 
					
						
							|  |  |  |         s=xscale, | 
					
						
							|  |  |  |         c="tab:green", | 
					
						
							|  |  |  |         label="Largest Candidate", | 
					
						
							|  |  |  |         alpha=xalpha, | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     ax3.set_xlabel("#parameters (MB)", fontsize=LabelSize) | 
					
						
							|  |  |  |     ax3.set_ylabel("test accuracy (%)", fontsize=LabelSize) | 
					
						
							|  |  |  |     ax3.legend(loc=4, fontsize=LegendFontsize) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     ax4.scatter(flops, test_accs, marker="o", s=0.5, c="tab:blue") | 
					
						
							|  |  |  |     ax4.scatter( | 
					
						
							|  |  |  |         [flops[x] for x in resnet_indexes], | 
					
						
							|  |  |  |         [test_accs[x] for x in resnet_indexes], | 
					
						
							|  |  |  |         marker="*", | 
					
						
							|  |  |  |         s=xscale, | 
					
						
							|  |  |  |         c="tab:orange", | 
					
						
							|  |  |  |         label="ResNet", | 
					
						
							|  |  |  |         alpha=xalpha, | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     ax4.scatter( | 
					
						
							|  |  |  |         [flops[x] for x in largest_indexes], | 
					
						
							|  |  |  |         [test_accs[x] for x in largest_indexes], | 
					
						
							|  |  |  |         marker="x", | 
					
						
							|  |  |  |         s=xscale, | 
					
						
							|  |  |  |         c="tab:green", | 
					
						
							|  |  |  |         label="Largest Candidate", | 
					
						
							|  |  |  |         alpha=xalpha, | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     ax4.set_xlabel("#FLOPs (M)", fontsize=LabelSize) | 
					
						
							|  |  |  |     # ax4.set_ylabel('test accuracy (%)', fontsize=LabelSize) | 
					
						
							|  |  |  |     ax4.legend(loc=4, fontsize=LegendFontsize) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     save_path = vis_save_dir / "tss-{:}.png".format(dataset.lower()) | 
					
						
							|  |  |  |     fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png") | 
					
						
							|  |  |  |     print("{:} save into {:}".format(time_string(), save_path)) | 
					
						
							|  |  |  |     plt.close("all") | 
					
						
							| 
									
										
										
										
											2020-11-30 00:48:10 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def visualize_rank_info(api, vis_save_dir, indicator): | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  |     vis_save_dir = vis_save_dir.resolve() | 
					
						
							|  |  |  |     # print ('{:} start to visualize {:} information'.format(time_string(), api)) | 
					
						
							|  |  |  |     vis_save_dir.mkdir(parents=True, exist_ok=True) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-03-18 16:02:55 +08:00
										 |  |  |     cifar010_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format( | 
					
						
							|  |  |  |         "cifar10", indicator | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     cifar100_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format( | 
					
						
							|  |  |  |         "cifar100", indicator | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     imagenet_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format( | 
					
						
							|  |  |  |         "ImageNet16-120", indicator | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  |     cifar010_info = torch.load(cifar010_cache_path) | 
					
						
							|  |  |  |     cifar100_info = torch.load(cifar100_cache_path) | 
					
						
							|  |  |  |     imagenet_info = torch.load(imagenet_cache_path) | 
					
						
							|  |  |  |     indexes = list(range(len(cifar010_info["params"]))) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     print("{:} start to visualize relative ranking".format(time_string())) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     dpi, width, height = 250, 3800, 1200 | 
					
						
							|  |  |  |     figsize = width / float(dpi), height / float(dpi) | 
					
						
							|  |  |  |     LabelSize, LegendFontsize = 14, 14 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     fig, axs = plt.subplots(1, 3, figsize=figsize) | 
					
						
							|  |  |  |     ax1, ax2, ax3 = axs | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def get_labels(info): | 
					
						
							|  |  |  |         ord_test_indexes = sorted(indexes, key=lambda i: info["test_accs"][i]) | 
					
						
							|  |  |  |         ord_valid_indexes = sorted(indexes, key=lambda i: info["valid_accs"][i]) | 
					
						
							|  |  |  |         labels = [] | 
					
						
							|  |  |  |         for idx in ord_test_indexes: | 
					
						
							|  |  |  |             labels.append(ord_valid_indexes.index(idx)) | 
					
						
							|  |  |  |         return labels | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def plot_ax(labels, ax, name): | 
					
						
							|  |  |  |         for tick in ax.xaxis.get_major_ticks(): | 
					
						
							|  |  |  |             tick.label.set_fontsize(LabelSize) | 
					
						
							|  |  |  |         for tick in ax.yaxis.get_major_ticks(): | 
					
						
							|  |  |  |             tick.label.set_fontsize(LabelSize) | 
					
						
							|  |  |  |             tick.label.set_rotation(90) | 
					
						
							|  |  |  |         ax.set_xlim(min(indexes), max(indexes)) | 
					
						
							|  |  |  |         ax.set_ylim(min(indexes), max(indexes)) | 
					
						
							|  |  |  |         ax.yaxis.set_ticks(np.arange(min(indexes), max(indexes), max(indexes) // 3)) | 
					
						
							|  |  |  |         ax.xaxis.set_ticks(np.arange(min(indexes), max(indexes), max(indexes) // 5)) | 
					
						
							|  |  |  |         ax.scatter(indexes, labels, marker="^", s=0.5, c="tab:green", alpha=0.8) | 
					
						
							|  |  |  |         ax.scatter(indexes, indexes, marker="o", s=0.5, c="tab:blue", alpha=0.8) | 
					
						
							| 
									
										
										
										
											2021-03-18 16:02:55 +08:00
										 |  |  |         ax.scatter( | 
					
						
							|  |  |  |             [-1], [-1], marker="^", s=100, c="tab:green", label="{:} test".format(name) | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |         ax.scatter( | 
					
						
							|  |  |  |             [-1], | 
					
						
							|  |  |  |             [-1], | 
					
						
							|  |  |  |             marker="o", | 
					
						
							|  |  |  |             s=100, | 
					
						
							|  |  |  |             c="tab:blue", | 
					
						
							|  |  |  |             label="{:} validation".format(name), | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  |         ax.legend(loc=4, fontsize=LegendFontsize) | 
					
						
							|  |  |  |         ax.set_xlabel("ranking on the {:} validation".format(name), fontsize=LabelSize) | 
					
						
							|  |  |  |         ax.set_ylabel("architecture ranking", fontsize=LabelSize) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     labels = get_labels(cifar010_info) | 
					
						
							|  |  |  |     plot_ax(labels, ax1, "CIFAR-10") | 
					
						
							|  |  |  |     labels = get_labels(cifar100_info) | 
					
						
							|  |  |  |     plot_ax(labels, ax2, "CIFAR-100") | 
					
						
							|  |  |  |     labels = get_labels(imagenet_info) | 
					
						
							|  |  |  |     plot_ax(labels, ax3, "ImageNet-16-120") | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-03-18 16:02:55 +08:00
										 |  |  |     save_path = ( | 
					
						
							|  |  |  |         vis_save_dir / "{:}-same-relative-rank.pdf".format(indicator) | 
					
						
							|  |  |  |     ).resolve() | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  |     fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="pdf") | 
					
						
							| 
									
										
										
										
											2021-03-18 16:02:55 +08:00
										 |  |  |     save_path = ( | 
					
						
							|  |  |  |         vis_save_dir / "{:}-same-relative-rank.png".format(indicator) | 
					
						
							|  |  |  |     ).resolve() | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  |     fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png") | 
					
						
							|  |  |  |     print("{:} save into {:}".format(time_string(), save_path)) | 
					
						
							|  |  |  |     plt.close("all") | 
					
						
							| 
									
										
										
										
											2020-11-30 00:48:10 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def compute_kendalltau(vectori, vectorj): | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  |     # indexes = list(range(len(vectori))) | 
					
						
							|  |  |  |     # rank_1 = sorted(indexes, key=lambda i: vectori[i]) | 
					
						
							|  |  |  |     # rank_2 = sorted(indexes, key=lambda i: vectorj[i]) | 
					
						
							|  |  |  |     return scipy.stats.kendalltau(vectori, vectorj).correlation | 
					
						
							| 
									
										
										
										
											2020-11-30 00:48:10 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def calculate_correlation(*vectors): | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  |     matrix = [] | 
					
						
							|  |  |  |     for i, vectori in enumerate(vectors): | 
					
						
							|  |  |  |         x = [] | 
					
						
							|  |  |  |         for j, vectorj in enumerate(vectors): | 
					
						
							|  |  |  |             # x.append(np.corrcoef(vectori, vectorj)[0,1]) | 
					
						
							|  |  |  |             x.append(compute_kendalltau(vectori, vectorj)) | 
					
						
							|  |  |  |         matrix.append(x) | 
					
						
							|  |  |  |     return np.array(matrix) | 
					
						
							| 
									
										
										
										
											2020-11-30 00:48:10 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def visualize_all_rank_info(api, vis_save_dir, indicator): | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  |     vis_save_dir = vis_save_dir.resolve() | 
					
						
							|  |  |  |     # print ('{:} start to visualize {:} information'.format(time_string(), api)) | 
					
						
							|  |  |  |     vis_save_dir.mkdir(parents=True, exist_ok=True) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-03-18 16:02:55 +08:00
										 |  |  |     cifar010_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format( | 
					
						
							|  |  |  |         "cifar10", indicator | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     cifar100_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format( | 
					
						
							|  |  |  |         "cifar100", indicator | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     imagenet_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format( | 
					
						
							|  |  |  |         "ImageNet16-120", indicator | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  |     cifar010_info = torch.load(cifar010_cache_path) | 
					
						
							|  |  |  |     cifar100_info = torch.load(cifar100_cache_path) | 
					
						
							|  |  |  |     imagenet_info = torch.load(imagenet_cache_path) | 
					
						
							|  |  |  |     indexes = list(range(len(cifar010_info["params"]))) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     print("{:} start to visualize relative ranking".format(time_string())) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     dpi, width, height = 250, 3200, 1400 | 
					
						
							|  |  |  |     figsize = width / float(dpi), height / float(dpi) | 
					
						
							|  |  |  |     LabelSize, LegendFontsize = 14, 14 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     fig, axs = plt.subplots(1, 2, figsize=figsize) | 
					
						
							|  |  |  |     ax1, ax2 = axs | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     sns_size, xformat = 15, ".2f" | 
					
						
							|  |  |  |     CoRelMatrix = calculate_correlation( | 
					
						
							|  |  |  |         cifar010_info["valid_accs"], | 
					
						
							|  |  |  |         cifar010_info["test_accs"], | 
					
						
							|  |  |  |         cifar100_info["valid_accs"], | 
					
						
							|  |  |  |         cifar100_info["test_accs"], | 
					
						
							|  |  |  |         imagenet_info["valid_accs"], | 
					
						
							|  |  |  |         imagenet_info["test_accs"], | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     sns.heatmap( | 
					
						
							|  |  |  |         CoRelMatrix, | 
					
						
							|  |  |  |         annot=True, | 
					
						
							|  |  |  |         annot_kws={"size": sns_size}, | 
					
						
							|  |  |  |         fmt=xformat, | 
					
						
							|  |  |  |         linewidths=0.5, | 
					
						
							|  |  |  |         ax=ax1, | 
					
						
							|  |  |  |         xticklabels=["C10-V", "C10-T", "C100-V", "C100-T", "I120-V", "I120-T"], | 
					
						
							|  |  |  |         yticklabels=["C10-V", "C10-T", "C100-V", "C100-T", "I120-V", "I120-T"], | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     selected_indexes, acc_bar = [], 92 | 
					
						
							|  |  |  |     for i, acc in enumerate(cifar010_info["test_accs"]): | 
					
						
							|  |  |  |         if acc > acc_bar: | 
					
						
							|  |  |  |             selected_indexes.append(i) | 
					
						
							|  |  |  |     cifar010_valid_accs = np.array(cifar010_info["valid_accs"])[selected_indexes] | 
					
						
							|  |  |  |     cifar010_test_accs = np.array(cifar010_info["test_accs"])[selected_indexes] | 
					
						
							|  |  |  |     cifar100_valid_accs = np.array(cifar100_info["valid_accs"])[selected_indexes] | 
					
						
							|  |  |  |     cifar100_test_accs = np.array(cifar100_info["test_accs"])[selected_indexes] | 
					
						
							|  |  |  |     imagenet_valid_accs = np.array(imagenet_info["valid_accs"])[selected_indexes] | 
					
						
							|  |  |  |     imagenet_test_accs = np.array(imagenet_info["test_accs"])[selected_indexes] | 
					
						
							|  |  |  |     CoRelMatrix = calculate_correlation( | 
					
						
							|  |  |  |         cifar010_valid_accs, | 
					
						
							|  |  |  |         cifar010_test_accs, | 
					
						
							|  |  |  |         cifar100_valid_accs, | 
					
						
							|  |  |  |         cifar100_test_accs, | 
					
						
							|  |  |  |         imagenet_valid_accs, | 
					
						
							|  |  |  |         imagenet_test_accs, | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     sns.heatmap( | 
					
						
							|  |  |  |         CoRelMatrix, | 
					
						
							|  |  |  |         annot=True, | 
					
						
							|  |  |  |         annot_kws={"size": sns_size}, | 
					
						
							|  |  |  |         fmt=xformat, | 
					
						
							|  |  |  |         linewidths=0.5, | 
					
						
							|  |  |  |         ax=ax2, | 
					
						
							|  |  |  |         xticklabels=["C10-V", "C10-T", "C100-V", "C100-T", "I120-V", "I120-T"], | 
					
						
							|  |  |  |         yticklabels=["C10-V", "C10-T", "C100-V", "C100-T", "I120-V", "I120-T"], | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     ax1.set_title("Correlation coefficient over ALL candidates") | 
					
						
							| 
									
										
										
										
											2021-03-18 16:02:55 +08:00
										 |  |  |     ax2.set_title( | 
					
						
							|  |  |  |         "Correlation coefficient over candidates with accuracy > {:}%".format(acc_bar) | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  |     save_path = (vis_save_dir / "{:}-all-relative-rank.png".format(indicator)).resolve() | 
					
						
							|  |  |  |     fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png") | 
					
						
							|  |  |  |     print("{:} save into {:}".format(time_string(), save_path)) | 
					
						
							|  |  |  |     plt.close("all") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | if __name__ == "__main__": | 
					
						
							| 
									
										
										
										
											2021-03-18 16:02:55 +08:00
										 |  |  |     parser = argparse.ArgumentParser( | 
					
						
							|  |  |  |         description="NATS-Bench", formatter_class=argparse.ArgumentDefaultsHelpFormatter | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  |     parser.add_argument( | 
					
						
							| 
									
										
										
										
											2021-03-18 16:02:55 +08:00
										 |  |  |         "--save_dir", | 
					
						
							|  |  |  |         type=str, | 
					
						
							|  |  |  |         default="output/vis-nas-bench", | 
					
						
							|  |  |  |         help="Folder to save checkpoints and log.", | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  |     ) | 
					
						
							|  |  |  |     # use for train the model | 
					
						
							|  |  |  |     args = parser.parse_args() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     to_save_dir = Path(args.save_dir) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     datasets = ["cifar10", "cifar100", "ImageNet16-120"] | 
					
						
							|  |  |  |     # Figure 3 (a-c) | 
					
						
							|  |  |  |     api_tss = create(None, "tss", verbose=True) | 
					
						
							|  |  |  |     for xdata in datasets: | 
					
						
							|  |  |  |         visualize_tss_info(api_tss, xdata, to_save_dir) | 
					
						
							|  |  |  |     # Figure 3 (d-f) | 
					
						
							|  |  |  |     api_sss = create(None, "size", verbose=True) | 
					
						
							|  |  |  |     for xdata in datasets: | 
					
						
							|  |  |  |         visualize_sss_info(api_sss, xdata, to_save_dir) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Figure 2 | 
					
						
							|  |  |  |     visualize_relative_info(None, to_save_dir, "tss") | 
					
						
							|  |  |  |     visualize_relative_info(None, to_save_dir, "sss") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Figure 4 | 
					
						
							|  |  |  |     visualize_rank_info(None, to_save_dir, "tss") | 
					
						
							|  |  |  |     visualize_rank_info(None, to_save_dir, "sss") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Figure 5 | 
					
						
							|  |  |  |     visualize_all_rank_info(None, to_save_dir, "tss") | 
					
						
							|  |  |  |     visualize_all_rank_info(None, to_save_dir, "sss") |