Update visualization code for NATS-Bench

Author: D-X-Y
Date:   2020-12-02 08:05:03 +08:00
parent 46b92e37e2
commit bda30c7098
2 changed files with 103 additions and 70 deletions
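Summary (from the diff): the first file retunes the NAS-algorithm comparison curves, switching to per-search-space CIFAR-10 budgets (-T800000 for tss, -T200000 for sss), tightening y-axis ranges, enlarging the canvas and fonts, and rotating x-tick labels; the second file rewrites visualize_relative_info to cache per-architecture results through the NATS-Bench API and, for the top-k% architectures, to plot the ranking from a single randomly selected training trial against the multi-trial average, reporting Kendall's tau for each dataset.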

[File 1 of 2]

@@ -43,20 +43,14 @@ def fetch_data(root_dir='./output/search', search_space='tss', dataset=None):
   # alg2name['REINFORCE'] = 'REINFORCE-0.01'
   # alg2name['RANDOM'] = 'RANDOM'
   # alg2name['BOHB'] = 'BOHB'
-  if dataset == 'cifar10':
-    suffixes = ['-T200000', '-T200000-FULL']
-  elif dataset == 'cifar100':
-    suffixes = ['-T40000', '-T40000-FULL']
-  elif search_space == 'tss':
-    suffixes = ['-T120000', '-T120000-FULL']
-  elif search_space == 'sss':
-    suffixes = ['-T60000', '-T60000-FULL']
-  else:
-    raise ValueError('Unknown dataset : {:}'.format(dataset))
+  if search_space == 'tss':
+    hp = '$\mathcal{H}^{1}$'
+    if dataset == 'cifar10':
+      suffixes = ['-T800000', '-T800000-FULL']
+  elif search_space == 'sss':
+    hp = '$\mathcal{H}^{2}$'
+    if dataset == 'cifar10':
+      suffixes = ['-T200000', '-T200000-FULL']
+  else:
+    raise ValueError('Unknown search space: {:}'.format(search_space))
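Aside: the -T800000 / -T200000 suffixes above appear to encode the estimated wall-clock search budget in seconds (they match the x_axis_s limits later in this file). A table-driven equivalent of the branch above, as an illustrative sketch only: BUDGETS and budget_suffixes are hypothetical names, not the repo's code.

# Sketch: table-driven version of the suffix selection (illustrative only).
BUDGETS = {('tss', 'cifar10'): 800000,  # assumed: seconds of estimated wall-clock time
           ('sss', 'cifar10'): 200000}

def budget_suffixes(search_space, dataset):
  try:
    budget = BUDGETS[(search_space, dataset)]
  except KeyError:
    raise ValueError('Unknown search space / dataset: {:} / {:}'.format(search_space, dataset))
  return ['-T{:}'.format(budget), '-T{:}-FULL'.format(budget)]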
@@ -92,21 +86,21 @@ def query_performance(api, data, dataset, ticket):
   return np.mean(results), np.std(results)

-y_min_s = {('cifar10', 'tss'): 90,
-           ('cifar10', 'sss'): 90,
+y_min_s = {('cifar10', 'tss'): 91,
+           ('cifar10', 'sss'): 91,
            ('cifar100', 'tss'): 65,
            ('cifar100', 'sss'): 65,
            ('ImageNet16-120', 'tss'): 36,
            ('ImageNet16-120', 'sss'): 40}

 y_max_s = {('cifar10', 'tss'): 94.5,
-           ('cifar10', 'sss'): 94.5,
+           ('cifar10', 'sss'): 93.5,
            ('cifar100', 'tss'): 72.5,
            ('cifar100', 'sss'): 70.5,
            ('ImageNet16-120', 'tss'): 46,
            ('ImageNet16-120', 'sss'): 46}

-x_axis_s = {('cifar10', 'tss'): 200000,
+x_axis_s = {('cifar10', 'tss'): 800000,
             ('cifar10', 'sss'): 200000,
             ('cifar100', 'tss'): 400,
             ('cifar100', 'sss'): 400,
@@ -124,9 +118,9 @@ def visualize_curve(api_dict, vis_save_dir):
   vis_save_dir = vis_save_dir.resolve()
   vis_save_dir.mkdir(parents=True, exist_ok=True)

-  dpi, width, height = 250, 4000, 2400
+  dpi, width, height = 250, 5000, 2000
   figsize = width / float(dpi), height / float(dpi)
-  LabelSize, LegendFontsize = 16, 16
+  LabelSize, LegendFontsize = 28, 28

   def sub_plot_fn(ax, search_space, dataset):
     max_time = x_axis_s[(dataset, search_space)]
@@ -137,6 +131,11 @@ def visualize_curve(api_dict, vis_save_dir):
     ax.set_xlim(0, x_axis_s[(dataset, search_space)])
     ax.set_ylim(y_min_s[(dataset, search_space)],
                 y_max_s[(dataset, search_space)])
+    for tick in ax.get_xticklabels():
+      tick.set_rotation(25)
+      tick.set_fontsize(LabelSize - 6)
+    for tick in ax.get_yticklabels():
+      tick.set_fontsize(LabelSize - 6)
     for idx, (alg, xdata) in enumerate(alg2data.items()):
       accuracies = []
       for ticket in time_tickets:
@@ -150,8 +149,8 @@ def visualize_curve(api_dict, vis_save_dir):
       ax.plot(time_tickets, accuracies, c=xdata['color'], linestyle=xdata['linestyle'], label='{:}'.format(alg))
     ax.set_xlabel('Estimated wall-clock time', fontsize=LabelSize)
     ax.set_ylabel('Test accuracy', fontsize=LabelSize)
-    ax.set_title(r'Searching results on {:} for {:}'.format(name2label[dataset], spaces2latex[search_space]),
-                 fontsize=LabelSize+4)
+    ax.set_title(r'Results on {:} over {:}'.format(name2label[dataset], spaces2latex[search_space]),
+                 fontsize=LabelSize)
     ax.legend(loc=4, fontsize=LegendFontsize)

   fig, axs = plt.subplots(1, 2, figsize=figsize)
@@ -165,7 +164,7 @@ def visualize_curve(api_dict, vis_save_dir):
 if __name__ == '__main__':
   parser = argparse.ArgumentParser(description='NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
-  parser.add_argument('--save_dir', type=str, default='output/vis-nas-bench/nas-algos-vs-h', help='Folder to save checkpoints and log.')
+  parser.add_argument('--save_dir', type=str, default='output/vis-nas-bench/nas-algos-vs-h', help='Folder to save checkpoints and log.')
   args = parser.parse_args()
   save_dir = Path(args.save_dir)
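For context on the sizing changes above: matplotlib interprets figsize in inches, so the script converts a pixel target to inches by dividing by dpi (5000x2000 px at dpi 250 gives a 20x8 in canvas). A minimal runnable sketch of the sizing and tick-styling pattern, assuming a headless Agg backend; it is not part of the commit.

import matplotlib
matplotlib.use('Agg')  # headless backend, suitable when only saving files
import matplotlib.pyplot as plt

dpi, width, height = 250, 5000, 2000
figsize = width / float(dpi), height / float(dpi)  # -> (20.0, 8.0) inches
LabelSize = 28

fig, axs = plt.subplots(1, 2, figsize=figsize)
for ax in axs:
  for tick in ax.get_xticklabels():
    tick.set_rotation(25)            # slanted x tick labels, as added in this commit
    tick.set_fontsize(LabelSize - 6)
  for tick in ax.get_yticklabels():
    tick.set_fontsize(LabelSize - 6)
fig.savefig('demo.png', dpi=dpi, bbox_inches='tight')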

[File 2 of 2]

@@ -11,7 +11,7 @@ import scipy
 import numpy as np
 from typing import List, Text, Dict, Any
 from shutil import copyfile
-from collections import defaultdict
+from collections import defaultdict, OrderedDict
 from copy import deepcopy
 from pathlib import Path
 import matplotlib
@@ -28,69 +28,103 @@ from models import get_cell_based_tiny_net
 from nats_bench import create

-def visualize_relative_info(api, vis_save_dir, indicator):
+name2label = {'cifar10': 'CIFAR-10',
+              'cifar100': 'CIFAR-100',
+              'ImageNet16-120': 'ImageNet-16-120'}
+
+def visualize_relative_info(vis_save_dir, search_space, indicator, topk):
   vis_save_dir = vis_save_dir.resolve()
-  # print ('{:} start to visualize {:} information'.format(time_string(), api))
+  print ('{:} start to visualize {:} with top-{:} information'.format(time_string(), search_space, topk))
   vis_save_dir.mkdir(parents=True, exist_ok=True)

+  cache_file_path = vis_save_dir / 'cache-{:}-info.pth'.format(search_space)
+  datasets = ['cifar10', 'cifar100', 'ImageNet16-120']
+  if not cache_file_path.exists():
+    api = create(None, search_space, fast_mode=False, verbose=False)
+    all_infos = OrderedDict()
+    for index in range(len(api)):
+      all_info = OrderedDict()
+      for dataset in datasets:
+        info_less = api.get_more_info(index, dataset, hp='12', is_random=False)
+        info_more = api.get_more_info(index, dataset, hp=api.full_train_epochs, is_random=False)
+        all_info[dataset] = dict(less=info_less['test-accuracy'],
+                                 more=info_more['test-accuracy'])
+      all_infos[index] = all_info
+    torch.save(all_infos, cache_file_path)
+    print ('{:} save all cache data into {:}'.format(time_string(), cache_file_path))
+  else:
+    api = create(None, search_space, fast_mode=True, verbose=False)
+    all_infos = torch.load(cache_file_path)
-  cifar010_cache_path = vis_save_dir / '{:}-cache-{:}-info.pth'.format('cifar10', indicator)
-  cifar100_cache_path = vis_save_dir / '{:}-cache-{:}-info.pth'.format('cifar100', indicator)
-  imagenet_cache_path = vis_save_dir / '{:}-cache-{:}-info.pth'.format('ImageNet16-120', indicator)
-  cifar010_info = torch.load(cifar010_cache_path)
-  cifar100_info = torch.load(cifar100_cache_path)
-  imagenet_info = torch.load(imagenet_cache_path)
-  indexes = list(range(len(cifar010_info['params'])))
-
-  print ('{:} start to visualize relative ranking'.format(time_string()))
-
-  cifar010_ord_indexes = sorted(indexes, key=lambda i: cifar010_info['test_accs'][i])
-  cifar100_ord_indexes = sorted(indexes, key=lambda i: cifar100_info['test_accs'][i])
-  imagenet_ord_indexes = sorted(indexes, key=lambda i: imagenet_info['test_accs'][i])
-
-  cifar100_labels, imagenet_labels = [], []
-  for idx in cifar010_ord_indexes:
-    cifar100_labels.append(cifar100_ord_indexes.index(idx))
-    imagenet_labels.append(imagenet_ord_indexes.index(idx))
-  print ('{:} prepare data done.'.format(time_string()))
-
-  dpi, width, height = 200, 1400, 800
+
+  dpi, width, height = 250, 5000, 1300
   figsize = width / float(dpi), height / float(dpi)
-  LabelSize, LegendFontsize = 18, 12
-  resnet_scale, resnet_alpha = 120, 0.5
+  LabelSize, LegendFontsize = 16, 16

-  fig = plt.figure(figsize=figsize)
-  ax = fig.add_subplot(111)
-  plt.xlim(min(indexes), max(indexes))
-  plt.ylim(min(indexes), max(indexes))
-  # plt.ylabel('y').set_rotation(30)
-  plt.yticks(np.arange(min(indexes), max(indexes), max(indexes)//3), fontsize=LegendFontsize, rotation='vertical')
-  plt.xticks(np.arange(min(indexes), max(indexes), max(indexes)//5), fontsize=LegendFontsize)
-  ax.scatter(indexes, cifar100_labels, marker='^', s=0.5, c='tab:green', alpha=0.8)
-  ax.scatter(indexes, imagenet_labels, marker='*', s=0.5, c='tab:red', alpha=0.8)
-  ax.scatter(indexes, indexes, marker='o', s=0.5, c='tab:blue', alpha=0.8)
-  ax.scatter([-1], [-1], marker='o', s=100, c='tab:blue', label='CIFAR-10')
-  ax.scatter([-1], [-1], marker='^', s=100, c='tab:green', label='CIFAR-100')
-  ax.scatter([-1], [-1], marker='*', s=100, c='tab:red', label='ImageNet-16-120')
-  plt.grid(zorder=0)
-  ax.set_axisbelow(True)
-  plt.legend(loc=0, fontsize=LegendFontsize)
-  ax.set_xlabel('architecture ranking in CIFAR-10', fontsize=LabelSize)
-  ax.set_ylabel('architecture ranking', fontsize=LabelSize)
-  save_path = (vis_save_dir / '{:}-relative-rank.pdf'.format(indicator)).resolve()
+  fig, axs = plt.subplots(1, 3, figsize=figsize)
+  datasets = ['cifar10', 'cifar100', 'ImageNet16-120']
+
+  def sub_plot_fn(ax, dataset, indicator):
+    performances = []
+    # pick up the top-{topk}% architectures
+    for _index in range(len(api)):
+      performances.append((all_infos[_index][dataset][indicator], _index))
+    performances = sorted(performances, reverse=True)
+    performances = performances[: int(len(api) * topk * 0.01)]
+    selected_indexes = [x[1] for x in performances]
+    print('{:} plot {:10s} with {:}, {:} architectures'.format(time_string(), dataset, indicator, len(selected_indexes)))
+    standard_scores = []
+    random_scores = []
+    for idx in selected_indexes:
+      standard_scores.append(
+          api.get_more_info(idx, dataset, hp=api.full_train_epochs if indicator == 'more' else '12', is_random=False)['test-accuracy'])
+      random_scores.append(
+          api.get_more_info(idx, dataset, hp=api.full_train_epochs if indicator == 'more' else '12', is_random=True)['test-accuracy'])
+    indexes = list(range(len(selected_indexes)))
+    standard_indexes = sorted(indexes, key=lambda i: standard_scores[i])
+    random_indexes = sorted(indexes, key=lambda i: random_scores[i])
+    random_labels = []
+    for idx in standard_indexes:
+      random_labels.append(random_indexes.index(idx))
+    for tick in ax.get_xticklabels():
+      tick.set_fontsize(LabelSize - 3)
+    for tick in ax.get_yticklabels():
+      tick.set_rotation(25)
+      tick.set_fontsize(LabelSize - 3)
+    ax.set_xlim(0, len(indexes))
+    ax.set_ylim(0, len(indexes))
+    ax.set_yticks(np.arange(min(indexes), max(indexes), max(indexes)//3))
+    ax.set_xticks(np.arange(min(indexes), max(indexes), max(indexes)//5))
+    ax.scatter(indexes, random_labels, marker='^', s=0.5, c='tab:green', alpha=0.8)
+    ax.scatter(indexes, indexes, marker='o', s=0.5, c='tab:blue', alpha=0.8)
+    ax.scatter([-1], [-1], marker='o', s=100, c='tab:blue', label='Average Over Multi-Trials')
+    ax.scatter([-1], [-1], marker='^', s=100, c='tab:green', label='Randomly Selected Trial')
+    coef, p = scipy.stats.kendalltau(standard_scores, random_scores)
+    ax.set_xlabel('architecture ranking in {:}'.format(name2label[dataset]), fontsize=LabelSize)
+    if dataset == 'cifar10':
+      ax.set_ylabel('architecture ranking', fontsize=LabelSize)
+    ax.legend(loc=4, fontsize=LegendFontsize)
+    return coef
+
+  for dataset, ax in zip(datasets, axs):
+    rank_coef = sub_plot_fn(ax, dataset, indicator)
+    print('sub-plot {:} on {:} done, the ranking coefficient is {:.4f}.'.format(dataset, search_space, rank_coef))
+
+  save_path = (vis_save_dir / '{:}-rank-{:}-top{:}.pdf'.format(search_space, indicator, topk)).resolve()
   fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
-  save_path = (vis_save_dir / '{:}-relative-rank.png'.format(indicator)).resolve()
+  save_path = (vis_save_dir / '{:}-rank-{:}-top{:}.png'.format(search_space, indicator, topk)).resolve()
   fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
-  print ('{:} save into {:}'.format(time_string(), save_path))
+  print('Save into {:}'.format(save_path))

 if __name__ == '__main__':
   parser = argparse.ArgumentParser(description='NATS-Bench', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
   parser.add_argument('--save_dir', type=str, default='output/vis-nas-bench/rank-stability', help='Folder to save checkpoints and log.')
-  # use for train the model
+  # used when training the model
   args = parser.parse_args()
   to_save_dir = Path(args.save_dir)

-  # Figure 2
-  visualize_relative_info(None, to_save_dir, 'tss')
-  visualize_relative_info(None, to_save_dir, 'sss')
+  for topk in [1, 5, 10, 20]:
+    visualize_relative_info(to_save_dir, 'tss', 'more', topk)
+    visualize_relative_info(to_save_dir, 'sss', 'less', topk)
   print ('{:} : complete running this file : {:}'.format(time_string(), __file__))
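The "ranking coefficient" printed per sub-plot is Kendall's tau between the multi-trial average accuracies (is_random=False) and the accuracies from one randomly sampled trial (is_random=True). A self-contained illustration, with synthetic numbers standing in for benchmark queries:

import numpy as np
import scipy.stats

rng = np.random.default_rng(0)
mean_accs = np.sort(rng.uniform(88, 94, size=50))       # stand-in: averaged test accuracies
single_trial = mean_accs + rng.normal(0, 0.3, size=50)  # stand-in: one randomly selected trial

coef, p = scipy.stats.kendalltau(mean_accs, single_trial)
print('Kendall tau = {:.4f} (p = {:.2e})'.format(coef, p))
# tau near 1.0 means a single training trial preserves the architecture
# ranking induced by the multi-trial average; lower tau means noisier rankings.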