Update visualization codes for NATS-Bench

D-X-Y 2020-12-02 08:05:03 +08:00
parent 46b92e37e2
commit bda30c7098
2 changed files with 103 additions and 70 deletions

View File

@@ -43,20 +43,14 @@ def fetch_data(root_dir='./output/search', search_space='tss', dataset=None):
   # alg2name['REINFORCE'] = 'REINFORCE-0.01'
   # alg2name['RANDOM'] = 'RANDOM'
   # alg2name['BOHB'] = 'BOHB'
-  if dataset == 'cifar10':
-    suffixes = ['-T200000', '-T200000-FULL']
-  elif dataset == 'cifar100':
-    suffixes = ['-T40000', '-T40000-FULL']
-  elif search_space == 'tss':
-    suffixes = ['-T120000', '-T120000-FULL']
-  elif search_space == 'sss':
-    suffixes = ['-T60000', '-T60000-FULL']
-  else:
-    raise ValueError('Unkonwn dataset : {:}'.format(dataset))
   if search_space == 'tss':
     hp = '$\mathcal{H}^{1}$'
+    if dataset == 'cifar10':
+      suffixes = ['-T800000', '-T800000-FULL']
   elif search_space == 'sss':
     hp = '$\mathcal{H}^{2}$'
+    if dataset == 'cifar10':
+      suffixes = ['-T200000', '-T200000-FULL']
   else:
     raise ValueError('Unkonwn search space: {:}'.format(search_space))
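
For readability, the hunk above resolves to roughly the following selection logic after the change. This is a consolidated sketch, not the file's exact surrounding code: the fetch_data body, the non-CIFAR-10 handling, and the exact indentation are assumed from context. The -T{budget} suffix appears to encode the simulated search-time budget of the result files (matching the x_axis_s limits below), and hp labels the training-epoch hyperparameter set used in figure titles.

    # Sketch of the post-change branch logic inside fetch_data (assumptions noted above).
    if search_space == 'tss':
      hp = r'$\mathcal{H}^{1}$'
      if dataset == 'cifar10':
        suffixes = ['-T800000', '-T800000-FULL']
    elif search_space == 'sss':
      hp = r'$\mathcal{H}^{2}$'
      if dataset == 'cifar10':
        suffixes = ['-T200000', '-T200000-FULL']
    else:
      raise ValueError('Unknown search space: {:}'.format(search_space))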
@@ -92,21 +86,21 @@ def query_performance(api, data, dataset, ticket):
   return np.mean(results), np.std(results)
-y_min_s = {('cifar10', 'tss'): 90,
-           ('cifar10', 'sss'): 90,
+y_min_s = {('cifar10', 'tss'): 91,
+           ('cifar10', 'sss'): 91,
            ('cifar100', 'tss'): 65,
            ('cifar100', 'sss'): 65,
            ('ImageNet16-120', 'tss'): 36,
            ('ImageNet16-120', 'sss'): 40}
 y_max_s = {('cifar10', 'tss'): 94.5,
-           ('cifar10', 'sss'): 94.5,
+           ('cifar10', 'sss'): 93.5,
            ('cifar100', 'tss'): 72.5,
            ('cifar100', 'sss'): 70.5,
            ('ImageNet16-120', 'tss'): 46,
            ('ImageNet16-120', 'sss'): 46}
-x_axis_s = {('cifar10', 'tss'): 200000,
+x_axis_s = {('cifar10', 'tss'): 800000,
             ('cifar10', 'sss'): 200000,
             ('cifar100', 'tss'): 400,
             ('cifar100', 'sss'): 400,
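
These dictionaries are keyed by (dataset, search_space) pairs and bound each subplot, so the CIFAR-10/tss curve now runs out to the new 8e5-second budget inside a 91-94.5% accuracy window. A minimal, self-contained sketch of how such lookups are consumed (mirroring the set_xlim/set_ylim calls in the next hunk; the tables here are a hypothetical one-entry subset):

    import matplotlib
    matplotlib.use('agg')
    import matplotlib.pyplot as plt

    # Hypothetical one-entry subset of the lookup tables above, for illustration only.
    y_min_s  = {('cifar10', 'tss'): 91}
    y_max_s  = {('cifar10', 'tss'): 94.5}
    x_axis_s = {('cifar10', 'tss'): 800000}

    fig, ax = plt.subplots()
    dataset, search_space = 'cifar10', 'tss'
    # Same pattern as sub_plot_fn: bound the panel by the pair-specific
    # time budget (x) and accuracy window (y).
    ax.set_xlim(0, x_axis_s[(dataset, search_space)])
    ax.set_ylim(y_min_s[(dataset, search_space)],
                y_max_s[(dataset, search_space)])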
@@ -124,9 +118,9 @@ def visualize_curve(api_dict, vis_save_dir):
   vis_save_dir = vis_save_dir.resolve()
   vis_save_dir.mkdir(parents=True, exist_ok=True)
-  dpi, width, height = 250, 4000, 2400
+  dpi, width, height = 250, 5000, 2000
   figsize = width / float(dpi), height / float(dpi)
-  LabelSize, LegendFontsize = 16, 16
+  LabelSize, LegendFontsize = 28, 28
   def sub_plot_fn(ax, search_space, dataset):
     max_time = x_axis_s[(dataset, search_space)]
@@ -137,6 +131,11 @@ def visualize_curve(api_dict, vis_save_dir):
     ax.set_xlim(0, x_axis_s[(dataset, search_space)])
     ax.set_ylim(y_min_s[(dataset, search_space)],
                 y_max_s[(dataset, search_space)])
+    for tick in ax.get_xticklabels():
+      tick.set_rotation(25)
+      tick.set_fontsize(LabelSize - 6)
+    for tick in ax.get_yticklabels():
+      tick.set_fontsize(LabelSize - 6)
     for idx, (alg, xdata) in enumerate(alg2data.items()):
       accuracies = []
       for ticket in time_tickets:
@@ -150,8 +149,8 @@ def visualize_curve(api_dict, vis_save_dir):
       ax.plot(time_tickets, accuracies, c=xdata['color'], linestyle=xdata['linestyle'], label='{:}'.format(alg))
       ax.set_xlabel('Estimated wall-clock time', fontsize=LabelSize)
       ax.set_ylabel('Test accuracy', fontsize=LabelSize)
-      ax.set_title(r'Searching results on {:} for {:}'.format(name2label[dataset], spaces2latex[search_space]),
-                   fontsize=LabelSize+4)
+      ax.set_title(r'Results on {:} over {:}'.format(name2label[dataset], spaces2latex[search_space]),
+                   fontsize=LabelSize)
       ax.legend(loc=4, fontsize=LegendFontsize)
   fig, axs = plt.subplots(1, 2, figsize=figsize)
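
A quick arithmetic check of the new figure geometry, using only the constants from the hunks above (matplotlib interprets figsize in inches):

    dpi, width, height = 250, 5000, 2000
    figsize = width / float(dpi), height / float(dpi)
    print(figsize)  # (20.0, 8.0) inches -> roughly two 10x8-inch panels in the 1x2 grid

The larger per-panel area is presumably why the label and legend font sizes are raised from 16 to 28 in this commit.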
@@ -165,7 +164,7 @@ def visualize_curve(api_dict, vis_save_dir):
 if __name__ == '__main__':
   parser = argparse.ArgumentParser(description='NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
   parser.add_argument('--save_dir', type=str, default='output/vis-nas-bench/nas-algos-vs-h', help='Folder to save checkpoints and log.')
   args = parser.parse_args()
   save_dir = Path(args.save_dir)

View File

@@ -11,7 +11,7 @@ import scipy
 import numpy as np
 from typing import List, Text, Dict, Any
 from shutil import copyfile
-from collections import defaultdict
+from collections import defaultdict, OrderedDict
 from copy import deepcopy
 from pathlib import Path
 import matplotlib
@@ -28,69 +28,103 @@ from models import get_cell_based_tiny_net
 from nats_bench import create
-def visualize_relative_info(api, vis_save_dir, indicator):
+name2label = {'cifar10': 'CIFAR-10',
+              'cifar100': 'CIFAR-100',
+              'ImageNet16-120': 'ImageNet-16-120'}
+def visualize_relative_info(vis_save_dir, search_space, indicator, topk):
   vis_save_dir = vis_save_dir.resolve()
-  # print ('{:} start to visualize {:} information'.format(time_string(), api))
+  print ('{:} start to visualize {:} with top-{:} information'.format(time_string(), search_space, topk))
   vis_save_dir.mkdir(parents=True, exist_ok=True)
-  cifar010_cache_path = vis_save_dir / '{:}-cache-{:}-info.pth'.format('cifar10', indicator)
-  cifar100_cache_path = vis_save_dir / '{:}-cache-{:}-info.pth'.format('cifar100', indicator)
-  imagenet_cache_path = vis_save_dir / '{:}-cache-{:}-info.pth'.format('ImageNet16-120', indicator)
-  cifar010_info = torch.load(cifar010_cache_path)
-  cifar100_info = torch.load(cifar100_cache_path)
-  imagenet_info = torch.load(imagenet_cache_path)
-  indexes = list(range(len(cifar010_info['params'])))
-  print ('{:} start to visualize relative ranking'.format(time_string()))
-  cifar010_ord_indexes = sorted(indexes, key=lambda i: cifar010_info['test_accs'][i])
-  cifar100_ord_indexes = sorted(indexes, key=lambda i: cifar100_info['test_accs'][i])
-  imagenet_ord_indexes = sorted(indexes, key=lambda i: imagenet_info['test_accs'][i])
-  cifar100_labels, imagenet_labels = [], []
-  for idx in cifar010_ord_indexes:
-    cifar100_labels.append( cifar100_ord_indexes.index(idx) )
-    imagenet_labels.append( imagenet_ord_indexes.index(idx) )
-  print ('{:} prepare data done.'.format(time_string()))
-  dpi, width, height = 200, 1400, 800
+  cache_file_path = vis_save_dir / 'cache-{:}-info.pth'.format(search_space)
+  datasets = ['cifar10', 'cifar100', 'ImageNet16-120']
+  if not cache_file_path.exists():
+    api = create(None, search_space, fast_mode=False, verbose=False)
+    all_infos = OrderedDict()
+    for index in range(len(api)):
+      all_info = OrderedDict()
+      for dataset in datasets:
+        info_less = api.get_more_info(index, dataset, hp='12', is_random=False)
+        info_more = api.get_more_info(index, dataset, hp=api.full_train_epochs, is_random=False)
+        all_info[dataset] = dict(less=info_less['test-accuracy'],
+                                 more=info_more['test-accuracy'])
+      all_infos[index] = all_info
+    torch.save(all_infos, cache_file_path)
+    print ('{:} save all cache data into {:}'.format(time_string(), cache_file_path))
+  else:
+    api = create(None, search_space, fast_mode=True, verbose=False)
+    all_infos = torch.load(cache_file_path)
+  dpi, width, height = 250, 5000, 1300
   figsize = width / float(dpi), height / float(dpi)
-  LabelSize, LegendFontsize = 18, 12
-  resnet_scale, resnet_alpha = 120, 0.5
-  fig = plt.figure(figsize=figsize)
-  ax = fig.add_subplot(111)
-  plt.xlim(min(indexes), max(indexes))
-  plt.ylim(min(indexes), max(indexes))
-  # plt.ylabel('y').set_rotation(30)
-  plt.yticks(np.arange(min(indexes), max(indexes), max(indexes)//3), fontsize=LegendFontsize, rotation='vertical')
-  plt.xticks(np.arange(min(indexes), max(indexes), max(indexes)//5), fontsize=LegendFontsize)
-  ax.scatter(indexes, cifar100_labels, marker='^', s=0.5, c='tab:green', alpha=0.8)
-  ax.scatter(indexes, imagenet_labels, marker='*', s=0.5, c='tab:red' , alpha=0.8)
-  ax.scatter(indexes, indexes        , marker='o', s=0.5, c='tab:blue' , alpha=0.8)
-  ax.scatter([-1], [-1], marker='o', s=100, c='tab:blue' , label='CIFAR-10')
-  ax.scatter([-1], [-1], marker='^', s=100, c='tab:green', label='CIFAR-100')
-  ax.scatter([-1], [-1], marker='*', s=100, c='tab:red'  , label='ImageNet-16-120')
-  plt.grid(zorder=0)
-  ax.set_axisbelow(True)
-  plt.legend(loc=0, fontsize=LegendFontsize)
-  ax.set_xlabel('architecture ranking in CIFAR-10', fontsize=LabelSize)
-  ax.set_ylabel('architecture ranking', fontsize=LabelSize)
-  save_path = (vis_save_dir / '{:}-relative-rank.pdf'.format(indicator)).resolve()
+  LabelSize, LegendFontsize = 16, 16
+  fig, axs = plt.subplots(1, 3, figsize=figsize)
+  datasets = ['cifar10', 'cifar100', 'ImageNet16-120']
+  def sub_plot_fn(ax, dataset, indicator):
+    performances = []
+    # pickup top 10% architectures
+    for _index in range(len(api)):
+      performances.append((all_infos[_index][dataset][indicator], _index))
+    performances = sorted(performances, reverse=True)
+    performances = performances[: int(len(api) * topk * 0.01)]
+    selected_indexes = [x[1] for x in performances]
+    print('{:} plot {:10s} with {:}, {:} architectures'.format(time_string(), dataset, indicator, len(selected_indexes)))
+    standard_scores = []
+    random_scores = []
+    for idx in selected_indexes:
+      standard_scores.append(
+        api.get_more_info(idx, dataset, hp=api.full_train_epochs if indicator == 'more' else '12', is_random=False)['test-accuracy'])
+      random_scores.append(
+        api.get_more_info(idx, dataset, hp=api.full_train_epochs if indicator == 'more' else '12', is_random=True)['test-accuracy'])
+    indexes = list(range(len(selected_indexes)))
+    standard_indexes = sorted(indexes, key=lambda i: standard_scores[i])
+    random_indexes = sorted(indexes, key=lambda i: random_scores[i])
+    random_labels = []
+    for idx in standard_indexes:
+      random_labels.append(random_indexes.index(idx))
+    for tick in ax.get_xticklabels():
+      tick.set_fontsize(LabelSize - 3)
+    for tick in ax.get_yticklabels():
+      tick.set_rotation(25)
+      tick.set_fontsize(LabelSize - 3)
+    ax.set_xlim(0, len(indexes))
+    ax.set_ylim(0, len(indexes))
+    ax.set_yticks(np.arange(min(indexes), max(indexes), max(indexes)//3))
+    ax.set_xticks(np.arange(min(indexes), max(indexes), max(indexes)//5))
+    ax.scatter(indexes, random_labels, marker='^', s=0.5, c='tab:green', alpha=0.8)
+    ax.scatter(indexes, indexes      , marker='o', s=0.5, c='tab:blue' , alpha=0.8)
+    ax.scatter([-1], [-1], marker='o', s=100, c='tab:blue' , label='Average Over Multi-Trials')
+    ax.scatter([-1], [-1], marker='^', s=100, c='tab:green', label='Randomly Selected Trial')
+    coef, p = scipy.stats.kendalltau(standard_scores, random_scores)
+    ax.set_xlabel('architecture ranking in {:}'.format(name2label[dataset]), fontsize=LabelSize)
+    if dataset == 'cifar10':
+      ax.set_ylabel('architecture ranking', fontsize=LabelSize)
+    ax.legend(loc=4, fontsize=LegendFontsize)
+    return coef
+  for dataset, ax in zip(datasets, axs):
+    rank_coef = sub_plot_fn(ax, dataset, indicator)
+    print('sub-plot {:} on {:} done, the ranking coefficient is {:.4f}.'.format(dataset, search_space, rank_coef))
+  save_path = (vis_save_dir / '{:}-rank-{:}-top{:}.pdf'.format(search_space, indicator, topk)).resolve()
   fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
-  save_path = (vis_save_dir / '{:}-relative-rank.png'.format(indicator)).resolve()
+  save_path = (vis_save_dir / '{:}-rank-{:}-top{:}.png'.format(search_space, indicator, topk)).resolve()
   fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
-  print ('{:} save into {:}'.format(time_string(), save_path))
+  print('Save into {:}'.format(save_path))
 if __name__ == '__main__':
   parser = argparse.ArgumentParser(description='NATS-Bench', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
   parser.add_argument('--save_dir', type=str, default='output/vis-nas-bench/rank-stability', help='Folder to save checkpoints and log.')
-  # use for train the model
   args = parser.parse_args()
   to_save_dir = Path(args.save_dir)
-  # Figure 2
-  visualize_relative_info(None, to_save_dir, 'tss')
-  visualize_relative_info(None, to_save_dir, 'sss')
+  for topk in [1, 5, 10, 20]:
+    visualize_relative_info(to_save_dir, 'tss', 'more', topk)
+    visualize_relative_info(to_save_dir, 'sss', 'less', topk)
+  print ('{:} : complete running this file : {:}'.format(time_string(), __file__))
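
The rewritten visualize_relative_info now measures ranking stability rather than cross-dataset ranking: for the top-k% architectures it compares the ranking induced by the accuracy averaged over multiple trials (is_random=False) with the ranking obtained from a single randomly selected trial (is_random=True), and reports Kendall's tau per dataset. A self-contained sketch of that computation with synthetic scores (the real script pulls both score lists from api.get_more_info):

    import numpy as np
    import scipy.stats

    # Synthetic stand-ins for the two score lists built in sub_plot_fn:
    # standard_scores ~ accuracy averaged over trials, random_scores ~ one random trial.
    rng = np.random.default_rng(0)
    standard_scores = rng.uniform(90, 95, size=100)
    random_scores = standard_scores + rng.normal(0, 0.3, size=100)  # noisy single-trial view

    # Rank each architecture under both views, as the diff does; random_labels is the
    # y-coordinate used for the green scatter points in the figure.
    indexes = list(range(len(standard_scores)))
    standard_indexes = sorted(indexes, key=lambda i: standard_scores[i])
    random_indexes = sorted(indexes, key=lambda i: random_scores[i])
    random_labels = [random_indexes.index(idx) for idx in standard_indexes]

    # Kendall's tau between the raw scores quantifies how stable the ranking is
    # when only a single trial is available; 1.0 means the order is unchanged.
    coef, p = scipy.stats.kendalltau(standard_scores, random_scores)
    print('Kendall tau = {:.4f}'.format(coef))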