Update REA, REINFORCE, RANDOM, and BOHB

D-X-Y 2020-07-14 11:53:21 +00:00
parent 168b08d9e6
commit 2c861f33c4
8 changed files with 79 additions and 88 deletions

.gitignore
View File

@@ -123,3 +123,5 @@ scripts-search/l2s-algos
 TEMP-L.sh
 .nfs00*
+*.swo
+*/*.swo

View File

@@ -5,9 +5,9 @@
 # required to install hpbandster ##################################
 # pip install hpbandster          ##################################
 ###################################################################
-# python exps/algos-v2/bohb.py --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3
+# OMP_NUM_THREADS=4 python exps/algos-v2/bohb.py --search_space tss --dataset cifar10 --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3 --rand_seed 1
 ###################################################################
-import os, sys, time, random, argparse
+import os, sys, time, random, argparse, collections
 from copy import deepcopy
 from pathlib import Path
 import torch
@@ -17,7 +17,7 @@ from config_utils import load_config
 from datasets import get_datasets, SearchDataset
 from procedures import prepare_seed, prepare_logger
 from log_utils import AverageMeter, time_string, convert_secs2time
-from nas_201_api import NASBench201API as API
+from nas_201_api import NASBench201API, NASBench301API
 from models import CellStructure, get_search_spaces
 # BOHB: Robust and Efficient Hyperparameter Optimization at Scale, ICML 2018
 import ConfigSpace
@@ -38,7 +38,7 @@ def get_topology_config_space(search_space, max_nodes=4):
 
 def get_size_config_space(search_space):
   cs = ConfigSpace.ConfigurationSpace()
   import pdb; pdb.set_trace()
   #edge2index = {}
   for i in range(1, max_nodes):
     for j in range(i):
@@ -63,52 +63,21 @@ def config2topology_func(max_nodes=4):
 class MyWorker(Worker):
 
-  def __init__(self, *args, convert_func=None, dataname=None, nas_bench=None, time_budget=None, **kwargs):
+  def __init__(self, *args, convert_func=None, dataset=None, api=None, **kwargs):
     super().__init__(*args, **kwargs)
     self.convert_func = convert_func
-    self._dataname = dataname
-    self._nas_bench = nas_bench
-    self.time_budget = time_budget
-    self.seen_archs = []
-    self.sim_cost_time = 0
-    self.real_cost_time = 0
-    self.is_end = False
+    self._dataset = dataset
+    self._api = api
+    self.total_times = []
+    self.trajectory = []
 
-  def get_the_best(self):
-    assert len(self.seen_archs) > 0
-    best_index, best_acc = -1, None
-    for arch_index in self.seen_archs:
-      info = self._nas_bench.get_more_info(arch_index, self._dataname, None, hp='200', is_random=True)
-      vacc = info['valid-accuracy']
-      if best_acc is None or best_acc < vacc:
-        best_acc = vacc
-        best_index = arch_index
-    assert best_index != -1
-    return best_index
-
   def compute(self, config, budget, **kwargs):
-    start_time = time.time()
-    structure = self.convert_func( config )
-    arch_index = self._nas_bench.query_index_by_arch( structure )
-    info = self._nas_bench.get_more_info(arch_index, self._dataname, None, hp='200', is_random=True)
-    cur_time = info['train-all-time'] + info['valid-per-time']
-    cur_vacc = info['valid-accuracy']
-    self.real_cost_time += (time.time() - start_time)
-    if self.sim_cost_time + cur_time <= self.time_budget and not self.is_end:
-      self.sim_cost_time += cur_time
-      self.seen_archs.append( arch_index )
-      return ({'loss': 100 - float(cur_vacc),
-               'info': {'seen-arch' : len(self.seen_archs),
-                        'sim-test-time' : self.sim_cost_time,
-                        'current-arch' : arch_index}
-              })
-    else:
-      self.is_end = True
-      return ({'loss': 100,
-               'info': {'seen-arch' : len(self.seen_archs),
-                        'sim-test-time' : self.sim_cost_time,
-                        'current-arch' : None}
-              })
+    arch = self.convert_func( config )
+    accuracy, latency, time_cost, total_time = self._api.simulate_train_eval(arch, self._dataset, iepoch=int(budget)-1, hp='12')
+    self.trajectory.append((accuracy, arch))
+    self.total_times.append(total_time)
+    return ({'loss': 100 - accuracy,
+             'info': self._api.query_index_by_arch(arch)})
 
 
 def main(xargs, api):
@@ -117,12 +86,13 @@ def main(xargs, api):
   logger = prepare_logger(args)
 
   logger.log('{:} use api : {:}'.format(time_string(), api))
+  api.reset_time()
   search_space = get_search_spaces(xargs.search_space, 'nas-bench-301')
   if xargs.search_space == 'tss':
-    cs = get_topology_config_space(xargs.max_nodes, search_space)
-    config2structure = config2topology_func(xargs.max_nodes)
+    cs = get_topology_config_space(search_space)
+    config2structure = config2topology_func()
   else:
-    cs = get_size_config_space(xargs.max_nodes, search_space)
+    cs = get_size_config_space(search_space)
     import pdb; pdb.set_trace()
 
   hb_run_id = '0'
@@ -133,41 +103,41 @@ def main(xargs, api):
   workers = []
   for i in range(num_workers):
-    w = MyWorker(nameserver=ns_host, nameserver_port=ns_port, convert_func=config2structure, dataname=dataname, nas_bench=nas_bench, time_budget=xargs.time_budget, run_id=hb_run_id, id=i)
+    w = MyWorker(nameserver=ns_host, nameserver_port=ns_port, convert_func=config2structure, dataset=xargs.dataset, api=api, run_id=hb_run_id, id=i)
     w.run(background=True)
     workers.append(w)
 
   start_time = time.time()
-  bohb = BOHB(configspace=cs,
-              run_id=hb_run_id,
-              eta=3, min_budget=12, max_budget=200,
+  bohb = BOHB(configspace=cs, run_id=hb_run_id,
+              eta=3, min_budget=1, max_budget=12,
               nameserver=ns_host,
              nameserver_port=ns_port,
              num_samples=xargs.num_samples,
              random_fraction=xargs.random_fraction, bandwidth_factor=xargs.bandwidth_factor,
              ping_interval=10, min_bandwidth=xargs.min_bandwidth)
 
   results = bohb.run(xargs.n_iters, min_n_workers=num_workers)
 
   bohb.shutdown(shutdown_workers=True)
   NS.shutdown()
 
-  real_cost_time = time.time() - start_time
-
-  id2config = results.get_id2config_mapping()
-  incumbent = results.get_incumbent_id()
-  logger.log('Best found configuration: {:} within {:.3f} s'.format(id2config[incumbent]['config'], real_cost_time))
-  best_arch = config2structure( id2config[incumbent]['config'] )
-  info = nas_bench.query_by_arch(best_arch, '200')
-  if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch))
-  else : logger.log('{:}'.format(info))
+  # print('There are {:} runs.'.format(len(results.get_all_runs())))
+  # workers[0].total_times
+  # workers[0].trajectory
+  current_best_index = []
+  for idx in range(len(workers[0].trajectory)):
+    trajectory = workers[0].trajectory[:idx+1]
+    arch = max(trajectory, key=lambda x: x[0])[1]
+    current_best_index.append(api.query_index_by_arch(arch))
+
+  best_arch = max(workers[0].trajectory, key=lambda x: x[0])[1]
+  logger.log('Best found configuration: {:} within {:.3f} s'.format(best_arch, workers[0].total_times[-1]))
+  info = api.query_info_str_by_arch(best_arch, '200' if xargs.search_space == 'tss' else '90')
+  logger.log('{:}'.format(info))
   logger.log('-'*100)
-  logger.log('workers : {:.1f}s with {:} archs'.format(workers[0].time_budget, len(workers[0].seen_archs)))
   logger.close()
-  return logger.log_dir, nas_bench.query_index_by_arch( best_arch ), real_cost_time
+
+  return logger.log_dir, current_best_index, workers[0].total_times
 
 
 if __name__ == '__main__':
@@ -185,8 +155,8 @@ if __name__ == '__main__':
   parser.add_argument('--bandwidth_factor', default=3, type=int, nargs='?', help='factor multiplied to the bandwidth')
   parser.add_argument('--n_iters', default=300, type=int, nargs='?', help='number of iterations for optimization method')
   # log
-  parser.add_argument('--save_dir', type=str, help='Folder to save checkpoints and log.')
-  parser.add_argument('--rand_seed', type=int, help='manual seed')
+  parser.add_argument('--save_dir', type=str, default='./output/search', help='Folder to save checkpoints and log.')
+  parser.add_argument('--rand_seed', type=int, default=-1, help='manual seed')
   args = parser.parse_args()
 
   if args.search_space == 'tss':
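
The reworked MyWorker no longer budgets time itself; it defers all cost accounting to the benchmark API and records an (accuracy, arch) trajectory that main() replays into an anytime curve. A minimal, self-contained sketch of that flow, where StubAPI is a hypothetical stand-in for a loaded NASBench201API/NASBench301API and only the call shapes visible in the hunks above are assumed:

import random

class StubAPI:
  """Hypothetical stand-in for NASBench201API / NASBench301API."""

  def __init__(self):
    self._used_time = 0

  def simulate_train_eval(self, arch, dataset, iepoch=None, hp='12'):
    # Mirrors the four-tuple unpacked in compute():
    # (valid accuracy, latency, this query's time, accumulated simulated time).
    accuracy, time_cost = random.uniform(80, 95), 100.0
    self._used_time += time_cost
    return accuracy, 0.01, time_cost, self._used_time

  def query_index_by_arch(self, arch):
    return hash(arch) % 15625

api, dataset = StubAPI(), 'cifar10'
trajectory, total_times = [], []
for step in range(5):
  arch = 'arch-{:}'.format(step)          # stands in for convert_func(config)
  accuracy, latency, time_cost, total_time = api.simulate_train_eval(
      arch, dataset, iepoch=11, hp='12')  # iepoch = int(budget) - 1 in compute()
  trajectory.append((accuracy, arch))
  total_times.append(total_time)

# Anytime curve, as rebuilt at the end of main(): the best arch seen after
# each evaluation, paired with the simulated clock at that point.
current_best = [max(trajectory[:i + 1], key=lambda x: x[0])[1]
                for i in range(len(trajectory))]
print(current_best)
print('simulated time: {:.1f}s'.format(total_times[-1]))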

View File

@@ -43,7 +43,7 @@ def main(xargs, api):
   current_best_index = []
   while len(total_time_cost) == 0 or total_time_cost[-1] < xargs.time_budget:
     arch = random_arch()
-    accuracy, _, _, total_cost = api.simulate_train_eval(arch, xargs.dataset, '12')
+    accuracy, _, _, total_cost = api.simulate_train_eval(arch, xargs.dataset, hp='12')
     total_time_cost.append(total_cost)
     history.append(arch)
     if best_arch is None or best_acc < accuracy:
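
The one-line change here (repeated in the REA and REINFORCE hunks below) is not cosmetic: once simulate_train_eval gains an iepoch parameter before hp (see the last file in this commit), the old positional '12' would bind to iepoch instead of hp. A sketch with a stub signature:

def simulate_train_eval(arch, dataset, iepoch=None, hp='12'):
  """Stub with the updated parameter order; returns the bindings."""
  return iepoch, hp

print(simulate_train_eval('a', 'cifar10', '12'))     # ('12', '12'): '12' lands in iepoch
print(simulate_train_eval('a', 'cifar10', hp='12'))  # (None, '12'): intended binding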

View File

@@ -160,7 +160,7 @@ def regularized_evolution(cycles, population_size, sample_size, time_budget, ran
   while len(population) < population_size:
     model = Model()
     model.arch = random_arch()
-    model.accuracy, _, _, total_cost = api.simulate_train_eval(model.arch, dataset, '12')
+    model.accuracy, _, _, total_cost = api.simulate_train_eval(model.arch, dataset, hp='12')
     # Append the info
     population.append(model)
     history.append((model.accuracy, model.arch))
@@ -184,7 +184,7 @@ def regularized_evolution(cycles, population_size, sample_size, time_budget, ran
     # Create the child model and store it.
     child = Model()
     child.arch = mutate_arch(parent.arch)
-    child.accuracy, _, _, total_cost = api.simulate_train_eval(child.arch, dataset, '12')
+    child.accuracy, _, _, total_cost = api.simulate_train_eval(child.arch, dataset, hp='12')
     # Append the info
     population.append(child)
     history.append((child.accuracy, child.arch))

View File

@@ -150,7 +150,7 @@ def main(xargs, api):
     start_time = time.time()
     log_prob, action = select_action( policy )
     arch = policy.generate_arch( action )
-    reward, _, _, current_total_cost = api.simulate_train_eval(arch, xargs.dataset, '12')
+    reward, _, _, current_total_cost = api.simulate_train_eval(arch, xargs.dataset, hp='12')
     trace.append((reward, arch))
     total_costs.append(current_total_cost)

View File

@ -1,18 +1,19 @@
#!/bin/bash #!/bin/bash
# bash ./exps/algos-v2/run-all.sh # bash ./exps/algos-v2/run-all.sh
set -e
echo script name: $0 echo script name: $0
echo $# arguments echo $# arguments
datasets="cifar10 cifar100 ImageNet16-120" datasets="cifar10 cifar100 ImageNet16-120"
search_spaces="tss sss" search_spaces="tss sss"
for dataset in ${datasets} for dataset in ${datasets}
do do
for search_space in ${search_spaces} for search_space in ${search_spaces}
do do
# python ./exps/algos-v2/reinforce.py --dataset ${dataset} --search_space ${search_space} --learning_rate 0.001 python ./exps/algos-v2/reinforce.py --dataset ${dataset} --search_space ${search_space} --learning_rate 0.001
python ./exps/algos-v2/regularized_ea.py --dataset ${dataset} --search_space ${search_space} --ea_cycles 200 --ea_population 10 --ea_sample_size 3 python ./exps/algos-v2/regularized_ea.py --dataset ${dataset} --search_space ${search_space} --ea_cycles 200 --ea_population 10 --ea_sample_size 3
# python ./exps/algos-v2/random_wo_share.py --dataset ${dataset} --search_space ${search_space} python ./exps/algos-v2/random_wo_share.py --dataset ${dataset} --search_space ${search_space}
python exps/algos-v2/bohb.py --dataset ${dataset} --search_space ${search_space} --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3
done done
done done
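
With set -e, the script now aborts on the first failing command instead of silently continuing to the next dataset/search-space pair, and all four baselines (REINFORCE, REA, RANDOM, BOHB) are run for every combination of the three datasets and two search spaces rather than REA alone.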

View File

@@ -5,7 +5,7 @@
 ###############################################################
 # Usage: python exps/experimental/vis-bench-algos.py          #
 ###############################################################
-import os, sys, time, torch, argparse
+import os, gc, sys, time, torch, argparse
 import numpy as np
 from typing import List, Text, Dict, Any
 from shutil import copyfile
@@ -31,6 +31,7 @@ def fetch_data(root_dir='./output/search', search_space='tss', dataset=None):
   alg2name['REA'] = 'R-EA-SS3'
   alg2name['REINFORCE'] = 'REINFORCE-0.001'
   alg2name['RANDOM'] = 'RANDOM'
+  alg2name['BOHB'] = 'BOHB'
   for alg, name in alg2name.items():
     alg2path[alg] = os.path.join(ss_dir, dataset, name, 'results.pth')
     assert os.path.isfile(alg2path[alg]), 'invalid path : {:}'.format(alg2path[alg])
@@ -58,14 +59,27 @@ def query_performance(api, data, dataset, ticket):
     results.append(interplate)
   return sum(results) / len(results)
 
+y_min_s = {('cifar10', 'tss'): 90,
+           ('cifar10', 'sss'): 92,
+           ('cifar100', 'tss'): 65,
+           ('cifar100', 'sss'): 65,
+           ('ImageNet16-120', 'tss'): 36,
+           ('ImageNet16-120', 'sss'): 40}
+
+y_max_s = {('cifar10', 'tss'): 94.5,
+           ('cifar10', 'sss'): 93.3,
+           ('cifar100', 'tss'): 72,
+           ('cifar100', 'sss'): 70,
+           ('ImageNet16-120', 'tss'): 44,
+           ('ImageNet16-120', 'sss'): 46}
+
 def visualize_curve(api, vis_save_dir, search_space, max_time):
   vis_save_dir = vis_save_dir.resolve()
   vis_save_dir.mkdir(parents=True, exist_ok=True)
 
-  dpi, width, height = 250, 5100, 1500
+  dpi, width, height = 250, 5200, 1400
   figsize = width / float(dpi), height / float(dpi)
-  LabelSize, LegendFontsize = 14, 14
+  LabelSize, LegendFontsize = 16, 16
 
   def sub_plot_fn(ax, dataset):
     alg2data = fetch_data(search_space=search_space, dataset=dataset)
@@ -73,6 +87,8 @@ def visualize_curve(api, vis_save_dir, search_space, max_time):
     total_tickets = 150
     time_tickets = [float(i) / total_tickets * max_time for i in range(total_tickets)]
     colors = ['b', 'g', 'c', 'm', 'y']
+    ax.set_xlim(0, 200)
+    ax.set_ylim(y_min_s[(dataset, search_space)], y_max_s[(dataset, search_space)])
     for idx, (alg, data) in enumerate(alg2data.items()):
       print('plot alg : {:}'.format(alg))
       accuracies = []
@@ -107,5 +123,7 @@ if __name__ == '__main__':
   api201 = NASBench201API(verbose=False)
   visualize_curve(api201, save_dir, 'tss', args.max_time)
+  del api201
+  gc.collect()
 
   api301 = NASBench301API(verbose=False)
   visualize_curve(api301, save_dir, 'sss', args.max_time)
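
The benchmark APIs keep their full result tables in memory, so the added del/gc.collect() releases the topology benchmark before the size benchmark is loaded, keeping only one resident at a time. The pattern, sketched with a plain list standing in for the API object:

import gc

table = [0] * (10 ** 7)       # stand-in for the NASBench201API tables
# ... plot the 'tss' curves with `table` ...
del table                     # drop the last reference
gc.collect()                  # reclaim promptly before the next allocation
table = [1] * (10 ** 7)       # stand-in for NASBench301API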

View File

@@ -68,14 +68,14 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
   def reset_time(self):
     self._used_time = 0
 
-  def simulate_train_eval(self, arch, dataset, hp='12', account_time=True):
+  def simulate_train_eval(self, arch, dataset, iepoch=None, hp='12', account_time=True):
     index = self.query_index_by_arch(arch)
     all_names = ('cifar10', 'cifar100', 'ImageNet16-120')
     assert dataset in all_names, 'Invalid dataset name : {:} vs {:}'.format(dataset, all_names)
     if dataset == 'cifar10':
-      info = self.get_more_info(index, 'cifar10-valid', iepoch=None, hp=hp, is_random=True)
+      info = self.get_more_info(index, 'cifar10-valid', iepoch=iepoch, hp=hp, is_random=True)
     else:
-      info = self.get_more_info(index, dataset, iepoch=None, hp=hp, is_random=True)
+      info = self.get_more_info(index, dataset, iepoch=iepoch, hp=hp, is_random=True)
     valid_acc, time_cost = info['valid-accuracy'], info['train-all-time'] + info['valid-per-time']
     latency = self.get_latency(index, dataset)
     if account_time:
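
reset_time() and the account_time flag maintain a simulated clock: each query accrues the benchmark's recorded train+valid time, which the searchers report instead of real wall-clock time. A minimal sketch of that accounting, assuming only what this hunk shows (SimClock and its method names are hypothetical simplifications):

class SimClock:
  """Sketch of the simulated-time bookkeeping, not the real API class."""

  def __init__(self):
    self._used_time = 0

  def reset_time(self):
    self._used_time = 0

  def account(self, time_cost, account_time=True):
    # simulate_train_eval accrues each query's train+valid time like this
    # when account_time is True.
    if account_time:
      self._used_time += time_cost
    return self._used_time

clock = SimClock()
clock.reset_time()                             # main() now calls api.reset_time() first
print(clock.account(100.0))                    # 100.0
print(clock.account(50.0))                     # 150.0
print(clock.account(9.0, account_time=False))  # 150.0: this query is not charged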