656 lines
		
	
	
		
			24 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			656 lines
		
	
	
		
			24 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| ###############################################################
 | |
| # NATS-Bench (arxiv.org/pdf/2009.00437.pdf), IEEE TPAMI 2021  #
 | |
| # The code to draw Figure 2 / 3 / 4 / 5 in our paper.         #
 | |
| ###############################################################
 | |
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.06           #
 | |
| ###############################################################
 | |
| # Usage: python exps/NATS-Bench/draw-fig2_5.py                #
 | |
| ###############################################################
 | |
| import os, sys, time, torch, argparse
 | |
| import scipy
 | |
| import numpy as np
 | |
| from typing import List, Text, Dict, Any
 | |
| from shutil import copyfile
 | |
| from collections import defaultdict
 | |
| from copy import deepcopy
 | |
| from pathlib import Path
 | |
| import matplotlib
 | |
| import seaborn as sns
 | |
| 
 | |
| matplotlib.use("agg")
 | |
| import matplotlib.pyplot as plt
 | |
| import matplotlib.ticker as ticker
 | |
| 
 | |
| lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
 | |
| if str(lib_dir) not in sys.path:
 | |
|     sys.path.insert(0, str(lib_dir))
 | |
| from config_utils import dict2config, load_config
 | |
| from log_utils import time_string
 | |
| from models import get_cell_based_tiny_net
 | |
| from nats_bench import create
 | |
| 
 | |
| 
 | |
def visualize_relative_info(api, vis_save_dir, indicator):
    """Plot CIFAR-100 / ImageNet16-120 test-accuracy ranks against the CIFAR-10 rank.

    Reads the three cache files ``{dataset}-cache-{indicator}-info.pth`` from
    ``vis_save_dir`` (this function never creates them, so they must already
    exist), orders the architectures by test accuracy on every dataset, and
    writes ``{indicator}-relative-rank.pdf`` and ``{indicator}-relative-rank.png``
    into the same directory.

    Args:
        api: Not used by this function.
        vis_save_dir: ``pathlib.Path`` of the directory holding the cache files;
            it also receives the rendered figures.
        indicator: Tag embedded in the cache and output file names.
    """
    vis_save_dir = vis_save_dir.resolve()
    vis_save_dir.mkdir(parents=True, exist_ok=True)

    def _load_cache(dataset):
        cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format(
            dataset, indicator
        )
        return torch.load(cache_path)

    cifar010_info = _load_cache("cifar10")
    cifar100_info = _load_cache("cifar100")
    imagenet_info = _load_cache("ImageNet16-120")
    indexes = list(range(len(cifar010_info["params"])))

    print("{:} start to visualize relative ranking".format(time_string()))

    def _order_by_test_acc(info):
        # Architecture indexes sorted by ascending test accuracy.
        return sorted(indexes, key=lambda i: info["test_accs"][i])

    def _rank_lookup(ordered):
        # Map architecture index -> position in `ordered`.  A dict lookup
        # replaces the repeated `list.index` scans, which were quadratic in
        # the number of architectures.
        return {arch: rank for rank, arch in enumerate(ordered)}

    cifar010_ord_indexes = _order_by_test_acc(cifar010_info)
    cifar100_rank = _rank_lookup(_order_by_test_acc(cifar100_info))
    imagenet_rank = _rank_lookup(_order_by_test_acc(imagenet_info))
    cifar100_labels = [cifar100_rank[idx] for idx in cifar010_ord_indexes]
    imagenet_labels = [imagenet_rank[idx] for idx in cifar010_ord_indexes]
    print("{:} prepare data done.".format(time_string()))

    dpi, width, height = 200, 1400, 800
    figsize = width / float(dpi), height / float(dpi)
    LabelSize, LegendFontsize = 18, 12

    lower, upper = min(indexes), max(indexes)
    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(111)
    plt.xlim(lower, upper)
    plt.ylim(lower, upper)
    # max(1, ...) keeps the tick step non-zero when there are only a handful
    # of architectures; for realistic sizes the step is unchanged.
    plt.yticks(
        np.arange(lower, upper, max(1, upper // 3)),
        fontsize=LegendFontsize,
        rotation="vertical",
    )
    plt.xticks(
        np.arange(lower, upper, max(1, upper // 5)),
        fontsize=LegendFontsize,
    )
    ax.scatter(indexes, cifar100_labels, marker="^", s=0.5, c="tab:green", alpha=0.8)
    ax.scatter(indexes, imagenet_labels, marker="*", s=0.5, c="tab:red", alpha=0.8)
    ax.scatter(indexes, indexes, marker="o", s=0.5, c="tab:blue", alpha=0.8)
    # Points at (-1, -1) lie outside the axis limits; they only exist to give
    # the legend large, readable markers.
    ax.scatter([-1], [-1], marker="o", s=100, c="tab:blue", label="CIFAR-10")
    ax.scatter([-1], [-1], marker="^", s=100, c="tab:green", label="CIFAR-100")
    ax.scatter([-1], [-1], marker="*", s=100, c="tab:red", label="ImageNet-16-120")
    plt.grid(zorder=0)
    ax.set_axisbelow(True)
    plt.legend(loc=0, fontsize=LegendFontsize)
    ax.set_xlabel("architecture ranking in CIFAR-10", fontsize=LabelSize)
    ax.set_ylabel("architecture ranking", fontsize=LabelSize)
    save_path = (vis_save_dir / "{:}-relative-rank.pdf".format(indicator)).resolve()
    fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="pdf")
    save_path = (vis_save_dir / "{:}-relative-rank.png".format(indicator)).resolve()
    fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png")
    print("{:} save into {:}".format(time_string(), save_path))
    # Release the figure, as the other visualize_* functions in this file do.
    plt.close("all")
 | |
| 
 | |
| 
 | |
def visualize_sss_info(api, dataset, vis_save_dir):
    """Draw the four accuracy-vs-cost scatter panels for the size search space.

    Per-architecture statistics are read from
    ``{dataset}-cache-sss-info.pth`` inside ``vis_save_dir`` when that file
    exists; otherwise they are collected from ``api`` (queried with
    ``hp="90"``) and saved to that file.  The figure is written to
    ``sss-{dataset.lower()}.png`` in the same directory.

    Args:
        api: Benchmark API object.  It must support ``len()``,
            ``get_cost_info``, ``get_more_info`` and ``query_index_by_arch``;
            the last one is called even when the cache file is present.
        dataset: Dataset name passed to the API and embedded in file names.
        vis_save_dir: ``pathlib.Path`` of the cache / output directory.
    """
    vis_save_dir = vis_save_dir.resolve()
    print("{:} start to visualize {:} information".format(time_string(), dataset))
    vis_save_dir.mkdir(parents=True, exist_ok=True)
    cache_file_path = vis_save_dir / "{:}-cache-sss-info.pth".format(dataset)
    if not cache_file_path.exists():
        print("Do not find cache file : {:}".format(cache_file_path))
        params, flops, train_accs, valid_accs, test_accs = [], [], [], [], []
        for index in range(len(api)):
            cost_info = api.get_cost_info(index, dataset, hp="90")
            params.append(cost_info["params"])
            flops.append(cost_info["flops"])
            # accuracy
            info = api.get_more_info(index, dataset, hp="90", is_random=False)
            train_accs.append(info["train-accuracy"])
            test_accs.append(info["test-accuracy"])
            if dataset == "cifar10":
                # For cifar10 the validation accuracy is read from the
                # separate "cifar10-valid" entry.
                info = api.get_more_info(
                    index, "cifar10-valid", hp="90", is_random=False
                )
            valid_accs.append(info["valid-accuracy"])
        info = {
            "params": params,
            "flops": flops,
            "train_accs": train_accs,
            "valid_accs": valid_accs,
            "test_accs": test_accs,
        }
        torch.save(info, cache_file_path)
    else:
        print("Find cache file : {:}".format(cache_file_path))
        info = torch.load(cache_file_path)
        params, flops, train_accs, valid_accs, test_accs = (
            info["params"],
            info["flops"],
            info["train_accs"],
            info["valid_accs"],
            info["test_accs"],
        )
    print("{:} collect data done.".format(time_string()))

    pyramid = ["8:16:24:32:40", "8:16:32:48:64", "32:40:48:56:64"]
    pyramid_indexes = [api.query_index_by_arch(x) for x in pyramid]
    largest_indexes = [api.query_index_by_arch("64:64:64:64:64")]

    dpi, width, height = 250, 8500, 1300
    figsize = width / float(dpi), height / float(dpi)
    LabelSize, LegendFontsize = 24, 24
    xscale, xalpha = 120, 0.8

    # (architecture indexes, marker, colour, legend label) of the highlighted
    # candidates drawn on top of every panel.
    highlights = [
        (pyramid_indexes, "*", "tab:orange", "Pyramid Structure"),
        (largest_indexes, "x", "tab:green", "Largest Candidate"),
    ]

    def _draw_panel(ax, xs, ys, xlabel, ylabel=None):
        # One scatter panel: all architectures plus the highlighted ones.
        ax.scatter(xs, ys, marker="o", s=0.5, c="tab:blue")
        for picked, marker, color, label in highlights:
            ax.scatter(
                [xs[i] for i in picked],
                [ys[i] for i in picked],
                marker=marker,
                s=xscale,
                c=color,
                label=label,
                alpha=xalpha,
            )
        ax.set_xlabel(xlabel, fontsize=LabelSize)
        if ylabel is not None:
            ax.set_ylabel(ylabel, fontsize=LabelSize)
        ax.legend(loc=4, fontsize=LegendFontsize)

    fig, axs = plt.subplots(1, 4, figsize=figsize)
    for ax in axs:
        # tick_params replaces per-tick `tick.label.set_fontsize(...)`;
        # `Tick.label` is not available in recent Matplotlib releases.
        ax.tick_params(axis="both", which="major", labelsize=LabelSize)
        ax.yaxis.set_major_formatter(ticker.FormatStrFormatter("%.0f"))
    ax1, ax2, ax3, ax4 = axs

    _draw_panel(ax1, params, train_accs, "#parameters (MB)", "train accuracy (%)")
    _draw_panel(ax2, flops, train_accs, "#FLOPs (M)")
    _draw_panel(ax3, params, test_accs, "#parameters (MB)", "test accuracy (%)")
    _draw_panel(ax4, flops, test_accs, "#FLOPs (M)")

    save_path = vis_save_dir / "sss-{:}.png".format(dataset.lower())
    fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png")
    print("{:} save into {:}".format(time_string(), save_path))
    plt.close("all")
 | |
| 
 | |
| 
 | |
def visualize_tss_info(api, dataset, vis_save_dir):
    """Draw the four accuracy-vs-cost scatter panels for the topology search space.

    Per-architecture statistics are read from
    ``{dataset}-cache-tss-info.pth`` inside ``vis_save_dir`` when that file
    exists; otherwise they are collected from ``api`` and saved to that file.
    The figure is written to ``tss-{dataset.lower()}.png`` in the same
    directory.

    Args:
        api: Benchmark API object.  It must support ``len()``,
            ``get_cost_info``, ``get_more_info`` and ``query_index_by_arch``;
            the last one is called even when the cache file is present.
        dataset: Dataset name passed to the API and embedded in file names.
        vis_save_dir: ``pathlib.Path`` of the cache / output directory.
    """
    vis_save_dir = vis_save_dir.resolve()
    print("{:} start to visualize {:} information".format(time_string(), dataset))
    vis_save_dir.mkdir(parents=True, exist_ok=True)
    cache_file_path = vis_save_dir / "{:}-cache-tss-info.pth".format(dataset)
    if not cache_file_path.exists():
        print("Do not find cache file : {:}".format(cache_file_path))
        params, flops, train_accs, valid_accs, test_accs = [], [], [], [], []
        for index in range(len(api)):
            # NOTE(review): cost is queried with hp="12" while accuracy uses
            # hp="200" -- presumably cost does not depend on hp; confirm.
            cost_info = api.get_cost_info(index, dataset, hp="12")
            params.append(cost_info["params"])
            flops.append(cost_info["flops"])
            # accuracy
            info = api.get_more_info(index, dataset, hp="200", is_random=False)
            train_accs.append(info["train-accuracy"])
            test_accs.append(info["test-accuracy"])
            if dataset == "cifar10":
                # For cifar10 the validation accuracy is read from the
                # separate "cifar10-valid" entry.
                info = api.get_more_info(
                    index, "cifar10-valid", hp="200", is_random=False
                )
            valid_accs.append(info["valid-accuracy"])
            print("")
        info = {
            "params": params,
            "flops": flops,
            "train_accs": train_accs,
            "valid_accs": valid_accs,
            "test_accs": test_accs,
        }
        torch.save(info, cache_file_path)
    else:
        print("Find cache file : {:}".format(cache_file_path))
        info = torch.load(cache_file_path)
        params, flops, train_accs, valid_accs, test_accs = (
            info["params"],
            info["flops"],
            info["train_accs"],
            info["valid_accs"],
            info["test_accs"],
        )
    print("{:} collect data done.".format(time_string()))

    resnet = [
        "|nor_conv_3x3~0|+|none~0|nor_conv_3x3~1|+|skip_connect~0|none~1|skip_connect~2|"
    ]
    resnet_indexes = [api.query_index_by_arch(x) for x in resnet]
    largest_indexes = [
        api.query_index_by_arch(
            "|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|nor_conv_3x3~0|nor_conv_3x3~1|nor_conv_3x3~2|"
        )
    ]

    dpi, width, height = 250, 8500, 1300
    figsize = width / float(dpi), height / float(dpi)
    LabelSize, LegendFontsize = 24, 24
    xscale, xalpha = 120, 0.8

    # (architecture indexes, marker, colour, legend label) of the highlighted
    # candidates drawn on top of every panel.
    highlights = [
        (resnet_indexes, "*", "tab:orange", "ResNet"),
        (largest_indexes, "x", "tab:green", "Largest Candidate"),
    ]

    def _draw_panel(ax, xs, ys, xlabel, ylabel=None):
        # One scatter panel: all architectures plus the highlighted ones.
        ax.scatter(xs, ys, marker="o", s=0.5, c="tab:blue")
        for picked, marker, color, label in highlights:
            ax.scatter(
                [xs[i] for i in picked],
                [ys[i] for i in picked],
                marker=marker,
                s=xscale,
                c=color,
                label=label,
                alpha=xalpha,
            )
        ax.set_xlabel(xlabel, fontsize=LabelSize)
        if ylabel is not None:
            ax.set_ylabel(ylabel, fontsize=LabelSize)
        ax.legend(loc=4, fontsize=LegendFontsize)

    fig, axs = plt.subplots(1, 4, figsize=figsize)
    for ax in axs:
        # tick_params replaces per-tick `tick.label.set_fontsize(...)`;
        # `Tick.label` is not available in recent Matplotlib releases.
        ax.tick_params(axis="both", which="major", labelsize=LabelSize)
        ax.yaxis.set_major_formatter(ticker.FormatStrFormatter("%.0f"))
    ax1, ax2, ax3, ax4 = axs

    _draw_panel(ax1, params, train_accs, "#parameters (MB)", "train accuracy (%)")
    _draw_panel(ax2, flops, train_accs, "#FLOPs (M)")
    _draw_panel(ax3, params, test_accs, "#parameters (MB)", "test accuracy (%)")
    _draw_panel(ax4, flops, test_accs, "#FLOPs (M)")

    save_path = vis_save_dir / "tss-{:}.png".format(dataset.lower())
    fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png")
    print("{:} save into {:}".format(time_string(), save_path))
    plt.close("all")
 | |
| 
 | |
| 
 | |
def visualize_rank_info(api, vis_save_dir, indicator):
    """Plot, per dataset, the validation-accuracy rank against the test-accuracy rank.

    Reads the three cache files ``{dataset}-cache-{indicator}-info.pth`` from
    ``vis_save_dir`` (they must already exist) and writes
    ``{indicator}-same-relative-rank.pdf`` and
    ``{indicator}-same-relative-rank.png`` into the same directory, with one
    panel each for CIFAR-10, CIFAR-100 and ImageNet-16-120.

    Args:
        api: Not used by this function.
        vis_save_dir: ``pathlib.Path`` of the directory holding the cache files;
            it also receives the rendered figures.
        indicator: Tag embedded in the cache and output file names.
    """
    vis_save_dir = vis_save_dir.resolve()
    vis_save_dir.mkdir(parents=True, exist_ok=True)

    def _load_cache(dataset):
        cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format(
            dataset, indicator
        )
        return torch.load(cache_path)

    cifar010_info = _load_cache("cifar10")
    cifar100_info = _load_cache("cifar100")
    imagenet_info = _load_cache("ImageNet16-120")
    indexes = list(range(len(cifar010_info["params"])))

    print("{:} start to visualize relative ranking".format(time_string()))

    dpi, width, height = 250, 3800, 1200
    figsize = width / float(dpi), height / float(dpi)
    LabelSize, LegendFontsize = 14, 14

    fig, axs = plt.subplots(1, 3, figsize=figsize)
    ax1, ax2, ax3 = axs

    def get_labels(info):
        # For architectures taken in ascending test-accuracy order, return
        # each one's position in the ascending validation-accuracy order.
        ord_test_indexes = sorted(indexes, key=lambda i: info["test_accs"][i])
        ord_valid_indexes = sorted(indexes, key=lambda i: info["valid_accs"][i])
        # A dict lookup replaces the per-element `list.index` scan, which was
        # quadratic in the number of architectures.
        valid_rank = {arch: rank for rank, arch in enumerate(ord_valid_indexes)}
        return [valid_rank[idx] for idx in ord_test_indexes]

    def plot_ax(labels, ax, name):
        lower, upper = min(indexes), max(indexes)
        # tick_params replaces per-tick `tick.label.set_*` calls; `Tick.label`
        # is not available in recent Matplotlib releases.
        ax.tick_params(axis="x", which="major", labelsize=LabelSize)
        ax.tick_params(
            axis="y", which="major", labelsize=LabelSize, labelrotation=90
        )
        ax.set_xlim(lower, upper)
        ax.set_ylim(lower, upper)
        # max(1, ...) keeps the tick step non-zero for very small inputs.
        ax.yaxis.set_ticks(np.arange(lower, upper, max(1, upper // 3)))
        ax.xaxis.set_ticks(np.arange(lower, upper, max(1, upper // 5)))
        # NOTE(review): x positions follow the test-accuracy order while the
        # axis label and legend refer to validation -- confirm the intended
        # orientation against the paper figure.
        ax.scatter(indexes, labels, marker="^", s=0.5, c="tab:green", alpha=0.8)
        ax.scatter(indexes, indexes, marker="o", s=0.5, c="tab:blue", alpha=0.8)
        # Points at (-1, -1) lie outside the axis limits; they only exist to
        # give the legend large, readable markers.
        ax.scatter(
            [-1], [-1], marker="^", s=100, c="tab:green", label="{:} test".format(name)
        )
        ax.scatter(
            [-1],
            [-1],
            marker="o",
            s=100,
            c="tab:blue",
            label="{:} validation".format(name),
        )
        ax.legend(loc=4, fontsize=LegendFontsize)
        ax.set_xlabel("ranking on the {:} validation".format(name), fontsize=LabelSize)
        ax.set_ylabel("architecture ranking", fontsize=LabelSize)

    plot_ax(get_labels(cifar010_info), ax1, "CIFAR-10")
    plot_ax(get_labels(cifar100_info), ax2, "CIFAR-100")
    plot_ax(get_labels(imagenet_info), ax3, "ImageNet-16-120")

    save_path = (
        vis_save_dir / "{:}-same-relative-rank.pdf".format(indicator)
    ).resolve()
    fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="pdf")
    save_path = (
        vis_save_dir / "{:}-same-relative-rank.png".format(indicator)
    ).resolve()
    fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png")
    print("{:} save into {:}".format(time_string(), save_path))
    plt.close("all")
 | |
| 
 | |
| 
 | |
def compute_kendalltau(vectori, vectorj):
    """Return the Kendall tau correlation between two equally long sequences.

    Args:
        vectori: First sequence of scores.
        vectorj: Second sequence of scores.

    Returns:
        The ``correlation`` field of ``scipy.stats.kendalltau(vectori, vectorj)``.
    """
    # Import the submodule explicitly: a bare `import scipy` does not
    # guarantee that `scipy.stats` is loaded on every SciPy version.
    from scipy import stats as scipy_stats

    return scipy_stats.kendalltau(vectori, vectorj).correlation
 | |
| 
 | |
| 
 | |
def calculate_correlation(*vectors):
    """Build the pairwise Kendall tau matrix of the given score vectors.

    Entry ``[i, j]`` of the returned ``numpy`` array is
    ``compute_kendalltau(vectors[i], vectors[j])``.
    """
    rows = [
        [compute_kendalltau(row_vec, col_vec) for col_vec in vectors]
        for row_vec in vectors
    ]
    return np.array(rows)
 | |
| 
 | |
| 
 | |
def visualize_all_rank_info(api, vis_save_dir, indicator):
    """Render two Kendall tau heatmaps over validation/test accuracies.

    The left heatmap covers every architecture stored in the three cache
    files ``{dataset}-cache-{indicator}-info.pth``; the right one is limited
    to architectures whose CIFAR-10 test accuracy is above ``acc_bar``.  The
    figure is saved as ``{indicator}-all-relative-rank.png`` in
    ``vis_save_dir``.

    Args:
        api: Not used by this function.
        vis_save_dir: ``pathlib.Path`` of the cache / output directory.
        indicator: Tag embedded in the cache and output file names.
    """
    vis_save_dir = vis_save_dir.resolve()
    vis_save_dir.mkdir(parents=True, exist_ok=True)

    cache_template = "{:}-cache-{:}-info.pth"
    cifar010_info = torch.load(
        vis_save_dir / cache_template.format("cifar10", indicator)
    )
    cifar100_info = torch.load(
        vis_save_dir / cache_template.format("cifar100", indicator)
    )
    imagenet_info = torch.load(
        vis_save_dir / cache_template.format("ImageNet16-120", indicator)
    )
    # Not used below; kept so a cache without "params" still raises KeyError here.
    indexes = list(range(len(cifar010_info["params"])))

    print("{:} start to visualize relative ranking".format(time_string()))

    dpi, width, height = 250, 3200, 1400
    figsize = width / float(dpi), height / float(dpi)
    LabelSize, LegendFontsize = 14, 14

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)

    sns_size, xformat = 15, ".2f"
    tick_names = ["C10-V", "C10-T", "C100-V", "C100-T", "I120-V", "I120-T"]

    def _heatmap(matrix, target_ax):
        # Both panels share every heatmap option except the data and the axis.
        sns.heatmap(
            matrix,
            annot=True,
            annot_kws={"size": sns_size},
            fmt=xformat,
            linewidths=0.5,
            ax=target_ax,
            xticklabels=list(tick_names),
            yticklabels=list(tick_names),
        )

    # Vector order matches tick_names: validation then test, per dataset.
    all_vectors = [
        info[key]
        for info in (cifar010_info, cifar100_info, imagenet_info)
        for key in ("valid_accs", "test_accs")
    ]
    _heatmap(calculate_correlation(*all_vectors), ax1)

    acc_bar = 92
    selected_indexes = [
        i for i, acc in enumerate(cifar010_info["test_accs"]) if acc > acc_bar
    ]
    selected_vectors = [np.array(vec)[selected_indexes] for vec in all_vectors]
    _heatmap(calculate_correlation(*selected_vectors), ax2)

    ax1.set_title("Correlation coefficient over ALL candidates")
    ax2.set_title(
        "Correlation coefficient over candidates with accuracy > {:}%".format(acc_bar)
    )
    save_path = (vis_save_dir / "{:}-all-relative-rank.png".format(indicator)).resolve()
    fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png")
    print("{:} save into {:}".format(time_string(), save_path))
    plt.close("all")
 | |
| 
 | |
| 
 | |
if __name__ == "__main__":
    # Command-line entry point: draw every figure into --save_dir.
    parser = argparse.ArgumentParser(
        description="NATS-Bench", formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        "--save_dir",
        type=str,
        default="output/vis-nas-bench",
        help="Folder to save checkpoints and log.",
    )
    args = parser.parse_args()

    to_save_dir = Path(args.save_dir)
    datasets = ["cifar10", "cifar100", "ImageNet16-120"]

    # Figure 3: (a-c) topology search space first, then (d-f) size search space.
    # These calls also create the cache files that the later figures load.
    for search_space, draw_figure in (
        ("tss", visualize_tss_info),
        ("size", visualize_sss_info),
    ):
        benchmark_api = create(None, search_space, verbose=True)
        for xdata in datasets:
            draw_figure(benchmark_api, xdata, to_save_dir)

    # Figure 2, Figure 4 and Figure 5, in that order, each for "tss" then "sss".
    for draw_figure in (
        visualize_relative_info,
        visualize_rank_info,
        visualize_all_rank_info,
    ):
        for indicator in ("tss", "sss"):
            draw_figure(None, to_save_dir, indicator)
 |