diff --git a/exps/algos-v2/bohb.py b/exps/algos-v2/bohb.py index 27263d6..0c5983c 100644 --- a/exps/algos-v2/bohb.py +++ b/exps/algos-v2/bohb.py @@ -6,6 +6,7 @@ # pip install hpbandster ################################## ################################################################### # OMP_NUM_THREADS=4 python exps/algos-v2/bohb.py --search_space tss --dataset cifar10 --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3 --rand_seed 1 +# OMP_NUM_THREADS=4 python exps/algos-v2/bohb.py --search_space sss --dataset cifar10 --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3 --rand_seed 1 ################################################################### import os, sys, time, random, argparse, collections from copy import deepcopy @@ -38,12 +39,9 @@ def get_topology_config_space(search_space, max_nodes=4): def get_size_config_space(search_space): cs = ConfigSpace.ConfigurationSpace() - import pdb; pdb.set_trace() - #edge2index = {} - for i in range(1, max_nodes): - for j in range(i): - node_str = '{:}<-{:}'.format(i, j) - cs.add_hyperparameter(ConfigSpace.CategoricalHyperparameter(node_str, search_space)) + for ilayer in range(search_space['numbers']): + node_str = 'layer-{:}'.format(ilayer) + cs.add_hyperparameter(ConfigSpace.CategoricalHyperparameter(node_str, search_space['candidates'])) return cs @@ -61,6 +59,16 @@ def config2topology_func(max_nodes=4): return config2structure +def config2size_func(search_space): + def config2structure(config): + channels = [] + for ilayer in range(search_space['numbers']): + node_str = 'layer-{:}'.format(ilayer) + channels.append(str(config[node_str])) + return ':'.join(channels) + return config2structure + + class MyWorker(Worker): def __init__(self, *args, convert_func=None, dataset=None, api=None, **kwargs): @@ -89,11 +97,11 @@ def main(xargs, api): api.reset_time() search_space = get_search_spaces(xargs.search_space, 'nas-bench-301') if xargs.search_space == 'tss': - cs = get_topology_config_space(search_space) - config2structure = config2topology_func() + cs = get_topology_config_space(search_space) + config2structure = config2topology_func() else: cs = get_size_config_space(search_space) - import pdb; pdb.set_trace() + config2structure = config2size_func(search_space) hb_run_id = '0' diff --git a/exps/algos-v2/run-all.sh b/exps/algos-v2/run-all.sh index 2f5ee1f..4b2199b 100644 --- a/exps/algos-v2/run-all.sh +++ b/exps/algos-v2/run-all.sh @@ -17,3 +17,6 @@ do python exps/algos-v2/bohb.py --dataset ${dataset} --search_space ${search_space} --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3 done done + +python exps/experimental/vis-bench-algos.py --search_space tss +python exps/experimental/vis-bench-algos.py --search_space sss diff --git a/exps/experimental/vis-bench-algos.py b/exps/experimental/vis-bench-algos.py index 550597a..b9adcf2 100644 --- a/exps/experimental/vis-bench-algos.py +++ b/exps/experimental/vis-bench-algos.py @@ -3,7 +3,8 @@ ############################################################### # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.06 # ############################################################### -# Usage: python exps/experimental/vis-bench-algos.py # +# Usage: python exps/experimental/vis-bench-algos.py --search_space tss +# Usage: python exps/experimental/vis-bench-algos.py --search_space sss ############################################################### import os, gc, sys, time, torch, argparse import numpy as np @@ -115,15 +116,17 @@ def visualize_curve(api, vis_save_dir, search_space, max_time): if __name__ == '__main__': parser = argparse.ArgumentParser(description='NAS-Bench-X', formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument('--save_dir', type=str, default='output/vis-nas-bench/nas-algos', help='Folder to save checkpoints and log.') - parser.add_argument('--max_time', type=float, default=20000, help='The maximum time budget.') + parser.add_argument('--save_dir', type=str, default='output/vis-nas-bench/nas-algos', help='Folder to save checkpoints and log.') + parser.add_argument('--search_space', type=str, choices=['tss', 'sss'], help='Choose the search space.') + parser.add_argument('--max_time', type=float, default=20000, help='The maximum time budget.') args = parser.parse_args() save_dir = Path(args.save_dir) - api201 = NASBench201API(verbose=False) - visualize_curve(api201, save_dir, 'tss', args.max_time) - del api201 - gc.collect() - api301 = NASBench301API(verbose=False) - visualize_curve(api301, save_dir, 'sss', args.max_time) + if args.search_space == 'tss': + api = NASBench201API(verbose=False) + elif args.search_space == 'sss': + api = NASBench301API(verbose=False) + else: + raise ValueError('Invalid search space : {:}'.format(args.search_space)) + visualize_curve(api, save_dir, args.search_space, args.max_time)