Update visualization codes for NATS-Bench

This commit is contained in:
D-X-Y 2020-11-30 00:48:10 +08:00
parent 550d24ec07
commit 29428bf5a3
6 changed files with 802 additions and 10 deletions

View File

@ -0,0 +1,90 @@
###############################################################
# NATS-Bench (https://arxiv.org/pdf/2009.00437.pdf) #
###############################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.06 #
###############################################################
# Usage: python exps/NATS-Bench/draw-correlations.py #
###############################################################
import os, gc, sys, time, scipy, torch, argparse
import numpy as np
from typing import List, Text, Dict, Any
from shutil import copyfile
from collections import defaultdict, OrderedDict
from copy import deepcopy
from pathlib import Path
import matplotlib
import seaborn as sns
matplotlib.use('agg')
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
from config_utils import dict2config, load_config
from nats_bench import create
from log_utils import time_string
def get_valid_test_acc(api, arch, dataset):
is_size_space = api.search_space_name == 'size'
if dataset == 'cifar10':
xinfo = api.get_more_info(arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False)
test_acc = xinfo['test-accuracy']
xinfo = api.get_more_info(arch, dataset='cifar10-valid', hp=90 if is_size_space else 200, is_random=False)
valid_acc = xinfo['valid-accuracy']
else:
xinfo = api.get_more_info(arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False)
valid_acc = xinfo['valid-accuracy']
test_acc = xinfo['test-accuracy']
return valid_acc, test_acc, 'validation = {:.2f}, test = {:.2f}\n'.format(valid_acc, test_acc)
def compute_kendalltau(vectori, vectorj):
# indexes = list(range(len(vectori)))
# rank_1 = sorted(indexes, key=lambda i: vectori[i])
# rank_2 = sorted(indexes, key=lambda i: vectorj[i])
# import pdb; pdb.set_trace()
coef, p = scipy.stats.kendalltau(vectori, vectorj)
return coef
def compute_spearmanr(vectori, vectorj):
coef, p = scipy.stats.spearmanr(vectori, vectorj)
return coef
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--save_dir', type=str, default='output/vis-nas-bench/nas-algos', help='Folder to save checkpoints and log.')
parser.add_argument('--search_space', type=str, choices=['tss', 'sss'], help='Choose the search space.')
args = parser.parse_args()
save_dir = Path(args.save_dir)
api = create(None, 'tss', fast_mode=True, verbose=False)
indexes = list(range(1, 10000, 300))
scores_1 = []
scores_2 = []
for index in indexes:
valid_acc, test_acc, _ = get_valid_test_acc(api, index, 'cifar10')
scores_1.append(valid_acc)
scores_2.append(test_acc)
correlation = compute_kendalltau(scores_1, scores_2)
print('The kendall tau correlation of {:} samples : {:}'.format(len(indexes), correlation))
correlation = compute_spearmanr(scores_1, scores_2)
print('The spearmanr correlation of {:} samples : {:}'.format(len(indexes), correlation))
# scores_1 = ['{:.2f}'.format(x) for x in scores_1]
# scores_2 = ['{:.2f}'.format(x) for x in scores_2]
# print(', '.join(scores_1))
# print(', '.join(scores_2))
dpi, width, height = 250, 1000, 1000
figsize = width / float(dpi), height / float(dpi)
LabelSize, LegendFontsize = 14, 14
fig, ax = plt.subplots(1, 1, figsize=figsize)
ax.scatter(scores_1, scores_2 , marker='^', s=0.5, c='tab:green', alpha=0.8)
save_path = '/Users/xuanyidong/Desktop/test-temp-rank.png'
fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
plt.close('all')

View File

@ -0,0 +1,415 @@
###############################################################
# NATS-Bench (https://arxiv.org/pdf/2009.00437.pdf) #
# The code to draw Figure 2 / 3 / 4 / 5 in our paper. #
###############################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.06 #
###############################################################
# Usage: python exps/NATS-Bench/draw-fig2_5.py #
###############################################################
import os, sys, time, torch, argparse
import scipy
import numpy as np
from typing import List, Text, Dict, Any
from shutil import copyfile
from collections import defaultdict
from copy import deepcopy
from pathlib import Path
import matplotlib
import seaborn as sns
matplotlib.use('agg')
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
from config_utils import dict2config, load_config
from log_utils import time_string
from models import get_cell_based_tiny_net
from nats_bench import create
def visualize_relative_info(api, vis_save_dir, indicator):
vis_save_dir = vis_save_dir.resolve()
# print ('{:} start to visualize {:} information'.format(time_string(), api))
vis_save_dir.mkdir(parents=True, exist_ok=True)
cifar010_cache_path = vis_save_dir / '{:}-cache-{:}-info.pth'.format('cifar10', indicator)
cifar100_cache_path = vis_save_dir / '{:}-cache-{:}-info.pth'.format('cifar100', indicator)
imagenet_cache_path = vis_save_dir / '{:}-cache-{:}-info.pth'.format('ImageNet16-120', indicator)
cifar010_info = torch.load(cifar010_cache_path)
cifar100_info = torch.load(cifar100_cache_path)
imagenet_info = torch.load(imagenet_cache_path)
indexes = list(range(len(cifar010_info['params'])))
print ('{:} start to visualize relative ranking'.format(time_string()))
cifar010_ord_indexes = sorted(indexes, key=lambda i: cifar010_info['test_accs'][i])
cifar100_ord_indexes = sorted(indexes, key=lambda i: cifar100_info['test_accs'][i])
imagenet_ord_indexes = sorted(indexes, key=lambda i: imagenet_info['test_accs'][i])
cifar100_labels, imagenet_labels = [], []
for idx in cifar010_ord_indexes:
cifar100_labels.append( cifar100_ord_indexes.index(idx) )
imagenet_labels.append( imagenet_ord_indexes.index(idx) )
print ('{:} prepare data done.'.format(time_string()))
dpi, width, height = 200, 1400, 800
figsize = width / float(dpi), height / float(dpi)
LabelSize, LegendFontsize = 18, 12
resnet_scale, resnet_alpha = 120, 0.5
fig = plt.figure(figsize=figsize)
ax = fig.add_subplot(111)
plt.xlim(min(indexes), max(indexes))
plt.ylim(min(indexes), max(indexes))
# plt.ylabel('y').set_rotation(30)
plt.yticks(np.arange(min(indexes), max(indexes), max(indexes)//3), fontsize=LegendFontsize, rotation='vertical')
plt.xticks(np.arange(min(indexes), max(indexes), max(indexes)//5), fontsize=LegendFontsize)
ax.scatter(indexes, cifar100_labels, marker='^', s=0.5, c='tab:green', alpha=0.8)
ax.scatter(indexes, imagenet_labels, marker='*', s=0.5, c='tab:red' , alpha=0.8)
ax.scatter(indexes, indexes , marker='o', s=0.5, c='tab:blue' , alpha=0.8)
ax.scatter([-1], [-1], marker='o', s=100, c='tab:blue' , label='CIFAR-10')
ax.scatter([-1], [-1], marker='^', s=100, c='tab:green', label='CIFAR-100')
ax.scatter([-1], [-1], marker='*', s=100, c='tab:red' , label='ImageNet-16-120')
plt.grid(zorder=0)
ax.set_axisbelow(True)
plt.legend(loc=0, fontsize=LegendFontsize)
ax.set_xlabel('architecture ranking in CIFAR-10', fontsize=LabelSize)
ax.set_ylabel('architecture ranking', fontsize=LabelSize)
save_path = (vis_save_dir / '{:}-relative-rank.pdf'.format(indicator)).resolve()
fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
save_path = (vis_save_dir / '{:}-relative-rank.png'.format(indicator)).resolve()
fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
print ('{:} save into {:}'.format(time_string(), save_path))
def visualize_sss_info(api, dataset, vis_save_dir):
vis_save_dir = vis_save_dir.resolve()
print ('{:} start to visualize {:} information'.format(time_string(), dataset))
vis_save_dir.mkdir(parents=True, exist_ok=True)
cache_file_path = vis_save_dir / '{:}-cache-sss-info.pth'.format(dataset)
if not cache_file_path.exists():
print ('Do not find cache file : {:}'.format(cache_file_path))
params, flops, train_accs, valid_accs, test_accs = [], [], [], [], []
for index in range(len(api)):
cost_info = api.get_cost_info(index, dataset, hp='90')
params.append(cost_info['params'])
flops.append(cost_info['flops'])
# accuracy
info = api.get_more_info(index, dataset, hp='90', is_random=False)
train_accs.append(info['train-accuracy'])
test_accs.append(info['test-accuracy'])
if dataset == 'cifar10':
info = api.get_more_info(index, 'cifar10-valid', hp='90', is_random=False)
valid_accs.append(info['valid-accuracy'])
else:
valid_accs.append(info['valid-accuracy'])
info = {'params': params, 'flops': flops, 'train_accs': train_accs, 'valid_accs': valid_accs, 'test_accs': test_accs}
torch.save(info, cache_file_path)
else:
print ('Find cache file : {:}'.format(cache_file_path))
info = torch.load(cache_file_path)
params, flops, train_accs, valid_accs, test_accs = info['params'], info['flops'], info['train_accs'], info['valid_accs'], info['test_accs']
print ('{:} collect data done.'.format(time_string()))
# pyramid = ['8:16:32:48:64', '8:8:16:32:48', '8:8:16:16:32', '8:8:16:16:48', '8:8:16:16:64', '16:16:32:32:64', '32:32:64:64:64']
pyramid = ['8:16:24:32:40', '8:16:32:48:64', '32:40:48:56:64']
pyramid_indexes = [api.query_index_by_arch(x) for x in pyramid]
largest_indexes = [api.query_index_by_arch('64:64:64:64:64')]
indexes = list(range(len(params)))
dpi, width, height = 250, 8500, 1300
figsize = width / float(dpi), height / float(dpi)
LabelSize, LegendFontsize = 24, 24
# resnet_scale, resnet_alpha = 120, 0.5
xscale, xalpha = 120, 0.8
fig, axs = plt.subplots(1, 4, figsize=figsize)
# ax1, ax2, ax3, ax4, ax5 = axs
for ax in axs:
for tick in ax.xaxis.get_major_ticks():
tick.label.set_fontsize(LabelSize)
ax.yaxis.set_major_formatter(ticker.FormatStrFormatter('%.0f'))
for tick in ax.yaxis.get_major_ticks():
tick.label.set_fontsize(LabelSize)
ax1, ax2, ax3, ax4 = axs
ax1.scatter(params, train_accs, marker='o', s=0.5, c='tab:blue')
ax1.scatter([params[x] for x in pyramid_indexes], [train_accs[x] for x in pyramid_indexes], marker='*', s=xscale, c='tab:orange', label='Pyramid Structure', alpha=xalpha)
ax1.scatter([params[x] for x in largest_indexes], [train_accs[x] for x in largest_indexes], marker='x', s=xscale, c='tab:green', label='Largest Candidate', alpha=xalpha)
ax1.set_xlabel('#parameters (MB)', fontsize=LabelSize)
ax1.set_ylabel('train accuracy (%)', fontsize=LabelSize)
ax1.legend(loc=4, fontsize=LegendFontsize)
ax2.scatter(flops, train_accs, marker='o', s=0.5, c='tab:blue')
ax2.scatter([flops[x] for x in pyramid_indexes], [train_accs[x] for x in pyramid_indexes], marker='*', s=xscale, c='tab:orange', label='Pyramid Structure', alpha=xalpha)
ax2.scatter([flops[x] for x in largest_indexes], [train_accs[x] for x in largest_indexes], marker='x', s=xscale, c='tab:green', label='Largest Candidate', alpha=xalpha)
ax2.set_xlabel('#FLOPs (M)', fontsize=LabelSize)
# ax2.set_ylabel('train accuracy (%)', fontsize=LabelSize)
ax2.legend(loc=4, fontsize=LegendFontsize)
ax3.scatter(params, test_accs, marker='o', s=0.5, c='tab:blue')
ax3.scatter([params[x] for x in pyramid_indexes], [test_accs[x] for x in pyramid_indexes], marker='*', s=xscale, c='tab:orange', label='Pyramid Structure', alpha=xalpha)
ax3.scatter([params[x] for x in largest_indexes], [test_accs[x] for x in largest_indexes], marker='x', s=xscale, c='tab:green', label='Largest Candidate', alpha=xalpha)
ax3.set_xlabel('#parameters (MB)', fontsize=LabelSize)
ax3.set_ylabel('test accuracy (%)', fontsize=LabelSize)
ax3.legend(loc=4, fontsize=LegendFontsize)
ax4.scatter(flops, test_accs, marker='o', s=0.5, c='tab:blue')
ax4.scatter([flops[x] for x in pyramid_indexes], [test_accs[x] for x in pyramid_indexes], marker='*', s=xscale, c='tab:orange', label='Pyramid Structure', alpha=xalpha)
ax4.scatter([flops[x] for x in largest_indexes], [test_accs[x] for x in largest_indexes], marker='x', s=xscale, c='tab:green', label='Largest Candidate', alpha=xalpha)
ax4.set_xlabel('#FLOPs (M)', fontsize=LabelSize)
# ax4.set_ylabel('test accuracy (%)', fontsize=LabelSize)
ax4.legend(loc=4, fontsize=LegendFontsize)
save_path = vis_save_dir / 'sss-{:}.png'.format(dataset.lower())
fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
print ('{:} save into {:}'.format(time_string(), save_path))
plt.close('all')
def visualize_tss_info(api, dataset, vis_save_dir):
vis_save_dir = vis_save_dir.resolve()
print ('{:} start to visualize {:} information'.format(time_string(), dataset))
vis_save_dir.mkdir(parents=True, exist_ok=True)
cache_file_path = vis_save_dir / '{:}-cache-tss-info.pth'.format(dataset)
if not cache_file_path.exists():
print ('Do not find cache file : {:}'.format(cache_file_path))
params, flops, train_accs, valid_accs, test_accs = [], [], [], [], []
for index in range(len(api)):
cost_info = api.get_cost_info(index, dataset, hp='12')
params.append(cost_info['params'])
flops.append(cost_info['flops'])
# accuracy
info = api.get_more_info(index, dataset, hp='200', is_random=False)
train_accs.append(info['train-accuracy'])
test_accs.append(info['test-accuracy'])
if dataset == 'cifar10':
info = api.get_more_info(index, 'cifar10-valid', hp='200', is_random=False)
valid_accs.append(info['valid-accuracy'])
else:
valid_accs.append(info['valid-accuracy'])
print('')
info = {'params': params, 'flops': flops, 'train_accs': train_accs, 'valid_accs': valid_accs, 'test_accs': test_accs}
torch.save(info, cache_file_path)
else:
print ('Find cache file : {:}'.format(cache_file_path))
info = torch.load(cache_file_path)
params, flops, train_accs, valid_accs, test_accs = info['params'], info['flops'], info['train_accs'], info['valid_accs'], info['test_accs']
print ('{:} collect data done.'.format(time_string()))
resnet = ['|nor_conv_3x3~0|+|none~0|nor_conv_3x3~1|+|skip_connect~0|none~1|skip_connect~2|']
resnet_indexes = [api.query_index_by_arch(x) for x in resnet]
largest_indexes = [api.query_index_by_arch('|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|nor_conv_3x3~0|nor_conv_3x3~1|nor_conv_3x3~2|')]
indexes = list(range(len(params)))
dpi, width, height = 250, 8500, 1300
figsize = width / float(dpi), height / float(dpi)
LabelSize, LegendFontsize = 24, 24
# resnet_scale, resnet_alpha = 120, 0.5
xscale, xalpha = 120, 0.8
fig, axs = plt.subplots(1, 4, figsize=figsize)
for ax in axs:
for tick in ax.xaxis.get_major_ticks():
tick.label.set_fontsize(LabelSize)
ax.yaxis.set_major_formatter(ticker.FormatStrFormatter('%.0f'))
for tick in ax.yaxis.get_major_ticks():
tick.label.set_fontsize(LabelSize)
ax1, ax2, ax3, ax4 = axs
ax1.scatter(params, train_accs, marker='o', s=0.5, c='tab:blue')
ax1.scatter([params[x] for x in resnet_indexes] , [train_accs[x] for x in resnet_indexes], marker='*', s=xscale, c='tab:orange', label='ResNet', alpha=xalpha)
ax1.scatter([params[x] for x in largest_indexes], [train_accs[x] for x in largest_indexes], marker='x', s=xscale, c='tab:green', label='Largest Candidate', alpha=xalpha)
ax1.set_xlabel('#parameters (MB)', fontsize=LabelSize)
ax1.set_ylabel('train accuracy (%)', fontsize=LabelSize)
ax1.legend(loc=4, fontsize=LegendFontsize)
ax2.scatter(flops, train_accs, marker='o', s=0.5, c='tab:blue')
ax2.scatter([flops[x] for x in resnet_indexes], [train_accs[x] for x in resnet_indexes], marker='*', s=xscale, c='tab:orange', label='ResNet', alpha=xalpha)
ax2.scatter([flops[x] for x in largest_indexes], [train_accs[x] for x in largest_indexes], marker='x', s=xscale, c='tab:green', label='Largest Candidate', alpha=xalpha)
ax2.set_xlabel('#FLOPs (M)', fontsize=LabelSize)
# ax2.set_ylabel('train accuracy (%)', fontsize=LabelSize)
ax2.legend(loc=4, fontsize=LegendFontsize)
ax3.scatter(params, test_accs, marker='o', s=0.5, c='tab:blue')
ax3.scatter([params[x] for x in resnet_indexes] , [test_accs[x] for x in resnet_indexes], marker='*', s=xscale, c='tab:orange', label='ResNet', alpha=xalpha)
ax3.scatter([params[x] for x in largest_indexes], [test_accs[x] for x in largest_indexes], marker='x', s=xscale, c='tab:green', label='Largest Candidate', alpha=xalpha)
ax3.set_xlabel('#parameters (MB)', fontsize=LabelSize)
ax3.set_ylabel('test accuracy (%)', fontsize=LabelSize)
ax3.legend(loc=4, fontsize=LegendFontsize)
ax4.scatter(flops, test_accs, marker='o', s=0.5, c='tab:blue')
ax4.scatter([flops[x] for x in resnet_indexes], [test_accs[x] for x in resnet_indexes], marker='*', s=xscale, c='tab:orange', label='ResNet', alpha=xalpha)
ax4.scatter([flops[x] for x in largest_indexes], [test_accs[x] for x in largest_indexes], marker='x', s=xscale, c='tab:green', label='Largest Candidate', alpha=xalpha)
ax4.set_xlabel('#FLOPs (M)', fontsize=LabelSize)
# ax4.set_ylabel('test accuracy (%)', fontsize=LabelSize)
ax4.legend(loc=4, fontsize=LegendFontsize)
save_path = vis_save_dir / 'tss-{:}.png'.format(dataset.lower())
fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
print ('{:} save into {:}'.format(time_string(), save_path))
plt.close('all')
def visualize_rank_info(api, vis_save_dir, indicator):
vis_save_dir = vis_save_dir.resolve()
# print ('{:} start to visualize {:} information'.format(time_string(), api))
vis_save_dir.mkdir(parents=True, exist_ok=True)
cifar010_cache_path = vis_save_dir / '{:}-cache-{:}-info.pth'.format('cifar10', indicator)
cifar100_cache_path = vis_save_dir / '{:}-cache-{:}-info.pth'.format('cifar100', indicator)
imagenet_cache_path = vis_save_dir / '{:}-cache-{:}-info.pth'.format('ImageNet16-120', indicator)
cifar010_info = torch.load(cifar010_cache_path)
cifar100_info = torch.load(cifar100_cache_path)
imagenet_info = torch.load(imagenet_cache_path)
indexes = list(range(len(cifar010_info['params'])))
print ('{:} start to visualize relative ranking'.format(time_string()))
dpi, width, height = 250, 3800, 1200
figsize = width / float(dpi), height / float(dpi)
LabelSize, LegendFontsize = 14, 14
fig, axs = plt.subplots(1, 3, figsize=figsize)
ax1, ax2, ax3 = axs
def get_labels(info):
ord_test_indexes = sorted(indexes, key=lambda i: info['test_accs'][i])
ord_valid_indexes = sorted(indexes, key=lambda i: info['valid_accs'][i])
labels = []
for idx in ord_test_indexes:
labels.append(ord_valid_indexes.index(idx))
return labels
def plot_ax(labels, ax, name):
for tick in ax.xaxis.get_major_ticks():
tick.label.set_fontsize(LabelSize)
for tick in ax.yaxis.get_major_ticks():
tick.label.set_fontsize(LabelSize)
tick.label.set_rotation(90)
ax.set_xlim(min(indexes), max(indexes))
ax.set_ylim(min(indexes), max(indexes))
ax.yaxis.set_ticks(np.arange(min(indexes), max(indexes), max(indexes)//3))
ax.xaxis.set_ticks(np.arange(min(indexes), max(indexes), max(indexes)//5))
ax.scatter(indexes, labels , marker='^', s=0.5, c='tab:green', alpha=0.8)
ax.scatter(indexes, indexes, marker='o', s=0.5, c='tab:blue' , alpha=0.8)
ax.scatter([-1], [-1], marker='^', s=100, c='tab:green' , label='{:} test'.format(name))
ax.scatter([-1], [-1], marker='o', s=100, c='tab:blue' , label='{:} validation'.format(name))
ax.legend(loc=4, fontsize=LegendFontsize)
ax.set_xlabel('ranking on the {:} validation'.format(name), fontsize=LabelSize)
ax.set_ylabel('architecture ranking', fontsize=LabelSize)
labels = get_labels(cifar010_info)
plot_ax(labels, ax1, 'CIFAR-10')
labels = get_labels(cifar100_info)
plot_ax(labels, ax2, 'CIFAR-100')
labels = get_labels(imagenet_info)
plot_ax(labels, ax3, 'ImageNet-16-120')
save_path = (vis_save_dir / '{:}-same-relative-rank.pdf'.format(indicator)).resolve()
fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
save_path = (vis_save_dir / '{:}-same-relative-rank.png'.format(indicator)).resolve()
fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
print ('{:} save into {:}'.format(time_string(), save_path))
plt.close('all')
def compute_kendalltau(vectori, vectorj):
# indexes = list(range(len(vectori)))
# rank_1 = sorted(indexes, key=lambda i: vectori[i])
# rank_2 = sorted(indexes, key=lambda i: vectorj[i])
return scipy.stats.kendalltau(vectori, vectorj).correlation
def calculate_correlation(*vectors):
matrix = []
for i, vectori in enumerate(vectors):
x = []
for j, vectorj in enumerate(vectors):
# x.append(np.corrcoef(vectori, vectorj)[0,1])
x.append(compute_kendalltau(vectori, vectorj))
matrix.append( x )
return np.array(matrix)
def visualize_all_rank_info(api, vis_save_dir, indicator):
vis_save_dir = vis_save_dir.resolve()
# print ('{:} start to visualize {:} information'.format(time_string(), api))
vis_save_dir.mkdir(parents=True, exist_ok=True)
cifar010_cache_path = vis_save_dir / '{:}-cache-{:}-info.pth'.format('cifar10', indicator)
cifar100_cache_path = vis_save_dir / '{:}-cache-{:}-info.pth'.format('cifar100', indicator)
imagenet_cache_path = vis_save_dir / '{:}-cache-{:}-info.pth'.format('ImageNet16-120', indicator)
cifar010_info = torch.load(cifar010_cache_path)
cifar100_info = torch.load(cifar100_cache_path)
imagenet_info = torch.load(imagenet_cache_path)
indexes = list(range(len(cifar010_info['params'])))
print ('{:} start to visualize relative ranking'.format(time_string()))
dpi, width, height = 250, 3200, 1400
figsize = width / float(dpi), height / float(dpi)
LabelSize, LegendFontsize = 14, 14
fig, axs = plt.subplots(1, 2, figsize=figsize)
ax1, ax2 = axs
sns_size, xformat = 15, '.2f'
CoRelMatrix = calculate_correlation(cifar010_info['valid_accs'], cifar010_info['test_accs'], cifar100_info['valid_accs'], cifar100_info['test_accs'], imagenet_info['valid_accs'], imagenet_info['test_accs'])
sns.heatmap(CoRelMatrix, annot=True, annot_kws={'size':sns_size}, fmt=xformat, linewidths=0.5, ax=ax1,
xticklabels=['C10-V', 'C10-T', 'C100-V', 'C100-T', 'I120-V', 'I120-T'],
yticklabels=['C10-V', 'C10-T', 'C100-V', 'C100-T', 'I120-V', 'I120-T'])
selected_indexes, acc_bar = [], 92
for i, acc in enumerate(cifar010_info['test_accs']):
if acc > acc_bar: selected_indexes.append( i )
cifar010_valid_accs = np.array(cifar010_info['valid_accs'])[ selected_indexes ]
cifar010_test_accs = np.array(cifar010_info['test_accs']) [ selected_indexes ]
cifar100_valid_accs = np.array(cifar100_info['valid_accs'])[ selected_indexes ]
cifar100_test_accs = np.array(cifar100_info['test_accs']) [ selected_indexes ]
imagenet_valid_accs = np.array(imagenet_info['valid_accs'])[ selected_indexes ]
imagenet_test_accs = np.array(imagenet_info['test_accs']) [ selected_indexes ]
CoRelMatrix = calculate_correlation(cifar010_valid_accs, cifar010_test_accs, cifar100_valid_accs, cifar100_test_accs, imagenet_valid_accs, imagenet_test_accs)
sns.heatmap(CoRelMatrix, annot=True, annot_kws={'size':sns_size}, fmt=xformat, linewidths=0.5, ax=ax2,
xticklabels=['C10-V', 'C10-T', 'C100-V', 'C100-T', 'I120-V', 'I120-T'],
yticklabels=['C10-V', 'C10-T', 'C100-V', 'C100-T', 'I120-V', 'I120-T'])
ax1.set_title('Correlation coefficient over ALL candidates')
ax2.set_title('Correlation coefficient over candidates with accuracy > {:}%'.format(acc_bar))
save_path = (vis_save_dir / '{:}-all-relative-rank.png'.format(indicator)).resolve()
fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
print ('{:} save into {:}'.format(time_string(), save_path))
plt.close('all')
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='NAS-Bench-X', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--save_dir', type=str, default='output/vis-nas-bench', help='Folder to save checkpoints and log.')
# use for train the model
args = parser.parse_args()
to_save_dir = Path(args.save_dir)
datasets = ['cifar10', 'cifar100', 'ImageNet16-120']
# Figure 3 (a-c)
api_tss = create(None, 'tss', verbose=True)
for xdata in datasets:
visualize_tss_info(api_tss, xdata, to_save_dir)
# Figure 3 (d-f)
api_sss = create(None, 'size', verbose=True)
for xdata in datasets:
visualize_sss_info(api_sss, xdata, to_save_dir)
# Figure 2
visualize_relative_info(None, to_save_dir, 'tss')
visualize_relative_info(None, to_save_dir, 'sss')
# Figure 4
visualize_rank_info(None, to_save_dir, 'tss')
visualize_rank_info(None, to_save_dir, 'sss')
# Figure 5
visualize_all_rank_info(None, to_save_dir, 'tss')
visualize_all_rank_info(None, to_save_dir, 'sss')

View File

@ -33,7 +33,7 @@ def fetch_data(root_dir='./output/search', search_space='tss', dataset=None):
alg2name['REA'] = 'R-EA-SS3'
alg2name['REINFORCE'] = 'REINFORCE-0.01'
alg2name['RANDOM'] = 'RANDOM'
# alg2name['BOHB'] = 'BOHB'
alg2name['BOHB'] = 'BOHB'
for alg, name in alg2name.items():
alg2path[alg] = os.path.join(ss_dir, dataset, name, 'results.pth')
assert os.path.isfile(alg2path[alg]), 'invalid path : {:}'.format(alg2path[alg])
@ -59,7 +59,26 @@ def query_performance(api, data, dataset, ticket):
accuracy_a, accuracy_b = info_a['test-accuracy'], info_b['test-accuracy']
interplate = (time_b-ticket) / (time_b-time_a) * accuracy_a + (ticket-time_a) / (time_b-time_a) * accuracy_b
results.append(interplate)
return sum(results) / len(results)
# return sum(results) / len(results)
return np.mean(results), np.std(results)
def show_valid_test(api, data, dataset):
valid_accs, test_accs, is_size_space = [], [], api.search_space_name == 'size'
for i, info in data.items():
time, arch = info['time_w_arch'][-1]
if dataset == 'cifar10':
xinfo = api.get_more_info(arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False)
test_accs.append(xinfo['test-accuracy'])
xinfo = api.get_more_info(arch, dataset='cifar10-valid', hp=90 if is_size_space else 200, is_random=False)
valid_accs.append(xinfo['valid-accuracy'])
else:
xinfo = api.get_more_info(arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False)
valid_accs.append(xinfo['valid-accuracy'])
test_accs.append(xinfo['test-accuracy'])
valid_str = '{:.2f}$\pm${:.2f}'.format(np.mean(valid_accs), np.std(valid_accs))
test_str = '{:.2f}$\pm${:.2f}'.format(np.mean(test_accs), np.std(test_accs))
return valid_str, test_str
y_min_s = {('cifar10', 'tss'): 90,
@ -69,11 +88,11 @@ y_min_s = {('cifar10', 'tss'): 90,
('ImageNet16-120', 'tss'): 36,
('ImageNet16-120', 'sss'): 40}
y_max_s = {('cifar10', 'tss'): 94.5,
y_max_s = {('cifar10', 'tss'): 94.3,
('cifar10', 'sss'): 93.3,
('cifar100', 'tss'): 72,
('cifar100', 'sss'): 70,
('ImageNet16-120', 'tss'): 44,
('cifar100', 'tss'): 72.5,
('cifar100', 'sss'): 70.5,
('ImageNet16-120', 'tss'): 46,
('ImageNet16-120', 'sss'): 46}
x_axis_s = {('cifar10', 'tss'): 200,
@ -87,6 +106,7 @@ name2label = {'cifar10': 'CIFAR-10',
'cifar100': 'CIFAR-100',
'ImageNet16-120': 'ImageNet-16-120'}
def visualize_curve(api, vis_save_dir, search_space):
vis_save_dir = vis_save_dir.resolve()
vis_save_dir.mkdir(parents=True, exist_ok=True)
@ -106,11 +126,13 @@ def visualize_curve(api, vis_save_dir, search_space):
ax.set_ylim(y_min_s[(xdataset, search_space)],
y_max_s[(xdataset, search_space)])
for idx, (alg, data) in enumerate(alg2data.items()):
print('{:} plot alg : {:}'.format(time_string(), alg))
accuracies = []
for ticket in time_tickets:
accuracy = query_performance(api, data, xdataset, ticket)
accuracy, accuracy_std = query_performance(api, data, xdataset, ticket)
accuracies.append(accuracy)
valid_str, test_str = show_valid_test(api, data, xdataset)
# print('{:} plot alg : {:10s}, final accuracy = {:.2f}$\pm${:.2f}'.format(time_string(), alg, accuracy, accuracy_std))
print('{:} plot alg : {:10s} | validation = {:} | test = {:}'.format(time_string(), alg, valid_str, test_str))
alg2accuracies[alg] = accuracies
ax.plot([x/100 for x in time_tickets], accuracies, c=colors[idx], label='{:}'.format(alg))
ax.set_xlabel('Estimated wall-clock time (1e2 seconds)', fontsize=LabelSize)

View File

@ -0,0 +1,180 @@
###############################################################
# NATS-Bench (https://arxiv.org/pdf/2009.00437.pdf) #
# The code to draw Figure 7 in our paper. #
###############################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.06 #
###############################################################
# Usage: python exps/NATS-Bench/draw-fig7.py #
###############################################################
import os, gc, sys, time, torch, argparse
import numpy as np
from typing import List, Text, Dict, Any
from shutil import copyfile
from collections import defaultdict, OrderedDict
from copy import deepcopy
from pathlib import Path
import matplotlib
import seaborn as sns
matplotlib.use('agg')
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
from config_utils import dict2config, load_config
from nats_bench import create
from log_utils import time_string
def get_valid_test_acc(api, arch, dataset):
is_size_space = api.search_space_name == 'size'
if dataset == 'cifar10':
xinfo = api.get_more_info(arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False)
test_acc = xinfo['test-accuracy']
xinfo = api.get_more_info(arch, dataset='cifar10-valid', hp=90 if is_size_space else 200, is_random=False)
valid_acc = xinfo['valid-accuracy']
else:
xinfo = api.get_more_info(arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False)
valid_acc = xinfo['valid-accuracy']
test_acc = xinfo['test-accuracy']
return valid_acc, test_acc, 'validation = {:.2f}, test = {:.2f}\n'.format(valid_acc, test_acc)
def fetch_data(root_dir='./output/search', search_space='tss', dataset=None, suffix='-WARM0.3'):
ss_dir = '{:}-{:}'.format(root_dir, search_space)
alg2name, alg2path = OrderedDict(), OrderedDict()
seeds = [777, 888, 999]
print('\n[fetch data] from {:} on {:}'.format(search_space, dataset))
if search_space == 'tss':
alg2name['GDAS'] = 'gdas-affine0_BN0-None'
alg2name['RSPS'] = 'random-affine0_BN0-None'
alg2name['DARTS (1st)'] = 'darts-v1-affine0_BN0-None'
alg2name['DARTS (2nd)'] = 'darts-v2-affine0_BN0-None'
alg2name['ENAS'] = 'enas-affine0_BN0-None'
alg2name['SETN'] = 'setn-affine0_BN0-None'
else:
alg2name['channel-wise interpolation'] = 'tas-affine0_BN0-AWD0.001{:}'.format(suffix)
alg2name['masking + Gumbel-Softmax'] = 'mask_gumbel-affine0_BN0-AWD0.001{:}'.format(suffix)
alg2name['masking + sampling'] = 'mask_rl-affine0_BN0-AWD0.0{:}'.format(suffix)
for alg, name in alg2name.items():
alg2path[alg] = os.path.join(ss_dir, dataset, name, 'seed-{:}-last-info.pth')
alg2data = OrderedDict()
for alg, path in alg2path.items():
alg2data[alg], ok_num = [], 0
for seed in seeds:
xpath = path.format(seed)
if os.path.isfile(xpath):
ok_num += 1
else:
print('This is an invalid path : {:}'.format(xpath))
continue
data = torch.load(xpath, map_location=torch.device('cpu'))
try:
data = torch.load(data['last_checkpoint'], map_location=torch.device('cpu'))
except:
xpath = str(data['last_checkpoint']).split('E100-')
if len(xpath) == 2 and os.path.isfile(xpath[0] + xpath[1]):
xpath = xpath[0] + xpath[1]
elif 'fbv2' in str(data['last_checkpoint']):
xpath = str(data['last_checkpoint']).replace('fbv2', 'mask_gumbel')
elif 'tunas' in str(data['last_checkpoint']):
xpath = str(data['last_checkpoint']).replace('tunas', 'mask_rl')
else:
raise ValueError('Invalid path: {:}'.format(data['last_checkpoint']))
data = torch.load(xpath, map_location=torch.device('cpu'))
alg2data[alg].append(data['genotypes'])
print('This algorithm : {:} has {:} valid ckps.'.format(alg, ok_num))
assert ok_num > 0, 'Must have at least 1 valid ckps.'
return alg2data
y_min_s = {('cifar10', 'tss'): 90,
('cifar10', 'sss'): 92,
('cifar100', 'tss'): 65,
('cifar100', 'sss'): 65,
('ImageNet16-120', 'tss'): 36,
('ImageNet16-120', 'sss'): 40}
y_max_s = {('cifar10', 'tss'): 94.5,
('cifar10', 'sss'): 93.3,
('cifar100', 'tss'): 72,
('cifar100', 'sss'): 70,
('ImageNet16-120', 'tss'): 44,
('ImageNet16-120', 'sss'): 46}
name2label = {'cifar10': 'CIFAR-10',
'cifar100': 'CIFAR-100',
'ImageNet16-120': 'ImageNet-16-120'}
name2suffix = {('sss', 'warm'): '-WARM0.3',
('sss', 'none'): '-WARMNone',
('tss', 'none') : None,
('tss', None) : None}
def visualize_curve(api, vis_save_dir, search_space, suffix):
vis_save_dir = vis_save_dir.resolve()
vis_save_dir.mkdir(parents=True, exist_ok=True)
dpi, width, height = 250, 5200, 1400
figsize = width / float(dpi), height / float(dpi)
LabelSize, LegendFontsize = 16, 16
def sub_plot_fn(ax, dataset):
print('{:} plot {:10s}'.format(time_string(), dataset))
alg2data = fetch_data(search_space=search_space, dataset=dataset, suffix=name2suffix[(search_space, suffix)])
alg2accuracies = OrderedDict()
epochs = 100
colors = ['b', 'g', 'c', 'm', 'y', 'r']
ax.set_xlim(0, epochs)
# ax.set_ylim(y_min_s[(dataset, search_space)], y_max_s[(dataset, search_space)])
for idx, (alg, data) in enumerate(alg2data.items()):
xs, accuracies = [], []
for iepoch in range(epochs + 1):
try:
structures, accs = [_[iepoch-1] for _ in data], []
except:
raise ValueError('This alg {:} on {:} has invalid checkpoints.'.format(alg, dataset))
for structure in structures:
info = api.get_more_info(structure, dataset=dataset, hp=90 if api.search_space_name == 'size' else 200, is_random=False)
accs.append(info['test-accuracy'])
accuracies.append(sum(accs)/len(accs))
xs.append(iepoch)
alg2accuracies[alg] = accuracies
ax.plot(xs, accuracies, c=colors[idx], label='{:}'.format(alg))
ax.set_xlabel('The searching epoch', fontsize=LabelSize)
ax.set_ylabel('Test accuracy on {:}'.format(name2label[dataset]), fontsize=LabelSize)
ax.set_title('Searching results on {:}'.format(name2label[dataset]), fontsize=LabelSize+4)
structures, valid_accs, test_accs = [_[epochs-1] for _ in data], [], []
print('{:} plot alg : {:} -- final {:} architectures.'.format(time_string(), alg, len(structures)))
for arch in structures:
valid_acc, test_acc, _ = get_valid_test_acc(api, arch, dataset)
test_accs.append(test_acc)
valid_accs.append(valid_acc)
print('{:} plot alg : {:} -- validation: {:.2f}$\pm${:.2f} -- test: {:.2f}$\pm${:.2f}'.format(
time_string(), alg, np.mean(valid_accs), np.std(valid_accs), np.mean(test_accs), np.std(test_accs)))
ax.legend(loc=4, fontsize=LegendFontsize)
fig, axs = plt.subplots(1, 3, figsize=figsize)
datasets = ['cifar10', 'cifar100', 'ImageNet16-120']
for dataset, ax in zip(datasets, axs):
sub_plot_fn(ax, dataset)
print('sub-plot {:} on {:} done.'.format(dataset, search_space))
save_path = (vis_save_dir / '{:}-ws-{:}-curve.png'.format(search_space, suffix)).resolve()
fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
print ('{:} save into {:}'.format(time_string(), save_path))
plt.close('all')
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='NATS-Bench', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--save_dir', type=str, default='output/vis-nas-bench/nas-algos', help='Folder to save checkpoints and log.')
args = parser.parse_args()
save_dir = Path(args.save_dir)
api_tss = create(None, 'tss', fast_mode=True, verbose=False)
visualize_curve(api_tss, save_dir, 'tss', None)
api_sss = create(None, 'sss', fast_mode=True, verbose=False)
visualize_curve(api_sss, save_dir, 'sss', 'warm')
visualize_curve(api_sss, save_dir, 'sss', 'none')

View File

@ -0,0 +1,85 @@
###############################################################
# NATS-Bench (https://arxiv.org/pdf/2009.00437.pdf) #
# The code to draw some results in Table 4 in our paper. #
###############################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.06 #
###############################################################
# Usage: python exps/NATS-Bench/draw-table.py #
###############################################################
import os, gc, sys, time, torch, argparse
import numpy as np
from typing import List, Text, Dict, Any
from shutil import copyfile
from collections import defaultdict, OrderedDict
from copy import deepcopy
from pathlib import Path
import matplotlib
import seaborn as sns
matplotlib.use('agg')
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
from config_utils import dict2config, load_config
from nats_bench import create
from log_utils import time_string
def get_valid_test_acc(api, arch, dataset):
is_size_space = api.search_space_name == 'size'
if dataset == 'cifar10':
xinfo = api.get_more_info(arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False)
test_acc = xinfo['test-accuracy']
xinfo = api.get_more_info(arch, dataset='cifar10-valid', hp=90 if is_size_space else 200, is_random=False)
valid_acc = xinfo['valid-accuracy']
else:
xinfo = api.get_more_info(arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False)
valid_acc = xinfo['valid-accuracy']
test_acc = xinfo['test-accuracy']
return valid_acc, test_acc, 'validation = {:.2f}, test = {:.2f}\n'.format(valid_acc, test_acc)
def show_valid_test(api, arch):
is_size_space = api.search_space_name == 'size'
final_str = ''
for dataset in ['cifar10', 'cifar100', 'ImageNet16-120']:
valid_acc, test_acc, perf_str = get_valid_test_acc(api, arch, dataset)
final_str += '{:} : {:}\n'.format(dataset, perf_str)
return final_str
def find_best_valid(api, dataset):
all_valid_accs, all_test_accs = [], []
for index, arch in enumerate(api):
# import pdb; pdb.set_trace()
valid_acc, test_acc, perf_str = get_valid_test_acc(api, index, dataset)
all_valid_accs.append((index, valid_acc))
all_test_accs.append((index, test_acc))
best_valid_index = sorted(all_valid_accs, key=lambda x: -x[1])[0][0]
best_test_index = sorted(all_test_accs, key=lambda x: -x[1])[0][0]
print('-' * 50 + '{:10s}'.format(dataset) + '-' * 50)
print('Best ({:}) architecture on validation: {:}'.format(best_valid_index, api[best_valid_index]))
print('Best ({:}) architecture on test: {:}'.format(best_test_index, api[best_test_index]))
_, _, perf_str = get_valid_test_acc(api, best_valid_index, dataset)
print('using validation ::: {:}'.format(perf_str))
_, _, perf_str = get_valid_test_acc(api, best_test_index, dataset)
print('using test ::: {:}'.format(perf_str))
if __name__ == '__main__':
api_tss = create(None, 'tss', fast_mode=False, verbose=False)
resnet = '|nor_conv_3x3~0|+|none~0|nor_conv_3x3~1|+|skip_connect~0|none~1|skip_connect~2|'
resnet_index = api_tss.query_index_by_arch(resnet)
print(show_valid_test(api_tss, resnet_index))
for dataset in ['cifar10', 'cifar100', 'ImageNet16-120']:
find_best_valid(api_tss, dataset)
largest = '64:64:64:64:64'
largest_index = api_sss.query_index_by_arch(largest)
print(show_valid_test(api_sss, largest_index))
for dataset in ['cifar10', 'cifar100', 'ImageNet16-120']:
find_best_valid(api_sss, dataset)

View File

@ -92,8 +92,8 @@ class NATStopology(NASBenchMetaAPI):
file_path_or_dict = os.path.join(
os.environ['TORCH_HOME'], '{:}.{:}'.format(
ALL_BASE_NAMES[-1], PICKLE_EXT))
print('{:} Try to use the default NATS-Bench (topology) path '
'from {:}.'.format(time_string(), file_path_or_dict))
print('{:} Try to use the default NATS-Bench (topology) path from '
'fast_mode={:} and path={:}.'.format(time_string(), self._fast_mode, file_path_or_dict))
if isinstance(file_path_or_dict, str):
file_path_or_dict = str(file_path_or_dict)
if verbose: