diff --git a/exps/NATS-Bench/draw-fig8.py b/exps/NATS-Bench/draw-fig8.py
index c3a5b72..d2a4ad5 100644
--- a/exps/NATS-Bench/draw-fig8.py
+++ b/exps/NATS-Bench/draw-fig8.py
@@ -43,20 +43,14 @@ def fetch_data(root_dir='./output/search', search_space='tss', dataset=None):
   # alg2name['REINFORCE'] = 'REINFORCE-0.01'
   # alg2name['RANDOM'] = 'RANDOM'
   # alg2name['BOHB'] = 'BOHB'
-  if dataset == 'cifar10':
-    suffixes = ['-T200000', '-T200000-FULL']
-  elif dataset == 'cifar100':
-    suffixes = ['-T40000', '-T40000-FULL']
-  elif search_space == 'tss':
-    suffixes = ['-T120000', '-T120000-FULL']
-  elif search_space == 'sss':
-    suffixes = ['-T60000', '-T60000-FULL']
-  else:
-    raise ValueError('Unkonwn dataset : {:}'.format(dataset))
   if search_space == 'tss':
     hp = '$\mathcal{H}^{1}$'
+    if dataset == 'cifar10':
+      suffixes = ['-T800000', '-T800000-FULL']
   elif search_space == 'sss':
     hp = '$\mathcal{H}^{2}$'
+    if dataset == 'cifar10':
+      suffixes = ['-T200000', '-T200000-FULL']
   else:
     raise ValueError('Unkonwn search space: {:}'.format(search_space))
 
@@ -92,21 +86,21 @@ def query_performance(api, data, dataset, ticket):
   return np.mean(results), np.std(results)
 
 
-y_min_s = {('cifar10', 'tss'): 90,
-           ('cifar10', 'sss'): 90,
+y_min_s = {('cifar10', 'tss'): 91,
+           ('cifar10', 'sss'): 91,
            ('cifar100', 'tss'): 65,
            ('cifar100', 'sss'): 65,
            ('ImageNet16-120', 'tss'): 36,
            ('ImageNet16-120', 'sss'): 40}
 
 y_max_s = {('cifar10', 'tss'): 94.5,
-           ('cifar10', 'sss'): 94.5,
+           ('cifar10', 'sss'): 93.5,
            ('cifar100', 'tss'): 72.5,
            ('cifar100', 'sss'): 70.5,
            ('ImageNet16-120', 'tss'): 46,
            ('ImageNet16-120', 'sss'): 46}
 
-x_axis_s = {('cifar10', 'tss'): 200000,
+x_axis_s = {('cifar10', 'tss'): 800000,
             ('cifar10', 'sss'): 200000,
             ('cifar100', 'tss'): 400,
             ('cifar100', 'sss'): 400,
@@ -124,9 +118,9 @@ def visualize_curve(api_dict, vis_save_dir):
   vis_save_dir = vis_save_dir.resolve()
   vis_save_dir.mkdir(parents=True, exist_ok=True)
 
-  dpi, width, height = 250, 4000, 2400
+  dpi, width, height = 250, 5000, 2000
   figsize = width / float(dpi), height / float(dpi)
-  LabelSize, LegendFontsize = 16, 16
+  LabelSize, LegendFontsize = 28, 28
 
   def sub_plot_fn(ax, search_space, dataset):
     max_time = x_axis_s[(dataset, search_space)]
@@ -137,6 +131,11 @@ def visualize_curve(api_dict, vis_save_dir):
     ax.set_xlim(0, x_axis_s[(dataset, search_space)])
     ax.set_ylim(y_min_s[(dataset, search_space)],
                 y_max_s[(dataset, search_space)])
+    for tick in ax.get_xticklabels():
+      tick.set_rotation(25)
+      tick.set_fontsize(LabelSize - 6)
+    for tick in ax.get_yticklabels():
+      tick.set_fontsize(LabelSize - 6)
     for idx, (alg, xdata) in enumerate(alg2data.items()):
       accuracies = []
       for ticket in time_tickets:
@@ -150,8 +149,8 @@ def visualize_curve(api_dict, vis_save_dir):
       ax.plot(time_tickets, accuracies, c=xdata['color'], linestyle=xdata['linestyle'], label='{:}'.format(alg))
     ax.set_xlabel('Estimated wall-clock time', fontsize=LabelSize)
     ax.set_ylabel('Test accuracy', fontsize=LabelSize)
-    ax.set_title(r'Searching results on {:} for {:}'.format(name2label[dataset], spaces2latex[search_space]),
-                 fontsize=LabelSize+4)
+    ax.set_title(r'Results on {:} over {:}'.format(name2label[dataset], spaces2latex[search_space]),
+                 fontsize=LabelSize)
     ax.legend(loc=4, fontsize=LegendFontsize)
 
   fig, axs = plt.subplots(1, 2, figsize=figsize)
@@ -165,7 +164,7 @@ def visualize_curve(api_dict, vis_save_dir):
 
 if __name__ == '__main__':
   parser = argparse.ArgumentParser(description='NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
-  parser.add_argument('--save_dir', type=str, default='output/vis-nas-bench/nas-algos-vs-h', help='Folder to save checkpoints and log.')
+  parser.add_argument('--save_dir', type=str, default='output/vis-nas-bench/nas-algos-vs-h', help='Folder to save checkpoints and log.')
   args = parser.parse_args()
 
   save_dir = Path(args.save_dir)
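
Note on the fetch_data hunk above: the time-budget suffixes are now chosen inside the search-space branches, and the patch only sets them for cifar10. A minimal standalone sketch of that control flow; the helper name pick_hp_and_suffixes is hypothetical, and other datasets would still need their own branches before suffixes is read:

# Sketch of the refactored branch in fetch_data above; pick_hp_and_suffixes
# is a hypothetical name, and only the cifar10 budgets from the hunk appear.
def pick_hp_and_suffixes(search_space, dataset):
  if search_space == 'tss':
    hp = r'$\mathcal{H}^{1}$'
    if dataset == 'cifar10':
      suffixes = ['-T800000', '-T800000-FULL']
  elif search_space == 'sss':
    hp = r'$\mathcal{H}^{2}$'
    if dataset == 'cifar10':
      suffixes = ['-T200000', '-T200000-FULL']
  else:
    raise ValueError('Unknown search space: {:}'.format(search_space))
  return hp, suffixes  # raises UnboundLocalError for datasets without a budget

assert pick_hp_and_suffixes('tss', 'cifar10') == (r'$\mathcal{H}^{1}$', ['-T800000', '-T800000-FULL'])
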
diff --git a/exps/NATS-Bench/draw-ranks.py b/exps/NATS-Bench/draw-ranks.py
index e8f7d65..031ef83 100644
--- a/exps/NATS-Bench/draw-ranks.py
+++ b/exps/NATS-Bench/draw-ranks.py
@@ -11,7 +11,7 @@ import scipy
 import numpy as np
 from typing import List, Text, Dict, Any
 from shutil import copyfile
-from collections import defaultdict
+from collections import defaultdict, OrderedDict
 from copy import deepcopy
 from pathlib import Path
 import matplotlib
@@ -28,69 +28,103 @@ from models import get_cell_based_tiny_net
 from nats_bench import create
 
 
-def visualize_relative_info(api, vis_save_dir, indicator):
+name2label = {'cifar10': 'CIFAR-10',
+              'cifar100': 'CIFAR-100',
+              'ImageNet16-120': 'ImageNet-16-120'}
+
+
+def visualize_relative_info(vis_save_dir, search_space, indicator, topk):
   vis_save_dir = vis_save_dir.resolve()
-  # print ('{:} start to visualize {:} information'.format(time_string(), api))
+  print ('{:} start to visualize {:} with top-{:} information'.format(time_string(), search_space, topk))
   vis_save_dir.mkdir(parents=True, exist_ok=True)
+  cache_file_path = vis_save_dir / 'cache-{:}-info.pth'.format(search_space)
+  datasets = ['cifar10', 'cifar100', 'ImageNet16-120']
+  if not cache_file_path.exists():
+    api = create(None, search_space, fast_mode=False, verbose=False)
+    all_infos = OrderedDict()
+    for index in range(len(api)):
+      all_info = OrderedDict()
+      for dataset in datasets:
+        info_less = api.get_more_info(index, dataset, hp='12', is_random=False)
+        info_more = api.get_more_info(index, dataset, hp=api.full_train_epochs, is_random=False)
+        all_info[dataset] = dict(less=info_less['test-accuracy'],
+                                 more=info_more['test-accuracy'])
+      all_infos[index] = all_info
+    torch.save(all_infos, cache_file_path)
+    print ('{:} save all cache data into {:}'.format(time_string(), cache_file_path))
+  else:
+    api = create(None, search_space, fast_mode=True, verbose=False)
+    all_infos = torch.load(cache_file_path)
 
-  cifar010_cache_path = vis_save_dir / '{:}-cache-{:}-info.pth'.format('cifar10', indicator)
-  cifar100_cache_path = vis_save_dir / '{:}-cache-{:}-info.pth'.format('cifar100', indicator)
-  imagenet_cache_path = vis_save_dir / '{:}-cache-{:}-info.pth'.format('ImageNet16-120', indicator)
-  cifar010_info = torch.load(cifar010_cache_path)
-  cifar100_info = torch.load(cifar100_cache_path)
-  imagenet_info = torch.load(imagenet_cache_path)
-  indexes = list(range(len(cifar010_info['params'])))
-  print ('{:} start to visualize relative ranking'.format(time_string()))
-
-  cifar010_ord_indexes = sorted(indexes, key=lambda i: cifar010_info['test_accs'][i])
-  cifar100_ord_indexes = sorted(indexes, key=lambda i: cifar100_info['test_accs'][i])
-  imagenet_ord_indexes = sorted(indexes, key=lambda i: imagenet_info['test_accs'][i])
-
-  cifar100_labels, imagenet_labels = [], []
-  for idx in cifar010_ord_indexes:
-    cifar100_labels.append( cifar100_ord_indexes.index(idx) )
-    imagenet_labels.append( imagenet_ord_indexes.index(idx) )
-  print ('{:} prepare data done.'.format(time_string()))
-
-  dpi, width, height = 200, 1400, 800
+  dpi, width, height = 250, 5000, 1300
   figsize = width / float(dpi), height / float(dpi)
-  LabelSize, LegendFontsize = 18, 12
-  resnet_scale, resnet_alpha = 120, 0.5
+  LabelSize, LegendFontsize = 16, 16
 
-  fig = plt.figure(figsize=figsize)
-  ax = fig.add_subplot(111)
-  plt.xlim(min(indexes), max(indexes))
-  plt.ylim(min(indexes), max(indexes))
-  # plt.ylabel('y').set_rotation(30)
-  plt.yticks(np.arange(min(indexes), max(indexes), max(indexes)//3), fontsize=LegendFontsize, rotation='vertical')
-  plt.xticks(np.arange(min(indexes), max(indexes), max(indexes)//5), fontsize=LegendFontsize)
-  ax.scatter(indexes, cifar100_labels, marker='^', s=0.5, c='tab:green', alpha=0.8)
-  ax.scatter(indexes, imagenet_labels, marker='*', s=0.5, c='tab:red'  , alpha=0.8)
-  ax.scatter(indexes, indexes        , marker='o', s=0.5, c='tab:blue' , alpha=0.8)
-  ax.scatter([-1], [-1], marker='o', s=100, c='tab:blue' , label='CIFAR-10')
-  ax.scatter([-1], [-1], marker='^', s=100, c='tab:green', label='CIFAR-100')
-  ax.scatter([-1], [-1], marker='*', s=100, c='tab:red'  , label='ImageNet-16-120')
-  plt.grid(zorder=0)
-  ax.set_axisbelow(True)
-  plt.legend(loc=0, fontsize=LegendFontsize)
-  ax.set_xlabel('architecture ranking in CIFAR-10', fontsize=LabelSize)
-  ax.set_ylabel('architecture ranking', fontsize=LabelSize)
-  save_path = (vis_save_dir / '{:}-relative-rank.pdf'.format(indicator)).resolve()
+  fig, axs = plt.subplots(1, 3, figsize=figsize)
+  datasets = ['cifar10', 'cifar100', 'ImageNet16-120']
+
+  def sub_plot_fn(ax, dataset, indicator):
+    performances = []
+    # pickup top 10% architectures
+    for _index in range(len(api)):
+      performances.append((all_infos[_index][dataset][indicator], _index))
+    performances = sorted(performances, reverse=True)
+    performances = performances[: int(len(api) * topk * 0.01)]
+    selected_indexes = [x[1] for x in performances]
+    print('{:} plot {:10s} with {:}, {:} architectures'.format(time_string(), dataset, indicator, len(selected_indexes)))
+    standard_scores = []
+    random_scores = []
+    for idx in selected_indexes:
+      standard_scores.append(
+        api.get_more_info(idx, dataset, hp=api.full_train_epochs if indicator == 'more' else '12', is_random=False)['test-accuracy'])
+      random_scores.append(
+        api.get_more_info(idx, dataset, hp=api.full_train_epochs if indicator == 'more' else '12', is_random=True)['test-accuracy'])
+    indexes = list(range(len(selected_indexes)))
+    standard_indexes = sorted(indexes, key=lambda i: standard_scores[i])
+    random_indexes = sorted(indexes, key=lambda i: random_scores[i])
+    random_labels = []
+    for idx in standard_indexes:
+      random_labels.append(random_indexes.index(idx))
+    for tick in ax.get_xticklabels():
+      tick.set_fontsize(LabelSize - 3)
+    for tick in ax.get_yticklabels():
+      tick.set_rotation(25)
+      tick.set_fontsize(LabelSize - 3)
+    ax.set_xlim(0, len(indexes))
+    ax.set_ylim(0, len(indexes))
+    ax.set_yticks(np.arange(min(indexes), max(indexes), max(indexes)//3))
+    ax.set_xticks(np.arange(min(indexes), max(indexes), max(indexes)//5))
+    ax.scatter(indexes, random_labels, marker='^', s=0.5, c='tab:green', alpha=0.8)
+    ax.scatter(indexes, indexes      , marker='o', s=0.5, c='tab:blue' , alpha=0.8)
+    ax.scatter([-1], [-1], marker='o', s=100, c='tab:blue' , label='Average Over Multi-Trials')
+    ax.scatter([-1], [-1], marker='^', s=100, c='tab:green', label='Randomly Selected Trial')
+
+    coef, p = scipy.stats.kendalltau(standard_scores, random_scores)
+    ax.set_xlabel('architecture ranking in {:}'.format(name2label[dataset]), fontsize=LabelSize)
+    if dataset == 'cifar10':
+      ax.set_ylabel('architecture ranking', fontsize=LabelSize)
+    ax.legend(loc=4, fontsize=LegendFontsize)
+    return coef
+
+  for dataset, ax in zip(datasets, axs):
+    rank_coef = sub_plot_fn(ax, dataset, indicator)
+    print('sub-plot {:} on {:} done, the ranking coefficient is {:.4f}.'.format(dataset, search_space, rank_coef))
+
+  save_path = (vis_save_dir / '{:}-rank-{:}-top{:}.pdf'.format(search_space, indicator, topk)).resolve()
   fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
-  save_path = (vis_save_dir / '{:}-relative-rank.png'.format(indicator)).resolve()
+  save_path = (vis_save_dir / '{:}-rank-{:}-top{:}.png'.format(search_space, indicator, topk)).resolve()
   fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
-  print ('{:} save into {:}'.format(time_string(), save_path))
+  print('Save into {:}'.format(save_path))
 
 
 if __name__ == '__main__':
   parser = argparse.ArgumentParser(description='NATS-Bench', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
   parser.add_argument('--save_dir', type=str, default='output/vis-nas-bench/rank-stability', help='Folder to save checkpoints and log.')
-  # use for train the model
   args = parser.parse_args()
-
   to_save_dir = Path(args.save_dir)
-  # Figure 2
-  visualize_relative_info(None, to_save_dir, 'tss')
-  visualize_relative_info(None, to_save_dir, 'sss')
\ No newline at end of file
+  for topk in [1, 5, 10, 20]:
+    visualize_relative_info(to_save_dir, 'tss', 'more', topk)
+    visualize_relative_info(to_save_dir, 'sss', 'less', topk)
+  print ('{:} : complete running this file : {:}'.format(time_string(), __file__))
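
Note on sub_plot_fn above: the rank-stability comparison reduces to aligning two orderings and scoring them with Kendall's tau. A toy sketch with made-up accuracy scores (not NATS-Bench numbers):

# Toy illustration of the ranking logic in sub_plot_fn; the scores below are
# fabricated for this example only.
import scipy.stats

standard_scores = [93.2, 91.5, 94.0, 92.1]  # averaged over multiple trials
random_scores = [93.0, 92.0, 93.8, 91.9]    # one randomly selected trial

indexes = list(range(len(standard_scores)))
standard_indexes = sorted(indexes, key=lambda i: standard_scores[i])
random_indexes = sorted(indexes, key=lambda i: random_scores[i])
# random_labels[k] is the random-trial rank of the architecture whose
# standard rank is k; points on the diagonal mean the two orderings agree.
random_labels = [random_indexes.index(idx) for idx in standard_indexes]

coef, p = scipy.stats.kendalltau(standard_scores, random_scores)
print(random_labels, '{:.4f}'.format(coef))  # tau of 1.0 means identical rankings
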
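Note on the cache written by visualize_relative_info above: once the patched script has run, the per-search-space file can be inspected directly. A hypothetical smoke test, assuming the default --save_dir and the key layout defined in the diff:

# Hypothetical smoke test; the path assumes the default --save_dir, and the
# 'less'/'more' keys hold the 12-epoch and full-training test accuracies.
import torch

cache = torch.load('output/vis-nas-bench/rank-stability/cache-tss-info.pth')
entry = cache[0]  # OrderedDict: architecture index -> per-dataset dict
for dataset in ('cifar10', 'cifar100', 'ImageNet16-120'):
  print('{:16s} less={:.2f} more={:.2f}'.format(
      dataset, entry[dataset]['less'], entry[dataset]['more']))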