Update REA, REINFORCE, and RANDOM

This commit is contained in:
D-X-Y 2020-07-13 10:04:52 +00:00
parent 041a9aa4a3
commit 6dc494be08
12 changed files with 277 additions and 53 deletions

View File

@ -72,6 +72,14 @@ def test_api(api, is_301=True):
print('{:}\n'.format(info))
print('{:} finish testing the api : {:}'.format(time_string(), api))
if not is_301:
arch_str = '|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1|+|skip_connect~0|nor_conv_3x3~1|skip_connect~2|'
matrix = api.str2matrix(arch_str)
print('Compute the adjacency matrix of {:}'.format(arch_str))
print(matrix)
info = api.simulate_train_eval(123, 'cifar10')
print('simulate_train_eval : {:}'.format(info))
def test_issue_81_82(api):
results = api.query_by_index(0, 'cifar10-valid', hp='12')

View File

@ -0,0 +1,91 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 #
##############################################################################
# Random Search for Hyper-Parameter Optimization, JMLR 2012 ##################
##############################################################################
# python ./exps/algos-v2/random_wo_share.py --dataset cifar10 --search_space tss
##############################################################################
import os, sys, time, glob, random, argparse
import numpy as np, collections
from copy import deepcopy
import torch
import torch.nn as nn
from pathlib import Path
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
from config_utils import load_config, dict2config, configure2str
from datasets import get_datasets, SearchDataset
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
from utils import get_model_infos, obtain_accuracy
from log_utils import AverageMeter, time_string, convert_secs2time
from models import get_search_spaces
from nas_201_api import NASBench201API, NASBench301API
from .regularized_ea import random_topology_func, random_size_func
def main(xargs, api):
torch.set_num_threads(4)
prepare_seed(xargs.rand_seed)
logger = prepare_logger(args)
search_space = get_search_spaces(xargs.search_space, 'nas-bench-301')
if xargs.search_space == 'tss':
random_arch = random_topology_func(search_space)
else:
random_arch = random_size_func(search_space)
x_start_time = time.time()
logger.log('{:} use nas_bench : {:}'.format(time_string(), nas_bench))
best_arch, best_acc, total_time_cost, history = None, -1, [], []
while total_time_cost[-1] < xargs.time_budget:
arch = random_arch()
accuracy, _, _, total_cost = api.simulate_train_eval(arch, xargs.dataset, '12')
total_time_cost.append(total_cost)
history.append(arch)
if best_arch is None or best_acc < accuracy:
best_acc, best_arch = accuracy, arch
logger.log('[{:03d}] : {:} : accuracy = {:.2f}%'.format(len(history), arch, accuracy))
logger.log('{:} best arch is {:}, accuracy = {:.2f}%, visit {:} archs with {:.1f} s (real-cost = {:.3f} s).'.format(time_string(), best_arch, best_acc, len(history), total_time_cost, time.time()-x_start_time))
info = api.query_info_str_by_arch(best_arch, '200' if xargs.search_space == 'tss' else '90')
logger.log('{:}'.format(info))
logger.log('-'*100)
logger.close()
return logger.log_dir, total_time_cost, history
if __name__ == '__main__':
parser = argparse.ArgumentParser("Random NAS")
parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.')
parser.add_argument('--search_space', type=str, choices=['tss', 'sss'], help='Choose the search space.')
parser.add_argument('--time_budget', type=int, default=20000, help='The total time cost budge for searching (in seconds).')
parser.add_argument('--loops_if_rand', type=int, default=500, help='The total runs for evaluation.')
# log
parser.add_argument('--save_dir', type=str, help='Folder to save checkpoints and log.')
parser.add_argument('--rand_seed', type=int, default=-1, help='manual seed')
args = parser.parse_args()
if args.search_space == 'tss':
api = NASBench201API(verbose=False)
elif args.search_space == 'sss':
api = NASBench301API(verbose=False)
else:
raise ValueError('Invalid search space : {:}'.format(args.search_space))
args.save_dir = os.path.join('{:}-{:}'.format(args.save_dir, args.search_space), args.dataset, 'RANDOM')
print('save-dir : {:}'.format(args.save_dir))
if args.rand_seed < 0:
save_dir, all_info = None, {}
for i in range(args.loops_if_rand):
print ('{:} : {:03d}/{:03d}'.format(time_string(), i, args.loops_if_rand))
args.rand_seed = random.randint(1, 100000)
save_dir, all_archs, all_total_times = main(args, api)
all_info[i] = {'all_archs': all_archs,
'all_total_times': all_total_times}
save_path = save_dir / 'results.pth'
print('save into {:}'.format(save_path))
torch.save(all_info, save_path)
else:
main(args, api)

View File

@ -3,12 +3,12 @@
##################################################################
# Regularized Evolution for Image Classifier Architecture Search #
##################################################################
# python ./exps/algos-v2/REA.py --dataset cifar10 --search_space tss --time_budget 12000 --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1
# python ./exps/algos-v2/REA.py --dataset cifar100 --search_space tss --time_budget 12000 --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1
# python ./exps/algos-v2/REA.py --dataset ImageNet16-120 --search_space tss --time_budget 12000 --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1
# python ./exps/algos-v2/REA.py --dataset cifar10 --search_space sss --time_budget 12000 --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1
# python ./exps/algos-v2/REA.py --dataset cifar100 --search_space sss --time_budget 12000 --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1
# python ./exps/algos-v2/REA.py --dataset ImageNet16-120 --search_space sss --time_budget 12000 --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1
# python ./exps/algos-v2/regularized_ea.py --dataset cifar10 --search_space tss --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1
# python ./exps/algos-v2/regularized_ea.py --dataset cifar100 --search_space tss --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1
# python ./exps/algos-v2/regularized_ea.py --dataset ImageNet16-120 --search_space tss --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1
# python ./exps/algos-v2/regularized_ea.py --dataset cifar10 --search_space sss --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1
# python ./exps/algos-v2/regularized_ea.py --dataset cifar100 --search_space sss --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1
# python ./exps/algos-v2/regularized_ea.py --dataset ImageNet16-120 --search_space sss --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1
##################################################################
import os, sys, time, glob, random, argparse
import numpy as np, collections
@ -160,7 +160,7 @@ def regularized_evolution(cycles, population_size, sample_size, time_budget, ran
while len(population) < population_size:
model = Model()
model.arch = random_arch()
model.accuracy, time_cost, total_cost = api.simulate_train_eval(model.arch, dataset, '12')
model.accuracy, _, _, total_cost = api.simulate_train_eval(model.arch, dataset, '12')
# Append the info
population.append(model)
history.append(model)
@ -183,7 +183,7 @@ def regularized_evolution(cycles, population_size, sample_size, time_budget, ran
# Create the child model and store it.
child = Model()
child.arch = mutate_arch(parent.arch)
child.accuracy, time_cost, total_cost = api.simulate_train_eval(model.arch, dataset, '12')
child.accuracy, _, _, total_cost = api.simulate_train_eval(model.arch, dataset, '12')
# Append the info
population.append(child)
history.append(child)
@ -195,11 +195,7 @@ def regularized_evolution(cycles, population_size, sample_size, time_budget, ran
def main(xargs, api):
assert torch.cuda.is_available(), 'CUDA is not available.'
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.set_num_threads(xargs.workers)
torch.set_num_threads(4)
prepare_seed(xargs.rand_seed)
logger = prepare_logger(args)
@ -235,12 +231,11 @@ if __name__ == '__main__':
parser.add_argument('--ea_cycles', type=int, help='The number of cycles in EA.')
parser.add_argument('--ea_population', type=int, help='The population size in EA.')
parser.add_argument('--ea_sample_size', type=int, help='The sample size in EA.')
parser.add_argument('--time_budget', type=int, help='The total time cost budge for searching (in seconds).')
parser.add_argument('--loops_if_rand', type=int, default=500, help='The total runs for evaluation.')
parser.add_argument('--time_budget', type=int, default=20000, help='The total time cost budge for searching (in seconds).')
parser.add_argument('--loops_if_rand', type=int, default=500, help='The total runs for evaluation.')
# log
parser.add_argument('--workers', type=int, default=2, help='number of data loading workers (default: 2)')
parser.add_argument('--save_dir', type=str, default='./output/search', help='Folder to save checkpoints and log.')
parser.add_argument('--rand_seed', type=int, default=-1, help='manual seed')
parser.add_argument('--rand_seed', type=int, default=-1, help='manual seed')
args = parser.parse_args()
if args.search_space == 'tss':

View File

@ -3,12 +3,12 @@
#####################################################################################################
# modified from https://github.com/pytorch/examples/blob/master/reinforcement_learning/reinforce.py #
#####################################################################################################
# python ./exps/algos-v2/reinforce.py --dataset cifar10 --search_space tss --time_budget 12000 --learning_rate 0.001
# python ./exps/algos-v2/reinforce.py --dataset cifar100 --search_space tss --time_budget 12000 --learning_rate 0.001
# python ./exps/algos-v2/reinforce.py --dataset ImageNet16-120 --search_space tss --time_budget 12000 --learning_rate 0.001
# python ./exps/algos-v2/reinforce.py --dataset cifar10 --search_space sss --time_budget 12000 --learning_rate 0.001
# python ./exps/algos-v2/reinforce.py --dataset cifar100 --search_space sss --time_budget 12000 --learning_rate 0.001
# python ./exps/algos-v2/reinforce.py --dataset ImageNet16-120 --search_space sss --time_budget 12000 --learning_rate 0.001
# python ./exps/algos-v2/reinforce.py --dataset cifar10 --search_space tss --learning_rate 0.001
# python ./exps/algos-v2/reinforce.py --dataset cifar100 --search_space tss --learning_rate 0.001
# python ./exps/algos-v2/reinforce.py --dataset ImageNet16-120 --search_space tss --learning_rate 0.001
# python ./exps/algos-v2/reinforce.py --dataset cifar10 --search_space sss --learning_rate 0.001
# python ./exps/algos-v2/reinforce.py --dataset cifar100 --search_space sss --learning_rate 0.001
# python ./exps/algos-v2/reinforce.py --dataset ImageNet16-120 --search_space sss --learning_rate 0.001
#####################################################################################################
import os, sys, time, glob, random, argparse
import numpy as np, collections
@ -120,15 +120,10 @@ def select_action(policy):
def main(xargs, api):
assert torch.cuda.is_available(), 'CUDA is not available.'
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.set_num_threads(xargs.workers)
torch.set_num_threads(4)
prepare_seed(xargs.rand_seed)
logger = prepare_logger(args)
search_space = get_search_spaces(xargs.search_space, 'nas-bench-301')
if xargs.search_space == 'tss':
policy = PolicyTopology(search_space)
@ -144,6 +139,7 @@ def main(xargs, api):
# nas dataset load
logger.log('{:} use api : {:}'.format(time_string(), api))
api.reset_time()
# REINFORCE
x_start_time = time.time()
@ -153,7 +149,7 @@ def main(xargs, api):
start_time = time.time()
log_prob, action = select_action( policy )
arch = policy.generate_arch( action )
reward, _, current_total_cost = api.simulate_train_eval(arch, xargs.dataset, '12')
reward, _, _, current_total_cost = api.simulate_train_eval(arch, xargs.dataset, '12')
trace.append((reward, arch))
total_costs.append(current_total_cost)
@ -177,7 +173,7 @@ def main(xargs, api):
logger.log('-'*100)
logger.close()
return logger.log_dir, [api.query_index_by_arch(x[0]) for x in trace], total_costs
return logger.log_dir, [api.query_index_by_arch(x[1]) for x in trace], total_costs
if __name__ == '__main__':
@ -186,15 +182,14 @@ if __name__ == '__main__':
parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.')
parser.add_argument('--search_space', type=str, choices=['tss', 'sss'], help='Choose the search space.')
parser.add_argument('--learning_rate', type=float, help='The learning rate for REINFORCE.')
parser.add_argument('--EMA_momentum', type=float, default=0.9, help='The momentum value for EMA.')
parser.add_argument('--time_budget', type=int, help='The total time cost budge for searching (in seconds).')
parser.add_argument('--loops_if_rand', type=int, default=500, help='The total runs for evaluation.')
parser.add_argument('--EMA_momentum', type=float, default=0.9, help='The momentum value for EMA.')
parser.add_argument('--time_budget', type=int, default=20000, help='The total time cost budge for searching (in seconds).')
parser.add_argument('--loops_if_rand', type=int, default=500, help='The total runs for evaluation.')
# log
parser.add_argument('--workers', type=int, default=2, help='number of data loading workers (default: 2)')
parser.add_argument('--save_dir', type=str, default='./output/search', help='Folder to save checkpoints and log.')
parser.add_argument('--arch_nas_dataset', type=str, help='The path to load the architecture dataset (tiny-nas-benchmark).')
parser.add_argument('--print_freq', type=int, help='print frequency (default: 200)')
parser.add_argument('--rand_seed', type=int, default=-1, help='manual seed')
parser.add_argument('--rand_seed', type=int, default=-1, help='manual seed')
args = parser.parse_args()
if args.search_space == 'tss':

17
exps/algos-v2/run-all.sh Normal file
View File

@ -0,0 +1,17 @@
#!/bin/bash
# bash ./exps/algos-v2/run-all.sh
echo script name: $0
echo $# arguments
datasets="cifar10 cifar100 ImageNet16-120"
search_spaces="tss sss"
for dataset in ${datasets}
do
for search_space in ${search_spaces}
do
python ./exps/algos-v2/reinforce.py --dataset ${dataset} --search_space ${search_space} --learning_rate 0.001
python ./exps/algos-v2/regularized_ea.py --dataset ${dataset} --search_space ${search_space} --ea_cycles 200 --ea_population 10 --ea_sample_size 3
done
done

View File

@ -84,7 +84,7 @@ def main(xargs, nas_bench):
if __name__ == '__main__':
parser = argparse.ArgumentParser("Regularized Evolution Algorithm")
parser = argparse.ArgumentParser("Random NAS")
parser.add_argument('--data_path', type=str, help='Path to dataset')
parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.')
# channels and number-of-cells

View File

@ -0,0 +1,107 @@
###############################################################
# NAS-Bench-201, ICLR 2020 (https://arxiv.org/abs/2001.00326) #
###############################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.06 #
###############################################################
# Usage: python exps/experimental/vis-bench-algos.py
###############################################################
import os, sys, time, torch, argparse
import numpy as np
from typing import List, Text, Dict, Any
from shutil import copyfile
from collections import defaultdict, OrderedDict
from copy import deepcopy
from pathlib import Path
import matplotlib
import seaborn as sns
matplotlib.use('agg')
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
from config_utils import dict2config, load_config
from nas_201_api import NASBench201API, NASBench301API
from log_utils import time_string
def fetch_data(root_dir='./output/search', search_space='tss', dataset=None):
ss_dir = '{:}-{:}'.format(root_dir, search_space)
alg2name, alg2path = OrderedDict(), OrderedDict()
alg2name['REA'] = 'R-EA-SS3'
alg2name['REINFORCE'] = 'REINFORCE-0.001'
for alg, name in alg2name.items():
alg2path[alg] = os.path.join(ss_dir, dataset, name, 'results.pth')
assert os.path.isfile(alg2path[alg])
alg2data = OrderedDict()
for alg, path in alg2path.items():
data = torch.load(path)
for index, info in data.items():
info['time_w_arch'] = [(x, y) for x, y in zip(info['all_total_times'], info['all_archs'])]
for j, arch in enumerate(info['all_archs']):
assert arch != -1, 'invalid arch from {:} {:} {:} ({:}, {:})'.format(alg, search_space, dataset, index, j)
alg2data[alg] = data
return alg2data
def query_performance(api, data, dataset, ticket):
results, is_301 = [], isinstance(api, NASBench301API)
for i, info in data.items():
time_w_arch = sorted(info['time_w_arch'], key=lambda x: abs(x[0]-ticket))
time_a, arch_a = time_w_arch[0]
time_b, arch_b = time_w_arch[1]
info_a = api.get_more_info(arch_a, dataset=dataset, hp=90 if is_301 else 200, is_random=False)
info_b = api.get_more_info(arch_b, dataset=dataset, hp=90 if is_301 else 200, is_random=False)
accuracy_a, accuracy_b = info_a['test-accuracy'], info_b['test-accuracy']
interplate = (time_b-ticket) / (time_b-time_a) * accuracy_a + (ticket-time_a) / (time_b-time_a) * accuracy_b
results.append(interplate)
return sum(results) / len(results)
def visualize_curve(api, vis_save_dir, search_space, max_time):
vis_save_dir = vis_save_dir.resolve()
vis_save_dir.mkdir(parents=True, exist_ok=True)
dpi, width, height = 250, 4700, 1500
figsize = width / float(dpi), height / float(dpi)
LabelSize, LegendFontsize = 14, 14
def sub_plot_fn(ax, dataset):
alg2data = fetch_data(search_space=search_space, dataset=dataset)
alg2accuracies = OrderedDict()
time_tickets = [float(i) / 100 * max_time for i in range(100)]
colors = ['b', 'g', 'c', 'm', 'y']
for idx, (alg, data) in enumerate(alg2data.items()):
print('plot alg : {:}'.format(alg))
accuracies = []
for ticket in time_tickets:
accuracy = query_performance(api, data, dataset, ticket)
accuracies.append(accuracy)
alg2accuracies[alg] = accuracies
ax.plot(time_tickets, accuracies, c=colors[idx], label='{:}'.format(alg))
ax.legend(loc=4, fontsize=LegendFontsize)
fig, axs = plt.subplots(1, 3, figsize=figsize)
datasets = ['cifar10', 'cifar100', 'ImageNet16-120']
for dataset, ax in zip(datasets, axs):
sub_plot_fn(ax, dataset)
print('sub-plot {:} on {:} done.'.format(dataset, search_space))
save_path = (vis_save_dir / '{:}-curve.png'.format(search_space)).resolve()
fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png')
print ('{:} save into {:}'.format(time_string(), save_path))
plt.close('all')
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='NAS-Bench-X', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--save_dir', type=str, default='output/vis-nas-bench/nas-algos', help='Folder to save checkpoints and log.')
parser.add_argument('--max_time', type=float, default=20000, help='The maximum time budget.')
args = parser.parse_args()
save_dir = Path(args.save_dir)
api201 = NASBench201API(verbose=False)
visualize_curve(api201, save_dir, 'tss', args.max_time)
api301 = NASBench301API(verbose=False)
visualize_curve(api301, save_dir, 'sss', args.max_time)

View File

@ -3,7 +3,7 @@
###############################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.06 #
###############################################################
# Usage: python exps/NAS-Bench-201/test-nas-api-vis.py
# Usage: python exps/experimental/visualize-nas-bench-x.py
###############################################################
import os, sys, time, torch, argparse
import numpy as np
@ -384,24 +384,25 @@ def visualize_all_rank_info(api, vis_save_dir, indicator):
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='NAS-Bench-X', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--save_dir', type=str, default='output/NAS-BENCH-202', help='Folder to save checkpoints and log.')
parser.add_argument('--check_N', type=int, default=32768, help='For safety.')
parser.add_argument('--save_dir', type=str, default='output/vis-nas-bench', help='Folder to save checkpoints and log.')
# use for train the model
args = parser.parse_args()
to_save_dir = Path(args.save_dir)
datasets = ['cifar10', 'cifar100', 'ImageNet16-120']
api201 = NASBench201API(None, verbose=True)
for xdata in datasets:
visualize_tss_info(api201, xdata, Path('output/vis-nas-bench'))
visualize_tss_info(api201, xdata, to_save_dir)
api301 = NASBench301API(None, verbose=True)
for xdata in datasets:
visualize_sss_info(api301, xdata, Path('output/vis-nas-bench'))
visualize_sss_info(api301, xdata, to_save_dir)
visualize_info(None, Path('output/vis-nas-bench/'), 'tss')
visualize_info(None, Path('output/vis-nas-bench/'), 'sss')
visualize_rank_info(None, Path('output/vis-nas-bench/'), 'tss')
visualize_rank_info(None, Path('output/vis-nas-bench/'), 'sss')
visualize_info(None, to_save_dir, 'tss')
visualize_info(None, to_save_dir, 'sss')
visualize_rank_info(None, to_save_dir, 'tss')
visualize_rank_info(None, to_save_dir, 'sss')
visualize_all_rank_info(None, Path('output/vis-nas-bench/'), 'tss')
visualize_all_rank_info(None, Path('output/vis-nas-bench/'), 'sss')
visualize_all_rank_info(None, to_save_dir, 'tss')
visualize_all_rank_info(None, to_save_dir, 'sss')

View File

@ -141,9 +141,12 @@ class NASBench201API(NASBenchMetaAPI):
# `is_random`
# When is_random=True, the performance of a random architecture will be returned
# When is_random=False, the performanceo of all trials will be averaged.
def get_more_info(self, index: int, dataset, iepoch=None, hp='12', is_random=True):
def get_more_info(self, index, dataset, iepoch=None, hp='12', is_random=True):
if self.verbose:
print('Call the get_more_info function with index={:}, dataset={:}, iepoch={:}, hp={:}, and is_random={:}.'.format(index, dataset, iepoch, hp, is_random))
index = self.query_index_by_arch(index) # To avoid the input is a string or an instance of a arch object
if index not in self.arch2infos_dict:
raise ValueError('Did not find {:} from arch2infos_dict.'.format(index))
archresult = self.arch2infos_dict[index][str(hp)]
# if randomly select one trial, select the seed at first
if isinstance(is_random, bool) and is_random:

View File

@ -131,7 +131,7 @@ class NASBench301API(NASBenchMetaAPI):
print('Call query_info_str_by_arch with arch={:} and hp={:}'.format(arch, hp))
return self._query_info_str_by_arch(arch, hp, print_information)
def get_more_info(self, index: int, dataset: Text, iepoch=None, hp='12', is_random=True):
def get_more_info(self, index, dataset: Text, iepoch=None, hp='12', is_random=True):
"""This function will return the metric for the `index`-th architecture
`dataset` indicates the dataset:
'cifar10-valid' : using the proposed train set of CIFAR-10 as the training set
@ -151,6 +151,9 @@ class NASBench301API(NASBenchMetaAPI):
"""
if self.verbose:
print('Call the get_more_info function with index={:}, dataset={:}, iepoch={:}, hp={:}, and is_random={:}.'.format(index, dataset, iepoch, hp, is_random))
index = self.query_index_by_arch(index) # To avoid the input is a string or an instance of a arch object
if index not in self.arch2infos_dict:
raise ValueError('Did not find {:} from arch2infos_dict.'.format(index))
archresult = self.arch2infos_dict[index][str(hp)]
# if randomly select one trial, select the seed at first
if isinstance(is_random, bool) and is_random:

View File

@ -68,7 +68,7 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
def reset_time(self):
self._used_time = 0
def simulate_train_eval(self, arch, dataset, hp='12'):
def simulate_train_eval(self, arch, dataset, hp='12', account_time=True):
index = self.query_index_by_arch(arch)
all_names = ('cifar10', 'cifar100', 'ImageNet16-120')
assert dataset in all_names, 'Invalid dataset name : {:} vs {:}'.format(dataset, all_names)
@ -77,8 +77,10 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
else:
info = self.get_more_info(index, dataset, iepoch=None, hp=hp, is_random=True)
valid_acc, time_cost = info['valid-accuracy'], info['train-all-time'] + info['valid-per-time']
self._used_time += time_cost
return valid_acc, time_cost, self._used_time
latency = self.get_latency(index, dataset)
if account_time:
self._used_time += time_cost
return valid_acc, latency, time_cost, self._used_time
def random(self):
"""Return a random index of all architectures."""

View File

@ -8,7 +8,9 @@ import torch.nn as nn
from models import CellStructure
from log_utils import time_string
def evaluate_one_shot(model, xloader, api, cal_mode, seed=111):
print ('This is an old version of codes to use NAS-Bench-API, and should be modified to align with the new version. Please contact me for more details if you use this function.')
weights = deepcopy(model.state_dict())
model.train(cal_mode)
with torch.no_grad():