Update new version of BOHB
This commit is contained in:
parent
2c861f33c4
commit
a99df6dc31
@ -6,6 +6,7 @@
|
|||||||
# pip install hpbandster ##################################
|
# pip install hpbandster ##################################
|
||||||
###################################################################
|
###################################################################
|
||||||
# OMP_NUM_THREADS=4 python exps/algos-v2/bohb.py --search_space tss --dataset cifar10 --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3 --rand_seed 1
|
# OMP_NUM_THREADS=4 python exps/algos-v2/bohb.py --search_space tss --dataset cifar10 --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3 --rand_seed 1
|
||||||
|
# OMP_NUM_THREADS=4 python exps/algos-v2/bohb.py --search_space sss --dataset cifar10 --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3 --rand_seed 1
|
||||||
###################################################################
|
###################################################################
|
||||||
import os, sys, time, random, argparse, collections
|
import os, sys, time, random, argparse, collections
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
@ -38,12 +39,9 @@ def get_topology_config_space(search_space, max_nodes=4):
|
|||||||
|
|
||||||
def get_size_config_space(search_space):
|
def get_size_config_space(search_space):
|
||||||
cs = ConfigSpace.ConfigurationSpace()
|
cs = ConfigSpace.ConfigurationSpace()
|
||||||
import pdb; pdb.set_trace()
|
for ilayer in range(search_space['numbers']):
|
||||||
#edge2index = {}
|
node_str = 'layer-{:}'.format(ilayer)
|
||||||
for i in range(1, max_nodes):
|
cs.add_hyperparameter(ConfigSpace.CategoricalHyperparameter(node_str, search_space['candidates']))
|
||||||
for j in range(i):
|
|
||||||
node_str = '{:}<-{:}'.format(i, j)
|
|
||||||
cs.add_hyperparameter(ConfigSpace.CategoricalHyperparameter(node_str, search_space))
|
|
||||||
return cs
|
return cs
|
||||||
|
|
||||||
|
|
||||||
@ -61,6 +59,16 @@ def config2topology_func(max_nodes=4):
|
|||||||
return config2structure
|
return config2structure
|
||||||
|
|
||||||
|
|
||||||
|
def config2size_func(search_space):
|
||||||
|
def config2structure(config):
|
||||||
|
channels = []
|
||||||
|
for ilayer in range(search_space['numbers']):
|
||||||
|
node_str = 'layer-{:}'.format(ilayer)
|
||||||
|
channels.append(str(config[node_str]))
|
||||||
|
return ':'.join(channels)
|
||||||
|
return config2structure
|
||||||
|
|
||||||
|
|
||||||
class MyWorker(Worker):
|
class MyWorker(Worker):
|
||||||
|
|
||||||
def __init__(self, *args, convert_func=None, dataset=None, api=None, **kwargs):
|
def __init__(self, *args, convert_func=None, dataset=None, api=None, **kwargs):
|
||||||
@ -89,11 +97,11 @@ def main(xargs, api):
|
|||||||
api.reset_time()
|
api.reset_time()
|
||||||
search_space = get_search_spaces(xargs.search_space, 'nas-bench-301')
|
search_space = get_search_spaces(xargs.search_space, 'nas-bench-301')
|
||||||
if xargs.search_space == 'tss':
|
if xargs.search_space == 'tss':
|
||||||
cs = get_topology_config_space(search_space)
|
cs = get_topology_config_space(search_space)
|
||||||
config2structure = config2topology_func()
|
config2structure = config2topology_func()
|
||||||
else:
|
else:
|
||||||
cs = get_size_config_space(search_space)
|
cs = get_size_config_space(search_space)
|
||||||
import pdb; pdb.set_trace()
|
config2structure = config2size_func(search_space)
|
||||||
|
|
||||||
hb_run_id = '0'
|
hb_run_id = '0'
|
||||||
|
|
||||||
|
@ -17,3 +17,6 @@ do
|
|||||||
python exps/algos-v2/bohb.py --dataset ${dataset} --search_space ${search_space} --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3
|
python exps/algos-v2/bohb.py --dataset ${dataset} --search_space ${search_space} --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3
|
||||||
done
|
done
|
||||||
done
|
done
|
||||||
|
|
||||||
|
python exps/experimental/vis-bench-algos.py --search_space tss
|
||||||
|
python exps/experimental/vis-bench-algos.py --search_space sss
|
||||||
|
@ -3,7 +3,8 @@
|
|||||||
###############################################################
|
###############################################################
|
||||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.06 #
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.06 #
|
||||||
###############################################################
|
###############################################################
|
||||||
# Usage: python exps/experimental/vis-bench-algos.py #
|
# Usage: python exps/experimental/vis-bench-algos.py --search_space tss
|
||||||
|
# Usage: python exps/experimental/vis-bench-algos.py --search_space sss
|
||||||
###############################################################
|
###############################################################
|
||||||
import os, gc, sys, time, torch, argparse
|
import os, gc, sys, time, torch, argparse
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -115,15 +116,17 @@ def visualize_curve(api, vis_save_dir, search_space, max_time):
|
|||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser(description='NAS-Bench-X', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
parser = argparse.ArgumentParser(description='NAS-Bench-X', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||||
parser.add_argument('--save_dir', type=str, default='output/vis-nas-bench/nas-algos', help='Folder to save checkpoints and log.')
|
parser.add_argument('--save_dir', type=str, default='output/vis-nas-bench/nas-algos', help='Folder to save checkpoints and log.')
|
||||||
parser.add_argument('--max_time', type=float, default=20000, help='The maximum time budget.')
|
parser.add_argument('--search_space', type=str, choices=['tss', 'sss'], help='Choose the search space.')
|
||||||
|
parser.add_argument('--max_time', type=float, default=20000, help='The maximum time budget.')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
save_dir = Path(args.save_dir)
|
save_dir = Path(args.save_dir)
|
||||||
|
|
||||||
api201 = NASBench201API(verbose=False)
|
if args.search_space == 'tss':
|
||||||
visualize_curve(api201, save_dir, 'tss', args.max_time)
|
api = NASBench201API(verbose=False)
|
||||||
del api201
|
elif args.search_space == 'sss':
|
||||||
gc.collect()
|
api = NASBench301API(verbose=False)
|
||||||
api301 = NASBench301API(verbose=False)
|
else:
|
||||||
visualize_curve(api301, save_dir, 'sss', args.max_time)
|
raise ValueError('Invalid search space : {:}'.format(args.search_space))
|
||||||
|
visualize_curve(api, save_dir, args.search_space, args.max_time)
|
||||||
|
Loading…
Reference in New Issue
Block a user