94 lines
4.5 KiB
Python
94 lines
4.5 KiB
Python
|
###############################################################
|
||
|
# NAS-Bench-201, ICLR 2020 (https://arxiv.org/abs/2001.00326) #
|
||
|
###############################################################
|
||
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.06 #
|
||
|
###############################################################
|
||
|
# Usage: python exps/NAS-Bench-201/test-nas-api-vis.py
|
||
|
###############################################################
|
||
|
import os, sys, time, torch, argparse
|
||
|
import numpy as np
|
||
|
from typing import List, Text, Dict, Any
|
||
|
from shutil import copyfile
|
||
|
from collections import defaultdict
|
||
|
from copy import deepcopy
|
||
|
from pathlib import Path
|
||
|
import matplotlib
|
||
|
import seaborn as sns
|
||
|
matplotlib.use('agg')
|
||
|
import matplotlib.pyplot as plt
|
||
|
import matplotlib.ticker as ticker
|
||
|
|
||
|
# Make the repository's `lib` folder (two levels up from this script's
# directory) importable, so the project-local modules imported below resolve.
lib_dir = Path(__file__).parent.joinpath('..', '..', 'lib').resolve()
if str(lib_dir) not in sys.path:
    sys.path.insert(0, str(lib_dir))
|
||
|
from config_utils import dict2config, load_config
|
||
|
from nas_201_api import NASBench201API, NASBench301API
|
||
|
from log_utils import time_string
|
||
|
from models import get_cell_based_tiny_net
|
||
|
|
||
|
|
||
|
def _load_cached_info(vis_save_dir, dataset, indicator):
    """Load the cached info dict for one dataset and search-space indicator.

    The file is ``'{dataset}-cache-{indicator}-info.pth'`` inside
    ``vis_save_dir``. ``visualize_info`` reads its ``'test_accs'`` entry (and,
    for CIFAR-10, its ``'params'`` entry).
    """
    cache_path = vis_save_dir / '{:}-cache-{:}-info.pth'.format(dataset, indicator)
    # NOTE(review): torch.load unpickles the file; only load locally produced caches.
    return torch.load(cache_path)


def _ranks_from_order(order):
    """Return the inverse permutation of ``order`` (``ranks[order[k]] == k``).

    ``order`` must be a permutation of ``range(len(order))``. This gives the
    same answer as ``order.index(i)`` for every ``i`` but in O(n) total instead
    of O(n) per lookup.
    """
    ranks = [0] * len(order)
    for position, arch_index in enumerate(order):
        ranks[arch_index] = position
    return ranks


def visualize_info(api, vis_save_dir, indicator):
    """Plot the rank of every architecture on CIFAR-100 / ImageNet16-120 vs. CIFAR-10.

    Loads the cached info files for CIFAR-10, CIFAR-100 and ImageNet16-120 from
    ``vis_save_dir`` and sorts architectures by ascending test accuracy on each
    dataset. It then scatters, for each architecture placed at its CIFAR-10 rank
    on the x-axis, its rank on every dataset. The figure is written to
    ``{indicator}-relative-rank.pdf`` and ``{indicator}-relative-rank.png`` in
    ``vis_save_dir``.

    Args:
      api: unused; kept for interface compatibility (callers pass ``None``).
      vis_save_dir: ``pathlib.Path`` holding the cache files; created if missing.
      indicator: tag embedded in the cache and figure file names
        (e.g. ``'tss'`` or ``'sss'``).

    Raises:
      ValueError: if the CIFAR-10 cache lists no architectures.
    """
    vis_save_dir = vis_save_dir.resolve()
    vis_save_dir.mkdir(parents=True, exist_ok=True)

    cifar010_info = _load_cached_info(vis_save_dir, 'cifar10', indicator)
    cifar100_info = _load_cached_info(vis_save_dir, 'cifar100', indicator)
    imagenet_info = _load_cached_info(vis_save_dir, 'ImageNet16-120', indicator)
    indexes = list(range(len(cifar010_info['params'])))
    if not indexes:
        raise ValueError('no architectures found in the CIFAR-10 cache for {:}'.format(indicator))

    print('{:} start to visualize relative ranking'.format(time_string()))

    # Architecture indexes sorted by ascending test accuracy on each dataset.
    cifar010_ord_indexes = sorted(indexes, key=lambda i: cifar010_info['test_accs'][i])
    cifar100_ord_indexes = sorted(indexes, key=lambda i: cifar100_info['test_accs'][i])
    imagenet_ord_indexes = sorted(indexes, key=lambda i: imagenet_info['test_accs'][i])

    # Walking architectures in CIFAR-10 rank order, look up their rank elsewhere.
    cifar100_ranks = _ranks_from_order(cifar100_ord_indexes)
    imagenet_ranks = _ranks_from_order(imagenet_ord_indexes)
    cifar100_labels = [cifar100_ranks[idx] for idx in cifar010_ord_indexes]
    imagenet_labels = [imagenet_ranks[idx] for idx in cifar010_ord_indexes]
    print('{:} prepare data done.'.format(time_string()))

    dpi, width, height = 200, 1400, 800
    figsize = width / float(dpi), height / float(dpi)
    LabelSize, LegendFontsize = 18, 12
    low, high = min(indexes), max(indexes)
    # np.arange fails with a zero step, which happens for tiny caches; clamp to 1.
    ytick_step = max(high // 3, 1)
    xtick_step = max(high // 5, 1)

    fig = plt.figure(figsize=figsize)
    try:
        ax = fig.add_subplot(111)
        plt.xlim(low, high)
        plt.ylim(low, high)
        plt.yticks(np.arange(low, high, ytick_step), fontsize=LegendFontsize, rotation='vertical')
        plt.xticks(np.arange(low, high, xtick_step), fontsize=LegendFontsize)
        ax.scatter(indexes, cifar100_labels, marker='^', s=0.5, c='tab:green', alpha=0.8)
        ax.scatter(indexes, imagenet_labels, marker='*', s=0.5, c='tab:red', alpha=0.8)
        ax.scatter(indexes, indexes, marker='o', s=0.5, c='tab:blue', alpha=0.8)
        # Large dummy points at (-1, -1), outside the axis limits, so the legend
        # shows readable markers instead of the tiny s=0.5 data markers.
        ax.scatter([-1], [-1], marker='o', s=100, c='tab:blue', label='CIFAR-10')
        ax.scatter([-1], [-1], marker='^', s=100, c='tab:green', label='CIFAR-100')
        ax.scatter([-1], [-1], marker='*', s=100, c='tab:red', label='ImageNet-16-120')
        plt.grid(zorder=0)
        ax.set_axisbelow(True)
        plt.legend(loc=0, fontsize=LegendFontsize)
        ax.set_xlabel('architecture ranking in CIFAR-10', fontsize=LabelSize)
        ax.set_ylabel('architecture ranking', fontsize=LabelSize)
        for ext in ('pdf', 'png'):
            save_path = (vis_save_dir / '{:}-relative-rank.{:}'.format(indicator, ext)).resolve()
            fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format=ext)
            print('{:} save into {:}'.format(time_string(), save_path))
    finally:
        # Release the figure; this function is called once per search space.
        plt.close(fig)
|
||
|
|
||
|
|
||
|
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='NAS-Bench-X', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    # NOTE(review): --save_dir and --check_N are parsed but not used by this
    # script; they are kept so existing command lines keep working.
    parser.add_argument('--save_dir', type=str, default='output/NAS-BENCH-202', help='Folder to save checkpoints and log.')
    parser.add_argument('--check_N', type=int, default=32768, help='For safety.')
    # The defaults below reproduce the previously hard-coded behavior.
    parser.add_argument('--vis_save_dir', type=str, default='output/vis-nas-bench/',
                        help='Folder holding the cached info files and receiving the figures.')
    parser.add_argument('--search_spaces', type=str, nargs='+', default=['tss', 'sss'],
                        help='Search-space indicators to visualize, in order.')
    args = parser.parse_args()

    for indicator in args.search_spaces:
        visualize_info(None, Path(args.vis_save_dir), indicator)
|