Update visualization codees for WS.

This commit is contained in:
D-X-Y 2020-07-20 06:19:16 +00:00
parent 16c5651bdc
commit e09decf114
5 changed files with 140 additions and 8 deletions

View File

@ -3,12 +3,12 @@
##################################################################################################### #####################################################################################################
# modified from https://github.com/pytorch/examples/blob/master/reinforcement_learning/reinforce.py # # modified from https://github.com/pytorch/examples/blob/master/reinforcement_learning/reinforce.py #
##################################################################################################### #####################################################################################################
# python ./exps/algos-v2/reinforce.py --dataset cifar10 --search_space tss --learning_rate 0.001 # python ./exps/algos-v2/reinforce.py --dataset cifar10 --search_space tss --learning_rate 0.01
# python ./exps/algos-v2/reinforce.py --dataset cifar100 --search_space tss --learning_rate 0.001 # python ./exps/algos-v2/reinforce.py --dataset cifar100 --search_space tss --learning_rate 0.01
# python ./exps/algos-v2/reinforce.py --dataset ImageNet16-120 --search_space tss --learning_rate 0.001 # python ./exps/algos-v2/reinforce.py --dataset ImageNet16-120 --search_space tss --learning_rate 0.01
# python ./exps/algos-v2/reinforce.py --dataset cifar10 --search_space sss --learning_rate 0.001 # python ./exps/algos-v2/reinforce.py --dataset cifar10 --search_space sss --learning_rate 0.01
# python ./exps/algos-v2/reinforce.py --dataset cifar100 --search_space sss --learning_rate 0.001 # python ./exps/algos-v2/reinforce.py --dataset cifar100 --search_space sss --learning_rate 0.01
# python ./exps/algos-v2/reinforce.py --dataset ImageNet16-120 --search_space sss --learning_rate 0.001 # python ./exps/algos-v2/reinforce.py --dataset ImageNet16-120 --search_space sss --learning_rate 0.01
##################################################################################################### #####################################################################################################
import os, sys, time, glob, random, argparse import os, sys, time, glob, random, argparse
import numpy as np, collections import numpy as np, collections

View File

@ -11,7 +11,7 @@ for dataset in ${datasets}
do do
for search_space in ${search_spaces} for search_space in ${search_spaces}
do do
python ./exps/algos-v2/reinforce.py --dataset ${dataset} --search_space ${search_space} --learning_rate 0.001 python ./exps/algos-v2/reinforce.py --dataset ${dataset} --search_space ${search_space} --learning_rate 0.01
python ./exps/algos-v2/regularized_ea.py --dataset ${dataset} --search_space ${search_space} --ea_cycles 200 --ea_population 10 --ea_sample_size 3 python ./exps/algos-v2/regularized_ea.py --dataset ${dataset} --search_space ${search_space} --ea_cycles 200 --ea_population 10 --ea_sample_size 3
python ./exps/algos-v2/random_wo_share.py --dataset ${dataset} --search_space ${search_space} python ./exps/algos-v2/random_wo_share.py --dataset ${dataset} --search_space ${search_space}
python ./exps/algos-v2/bohb.py --dataset ${dataset} --search_space ${search_space} --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3 python ./exps/algos-v2/bohb.py --dataset ${dataset} --search_space ${search_space} --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3

View File

@ -399,6 +399,9 @@ def main(xargs):
logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(epoch_str, need_time, min(w_scheduler.get_lr()))) logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(epoch_str, need_time, min(w_scheduler.get_lr())))
network.set_drop_path(float(epoch+1) / total_epoch, xargs.drop_path_rate) network.set_drop_path(float(epoch+1) / total_epoch, xargs.drop_path_rate)
if xargs.algo == 'gdas':
network.set_tau( xargs.tau_max - (xargs.tau_max-xargs.tau_min) * epoch / (total_epoch-1) )
logger.log('[Reset tau as : {:}'.format(network.tau))
search_w_loss, search_w_top1, search_w_top5, search_a_loss, search_a_top1, search_a_top5 \ search_w_loss, search_w_top1, search_w_top5, search_a_loss, search_a_top1, search_a_top5 \
= search_func(search_loader, network, criterion, w_scheduler, w_optimizer, a_optimizer, epoch_str, xargs.print_freq, xargs.algo, logger) = search_func(search_loader, network, criterion, w_scheduler, w_optimizer, a_optimizer, epoch_str, xargs.print_freq, xargs.algo, logger)
search_time.update(time.time() - start_time) search_time.update(time.time() - start_time)
@ -480,6 +483,9 @@ if __name__ == '__main__':
parser.add_argument('--dataset' , type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.') parser.add_argument('--dataset' , type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.')
parser.add_argument('--search_space', type=str, default='tss', choices=['tss'], help='The search space name.') parser.add_argument('--search_space', type=str, default='tss', choices=['tss'], help='The search space name.')
parser.add_argument('--algo' , type=str, choices=['darts-v1', 'darts-v2', 'gdas', 'setn', 'random', 'enas'], help='The search space name.') parser.add_argument('--algo' , type=str, choices=['darts-v1', 'darts-v2', 'gdas', 'setn', 'random', 'enas'], help='The search space name.')
# FOR GDAS
parser.add_argument('--tau_min', type=float, default=0.1, help='The minimum tau for Gumbel Softmax.')
parser.add_argument('--tau_max', type=float, default=10, help='The maximum tau for Gumbel Softmax.')
# channels and number-of-cells # channels and number-of-cells
parser.add_argument('--max_nodes' , type=int, default=4, help='The maximum number of nodes.') parser.add_argument('--max_nodes' , type=int, default=4, help='The maximum number of nodes.')
parser.add_argument('--channel' , type=int, default=16, help='The number of channels.') parser.add_argument('--channel' , type=int, default=16, help='The number of channels.')

View File

@ -30,7 +30,7 @@ def fetch_data(root_dir='./output/search', search_space='tss', dataset=None):
ss_dir = '{:}-{:}'.format(root_dir, search_space) ss_dir = '{:}-{:}'.format(root_dir, search_space)
alg2name, alg2path = OrderedDict(), OrderedDict() alg2name, alg2path = OrderedDict(), OrderedDict()
alg2name['REA'] = 'R-EA-SS3' alg2name['REA'] = 'R-EA-SS3'
alg2name['REINFORCE'] = 'REINFORCE-0.001' alg2name['REINFORCE'] = 'REINFORCE-0.01'
alg2name['RANDOM'] = 'RANDOM' alg2name['RANDOM'] = 'RANDOM'
alg2name['BOHB'] = 'BOHB' alg2name['BOHB'] = 'BOHB'
for alg, name in alg2name.items(): for alg, name in alg2name.items():

View File

@ -0,0 +1,126 @@
###############################################################
# NAS-Bench-201, ICLR 2020 (https://arxiv.org/abs/2001.00326) #
###############################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.06 #
###############################################################
# Usage: python exps/experimental/vis-bench-ws.py --search_space tss
# Usage: python exps/experimental/vis-bench-ws.py --search_space sss
###############################################################
import os, gc, sys, time, torch, argparse
import numpy as np
from typing import List, Text, Dict, Any
from shutil import copyfile
from collections import defaultdict, OrderedDict
from copy import deepcopy
from pathlib import Path
import matplotlib
import seaborn as sns
matplotlib.use('agg')
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
from config_utils import dict2config, load_config
from nas_201_api import NASBench201API, NASBench301API
from log_utils import time_string
def fetch_data(root_dir='./output/search', search_space='tss', dataset=None):
ss_dir = '{:}-{:}'.format(root_dir, search_space)
alg2name, alg2path = OrderedDict(), OrderedDict()
seeds = [777]
alg2name['GDAS'] = 'gdas-affine1_BN0-None'
"""
alg2name['DARTS (1st)'] = 'darts-v1-affine1_BN0-None'
alg2name['DARTS (2nd)'] = 'darts-v2-affine1_BN0-None'
alg2name['SETN'] = 'setn-affine1_BN0-None'
alg2name['RSPS'] = 'random-affine1_BN0-None'
"""
for alg, name in alg2name.items():
alg2path[alg] = os.path.join(ss_dir, dataset, name, 'seed-{:}-last-info.pth')
alg2data = OrderedDict()
for alg, path in alg2path.items():
alg2data[alg] = []
for seed in seeds:
xpath = path.format(seed)
assert os.path.isfile(xpath), 'invalid path : {:}'.format(xpath)
data = torch.load(xpath, map_location=torch.device('cpu'))
data = torch.load(data['last_checkpoint'], map_location=torch.device('cpu'))
alg2data[alg].append(data['genotypes'])
return alg2data
y_min_s = {('cifar10', 'tss'): 90,
('cifar10', 'sss'): 92,
('cifar100', 'tss'): 65,
('cifar100', 'sss'): 65,
('ImageNet16-120', 'tss'): 36,
('ImageNet16-120', 'sss'): 40}
y_max_s = {('cifar10', 'tss'): 94.5,
('cifar10', 'sss'): 93.3,
('cifar100', 'tss'): 72,
('cifar100', 'sss'): 70,
('ImageNet16-120', 'tss'): 44,
('ImageNet16-120', 'sss'): 46}
def visualize_curve(api, vis_save_dir, search_space):
vis_save_dir = vis_save_dir.resolve()
vis_save_dir.mkdir(parents=True, exist_ok=True)
dpi, width, height = 250, 5200, 1400
figsize = width / float(dpi), height / float(dpi)
LabelSize, LegendFontsize = 16, 16
def sub_plot_fn(ax, dataset):
alg2data = fetch_data(search_space=search_space, dataset=dataset)
alg2accuracies = OrderedDict()
epochs = 20
colors = ['b', 'g', 'c', 'm', 'y']
ax.set_xlim(0, epochs)
# ax.set_ylim(y_min_s[(dataset, search_space)], y_max_s[(dataset, search_space)])
for idx, (alg, data) in enumerate(alg2data.items()):
print('plot alg : {:}'.format(alg))
xs, accuracies = [], []
for iepoch in range(epochs+1):
structures, accs = [_[iepoch-1] for _ in data], []
for structure in structures:
info = api.get_more_info(structure, dataset=dataset, hp=90 if isinstance(api, NASBench301API) else 200, is_random=False)
accs.append(info['test-accuracy'])
accuracies.append(sum(accs)/len(accs))
xs.append(iepoch)
alg2accuracies[alg] = accuracies
ax.plot(xs, accuracies, c=colors[idx], label='{:}'.format(alg))
ax.set_xlabel('The searching epoch', fontsize=LabelSize)
ax.set_ylabel('Test accuracy on {:}'.format(dataset), fontsize=LabelSize)
ax.set_title('Searching results on {:}'.format(dataset), fontsize=LabelSize+4)
ax.legend(loc=4, fontsize=LegendFontsize)
fig, axs = plt.subplots(1, 3, figsize=figsize)
datasets = ['cifar10', 'cifar100', 'ImageNet16-120']
for dataset, ax in zip(datasets, axs):
sub_plot_fn(ax, dataset)
print('sub-plot {:} on {:} done.'.format(dataset, search_space))
save_path = (vis_save_dir / '{:}-ws-curve.png'.format(search_space)).resolve()
fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
print ('{:} save into {:}'.format(time_string(), save_path))
plt.close('all')
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='NAS-Bench-X', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--save_dir', type=str, default='output/vis-nas-bench/nas-algos', help='Folder to save checkpoints and log.')
parser.add_argument('--search_space', type=str, default='tss', choices=['tss', 'sss'], help='Choose the search space.')
args = parser.parse_args()
save_dir = Path(args.save_dir)
alg2data = fetch_data(search_space='tss', dataset='cifar10')
if args.search_space == 'tss':
api = NASBench201API(verbose=False)
elif args.search_space == 'sss':
api = NASBench301API(verbose=False)
else:
raise ValueError('Invalid search space : {:}'.format(args.search_space))
visualize_curve(api, save_dir, args.search_space)