###############################################################
# NAS-Bench-201, ICLR 2020 (https://arxiv.org/abs/2001.00326) #
###############################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.06           #
###############################################################
# Usage: python exps/experimental/visualize-nas-bench-x.py
###############################################################
import os, sys, time, torch, argparse
import numpy as np
from typing import List, Text, Dict, Any
from shutil import copyfile
from collections import defaultdict
from copy import deepcopy
from pathlib import Path
import matplotlib
import seaborn as sns

# The non-interactive backend must be selected before pyplot is imported.
matplotlib.use("agg")
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
if str(lib_dir) not in sys.path:
    sys.path.insert(0, str(lib_dir))
from config_utils import dict2config, load_config
from log_utils import time_string
from models import get_cell_based_tiny_net
from nats_bench import create


# Dataset names used both for the API queries and for the cache file names.
_DATASETS = ("cifar10", "cifar100", "ImageNet16-120")


def _load_cached_infos(vis_save_dir, indicator):
    """Load the cached statistics of every dataset in ``_DATASETS``.

    Args:
        vis_save_dir: directory (``pathlib.Path``) holding the cache files.
        indicator: search-space tag used in the file name
            ``<dataset>-cache-<indicator>-info.pth``.

    Returns:
        A list with one loaded object per dataset, in ``_DATASETS`` order.
    """
    # NOTE: torch.load unpickles its input; only point this at cache files
    # written by visualize_tss_info / visualize_sss_info.
    return [
        torch.load(vis_save_dir / "{:}-cache-{:}-info.pth".format(dataset, indicator))
        for dataset in _DATASETS
    ]


def _order_by(indexes, values):
    """Return ``indexes`` sorted by ascending ``values[i]`` (stable sort)."""
    return sorted(indexes, key=lambda i: values[i])


def _rank_lookup(ordered_indexes):
    """Map each index to its position in ``ordered_indexes``.

    A dict lookup replaces repeated ``list.index`` calls, which would make the
    ranking step quadratic in the number of architectures.
    """
    return {index: rank for rank, index in enumerate(ordered_indexes)}


def _collect_info(api, dataset, cost_hp, acc_hp):
    """Query ``api`` for the cost and accuracy of every architecture.

    Args:
        api: benchmark API supporting ``len()``, ``get_cost_info`` and
            ``get_more_info``.
        dataset: dataset name passed to the API.
        cost_hp: ``hp`` value used for ``get_cost_info``.
        acc_hp: ``hp`` value used for ``get_more_info``.

    Returns:
        A dict with the keys ``params``, ``flops``, ``train_accs``,
        ``valid_accs`` and ``test_accs``; each value is a list indexed by
        architecture index.
    """
    params, flops, train_accs, valid_accs, test_accs = [], [], [], [], []
    for index in range(len(api)):
        cost_info = api.get_cost_info(index, dataset, hp=cost_hp)
        params.append(cost_info["params"])
        flops.append(cost_info["flops"])
        # accuracy
        info = api.get_more_info(index, dataset, hp=acc_hp, is_random=False)
        train_accs.append(info["train-accuracy"])
        test_accs.append(info["test-accuracy"])
        if dataset == "cifar10":
            # For CIFAR-10 the validation accuracy is read from the separate
            # "cifar10-valid" entry instead of the "cifar10" one.
            info = api.get_more_info(
                index, "cifar10-valid", hp=acc_hp, is_random=False
            )
        valid_accs.append(info["valid-accuracy"])
    return {
        "params": params,
        "flops": flops,
        "train_accs": train_accs,
        "valid_accs": valid_accs,
        "test_accs": test_accs,
    }


def _plot_cost_vs_accuracy(info, highlights, save_path):
    """Draw the four cost-vs-accuracy scatter panels and save them as PNG.

    The panels are, left to right: #parameters vs train accuracy, #parameters
    vs test accuracy, #FLOPs vs train accuracy and #FLOPs vs test accuracy.

    Args:
        info: dict with the keys ``params``, ``flops``, ``train_accs`` and
            ``test_accs`` (lists indexed by architecture index).
        highlights: iterable of ``(label, marker, color, arch_indexes)``
            tuples; each one is drawn as an enlarged, labelled scatter on top
            of the full point cloud.
        save_path: output file path of the PNG figure.
    """
    params, flops = info["params"], info["flops"]
    train_accs, test_accs = info["train_accs"], info["test_accs"]

    dpi, width, height = 250, 8500, 1300
    figsize = width / float(dpi), height / float(dpi)
    LabelSize, LegendFontsize = 24, 24
    xscale, xalpha = 120, 0.8

    fig, axs = plt.subplots(1, 4, figsize=figsize)
    panels = (
        (params, "#parameters (MB)", train_accs, "train accuracy (%)"),
        (params, "#parameters (MB)", test_accs, "test accuracy (%)"),
        (flops, "#FLOPs (M)", train_accs, "train accuracy (%)"),
        (flops, "#FLOPs (M)", test_accs, "test accuracy (%)"),
    )
    for ax, (xs, xlabel, ys, ylabel) in zip(axs, panels):
        # Tick.label no longer exists in recent Matplotlib releases;
        # tick_params sets the label size on every Matplotlib version.
        ax.tick_params(axis="both", labelsize=LabelSize)
        ax.yaxis.set_major_formatter(ticker.FormatStrFormatter("%.0f"))
        ax.scatter(xs, ys, marker="o", s=0.5, c="tab:blue")
        for label, marker, color, arch_indexes in highlights:
            ax.scatter(
                [xs[i] for i in arch_indexes],
                [ys[i] for i in arch_indexes],
                marker=marker,
                s=xscale,
                c=color,
                label=label,
                alpha=xalpha,
            )
        ax.set_xlabel(xlabel, fontsize=LabelSize)
        ax.set_ylabel(ylabel, fontsize=LabelSize)
        ax.legend(loc=4, fontsize=LegendFontsize)

    fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png")
    print("{:} save into {:}".format(time_string(), save_path))
    plt.close("all")


def visualize_info(api, vis_save_dir, indicator):
    """Plot how the CIFAR-10 test ranking relates to the other datasets.

    Architectures are ordered along the x axis by their CIFAR-10 test
    accuracy; the y axis shows the rank of the same architecture by test
    accuracy on CIFAR-10, CIFAR-100 and ImageNet-16-120.  The figure is saved
    as ``<indicator>-relative-rank.pdf`` and ``.png`` inside ``vis_save_dir``.

    Args:
        api: unused; kept for a uniform signature with the other visualizers.
        vis_save_dir: directory (``pathlib.Path``) with the cache files; also
            the output directory.
        indicator: search-space tag of the cache files (e.g. "tss" or "sss").
    """
    vis_save_dir = vis_save_dir.resolve()
    vis_save_dir.mkdir(parents=True, exist_ok=True)

    cifar010_info, cifar100_info, imagenet_info = _load_cached_infos(
        vis_save_dir, indicator
    )
    indexes = list(range(len(cifar010_info["params"])))

    print("{:} start to visualize relative ranking".format(time_string()))
    cifar010_ord_indexes = _order_by(indexes, cifar010_info["test_accs"])
    cifar100_rank = _rank_lookup(_order_by(indexes, cifar100_info["test_accs"]))
    imagenet_rank = _rank_lookup(_order_by(indexes, imagenet_info["test_accs"]))
    cifar100_labels = [cifar100_rank[idx] for idx in cifar010_ord_indexes]
    imagenet_labels = [imagenet_rank[idx] for idx in cifar010_ord_indexes]
    print("{:} prepare data done.".format(time_string()))

    dpi, width, height = 200, 1400, 800
    figsize = width / float(dpi), height / float(dpi)
    LabelSize, LegendFontsize = 18, 12

    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(111)
    plt.xlim(min(indexes), max(indexes))
    plt.ylim(min(indexes), max(indexes))
    plt.yticks(
        np.arange(min(indexes), max(indexes), max(indexes) // 3),
        fontsize=LegendFontsize,
        rotation="vertical",
    )
    plt.xticks(
        np.arange(min(indexes), max(indexes), max(indexes) // 5),
        fontsize=LegendFontsize,
    )
    ax.scatter(indexes, cifar100_labels, marker="^", s=0.5, c="tab:green", alpha=0.8)
    ax.scatter(indexes, imagenet_labels, marker="*", s=0.5, c="tab:red", alpha=0.8)
    ax.scatter(indexes, indexes, marker="o", s=0.5, c="tab:blue", alpha=0.8)
    # Points at (-1, -1) lie outside the axis limits: they only provide
    # large, readable markers for the legend.
    ax.scatter([-1], [-1], marker="o", s=100, c="tab:blue", label="CIFAR-10")
    ax.scatter([-1], [-1], marker="^", s=100, c="tab:green", label="CIFAR-100")
    ax.scatter([-1], [-1], marker="*", s=100, c="tab:red", label="ImageNet-16-120")
    plt.grid(zorder=0)
    ax.set_axisbelow(True)
    plt.legend(loc=0, fontsize=LegendFontsize)
    ax.set_xlabel("architecture ranking in CIFAR-10", fontsize=LabelSize)
    ax.set_ylabel("architecture ranking", fontsize=LabelSize)

    save_path = (vis_save_dir / "{:}-relative-rank.pdf".format(indicator)).resolve()
    fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="pdf")
    save_path = (vis_save_dir / "{:}-relative-rank.png".format(indicator)).resolve()
    fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png")
    print("{:} save into {:}".format(time_string(), save_path))
    # Release the figure, as the other visualizers in this file do.
    plt.close("all")


def visualize_sss_info(api, dataset, vis_save_dir):
    """Plot cost vs accuracy for every candidate queried from ``api``.

    Statistics are read from ``<dataset>-cache-sss-info.pth`` when that file
    exists; otherwise they are queried from ``api`` (``hp="90"``) and cached.
    The figure is saved as ``sss-<dataset>.png`` inside ``vis_save_dir``.

    Args:
        api: benchmark API queried for costs, accuracies and the indexes of
            the highlighted architectures.
        dataset: dataset name passed to the API and used in file names.
        vis_save_dir: output / cache directory (``pathlib.Path``).
    """
    vis_save_dir = vis_save_dir.resolve()
    print("{:} start to visualize {:} information".format(time_string(), dataset))
    vis_save_dir.mkdir(parents=True, exist_ok=True)

    cache_file_path = vis_save_dir / "{:}-cache-sss-info.pth".format(dataset)
    if cache_file_path.exists():
        print("Find cache file : {:}".format(cache_file_path))
        info = torch.load(cache_file_path)
    else:
        print("Do not find cache file : {:}".format(cache_file_path))
        info = _collect_info(api, dataset, cost_hp="90", acc_hp="90")
        torch.save(info, cache_file_path)
    print("{:} collect data done.".format(time_string()))

    pyramid = [
        "8:16:32:48:64",
        "8:8:16:32:48",
        "8:8:16:16:32",
        "8:8:16:16:48",
        "8:8:16:16:64",
        "16:16:32:32:64",
        "32:32:64:64:64",
    ]
    pyramid_indexes = [api.query_index_by_arch(x) for x in pyramid]
    largest_indexes = [api.query_index_by_arch("64:64:64:64:64")]

    highlights = (
        ("Pyramid Structure", "*", "tab:orange", pyramid_indexes),
        ("Largest Candidate", "x", "tab:green", largest_indexes),
    )
    save_path = vis_save_dir / "sss-{:}.png".format(dataset)
    _plot_cost_vs_accuracy(info, highlights, save_path)


def visualize_tss_info(api, dataset, vis_save_dir):
    """Plot cost vs accuracy for every candidate queried from ``api``.

    Statistics are read from ``<dataset>-cache-tss-info.pth`` when that file
    exists; otherwise they are queried from ``api`` (costs with ``hp="12"``,
    accuracies with ``hp="200"``) and cached.  The figure is saved as
    ``tss-<dataset>.png`` inside ``vis_save_dir``.

    Args:
        api: benchmark API queried for costs, accuracies and the indexes of
            the highlighted architectures.
        dataset: dataset name passed to the API and used in file names.
        vis_save_dir: output / cache directory (``pathlib.Path``).
    """
    vis_save_dir = vis_save_dir.resolve()
    print("{:} start to visualize {:} information".format(time_string(), dataset))
    vis_save_dir.mkdir(parents=True, exist_ok=True)

    cache_file_path = vis_save_dir / "{:}-cache-tss-info.pth".format(dataset)
    if cache_file_path.exists():
        print("Find cache file : {:}".format(cache_file_path))
        info = torch.load(cache_file_path)
    else:
        print("Do not find cache file : {:}".format(cache_file_path))
        info = _collect_info(api, dataset, cost_hp="12", acc_hp="200")
        # Blank separator line after the collection pass.
        print("")
        torch.save(info, cache_file_path)
    print("{:} collect data done.".format(time_string()))

    resnet = [
        "|nor_conv_3x3~0|+|none~0|nor_conv_3x3~1|+|skip_connect~0|none~1|skip_connect~2|"
    ]
    resnet_indexes = [api.query_index_by_arch(x) for x in resnet]
    largest_indexes = [
        api.query_index_by_arch(
            "|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|nor_conv_3x3~0|nor_conv_3x3~1|nor_conv_3x3~2|"
        )
    ]

    highlights = (
        ("ResNet", "*", "tab:orange", resnet_indexes),
        ("Largest Candidate", "x", "tab:green", largest_indexes),
    )
    save_path = vis_save_dir / "tss-{:}.png".format(dataset)
    _plot_cost_vs_accuracy(info, highlights, save_path)


def visualize_rank_info(api, vis_save_dir, indicator):
    """Plot the validation ranking against the test ranking per dataset.

    One panel is drawn for each of CIFAR-10, CIFAR-100 and ImageNet-16-120.
    The figure is saved as ``<indicator>-same-relative-rank.pdf`` and ``.png``
    inside ``vis_save_dir``.

    Args:
        api: unused; kept for a uniform signature with the other visualizers.
        vis_save_dir: directory (``pathlib.Path``) with the cache files; also
            the output directory.
        indicator: search-space tag of the cache files (e.g. "tss" or "sss").
    """
    vis_save_dir = vis_save_dir.resolve()
    vis_save_dir.mkdir(parents=True, exist_ok=True)

    cifar010_info, cifar100_info, imagenet_info = _load_cached_infos(
        vis_save_dir, indicator
    )
    indexes = list(range(len(cifar010_info["params"])))

    print("{:} start to visualize relative ranking".format(time_string()))
    dpi, width, height = 250, 3800, 1200
    figsize = width / float(dpi), height / float(dpi)
    LabelSize, LegendFontsize = 14, 14

    fig, axs = plt.subplots(1, 3, figsize=figsize)
    ax1, ax2, ax3 = axs

    def get_labels(info):
        # x position k is the architecture with the k-th lowest validation
        # accuracy and the returned value is its rank by test accuracy.  This
        # matches the x-axis label ("ranking on the ... validation") and the
        # legend below; the previous code ordered by test accuracy and
        # returned the validation rank, i.e. the transpose of what is labelled.
        ord_valid_indexes = _order_by(indexes, info["valid_accs"])
        test_rank = _rank_lookup(_order_by(indexes, info["test_accs"]))
        return [test_rank[idx] for idx in ord_valid_indexes]

    def plot_ax(labels, ax, name):
        # Tick.label no longer exists in recent Matplotlib releases;
        # tick_params also covers ticks created by set_ticks below.
        ax.tick_params(axis="x", labelsize=LabelSize)
        ax.tick_params(axis="y", labelsize=LabelSize, labelrotation=90)
        ax.set_xlim(min(indexes), max(indexes))
        ax.set_ylim(min(indexes), max(indexes))
        ax.yaxis.set_ticks(np.arange(min(indexes), max(indexes), max(indexes) // 3))
        ax.xaxis.set_ticks(np.arange(min(indexes), max(indexes), max(indexes) // 5))
        ax.scatter(indexes, labels, marker="^", s=0.5, c="tab:green", alpha=0.8)
        ax.scatter(indexes, indexes, marker="o", s=0.5, c="tab:blue", alpha=0.8)
        # Off-axis points that only supply large legend markers.
        ax.scatter(
            [-1], [-1], marker="^", s=100, c="tab:green", label="{:} test".format(name)
        )
        ax.scatter(
            [-1],
            [-1],
            marker="o",
            s=100,
            c="tab:blue",
            label="{:} validation".format(name),
        )
        ax.legend(loc=4, fontsize=LegendFontsize)
        ax.set_xlabel("ranking on the {:} validation".format(name), fontsize=LabelSize)
        ax.set_ylabel("architecture ranking", fontsize=LabelSize)

    plot_ax(get_labels(cifar010_info), ax1, "CIFAR-10")
    plot_ax(get_labels(cifar100_info), ax2, "CIFAR-100")
    plot_ax(get_labels(imagenet_info), ax3, "ImageNet-16-120")

    save_path = (
        vis_save_dir / "{:}-same-relative-rank.pdf".format(indicator)
    ).resolve()
    fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="pdf")
    save_path = (
        vis_save_dir / "{:}-same-relative-rank.png".format(indicator)
    ).resolve()
    fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png")
    print("{:} save into {:}".format(time_string(), save_path))
    plt.close("all")


def calculate_correlation(*vectors):
    """Return the matrix of pairwise correlation coefficients.

    Entry ``[i, j]`` is ``np.corrcoef(vectors[i], vectors[j])[0, 1]``.

    Args:
        *vectors: equally long 1-D sequences of numbers.

    Returns:
        A ``numpy.ndarray`` of shape ``(len(vectors), len(vectors))``.
    """
    return np.array(
        [[np.corrcoef(row, col)[0, 1] for col in vectors] for row in vectors]
    )


def visualize_all_rank_info(api, vis_save_dir, indicator):
    """Draw heatmaps of the correlation between all accuracy vectors.

    The left heatmap uses every architecture; the right one only uses the
    architectures whose CIFAR-10 test accuracy exceeds ``acc_bar`` (92).  The
    figure is saved as ``<indicator>-all-relative-rank.png`` inside
    ``vis_save_dir``.

    Args:
        api: unused; kept for a uniform signature with the other visualizers.
        vis_save_dir: directory (``pathlib.Path``) with the cache files; also
            the output directory.
        indicator: search-space tag of the cache files (e.g. "tss" or "sss").
    """
    vis_save_dir = vis_save_dir.resolve()
    vis_save_dir.mkdir(parents=True, exist_ok=True)

    infos = _load_cached_infos(vis_save_dir, indicator)
    cifar010_info = infos[0]

    print("{:} start to visualize relative ranking".format(time_string()))
    dpi, width, height = 250, 3200, 1400
    figsize = width / float(dpi), height / float(dpi)

    fig, axs = plt.subplots(1, 2, figsize=figsize)
    ax1, ax2 = axs
    sns_size = 15
    tick_names = ["C10-V", "C10-T", "C100-V", "C100-T", "I120-V", "I120-T"]

    def draw_heatmap(vectors, ax):
        # `vectors` follow the order of `tick_names`.
        sns.heatmap(
            calculate_correlation(*vectors),
            annot=True,
            annot_kws={"size": sns_size},
            fmt=".3f",
            linewidths=0.5,
            ax=ax,
            xticklabels=tick_names,
            yticklabels=tick_names,
        )

    draw_heatmap(
        [info[key] for info in infos for key in ("valid_accs", "test_accs")], ax1
    )

    acc_bar = 92
    selected_indexes = [
        i for i, acc in enumerate(cifar010_info["test_accs"]) if acc > acc_bar
    ]
    draw_heatmap(
        [
            np.array(info[key])[selected_indexes]
            for info in infos
            for key in ("valid_accs", "test_accs")
        ],
        ax2,
    )

    ax1.set_title("Correlation coefficient over ALL candidates")
    ax2.set_title(
        "Correlation coefficient over candidates with accuracy > {:}%".format(acc_bar)
    )
    save_path = (vis_save_dir / "{:}-all-relative-rank.png".format(indicator)).resolve()
    fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png")
    print("{:} save into {:}".format(time_string(), save_path))
    plt.close("all")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="NAS-Bench-X",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--save_dir",
        type=str,
        default="output/vis-nas-bench",
        help="Folder to save checkpoints and log.",
    )
    args = parser.parse_args()

    to_save_dir = Path(args.save_dir)

    # The per-dataset passes write the cache files that the ranking
    # visualizers below read, so they have to run first.
    api201 = create(None, "tss", verbose=True)
    for xdata in _DATASETS:
        visualize_tss_info(api201, xdata, to_save_dir)

    api_sss = create(None, "size", verbose=True)
    for xdata in _DATASETS:
        visualize_sss_info(api_sss, xdata, to_save_dir)

    for visualize in (visualize_info, visualize_rank_info, visualize_all_rank_info):
        for indicator in ("tss", "sss"):
            visualize(None, to_save_dir, indicator)