Update visualization code for NATS-Bench
@@ -11,7 +11,7 @@ import scipy
 import numpy as np
 from typing import List, Text, Dict, Any
 from shutil import copyfile
-from collections import defaultdict
+from collections import defaultdict, OrderedDict
 from copy    import deepcopy
 from pathlib import Path
 import matplotlib
@@ -28,69 +28,103 @@ from models import get_cell_based_tiny_net
 from nats_bench import create
 
 
-def visualize_relative_info(api, vis_save_dir, indicator):
+name2label = {'cifar10': 'CIFAR-10',
+              'cifar100': 'CIFAR-100',
+              'ImageNet16-120': 'ImageNet-16-120'}
+
+
+def visualize_relative_info(vis_save_dir, search_space, indicator, topk):
   vis_save_dir = vis_save_dir.resolve()
+  # print ('{:} start to visualize {:} information'.format(time_string(), api))
+  print ('{:} start to visualize {:} with top-{:} information'.format(time_string(), search_space, topk))
   vis_save_dir.mkdir(parents=True, exist_ok=True)
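+  # one cache file per search space; per the NATS-Bench naming, 'tss' is the topology search space and 'sss' the size search space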
+  cache_file_path = vis_save_dir / 'cache-{:}-info.pth'.format(search_space)
+  datasets = ['cifar10', 'cifar100', 'ImageNet16-120']
+  if not cache_file_path.exists():
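+    # first run: create the full API (fast_mode=False loads every result up front) and cache the test accuracies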
+    api = create(None, search_space, fast_mode=False, verbose=False)
+    all_infos = OrderedDict()
+    for index in range(len(api)):
+      all_info = OrderedDict()
+      for dataset in datasets:
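+        # 'less' = test accuracy under the short 12-epoch schedule; 'more' = under the full schedule (api.full_train_epochs)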
+        info_less = api.get_more_info(index, dataset, hp='12', is_random=False)
+        info_more = api.get_more_info(index, dataset, hp=api.full_train_epochs, is_random=False)
+        all_info[dataset] = dict(less=info_less['test-accuracy'],
+                                 more=info_more['test-accuracy'])
+      all_infos[index] = all_info
+    torch.save(all_infos, cache_file_path)
+    print ('{:} save all cache data into {:}'.format(time_string(), cache_file_path))
+  else:
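+    # cache hit: a fast_mode API (lazy loading) suffices, since the accuracies are read from the cache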
+    api = create(None, search_space, fast_mode=True, verbose=False)
+    all_infos = torch.load(cache_file_path)
 
-  cifar010_cache_path = vis_save_dir / '{:}-cache-{:}-info.pth'.format('cifar10', indicator)
-  cifar100_cache_path = vis_save_dir / '{:}-cache-{:}-info.pth'.format('cifar100', indicator)
-  imagenet_cache_path = vis_save_dir / '{:}-cache-{:}-info.pth'.format('ImageNet16-120', indicator)
-  cifar010_info = torch.load(cifar010_cache_path)
-  cifar100_info = torch.load(cifar100_cache_path)
-  imagenet_info = torch.load(imagenet_cache_path)
-  indexes       = list(range(len(cifar010_info['params'])))
-
-  print ('{:} start to visualize relative ranking'.format(time_string()))
-
-  cifar010_ord_indexes = sorted(indexes, key=lambda i: cifar010_info['test_accs'][i])
-  cifar100_ord_indexes = sorted(indexes, key=lambda i: cifar100_info['test_accs'][i])
-  imagenet_ord_indexes = sorted(indexes, key=lambda i: imagenet_info['test_accs'][i])
-
-  cifar100_labels, imagenet_labels = [], []
-  for idx in cifar010_ord_indexes:
-    cifar100_labels.append( cifar100_ord_indexes.index(idx) )
-    imagenet_labels.append( imagenet_ord_indexes.index(idx) )
-  print ('{:} prepare data done.'.format(time_string()))
-
-  dpi, width, height = 200, 1400,  800
+  dpi, width, height = 250, 5000, 1300
   figsize = width / float(dpi), height / float(dpi)
-  LabelSize, LegendFontsize = 18, 12
-  resnet_scale, resnet_alpha = 120, 0.5
+  LabelSize, LegendFontsize = 16, 16
 
-  fig = plt.figure(figsize=figsize)
-  ax  = fig.add_subplot(111)
-  plt.xlim(min(indexes), max(indexes))
-  plt.ylim(min(indexes), max(indexes))
-  # plt.ylabel('y').set_rotation(30)
-  plt.yticks(np.arange(min(indexes), max(indexes), max(indexes)//3), fontsize=LegendFontsize, rotation='vertical')
-  plt.xticks(np.arange(min(indexes), max(indexes), max(indexes)//5), fontsize=LegendFontsize)
-  ax.scatter(indexes, cifar100_labels, marker='^', s=0.5, c='tab:green', alpha=0.8)
-  ax.scatter(indexes, imagenet_labels, marker='*', s=0.5, c='tab:red'  , alpha=0.8)
-  ax.scatter(indexes, indexes        , marker='o', s=0.5, c='tab:blue' , alpha=0.8)
-  ax.scatter([-1], [-1], marker='o', s=100, c='tab:blue' , label='CIFAR-10')
-  ax.scatter([-1], [-1], marker='^', s=100, c='tab:green', label='CIFAR-100')
-  ax.scatter([-1], [-1], marker='*', s=100, c='tab:red'  , label='ImageNet-16-120')
-  plt.grid(zorder=0)
-  ax.set_axisbelow(True)
-  plt.legend(loc=0, fontsize=LegendFontsize)
-  ax.set_xlabel('architecture ranking in CIFAR-10', fontsize=LabelSize)
-  ax.set_ylabel('architecture ranking', fontsize=LabelSize)
-  save_path = (vis_save_dir / '{:}-relative-rank.pdf'.format(indicator)).resolve()
+  fig, axs = plt.subplots(1, 3, figsize=figsize)
+  datasets = ['cifar10', 'cifar100', 'ImageNet16-120']
+
+  def sub_plot_fn(ax, dataset, indicator):
+    performances = []
+    # pick the top topk% architectures (e.g., topk=10 keeps the best 10%) by cached test accuracy
+    for _index in range(len(api)):
+      performances.append((all_infos[_index][dataset][indicator], _index))
+    performances = sorted(performances, reverse=True)
+    performances = performances[: int(len(api) * topk * 0.01)]
+    selected_indexes = [x[1] for x in performances]
+    print('{:} plot {:10s} with {:}, {:} architectures'.format(time_string(), dataset, indicator, len(selected_indexes)))
+    standard_scores = []
+    random_scores = []
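+    # is_random=False returns the accuracy averaged over all trials; is_random=True samples a single trial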
+    for idx in selected_indexes:
+      standard_scores.append(
+        api.get_more_info(idx, dataset, hp=api.full_train_epochs if indicator == 'more' else '12', is_random=False)['test-accuracy'])
+      random_scores.append(
+        api.get_more_info(idx, dataset, hp=api.full_train_epochs if indicator == 'more' else '12', is_random=True)['test-accuracy'])
+    indexes = list(range(len(selected_indexes)))
+    standard_indexes = sorted(indexes, key=lambda i: standard_scores[i])
+    random_indexes = sorted(indexes, key=lambda i: random_scores[i])
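+    # x: rank under the trial-averaged accuracy; y: rank of the same architecture under the single-trial accuracy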
+    random_labels = []
+    for idx in standard_indexes:
+      random_labels.append(random_indexes.index(idx))
+    for tick in ax.get_xticklabels():
+      tick.set_fontsize(LabelSize - 3)
+    for tick in ax.get_yticklabels():
+      tick.set_rotation(25)
+      tick.set_fontsize(LabelSize - 3)
+    ax.set_xlim(0, len(indexes))
+    ax.set_ylim(0, len(indexes))
+    ax.set_yticks(np.arange(min(indexes), max(indexes), max(indexes)//3))
+    ax.set_xticks(np.arange(min(indexes), max(indexes), max(indexes)//5))
+    ax.scatter(indexes, random_labels, marker='^', s=0.5, c='tab:green', alpha=0.8)
+    ax.scatter(indexes, indexes      , marker='o', s=0.5, c='tab:blue' , alpha=0.8)
+    ax.scatter([-1], [-1], marker='o', s=100, c='tab:blue' , label='Average Over Multi-Trials')
+    ax.scatter([-1], [-1], marker='^', s=100, c='tab:green', label='Randomly Selected Trial')
+
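+    # Kendall's tau between the averaged and single-trial scores quantifies ranking stability across trials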
+    coef, p = scipy.stats.kendalltau(standard_scores, random_scores)
+    ax.set_xlabel('architecture ranking in {:}'.format(name2label[dataset]), fontsize=LabelSize)
+    if dataset == 'cifar10':
+      ax.set_ylabel('architecture ranking', fontsize=LabelSize)
+    ax.legend(loc=4, fontsize=LegendFontsize)
+    return coef
+
+  for dataset, ax in zip(datasets, axs):
+    rank_coef = sub_plot_fn(ax, dataset, indicator)
+    print('sub-plot {:} on {:} done, the ranking coefficient is {:.4f}.'.format(dataset, search_space, rank_coef))
+
+  save_path = (vis_save_dir / '{:}-rank-{:}-top{:}.pdf'.format(search_space, indicator, topk)).resolve()
   fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
-  save_path = (vis_save_dir / '{:}-relative-rank.png'.format(indicator)).resolve()
+  save_path = (vis_save_dir / '{:}-rank-{:}-top{:}.png'.format(search_space, indicator, topk)).resolve()
   fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
-  print ('{:} save into {:}'.format(time_string(), save_path))
+  print('Save into {:}'.format(save_path))
 
 
 if __name__ == '__main__':
   parser = argparse.ArgumentParser(description='NATS-Bench', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
   parser.add_argument('--save_dir',    type=str, default='output/vis-nas-bench/rank-stability', help='Folder to save checkpoints and log.')
   # used for training the model
   args = parser.parse_args()
 
   to_save_dir = Path(args.save_dir)
 
   # Figure 2
-  visualize_relative_info(None, to_save_dir, 'tss')
-  visualize_relative_info(None, to_save_dir, 'sss')
+  for topk in [1, 5, 10, 20]:
+    visualize_relative_info(to_save_dir, 'tss', 'more', topk)
+    visualize_relative_info(to_save_dir, 'sss', 'less', topk)
   print ('{:} : complete running this file : {:}'.format(time_string(), __file__))