Update new version of BOHB
This commit is contained in:
		| @@ -6,6 +6,7 @@ | ||||
| # pip install hpbandster         ################################## | ||||
| ################################################################### | ||||
| # OMP_NUM_THREADS=4 python exps/algos-v2/bohb.py --search_space tss --dataset cifar10 --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3 --rand_seed 1 | ||||
| # OMP_NUM_THREADS=4 python exps/algos-v2/bohb.py --search_space sss --dataset cifar10 --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3 --rand_seed 1 | ||||
| ################################################################### | ||||
| import os, sys, time, random, argparse, collections | ||||
| from copy import deepcopy | ||||
| @@ -38,12 +39,9 @@ def get_topology_config_space(search_space, max_nodes=4): | ||||
|  | ||||
| def get_size_config_space(search_space): | ||||
|   cs = ConfigSpace.ConfigurationSpace() | ||||
|   import pdb; pdb.set_trace() | ||||
|   #edge2index   = {} | ||||
|   for i in range(1, max_nodes): | ||||
|     for j in range(i): | ||||
|       node_str = '{:}<-{:}'.format(i, j) | ||||
|       cs.add_hyperparameter(ConfigSpace.CategoricalHyperparameter(node_str, search_space)) | ||||
|   for ilayer in range(search_space['numbers']): | ||||
|     node_str = 'layer-{:}'.format(ilayer) | ||||
|     cs.add_hyperparameter(ConfigSpace.CategoricalHyperparameter(node_str, search_space['candidates'])) | ||||
|   return cs | ||||
|  | ||||
|  | ||||
| @@ -61,6 +59,16 @@ def config2topology_func(max_nodes=4): | ||||
|   return config2structure | ||||
|  | ||||
|  | ||||
| def config2size_func(search_space): | ||||
|   def config2structure(config): | ||||
|     channels = [] | ||||
|     for ilayer in range(search_space['numbers']): | ||||
|       node_str = 'layer-{:}'.format(ilayer) | ||||
|       channels.append(str(config[node_str])) | ||||
|     return ':'.join(channels) | ||||
|   return config2structure | ||||
|  | ||||
|  | ||||
| class MyWorker(Worker): | ||||
|  | ||||
|   def __init__(self, *args, convert_func=None, dataset=None, api=None, **kwargs): | ||||
| @@ -89,11 +97,11 @@ def main(xargs, api): | ||||
|   api.reset_time() | ||||
|   search_space = get_search_spaces(xargs.search_space, 'nas-bench-301') | ||||
|   if xargs.search_space == 'tss': | ||||
|   	cs = get_topology_config_space(search_space) | ||||
|   	config2structure = config2topology_func() | ||||
|     cs = get_topology_config_space(search_space) | ||||
|     config2structure = config2topology_func() | ||||
|   else: | ||||
|     cs = get_size_config_space(search_space) | ||||
|     import pdb; pdb.set_trace() | ||||
|     config2structure = config2size_func(search_space) | ||||
|    | ||||
|   hb_run_id = '0' | ||||
|  | ||||
|   | ||||
| @@ -17,3 +17,6 @@ do | ||||
|     python exps/algos-v2/bohb.py --dataset ${dataset} --search_space ${search_space} --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3 | ||||
|   done | ||||
| done | ||||
|  | ||||
| python exps/experimental/vis-bench-algos.py --search_space tss | ||||
| python exps/experimental/vis-bench-algos.py --search_space sss | ||||
|   | ||||
		Reference in New Issue
	
	Block a user