Add get_torch_home func for NATS-Bench

D-X-Y 2020-12-01 22:25:23 +08:00
parent 8afb62ad2e
commit 46b92e37e2
7 changed files with 294 additions and 10 deletions
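
The helper added in this commit lets the benchmark APIs fall back to a default data root when TORCH_HOME is not set. A minimal usage sketch, assuming the NATS-Bench files have already been downloaded into ~/.torch (the fallback location) or into the directory that TORCH_HOME points to:

import os
from nats_bench import create

# With this change, an unset TORCH_HOME no longer raises a KeyError on
# os.environ['TORCH_HOME']; the default path now falls back to $HOME/.torch.
os.environ.pop('TORCH_HOME', None)
api_tss = create(None, 'tss', fast_mode=True, verbose=False)
print(api_tss)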

View File

@@ -385,7 +385,7 @@ def visualize_all_rank_info(api, vis_save_dir, indicator):
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='NAS-Bench-X', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser = argparse.ArgumentParser(description='NATS-Bench', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--save_dir', type=str, default='output/vis-nas-bench', help='Folder to save checkpoints and log.')
  # used to train the model
args = parser.parse_args()

View File

@@ -0,0 +1,175 @@
###############################################################
# NATS-Bench (https://arxiv.org/pdf/2009.00437.pdf) #
# The code to draw Figure 6 in our paper. #
###############################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.06 #
###############################################################
# Usage: python exps/NATS-Bench/draw-fig8.py #
###############################################################
import os, gc, sys, time, torch, argparse
import numpy as np
from typing import List, Text, Dict, Any
from shutil import copyfile
from collections import defaultdict, OrderedDict
from copy import deepcopy
from pathlib import Path
import matplotlib
import seaborn as sns
matplotlib.use('agg')
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
from config_utils import dict2config, load_config
from nats_bench import create
from log_utils import time_string
plt.rcParams.update({
"text.usetex": True,
"font.family": "sans-serif",
"font.sans-serif": ["Helvetica"]})
## for Palatino and other serif fonts use:
plt.rcParams.update({
"text.usetex": True,
"font.family": "serif",
"font.serif": ["Palatino"],
})
def fetch_data(root_dir='./output/search', search_space='tss', dataset=None):
ss_dir = '{:}-{:}'.format(root_dir, search_space)
alg2all = OrderedDict()
# alg2name['REINFORCE'] = 'REINFORCE-0.01'
# alg2name['RANDOM'] = 'RANDOM'
# alg2name['BOHB'] = 'BOHB'
if dataset == 'cifar10':
suffixes = ['-T200000', '-T200000-FULL']
elif dataset == 'cifar100':
suffixes = ['-T40000', '-T40000-FULL']
elif search_space == 'tss':
suffixes = ['-T120000', '-T120000-FULL']
elif search_space == 'sss':
suffixes = ['-T60000', '-T60000-FULL']
else:
    raise ValueError('Unknown dataset: {:}'.format(dataset))
if search_space == 'tss':
    hp = r'$\mathcal{H}^{1}$'
elif search_space == 'sss':
    hp = r'$\mathcal{H}^{2}$'
else:
    raise ValueError('Unknown search space: {:}'.format(search_space))
alg2all[r'REA ($\mathcal{H}^{0}$)'] = dict(
path=os.path.join(ss_dir, dataset + suffixes[0], 'R-EA-SS3', 'results.pth'),
color='b', linestyle='-')
alg2all[r'REA ({:})'.format(hp)] = dict(
path=os.path.join(ss_dir, dataset + suffixes[1], 'R-EA-SS3', 'results.pth'),
color='b', linestyle='--')
for alg, xdata in alg2all.items():
data = torch.load(xdata['path'])
for index, info in data.items():
info['time_w_arch'] = [(x, y) for x, y in zip(info['all_total_times'], info['all_archs'])]
for j, arch in enumerate(info['all_archs']):
assert arch != -1, 'invalid arch from {:} {:} {:} ({:}, {:})'.format(alg, search_space, dataset, index, j)
xdata['data'] = data
return alg2all
def query_performance(api, data, dataset, ticket):
results, is_size_space = [], api.search_space_name == 'size'
for i, info in data.items():
time_w_arch = sorted(info['time_w_arch'], key=lambda x: abs(x[0]-ticket))
time_a, arch_a = time_w_arch[0]
time_b, arch_b = time_w_arch[1]
info_a = api.get_more_info(arch_a, dataset=dataset, hp=90 if is_size_space else 200, is_random=False)
info_b = api.get_more_info(arch_b, dataset=dataset, hp=90 if is_size_space else 200, is_random=False)
accuracy_a, accuracy_b = info_a['test-accuracy'], info_b['test-accuracy']
    interpolate = (time_b-ticket) / (time_b-time_a) * accuracy_a + (ticket-time_a) / (time_b-time_a) * accuracy_b
    results.append(interpolate)
# return sum(results) / len(results)
return np.mean(results), np.std(results)
y_min_s = {('cifar10', 'tss'): 90,
('cifar10', 'sss'): 90,
('cifar100', 'tss'): 65,
('cifar100', 'sss'): 65,
('ImageNet16-120', 'tss'): 36,
('ImageNet16-120', 'sss'): 40}
y_max_s = {('cifar10', 'tss'): 94.5,
('cifar10', 'sss'): 94.5,
('cifar100', 'tss'): 72.5,
('cifar100', 'sss'): 70.5,
('ImageNet16-120', 'tss'): 46,
('ImageNet16-120', 'sss'): 46}
x_axis_s = {('cifar10', 'tss'): 200000,
('cifar10', 'sss'): 200000,
('cifar100', 'tss'): 400,
('cifar100', 'sss'): 400,
('ImageNet16-120', 'tss'): 1200,
('ImageNet16-120', 'sss'): 600}
name2label = {'cifar10': 'CIFAR-10',
'cifar100': 'CIFAR-100',
'ImageNet16-120': 'ImageNet-16-120'}
spaces2latex = {'tss': r'$\mathcal{S}_{t}$',
'sss': r'$\mathcal{S}_{s}$',}
def visualize_curve(api_dict, vis_save_dir):
vis_save_dir = vis_save_dir.resolve()
vis_save_dir.mkdir(parents=True, exist_ok=True)
dpi, width, height = 250, 4000, 2400
figsize = width / float(dpi), height / float(dpi)
LabelSize, LegendFontsize = 16, 16
def sub_plot_fn(ax, search_space, dataset):
max_time = x_axis_s[(dataset, search_space)]
alg2data = fetch_data(search_space=search_space, dataset=dataset)
alg2accuracies = OrderedDict()
total_tickets = 200
time_tickets = [float(i) / total_tickets * int(max_time) for i in range(total_tickets)]
ax.set_xlim(0, x_axis_s[(dataset, search_space)])
ax.set_ylim(y_min_s[(dataset, search_space)],
y_max_s[(dataset, search_space)])
for idx, (alg, xdata) in enumerate(alg2data.items()):
accuracies = []
for ticket in time_tickets:
# import pdb; pdb.set_trace()
accuracy, accuracy_std = query_performance(
api_dict[search_space], xdata['data'], dataset, ticket)
accuracies.append(accuracy)
# print('{:} plot alg : {:10s}, final accuracy = {:.2f}$\pm${:.2f}'.format(time_string(), alg, accuracy, accuracy_std))
print('{:} plot alg : {:10s} on {:}'.format(time_string(), alg, search_space))
alg2accuracies[alg] = accuracies
ax.plot(time_tickets, accuracies, c=xdata['color'], linestyle=xdata['linestyle'], label='{:}'.format(alg))
ax.set_xlabel('Estimated wall-clock time', fontsize=LabelSize)
ax.set_ylabel('Test accuracy', fontsize=LabelSize)
ax.set_title(r'Searching results on {:} for {:}'.format(name2label[dataset], spaces2latex[search_space]),
fontsize=LabelSize+4)
ax.legend(loc=4, fontsize=LegendFontsize)
fig, axs = plt.subplots(1, 2, figsize=figsize)
sub_plot_fn(axs[0], 'tss', 'cifar10')
sub_plot_fn(axs[1], 'sss', 'cifar10')
save_path = (vis_save_dir / 'full-curve.png').resolve()
fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
print ('{:} save into {:}'.format(time_string(), save_path))
plt.close('all')
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--save_dir', type=str, default='output/vis-nas-bench/nas-algos-vs-h', help='Folder to save checkpoints and log.')
args = parser.parse_args()
save_dir = Path(args.save_dir)
api_tss = create(None, 'tss', fast_mode=True, verbose=False)
api_sss = create(None, 'sss', fast_mode=True, verbose=False)
visualize_curve(dict(tss=api_tss, sss=api_sss), save_dir)
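
For reference, a standalone sketch of the interpolation performed in query_performance above: the test accuracy at an arbitrary wall-clock ticket is linearly interpolated between the two recorded timestamps closest to it. The helper name and the numbers below are illustrative only.

def interpolate_at(ticket, time_acc_pairs):
  # take the two recorded (time, accuracy) points closest to the requested ticket
  (time_a, acc_a), (time_b, acc_b) = sorted(
      time_acc_pairs, key=lambda x: abs(x[0] - ticket))[:2]
  # linearly interpolate between the two neighbouring measurements
  return (time_b - ticket) / (time_b - time_a) * acc_a + \
         (ticket - time_a) / (time_b - time_a) * acc_b

print(interpolate_at(150, [(100, 90.0), (200, 92.0)]))  # -> 91.0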

View File

@@ -0,0 +1,96 @@
###############################################################
# NATS-Bench (https://arxiv.org/pdf/2009.00437.pdf) #
# The code to draw Figure 2 / 3 / 4 / 5 in our paper. #
###############################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.06 #
###############################################################
# Usage: python exps/NATS-Bench/draw-ranks.py #
###############################################################
import os, sys, time, torch, argparse
import scipy
import numpy as np
from typing import List, Text, Dict, Any
from shutil import copyfile
from collections import defaultdict
from copy import deepcopy
from pathlib import Path
import matplotlib
import seaborn as sns
matplotlib.use('agg')
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
from config_utils import dict2config, load_config
from log_utils import time_string
from models import get_cell_based_tiny_net
from nats_bench import create
def visualize_relative_info(api, vis_save_dir, indicator):
vis_save_dir = vis_save_dir.resolve()
# print ('{:} start to visualize {:} information'.format(time_string(), api))
vis_save_dir.mkdir(parents=True, exist_ok=True)
cifar010_cache_path = vis_save_dir / '{:}-cache-{:}-info.pth'.format('cifar10', indicator)
cifar100_cache_path = vis_save_dir / '{:}-cache-{:}-info.pth'.format('cifar100', indicator)
imagenet_cache_path = vis_save_dir / '{:}-cache-{:}-info.pth'.format('ImageNet16-120', indicator)
cifar010_info = torch.load(cifar010_cache_path)
cifar100_info = torch.load(cifar100_cache_path)
imagenet_info = torch.load(imagenet_cache_path)
indexes = list(range(len(cifar010_info['params'])))
print ('{:} start to visualize relative ranking'.format(time_string()))
cifar010_ord_indexes = sorted(indexes, key=lambda i: cifar010_info['test_accs'][i])
cifar100_ord_indexes = sorted(indexes, key=lambda i: cifar100_info['test_accs'][i])
imagenet_ord_indexes = sorted(indexes, key=lambda i: imagenet_info['test_accs'][i])
cifar100_labels, imagenet_labels = [], []
for idx in cifar010_ord_indexes:
cifar100_labels.append( cifar100_ord_indexes.index(idx) )
imagenet_labels.append( imagenet_ord_indexes.index(idx) )
print ('{:} prepare data done.'.format(time_string()))
dpi, width, height = 200, 1400, 800
figsize = width / float(dpi), height / float(dpi)
LabelSize, LegendFontsize = 18, 12
resnet_scale, resnet_alpha = 120, 0.5
fig = plt.figure(figsize=figsize)
ax = fig.add_subplot(111)
plt.xlim(min(indexes), max(indexes))
plt.ylim(min(indexes), max(indexes))
# plt.ylabel('y').set_rotation(30)
plt.yticks(np.arange(min(indexes), max(indexes), max(indexes)//3), fontsize=LegendFontsize, rotation='vertical')
plt.xticks(np.arange(min(indexes), max(indexes), max(indexes)//5), fontsize=LegendFontsize)
ax.scatter(indexes, cifar100_labels, marker='^', s=0.5, c='tab:green', alpha=0.8)
ax.scatter(indexes, imagenet_labels, marker='*', s=0.5, c='tab:red' , alpha=0.8)
ax.scatter(indexes, indexes , marker='o', s=0.5, c='tab:blue' , alpha=0.8)
ax.scatter([-1], [-1], marker='o', s=100, c='tab:blue' , label='CIFAR-10')
ax.scatter([-1], [-1], marker='^', s=100, c='tab:green', label='CIFAR-100')
ax.scatter([-1], [-1], marker='*', s=100, c='tab:red' , label='ImageNet-16-120')
plt.grid(zorder=0)
ax.set_axisbelow(True)
plt.legend(loc=0, fontsize=LegendFontsize)
ax.set_xlabel('architecture ranking in CIFAR-10', fontsize=LabelSize)
ax.set_ylabel('architecture ranking', fontsize=LabelSize)
save_path = (vis_save_dir / '{:}-relative-rank.pdf'.format(indicator)).resolve()
fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='pdf')
save_path = (vis_save_dir / '{:}-relative-rank.png'.format(indicator)).resolve()
fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
print ('{:} save into {:}'.format(time_string(), save_path))
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='NATS-Bench', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--save_dir', type=str, default='output/vis-nas-bench/rank-stability', help='Folder to save checkpoints and log.')
  # used to train the model
args = parser.parse_args()
to_save_dir = Path(args.save_dir)
# Figure 2
visualize_relative_info(None, to_save_dir, 'tss')
visualize_relative_info(None, to_save_dir, 'sss')
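
For clarity, a toy sketch of the rank-alignment step above: architectures are ordered by CIFAR-10 test accuracy, and each one's position in the CIFAR-100 ordering gives the y-coordinate of the scatter plot. The accuracy values below are made up for illustration.

cifar10_accs = [91.2, 93.5, 90.1]
cifar100_accs = [70.3, 71.9, 68.0]
indexes = list(range(len(cifar10_accs)))
# rank the architectures by accuracy on each dataset (ascending, as in the script)
c10_order = sorted(indexes, key=lambda i: cifar10_accs[i])
c100_order = sorted(indexes, key=lambda i: cifar100_accs[i])
# walk the architectures in CIFAR-10 order and look up each one's CIFAR-100 rank
c100_ranks = [c100_order.index(i) for i in c10_order]
print(c100_ranks)  # [0, 1, 2]: the two rankings agree perfectly on this toy data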

View File

@@ -9,6 +9,7 @@
# python ./exps/NATS-algos/regularized_ea.py --dataset cifar10 --search_space sss --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1
# python ./exps/NATS-algos/regularized_ea.py --dataset cifar100 --search_space sss --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1
# python ./exps/NATS-algos/regularized_ea.py --dataset ImageNet16-120 --search_space sss --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1
# python ./exps/NATS-algos/regularized_ea.py --dataset ${dataset} --search_space ${search_space} --time_budget ${time_budget} --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --use_proxy 0
##################################################################
import os, sys, time, glob, random, argparse
import numpy as np, collections
@@ -119,10 +120,8 @@ def regularized_evolution(cycles, population_size, sample_size, time_budget, ran
while len(population) < population_size:
model = Model()
model.arch = random_arch()
if use_proxy:
model.accuracy, _, _, total_cost = api.simulate_train_eval(model.arch, dataset, hp='12')
else:
model.accuracy, _, _, total_cost = api.simulate_train_eval(model.arch, dataset, hp=api.full_train_epochs)
model.accuracy, _, _, total_cost = api.simulate_train_eval(
model.arch, dataset, hp='12' if use_proxy else api.full_train_epochs)
# Append the info
population.append(model)
history.append((model.accuracy, model.arch))
@@ -146,7 +145,8 @@ def regularized_evolution(cycles, population_size, sample_size, time_budget, ran
# Create the child model and store it.
child = Model()
child.arch = mutate_arch(parent.arch)
child.accuracy, _, _, total_cost = api.simulate_train_eval(child.arch, dataset, hp='12')
child.accuracy, _, _, total_cost = api.simulate_train_eval(
child.arch, dataset, hp='12' if use_proxy else api.full_train_epochs)
# Append the info
population.append(child)
history.append((child.accuracy, child.arch))

View File

@@ -17,6 +17,7 @@ from typing import Dict, Optional, Text, Union, Any
from nats_bench.api_utils import ArchResults
from nats_bench.api_utils import NASBenchMetaAPI
from nats_bench.api_utils import get_torch_home
from nats_bench.api_utils import nats_is_dir
from nats_bench.api_utils import nats_is_file
from nats_bench.api_utils import PICKLE_EXT
@@ -88,10 +89,10 @@ class NATSsize(NASBenchMetaAPI):
if file_path_or_dict is None:
if self._fast_mode:
self._archive_dir = os.path.join(
os.environ['TORCH_HOME'], '{:}-simple'.format(ALL_BASE_NAMES[-1]))
get_torch_home(), '{:}-simple'.format(ALL_BASE_NAMES[-1]))
else:
file_path_or_dict = os.path.join(
os.environ['TORCH_HOME'], '{:}.{:}'.format(
get_torch_home(), '{:}.{:}'.format(
ALL_BASE_NAMES[-1], PICKLE_EXT))
print('{:} Try to use the default NATS-Bench (size) path from '
'fast_mode={:} and path={:}.'.format(time_string(), self._fast_mode,

View File

@@ -17,6 +17,7 @@ from typing import Any, Dict, List, Optional, Text, Union
from nats_bench.api_utils import ArchResults
from nats_bench.api_utils import NASBenchMetaAPI
from nats_bench.api_utils import get_torch_home
from nats_bench.api_utils import nats_is_dir
from nats_bench.api_utils import nats_is_file
from nats_bench.api_utils import PICKLE_EXT
@@ -88,10 +89,10 @@ class NATStopology(NASBenchMetaAPI):
if file_path_or_dict is None:
if self._fast_mode:
self._archive_dir = os.path.join(
os.environ['TORCH_HOME'], '{:}-simple'.format(ALL_BASE_NAMES[-1]))
get_torch_home(), '{:}-simple'.format(ALL_BASE_NAMES[-1]))
else:
file_path_or_dict = os.path.join(
os.environ['TORCH_HOME'], '{:}.{:}'.format(
get_torch_home(), '{:}.{:}'.format(
ALL_BASE_NAMES[-1], PICKLE_EXT))
print('{:} Try to use the default NATS-Bench (topology) path from '
'fast_mode={:} and path={:}.'.format(time_string(), self._fast_mode, file_path_or_dict))

View File

@@ -45,6 +45,17 @@ def get_file_system():
return _FILE_SYSTEM
def get_torch_home():
if 'TORCH_HOME' in os.environ:
return os.environ['TORCH_HOME']
elif 'HOME' in os.environ:
return os.path.join(os.environ['HOME'], '.torch')
else:
    raise ValueError('Did not find HOME in os.environ. '
                     'Please set at least one of HOME or TORCH_HOME '
                     'in the environment.')
def nats_is_dir(file_path):
if _FILE_SYSTEM == 'default':
return os.path.isdir(file_path)
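
A quick sketch of how the new get_torch_home helper is expected to resolve the data root; the environment values below are illustrative only.

import os
from nats_bench.api_utils import get_torch_home

os.environ['TORCH_HOME'] = '/data/torch'
print(get_torch_home())       # '/data/torch': TORCH_HOME wins whenever it is set

del os.environ['TORCH_HOME']  # otherwise fall back to $HOME/.torch
print(get_torch_home())       # e.g. '/home/user/.torch'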