281 lines
		
	
	
		
			9.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			281 lines
		
	
	
		
			9.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| ##################################################
 | |
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 #
 | |
| ###################################################################
 | |
| # BOHB: Robust and Efficient Hyperparameter Optimization at Scale #
 | |
| # required to install hpbandster ##################################
 | |
| # pip install hpbandster         ##################################
 | |
| ###################################################################
 | |
| # OMP_NUM_THREADS=4 python exps/NATS-algos/bohb.py --search_space tss --dataset cifar10 --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3 --rand_seed 1
 | |
| # OMP_NUM_THREADS=4 python exps/NATS-algos/bohb.py --search_space sss --dataset cifar10 --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3 --rand_seed 1
 | |
| ###################################################################
 | |
| import os, sys, time, random, argparse, collections
 | |
| from copy import deepcopy
 | |
| from pathlib import Path
 | |
| import torch
 | |
| 
 | |
# Make the repository's ``lib`` directory (two levels up, then ``lib``)
# importable so the project modules imported below can be found.
lib_dir = Path(__file__).parent.joinpath("..", "..", "lib").resolve()
if str(lib_dir) not in sys.path:
    # Insert at the front so these modules are found before any
    # same-named modules elsewhere on the path.
    sys.path.insert(0, str(lib_dir))
 | |
| from config_utils import load_config
 | |
| from datasets import get_datasets, SearchDataset
 | |
| from procedures import prepare_seed, prepare_logger
 | |
| from log_utils import AverageMeter, time_string, convert_secs2time
 | |
| from nats_bench import create
 | |
| from models import CellStructure, get_search_spaces
 | |
| 
 | |
| # BOHB: Robust and Efficient Hyperparameter Optimization at Scale, ICML 2018
 | |
| import ConfigSpace
 | |
| from hpbandster.optimizers.bohb import BOHB
 | |
| import hpbandster.core.nameserver as hpns
 | |
| from hpbandster.core.worker import Worker
 | |
| 
 | |
| 
 | |
def get_topology_config_space(search_space, max_nodes=4):
    """Build a ConfigSpace with one categorical choice per cell edge.

    Every edge ``i<-j`` with ``0 <= j < i < max_nodes`` becomes a
    hyperparameter named ``"{i}<-{j}"`` whose choices are the entries of
    ``search_space`` (the candidate operation names).

    Args:
        search_space: sequence of candidate operations for each edge.
        max_nodes: number of nodes in the cell (node 0 is the input).

    Returns:
        The populated ``ConfigSpace.ConfigurationSpace``.
    """
    edge_names = [
        "{:}<-{:}".format(target, source)
        for target in range(1, max_nodes)
        for source in range(target)
    ]
    space = ConfigSpace.ConfigurationSpace()
    for edge in edge_names:
        space.add_hyperparameter(
            ConfigSpace.CategoricalHyperparameter(edge, search_space)
        )
    return space
 | |
| 
 | |
| 
 | |
def get_size_config_space(search_space):
    """Build a ConfigSpace with one categorical choice per layer.

    Creates hyperparameters ``layer-0`` .. ``layer-{numbers-1}``, each choosing
    among ``search_space["candidates"]``.

    Args:
        search_space: mapping with keys ``"numbers"`` (layer count) and
            ``"candidates"`` (the allowed values for every layer).

    Returns:
        The populated ``ConfigSpace.ConfigurationSpace``.
    """
    space = ConfigSpace.ConfigurationSpace()
    num_layers = search_space["numbers"]
    for layer in range(num_layers):
        choice = ConfigSpace.CategoricalHyperparameter(
            "layer-{:}".format(layer), search_space["candidates"]
        )
        space.add_hyperparameter(choice)
    return space
 | |
| 
 | |
| 
 | |
def config2topology_func(max_nodes=4):
    """Return a converter from a BOHB config to a ``CellStructure``.

    The returned callable looks up the operation chosen for every edge
    ``"i<-j"`` (the hyperparameter names produced by
    ``get_topology_config_space``). For each target node ``i`` it collects
    a tuple of ``(op_name, j)`` pairs and wraps the list of those tuples in
    a ``CellStructure``.
    """

    def config2structure(config):
        genotypes = [
            tuple(
                (config["{:}<-{:}".format(node, prev)], prev) for prev in range(node)
            )
            for node in range(1, max_nodes)
        ]
        return CellStructure(genotypes)

    return config2structure
 | |
| 
 | |
| 
 | |
def config2size_func(search_space):
    """Return a converter from a BOHB config to a size-space architecture string.

    The returned callable reads ``layer-0`` .. ``layer-{numbers-1}`` from the
    config, stringifies each value and joins them with ``":"``
    (e.g. ``{"layer-0": 8, "layer-1": 16}`` -> ``"8:16"``).
    """

    def config2structure(config):
        return ":".join(
            str(config["layer-{:}".format(index)])
            for index in range(search_space["numbers"])
        )

    return config2structure
 | |
| 
 | |
| 
 | |
class MyWorker(Worker):
    """hpbandster ``Worker`` that scores architectures via the NATS-Bench API.

    Every call to ``compute`` appends to two parallel lists, which the caller
    reads once the search has finished:

    * ``trajectory``  -- ``(accuracy, arch)`` pairs in evaluation order;
    * ``total_times`` -- the ``total_time`` value returned by
      ``api.simulate_train_eval`` for that evaluation.
    """

    def __init__(self, *args, convert_func=None, dataset=None, api=None, **kwargs):
        # Everything except our three keyword-only options (e.g. run_id,
        # nameserver, id) is forwarded unchanged to hpbandster's Worker.
        super().__init__(*args, **kwargs)
        self.convert_func = convert_func  # BOHB config -> architecture
        self._dataset, self._api = dataset, api
        self.total_times, self.trajectory = [], []

    def compute(self, config, budget, **kwargs):
        """Evaluate one configuration and return hpbandster's result dict.

        ``budget`` is truncated to an int and shifted by one to form the
        ``iepoch`` argument (presumably a 0-based epoch index) of
        ``simulate_train_eval`` with ``hp="12"``. The loss is
        ``100 - accuracy`` so that BOHB's minimization maximizes accuracy;
        ``info`` carries the benchmark index of the evaluated architecture.
        """
        candidate = self.convert_func(config)
        epoch_index = int(budget) - 1
        accuracy, _latency, _time_cost, total_time = self._api.simulate_train_eval(
            candidate, self._dataset, iepoch=epoch_index, hp="12"
        )
        self.trajectory.append((accuracy, candidate))
        self.total_times.append(total_time)
        return {
            "loss": 100 - accuracy,
            "info": self._api.query_index_by_arch(candidate),
        }
 | |
| 
 | |
| 
 | |
def main(xargs, api):
    """Run one BOHB search over a NATS-Bench search space and log the best arch.

    Args:
        xargs: parsed command-line namespace. Reads ``rand_seed``,
            ``search_space`` (``"tss"`` selects the topology space, anything
            else the size space), ``dataset``, ``num_samples``,
            ``random_fraction``, ``bandwidth_factor``, ``min_bandwidth`` and
            ``n_iters``; it is also passed to ``prepare_logger``.
        api: NATS-Bench API object (from ``nats_bench.create``) used to
            simulate training and to look up architecture indices/info.

    Returns:
        ``(log_dir, current_best_index, total_times)`` where
        ``current_best_index[k]`` is the benchmark index of the most accurate
        architecture among the first ``k + 1`` evaluations, and
        ``total_times[k]`` is the ``total_time`` reported by the simulator for
        evaluation ``k``.

    Raises:
        ValueError: if the search finished without evaluating any architecture.
    """
    torch.set_num_threads(4)
    prepare_seed(xargs.rand_seed)
    # Use the parameter, not the module-level ``args``: that global only
    # exists when this file runs as a script, so importing and calling
    # main() used to fail with a NameError.
    logger = prepare_logger(xargs)

    logger.log("{:} use api : {:}".format(time_string(), api))
    api.reset_time()
    search_space = get_search_spaces(xargs.search_space, "nats-bench")
    if xargs.search_space == "tss":
        cs = get_topology_config_space(search_space)
        config2structure = config2topology_func()
    else:
        cs = get_size_config_space(search_space)
        config2structure = config2size_func(search_space)

    hb_run_id = "0"

    NS = hpns.NameServer(run_id=hb_run_id, host="localhost", port=0)
    ns_host, ns_port = NS.start()
    num_workers = 1
    try:
        workers = []
        for i in range(num_workers):
            w = MyWorker(
                nameserver=ns_host,
                nameserver_port=ns_port,
                convert_func=config2structure,
                dataset=xargs.dataset,
                api=api,
                run_id=hb_run_id,
                id=i,
            )
            w.run(background=True)
            workers.append(w)

        bohb = BOHB(
            configspace=cs,
            run_id=hb_run_id,
            eta=3,
            min_budget=1,
            max_budget=12,
            nameserver=ns_host,
            nameserver_port=ns_port,
            num_samples=xargs.num_samples,
            random_fraction=xargs.random_fraction,
            bandwidth_factor=xargs.bandwidth_factor,
            ping_interval=10,
            min_bandwidth=xargs.min_bandwidth,
        )
        try:
            bohb.run(xargs.n_iters, min_n_workers=num_workers)
        finally:
            # Stop the optimizer and its workers even if run() raised.
            bohb.shutdown(shutdown_workers=True)
    finally:
        # Always release the name server, including on failure paths.
        NS.shutdown()

    trajectory = workers[0].trajectory
    if not trajectory:
        logger.close()
        raise ValueError("BOHB finished without evaluating any architecture.")

    # Best-so-far in a single pass instead of re-scanning every prefix.
    # The strict ">" keeps the earliest of equally accurate architectures,
    # which is the same tie-breaking as max().
    current_best_index = []
    best_accuracy, best_arch = trajectory[0]
    for accuracy, arch in trajectory:
        if accuracy > best_accuracy:
            best_accuracy, best_arch = accuracy, arch
        current_best_index.append(api.query_index_by_arch(best_arch))

    logger.log(
        "Best found configuration: {:} within {:.3f} s".format(
            best_arch, workers[0].total_times[-1]
        )
    )
    info = api.query_info_str_by_arch(
        best_arch, "200" if xargs.search_space == "tss" else "90"
    )
    logger.log("{:}".format(info))
    logger.log("-" * 100)
    logger.close()

    return logger.log_dir, current_best_index, workers[0].total_times
 | |
| 
 | |
| 
 | |
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        "BOHB: Robust and Efficient Hyperparameter Optimization at Scale"
    )
    # Both --dataset and --search_space are required: without them the API
    # creation / simulation below fails with an obscure error.
    parser.add_argument(
        "--dataset",
        type=str,
        required=True,
        choices=["cifar10", "cifar100", "ImageNet16-120"],
        help="Choose between Cifar10/100 and ImageNet-16.",
    )
    # general arg
    parser.add_argument(
        "--search_space",
        type=str,
        required=True,
        choices=["tss", "sss"],
        help="Choose the search space.",
    )
    # NOTE: within this script --time_budget only affects the output directory
    # name; the length of the search is governed by --n_iters.
    parser.add_argument(
        "--time_budget",
        type=int,
        default=20000,
        help="The total time cost budget for searching (in seconds).",
    )
    parser.add_argument(
        "--loops_if_rand", type=int, default=500, help="The total runs for evaluation."
    )
    # BOHB
    # NOTE: --strategy is parsed but not read anywhere in this script.
    parser.add_argument(
        "--strategy",
        default="sampling",
        type=str,
        nargs="?",
        help="optimization strategy for the acquisition function",
    )
    parser.add_argument(
        "--min_bandwidth",
        default=0.3,
        type=float,
        nargs="?",
        help="minimum bandwidth for KDE",
    )
    parser.add_argument(
        "--num_samples",
        default=64,
        type=int,
        nargs="?",
        help="number of samples for the acquisition function",
    )
    parser.add_argument(
        "--random_fraction",
        default=0.33,
        type=float,
        nargs="?",
        help="fraction of random configurations",
    )
    parser.add_argument(
        "--bandwidth_factor",
        default=3,
        type=int,
        nargs="?",
        help="factor multiplied to the bandwidth",
    )
    parser.add_argument(
        "--n_iters",
        default=300,
        type=int,
        nargs="?",
        help="number of iterations for optimization method",
    )
    # log
    parser.add_argument(
        "--save_dir",
        type=str,
        default="./output/search",
        help="Folder to save checkpoints and log.",
    )
    parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed")
    args = parser.parse_args()

    # A negative seed repeats the search --loops_if_rand times with random
    # seeds; at least one run is needed to obtain a directory to save into.
    # Check this before the (expensive) API creation.
    if args.rand_seed < 0 and args.loops_if_rand < 1:
        parser.error("--loops_if_rand must be positive when --rand_seed is negative.")

    api = create(None, args.search_space, fast_mode=False, verbose=False)

    args.save_dir = os.path.join(
        "{:}-{:}".format(args.save_dir, args.search_space),
        "{:}-T{:}".format(args.dataset, args.time_budget),
        "BOHB",
    )
    print("save-dir : {:}".format(args.save_dir))

    if args.rand_seed < 0:
        save_dir, all_info = None, collections.OrderedDict()
        for i in range(args.loops_if_rand):
            print("{:} : {:03d}/{:03d}".format(time_string(), i, args.loops_if_rand))
            args.rand_seed = random.randint(1, 100000)
            save_dir, all_archs, all_total_times = main(args, api)
            all_info[i] = {"all_archs": all_archs, "all_total_times": all_total_times}
        save_path = save_dir / "results.pth"
        print("save into {:}".format(save_path))
        torch.save(all_info, save_path)
    else:
        main(args, api)
 |