add autodl

AutoDL-Projects/exps/NATS-algos/README.md (new file, 29 lines)
							| @@ -0,0 +1,29 @@ | ||||
| # NAS Algorithms evaluated in [NATS-Bench](https://arxiv.org/abs/2009.00437) | ||||
|  | ||||
| The Python files in this folder are used to reproduce the results of "NATS-Bench: Benchmarking NAS Algorithms for Architecture Topology and Size". | ||||
|  | ||||
| - [`search-size.py`](https://github.com/D-X-Y/AutoDL-Projects/blob/main/exps/NATS-algos/search-size.py) contains the code for weight-sharing-based search on the size search space. | ||||
| - [`search-cell.py`](https://github.com/D-X-Y/AutoDL-Projects/blob/main/exps/NATS-algos/search-cell.py) contains the code for weight-sharing-based search on the topology search space. | ||||
| - [`bohb.py`](https://github.com/D-X-Y/AutoDL-Projects/blob/main/exps/NATS-algos/bohb.py) contains the BOHB algorithm for both the size and topology search spaces. | ||||
| - [`random_wo_share.py`](https://github.com/D-X-Y/AutoDL-Projects/blob/main/exps/NATS-algos/random_wo_share.py) contains the random search algorithm for both search spaces. | ||||
| - [`regularized_ea.py`](https://github.com/D-X-Y/AutoDL-Projects/blob/main/exps/NATS-algos/regularized_ea.py) contains the REA algorithm for both search spaces (see the example invocation below). | ||||
| - [`reinforce.py`](https://github.com/D-X-Y/AutoDL-Projects/blob/main/exps/NATS-algos/reinforce.py) contains the REINFORCE algorithm for both search spaces. | ||||
|  | ||||
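For example, a single run of REA on the CIFAR-10 topology search space can be launched with the command below, which is taken from the header of `regularized_ea.py`:

```
python ./exps/NATS-algos/regularized_ea.py --dataset cifar10 --search_space tss --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1
```
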
| ## Requirements | ||||
|  | ||||
| - `nats_bench` >= v1.2: install it with `pip install nats_bench` or from [source](https://github.com/D-X-Y/NATS-Bench); see the usage sketch below. | ||||
| - `hpbandster`: required only if you want to run BOHB. | ||||
|  | ||||
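All of these scripts talk to the benchmark through the `nats_bench` API. Below is a minimal sketch of that usage, assuming the benchmark files are already available in the default location; the calls mirror the ones used by the scripts in this folder, and index `0` is only an example architecture:

```
from nats_bench import create

# "tss" = topology search space, "sss" = size search space.
api = create(None, "tss", fast_mode=True, verbose=False)
api.reset_time()  # reset the simulated search clock

# Simulate training architecture #0 under the 12-epoch proxy schedule and get
# its validation accuracy together with the simulated time cost.
accuracy, latency, time_cost, total_cost = api.simulate_train_eval(0, "cifar10", hp="12")

# Query the full-training (200-epoch) statistics of that architecture.
print(api.query_info_str_by_arch(0, "200"))
```
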
| ## Citation | ||||
|  | ||||
| If you find that this project helps your research, please consider citing the related paper: | ||||
| ``` | ||||
| @article{dong2021nats, | ||||
|   title   = {{NATS-Bench}: Benchmarking NAS Algorithms for Architecture Topology and Size}, | ||||
|   author  = {Dong, Xuanyi and Liu, Lu and Musial, Katarzyna and Gabrys, Bogdan}, | ||||
|   doi     = {10.1109/TPAMI.2021.3054824}, | ||||
|   journal = {IEEE Transactions on Pattern Analysis and Machine Intelligence (TPAMI)}, | ||||
|   year    = {2021}, | ||||
|   note    = {\mbox{doi}:\url{10.1109/TPAMI.2021.3054824}} | ||||
| } | ||||
| ``` | ||||

AutoDL-Projects/exps/NATS-algos/bohb.py (new file, 276 lines)
							| @@ -0,0 +1,276 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 # | ||||
| ################################################################### | ||||
| # BOHB: Robust and Efficient Hyperparameter Optimization at Scale # | ||||
| # required to install hpbandster ################################## | ||||
| # pip install hpbandster         ################################## | ||||
| ################################################################### | ||||
| # OMP_NUM_THREADS=4 python exps/NATS-algos/bohb.py --search_space tss --dataset cifar10 --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3 --rand_seed 1 | ||||
| # OMP_NUM_THREADS=4 python exps/NATS-algos/bohb.py --search_space sss --dataset cifar10 --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3 --rand_seed 1 | ||||
| ################################################################### | ||||
| import os, sys, time, random, argparse, collections | ||||
| from copy import deepcopy | ||||
| import torch | ||||
|  | ||||
| from xautodl.config_utils import load_config | ||||
| from xautodl.datasets import get_datasets, SearchDataset | ||||
| from xautodl.procedures import prepare_seed, prepare_logger | ||||
| from xautodl.log_utils import AverageMeter, time_string, convert_secs2time | ||||
| from xautodl.models import CellStructure, get_search_spaces | ||||
| from nats_bench import create | ||||
|  | ||||
| # BOHB: Robust and Efficient Hyperparameter Optimization at Scale, ICML 2018 | ||||
| import ConfigSpace | ||||
| from hpbandster.optimizers.bohb import BOHB | ||||
| import hpbandster.core.nameserver as hpns | ||||
| from hpbandster.core.worker import Worker | ||||
|  | ||||
|  | ||||
| def get_topology_config_space(search_space, max_nodes=4): | ||||
|     cs = ConfigSpace.ConfigurationSpace() | ||||
|     # edge2index   = {} | ||||
|     for i in range(1, max_nodes): | ||||
|         for j in range(i): | ||||
|             node_str = "{:}<-{:}".format(i, j) | ||||
|             cs.add_hyperparameter( | ||||
|                 ConfigSpace.CategoricalHyperparameter(node_str, search_space) | ||||
|             ) | ||||
|     return cs | ||||
|  | ||||
|  | ||||
| def get_size_config_space(search_space): | ||||
|     cs = ConfigSpace.ConfigurationSpace() | ||||
|     for ilayer in range(search_space["numbers"]): | ||||
|         node_str = "layer-{:}".format(ilayer) | ||||
|         cs.add_hyperparameter( | ||||
|             ConfigSpace.CategoricalHyperparameter(node_str, search_space["candidates"]) | ||||
|         ) | ||||
|     return cs | ||||
|  | ||||
|  | ||||
| def config2topology_func(max_nodes=4): | ||||
|     def config2structure(config): | ||||
|         genotypes = [] | ||||
|         for i in range(1, max_nodes): | ||||
|             xlist = [] | ||||
|             for j in range(i): | ||||
|                 node_str = "{:}<-{:}".format(i, j) | ||||
|                 op_name = config[node_str] | ||||
|                 xlist.append((op_name, j)) | ||||
|             genotypes.append(tuple(xlist)) | ||||
|         return CellStructure(genotypes) | ||||
|  | ||||
|     return config2structure | ||||
|  | ||||
|  | ||||
| def config2size_func(search_space): | ||||
|     def config2structure(config): | ||||
|         channels = [] | ||||
|         for ilayer in range(search_space["numbers"]): | ||||
|             node_str = "layer-{:}".format(ilayer) | ||||
|             channels.append(str(config[node_str])) | ||||
|         return ":".join(channels) | ||||
|  | ||||
|     return config2structure | ||||
|  | ||||
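# Illustration only (not executed): a sampled topology configuration assigns one
# operation to every edge "i<-j" of the 4-node cell, e.g.
#   {"1<-0": "nor_conv_3x3",
#    "2<-0": "skip_connect", "2<-1": "none",
#    "3<-0": "avg_pool_3x3", "3<-1": "nor_conv_1x1", "3<-2": "skip_connect"}
# (the operation names are the usual NATS-Bench topology candidates and are shown
# here as an assumed example). config2topology_func()(config) turns such a dict
# into a CellStructure, while config2size_func(...)(config) joins the per-layer
# channel choices into a string like "8:16:32:48:64" for the size search space.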
|  | ||||
| class MyWorker(Worker): | ||||
|     def __init__(self, *args, convert_func=None, dataset=None, api=None, **kwargs): | ||||
|         super().__init__(*args, **kwargs) | ||||
|         self.convert_func = convert_func | ||||
|         self._dataset = dataset | ||||
|         self._api = api | ||||
|         self.total_times = [] | ||||
|         self.trajectory = [] | ||||
|  | ||||
|     def compute(self, config, budget, **kwargs): | ||||
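        # `budget` is the (possibly fractional) number of proxy-training epochs
        # assigned by BOHB (between min_budget=1 and max_budget=12 below); the
        # API simulates training up to that epoch and also returns the
        # accumulated simulated time.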
|         arch = self.convert_func(config) | ||||
|         accuracy, latency, time_cost, total_time = self._api.simulate_train_eval( | ||||
|             arch, self._dataset, iepoch=int(budget) - 1, hp="12" | ||||
|         ) | ||||
|         self.trajectory.append((accuracy, arch)) | ||||
|         self.total_times.append(total_time) | ||||
|         return {"loss": 100 - accuracy, "info": self._api.query_index_by_arch(arch)} | ||||
|  | ||||
|  | ||||
| def main(xargs, api): | ||||
|     torch.set_num_threads(4) | ||||
|     prepare_seed(xargs.rand_seed) | ||||
|     logger = prepare_logger(args) | ||||
|  | ||||
|     logger.log("{:} use api : {:}".format(time_string(), api)) | ||||
|     api.reset_time() | ||||
|     search_space = get_search_spaces(xargs.search_space, "nats-bench") | ||||
|     if xargs.search_space == "tss": | ||||
|         cs = get_topology_config_space(search_space) | ||||
|         config2structure = config2topology_func() | ||||
|     else: | ||||
|         cs = get_size_config_space(search_space) | ||||
|         config2structure = config2size_func(search_space) | ||||
|  | ||||
|     hb_run_id = "0" | ||||
|  | ||||
|     NS = hpns.NameServer(run_id=hb_run_id, host="localhost", port=0) | ||||
|     ns_host, ns_port = NS.start() | ||||
|     num_workers = 1 | ||||
|  | ||||
|     workers = [] | ||||
|     for i in range(num_workers): | ||||
|         w = MyWorker( | ||||
|             nameserver=ns_host, | ||||
|             nameserver_port=ns_port, | ||||
|             convert_func=config2structure, | ||||
|             dataset=xargs.dataset, | ||||
|             api=api, | ||||
|             run_id=hb_run_id, | ||||
|             id=i, | ||||
|         ) | ||||
|         w.run(background=True) | ||||
|         workers.append(w) | ||||
|  | ||||
|     start_time = time.time() | ||||
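    # Hyperband-style budget schedule: eta=3 with per-trial budgets between 1
    # and 12 epochs, matching the 12-epoch proxy training data of NATS-Bench.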
|     bohb = BOHB( | ||||
|         configspace=cs, | ||||
|         run_id=hb_run_id, | ||||
|         eta=3, | ||||
|         min_budget=1, | ||||
|         max_budget=12, | ||||
|         nameserver=ns_host, | ||||
|         nameserver_port=ns_port, | ||||
|         num_samples=xargs.num_samples, | ||||
|         random_fraction=xargs.random_fraction, | ||||
|         bandwidth_factor=xargs.bandwidth_factor, | ||||
|         ping_interval=10, | ||||
|         min_bandwidth=xargs.min_bandwidth, | ||||
|     ) | ||||
|  | ||||
|     results = bohb.run(xargs.n_iters, min_n_workers=num_workers) | ||||
|  | ||||
|     bohb.shutdown(shutdown_workers=True) | ||||
|     NS.shutdown() | ||||
|  | ||||
|     # print('There are {:} runs.'.format(len(results.get_all_runs()))) | ||||
|     # workers[0].total_times | ||||
|     # workers[0].trajectory | ||||
|     current_best_index = [] | ||||
|     for idx in range(len(workers[0].trajectory)): | ||||
|         trajectory = workers[0].trajectory[: idx + 1] | ||||
|         arch = max(trajectory, key=lambda x: x[0])[1] | ||||
|         current_best_index.append(api.query_index_by_arch(arch)) | ||||
|  | ||||
|     best_arch = max(workers[0].trajectory, key=lambda x: x[0])[1] | ||||
|     logger.log( | ||||
|         "Best found configuration: {:} within {:.3f} s".format( | ||||
|             best_arch, workers[0].total_times[-1] | ||||
|         ) | ||||
|     ) | ||||
|     info = api.query_info_str_by_arch( | ||||
|         best_arch, "200" if xargs.search_space == "tss" else "90" | ||||
|     ) | ||||
|     logger.log("{:}".format(info)) | ||||
|     logger.log("-" * 100) | ||||
|     logger.close() | ||||
|  | ||||
|     return logger.log_dir, current_best_index, workers[0].total_times | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     parser = argparse.ArgumentParser( | ||||
|         "BOHB: Robust and Efficient Hyperparameter Optimization at Scale" | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--dataset", | ||||
|         type=str, | ||||
|         choices=["cifar10", "cifar100", "ImageNet16-120"], | ||||
|         help="Choose between CIFAR-10/100 and ImageNet16-120.", | ||||
|     ) | ||||
|     # general arg | ||||
|     parser.add_argument( | ||||
|         "--search_space", | ||||
|         type=str, | ||||
|         choices=["tss", "sss"], | ||||
|         help="Choose the search space.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--time_budget", | ||||
|         type=int, | ||||
|         default=20000, | ||||
|         help="The total time cost budget for searching (in seconds).", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--loops_if_rand", type=int, default=500, help="The total runs for evaluation." | ||||
|     ) | ||||
|     # BOHB | ||||
|     parser.add_argument( | ||||
|         "--strategy", | ||||
|         default="sampling", | ||||
|         type=str, | ||||
|         nargs="?", | ||||
|         help="optimization strategy for the acquisition function", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--min_bandwidth", | ||||
|         default=0.3, | ||||
|         type=float, | ||||
|         nargs="?", | ||||
|         help="minimum bandwidth for KDE", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--num_samples", | ||||
|         default=64, | ||||
|         type=int, | ||||
|         nargs="?", | ||||
|         help="number of samples for the acquisition function", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--random_fraction", | ||||
|         default=0.33, | ||||
|         type=float, | ||||
|         nargs="?", | ||||
|         help="fraction of random configurations", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--bandwidth_factor", | ||||
|         default=3, | ||||
|         type=int, | ||||
|         nargs="?", | ||||
|         help="factor multiplied to the bandwidth", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--n_iters", | ||||
|         default=300, | ||||
|         type=int, | ||||
|         nargs="?", | ||||
|         help="number of iterations for optimization method", | ||||
|     ) | ||||
|     # log | ||||
|     parser.add_argument( | ||||
|         "--save_dir", | ||||
|         type=str, | ||||
|         default="./output/search", | ||||
|         help="Folder to save checkpoints and log.", | ||||
|     ) | ||||
|     parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed") | ||||
|     args = parser.parse_args() | ||||
|  | ||||
|     api = create(None, args.search_space, fast_mode=False, verbose=False) | ||||
|  | ||||
|     args.save_dir = os.path.join( | ||||
|         "{:}-{:}".format(args.save_dir, args.search_space), | ||||
|         "{:}-T{:}".format(args.dataset, args.time_budget), | ||||
|         "BOHB", | ||||
|     ) | ||||
|     print("save-dir : {:}".format(args.save_dir)) | ||||
|  | ||||
|     if args.rand_seed < 0: | ||||
|         save_dir, all_info = None, collections.OrderedDict() | ||||
|         for i in range(args.loops_if_rand): | ||||
|             print("{:} : {:03d}/{:03d}".format(time_string(), i, args.loops_if_rand)) | ||||
|             args.rand_seed = random.randint(1, 100000) | ||||
|             save_dir, all_archs, all_total_times = main(args, api) | ||||
|             all_info[i] = {"all_archs": all_archs, "all_total_times": all_total_times} | ||||
|         save_path = save_dir / "results.pth" | ||||
|         print("save into {:}".format(save_path)) | ||||
|         torch.save(all_info, save_path) | ||||
|     else: | ||||
|         main(args, api) | ||||

AutoDL-Projects/exps/NATS-algos/random_wo_share.py (new file, 156 lines)
							| @@ -0,0 +1,156 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 # | ||||
| ############################################################################## | ||||
| # Random Search for Hyper-Parameter Optimization, JMLR 2012 ################## | ||||
| ############################################################################## | ||||
| # python ./exps/NATS-algos/random_wo_share.py --dataset cifar10 --search_space tss | ||||
| # python ./exps/NATS-algos/random_wo_share.py --dataset cifar100 --search_space tss | ||||
| # python ./exps/NATS-algos/random_wo_share.py --dataset ImageNet16-120 --search_space tss | ||||
| ############################################################################## | ||||
| import os, sys, time, glob, random, argparse | ||||
| import numpy as np, collections | ||||
| from copy import deepcopy | ||||
| import torch | ||||
| import torch.nn as nn | ||||
|  | ||||
| from xautodl.config_utils import load_config, dict2config, configure2str | ||||
| from xautodl.datasets import get_datasets, SearchDataset | ||||
| from xautodl.procedures import ( | ||||
|     prepare_seed, | ||||
|     prepare_logger, | ||||
|     save_checkpoint, | ||||
|     copy_checkpoint, | ||||
|     get_optim_scheduler, | ||||
| ) | ||||
| from xautodl.utils import get_model_infos, obtain_accuracy | ||||
| from xautodl.log_utils import AverageMeter, time_string, convert_secs2time | ||||
| from xautodl.models import CellStructure, get_search_spaces | ||||
| from nats_bench import create | ||||
|  | ||||
|  | ||||
| def random_topology_func(op_names, max_nodes=4): | ||||
|     # Return a random architecture | ||||
|     def random_architecture(): | ||||
|         genotypes = [] | ||||
|         for i in range(1, max_nodes): | ||||
|             xlist = [] | ||||
|             for j in range(i): | ||||
|                 node_str = "{:}<-{:}".format(i, j) | ||||
|                 op_name = random.choice(op_names) | ||||
|                 xlist.append((op_name, j)) | ||||
|             genotypes.append(tuple(xlist)) | ||||
|         return CellStructure(genotypes) | ||||
|  | ||||
|     return random_architecture | ||||
|  | ||||
|  | ||||
| def random_size_func(info): | ||||
|     # Return a random architecture | ||||
|     def random_architecture(): | ||||
|         channels = [] | ||||
|         for i in range(info["numbers"]): | ||||
|             channels.append(str(random.choice(info["candidates"]))) | ||||
|         return ":".join(channels) | ||||
|  | ||||
|     return random_architecture | ||||
|  | ||||
|  | ||||
| def main(xargs, api): | ||||
|     torch.set_num_threads(4) | ||||
|     prepare_seed(xargs.rand_seed) | ||||
|     logger = prepare_logger(args) | ||||
|  | ||||
|     logger.log("{:} use api : {:}".format(time_string(), api)) | ||||
|     api.reset_time() | ||||
|  | ||||
|     search_space = get_search_spaces(xargs.search_space, "nats-bench") | ||||
|     if xargs.search_space == "tss": | ||||
|         random_arch = random_topology_func(search_space) | ||||
|     else: | ||||
|         random_arch = random_size_func(search_space) | ||||
|  | ||||
|     best_arch, best_acc, total_time_cost, history = None, -1, [], [] | ||||
|     current_best_index = [] | ||||
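    # Keep drawing random architectures until the *simulated* training time
    # accumulated by api.simulate_train_eval exceeds the search budget.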
|     while len(total_time_cost) == 0 or total_time_cost[-1] < xargs.time_budget: | ||||
|         arch = random_arch() | ||||
|         accuracy, _, _, total_cost = api.simulate_train_eval( | ||||
|             arch, xargs.dataset, hp="12" | ||||
|         ) | ||||
|         total_time_cost.append(total_cost) | ||||
|         history.append(arch) | ||||
|         if best_arch is None or best_acc < accuracy: | ||||
|             best_acc, best_arch = accuracy, arch | ||||
|         logger.log( | ||||
|             "[{:03d}] : {:} : accuracy = {:.2f}%".format(len(history), arch, accuracy) | ||||
|         ) | ||||
|         current_best_index.append(api.query_index_by_arch(best_arch)) | ||||
|     logger.log( | ||||
|         "{:} best arch is {:} (accuracy = {:.2f}%); visited {:} archs within {:.1f} s of simulated time.".format( | ||||
|             time_string(), best_arch, best_acc, len(history), total_time_cost[-1] | ||||
|         ) | ||||
|     ) | ||||
|  | ||||
|     info = api.query_info_str_by_arch( | ||||
|         best_arch, "200" if xargs.search_space == "tss" else "90" | ||||
|     ) | ||||
|     logger.log("{:}".format(info)) | ||||
|     logger.log("-" * 100) | ||||
|     logger.close() | ||||
|     return logger.log_dir, current_best_index, total_time_cost | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     parser = argparse.ArgumentParser("Random NAS") | ||||
|     parser.add_argument( | ||||
|         "--dataset", | ||||
|         type=str, | ||||
|         choices=["cifar10", "cifar100", "ImageNet16-120"], | ||||
|         help="Choose between CIFAR-10/100 and ImageNet16-120.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--search_space", | ||||
|         type=str, | ||||
|         choices=["tss", "sss"], | ||||
|         help="Choose the search space.", | ||||
|     ) | ||||
|  | ||||
|     parser.add_argument( | ||||
|         "--time_budget", | ||||
|         type=int, | ||||
|         default=20000, | ||||
|         help="The total time cost budget for searching (in seconds).", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--loops_if_rand", type=int, default=500, help="The total runs for evaluation." | ||||
|     ) | ||||
|     # log | ||||
|     parser.add_argument( | ||||
|         "--save_dir", | ||||
|         type=str, | ||||
|         default="./output/search", | ||||
|         help="Folder to save checkpoints and log.", | ||||
|     ) | ||||
|     parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed") | ||||
|     args = parser.parse_args() | ||||
|  | ||||
|     api = create(None, args.search_space, fast_mode=True, verbose=False) | ||||
|  | ||||
|     args.save_dir = os.path.join( | ||||
|         "{:}-{:}".format(args.save_dir, args.search_space), | ||||
|         "{:}-T{:}".format(args.dataset, args.time_budget), | ||||
|         "RANDOM", | ||||
|     ) | ||||
|     print("save-dir : {:}".format(args.save_dir)) | ||||
|  | ||||
|     if args.rand_seed < 0: | ||||
|         save_dir, all_info = None, collections.OrderedDict() | ||||
|         for i in range(args.loops_if_rand): | ||||
|             print("{:} : {:03d}/{:03d}".format(time_string(), i, args.loops_if_rand)) | ||||
|             args.rand_seed = random.randint(1, 100000) | ||||
|             save_dir, all_archs, all_total_times = main(args, api) | ||||
|             all_info[i] = {"all_archs": all_archs, "all_total_times": all_total_times} | ||||
|         save_path = save_dir / "results.pth" | ||||
|         print("save into {:}".format(save_path)) | ||||
|         torch.save(all_info, save_path) | ||||
|     else: | ||||
|         main(args, api) | ||||

AutoDL-Projects/exps/NATS-algos/regularized_ea.py (new file, 302 lines)
							| @@ -0,0 +1,302 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 # | ||||
| ################################################################## | ||||
| # Regularized Evolution for Image Classifier Architecture Search # | ||||
| ################################################################## | ||||
| # python ./exps/NATS-algos/regularized_ea.py --dataset cifar10 --search_space tss --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1 | ||||
| # python ./exps/NATS-algos/regularized_ea.py --dataset cifar100 --search_space tss --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1 | ||||
| # python ./exps/NATS-algos/regularized_ea.py --dataset ImageNet16-120 --search_space tss --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1 | ||||
| # python ./exps/NATS-algos/regularized_ea.py --dataset cifar10 --search_space sss --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1 | ||||
| # python ./exps/NATS-algos/regularized_ea.py --dataset cifar100 --search_space sss --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1 | ||||
| # python ./exps/NATS-algos/regularized_ea.py --dataset ImageNet16-120 --search_space sss --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1 | ||||
| # python ./exps/NATS-algos/regularized_ea.py  --dataset ${dataset} --search_space ${search_space} --time_budget ${time_budget} --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --use_proxy 0 | ||||
| ################################################################## | ||||
| import os, sys, time, glob, random, argparse | ||||
| import numpy as np, collections | ||||
| from copy import deepcopy | ||||
| import torch | ||||
| import torch.nn as nn | ||||
|  | ||||
| from xautodl.config_utils import load_config, dict2config, configure2str | ||||
| from xautodl.datasets import get_datasets, SearchDataset | ||||
| from xautodl.procedures import ( | ||||
|     prepare_seed, | ||||
|     prepare_logger, | ||||
|     save_checkpoint, | ||||
|     copy_checkpoint, | ||||
|     get_optim_scheduler, | ||||
| ) | ||||
| from xautodl.utils import get_model_infos, obtain_accuracy | ||||
| from xautodl.log_utils import AverageMeter, time_string, convert_secs2time | ||||
| from xautodl.models import CellStructure, get_search_spaces | ||||
| from nats_bench import create | ||||
|  | ||||
|  | ||||
| class Model(object): | ||||
|     def __init__(self): | ||||
|         self.arch = None | ||||
|         self.accuracy = None | ||||
|  | ||||
|     def __str__(self): | ||||
|         """Returns a readable string for this model's architecture.""" | ||||
|         return "{:}".format(self.arch) | ||||
|  | ||||
|  | ||||
| def random_topology_func(op_names, max_nodes=4): | ||||
|     # Return a random architecture | ||||
|     def random_architecture(): | ||||
|         genotypes = [] | ||||
|         for i in range(1, max_nodes): | ||||
|             xlist = [] | ||||
|             for j in range(i): | ||||
|                 node_str = "{:}<-{:}".format(i, j) | ||||
|                 op_name = random.choice(op_names) | ||||
|                 xlist.append((op_name, j)) | ||||
|             genotypes.append(tuple(xlist)) | ||||
|         return CellStructure(genotypes) | ||||
|  | ||||
|     return random_architecture | ||||
|  | ||||
|  | ||||
| def random_size_func(info): | ||||
|     # Return a random architecture | ||||
|     def random_architecture(): | ||||
|         channels = [] | ||||
|         for i in range(info["numbers"]): | ||||
|             channels.append(str(random.choice(info["candidates"]))) | ||||
|         return ":".join(channels) | ||||
|  | ||||
|     return random_architecture | ||||
|  | ||||
|  | ||||
| def mutate_topology_func(op_names): | ||||
|     """Computes the architecture for a child of the given parent architecture. | ||||
|     The parent architecture is cloned and mutated to produce the child: one randomly chosen edge has its operation switched to a different randomly chosen operation. | ||||
|     """ | ||||
|  | ||||
|     def mutate_topology_func(parent_arch): | ||||
|         child_arch = deepcopy(parent_arch) | ||||
|         node_id = random.randint(0, len(child_arch.nodes) - 1) | ||||
|         node_info = list(child_arch.nodes[node_id]) | ||||
|         snode_id = random.randint(0, len(node_info) - 1) | ||||
|         xop = random.choice(op_names) | ||||
|         while xop == node_info[snode_id][0]: | ||||
|             xop = random.choice(op_names) | ||||
|         node_info[snode_id] = (xop, node_info[snode_id][1]) | ||||
|         child_arch.nodes[node_id] = tuple(node_info) | ||||
|         return child_arch | ||||
|  | ||||
|     return mutate_topology_func | ||||
|  | ||||
|  | ||||
| def mutate_size_func(info): | ||||
|     """Computes the architecture for a child of the given parent architecture. | ||||
|     The parent architecture is cloned and mutated to produce the child: one randomly chosen layer has its channel count re-sampled from the candidate set. | ||||
|     """ | ||||
|  | ||||
|     def mutate_size_func(parent_arch): | ||||
|         child_arch = deepcopy(parent_arch) | ||||
|         child_arch = child_arch.split(":") | ||||
|         index = random.randint(0, len(child_arch) - 1) | ||||
|         child_arch[index] = str(random.choice(info["candidates"])) | ||||
|         return ":".join(child_arch) | ||||
|  | ||||
|     return mutate_size_func | ||||
|  | ||||
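# Illustration only: for the size search space an architecture is a string of
# per-layer channel counts such as "8:16:32:48:64" (assumed example values);
# mutate_size_func copies the parent and re-samples exactly one position, e.g.
# "8:16:32:48:64" -> "8:16:24:48:64".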
|  | ||||
| def regularized_evolution( | ||||
|     cycles, | ||||
|     population_size, | ||||
|     sample_size, | ||||
|     time_budget, | ||||
|     random_arch, | ||||
|     mutate_arch, | ||||
|     api, | ||||
|     use_proxy, | ||||
|     dataset, | ||||
| ): | ||||
|     """Algorithm for regularized evolution (i.e. aging evolution). | ||||
|  | ||||
|     Follows "Algorithm 1" in Real et al. "Regularized Evolution for Image | ||||
|     Classifier Architecture Search". | ||||
|  | ||||
|     Args: | ||||
|       cycles: the number of cycles the algorithm should run for. | ||||
|       population_size: the number of individuals to keep in the population. | ||||
|       sample_size: the number of individuals that should participate in each tournament. | ||||
|       time_budget: the upper bound of the (simulated) search cost, in seconds. | ||||
|  | ||||
|     Returns: | ||||
|       history: a list of `Model` instances, representing all the models computed | ||||
|           during the evolution experiment. | ||||
|     """ | ||||
|     population = collections.deque() | ||||
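    # A deque implements the "aging" part of regularized evolution: new models
    # are appended on the right and the oldest one is removed via popleft()
    # at the end of every cycle.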
|     api.reset_time() | ||||
|     history, total_time_cost = ( | ||||
|         [], | ||||
|         [], | ||||
|     )  # Not used by the algorithm, only used to report results. | ||||
|     current_best_index = [] | ||||
|     # Initialize the population with random models. | ||||
|     while len(population) < population_size: | ||||
|         model = Model() | ||||
|         model.arch = random_arch() | ||||
|         model.accuracy, _, _, total_cost = api.simulate_train_eval( | ||||
|             model.arch, dataset, hp="12" if use_proxy else api.full_train_epochs | ||||
|         ) | ||||
|         # Append the info | ||||
|         population.append(model) | ||||
|         history.append((model.accuracy, model.arch)) | ||||
|         total_time_cost.append(total_cost) | ||||
|         current_best_index.append( | ||||
|             api.query_index_by_arch(max(history, key=lambda x: x[0])[1]) | ||||
|         ) | ||||
|  | ||||
|     # Carry out evolution in cycles. Each cycle produces a model and removes another. | ||||
|     while total_time_cost[-1] < time_budget: | ||||
|         # Sample randomly chosen models from the current population. | ||||
|         start_time, sample = time.time(), [] | ||||
|         while len(sample) < sample_size: | ||||
|             # Inefficient, but written this way for clarity. In the case of neural | ||||
|             # nets, the efficiency of this line is irrelevant because training neural | ||||
|             # nets is the rate-determining step. | ||||
|             candidate = random.choice(list(population)) | ||||
|             sample.append(candidate) | ||||
|  | ||||
|         # The parent is the best model in the sample. | ||||
|         parent = max(sample, key=lambda i: i.accuracy) | ||||
|  | ||||
|         # Create the child model and store it. | ||||
|         child = Model() | ||||
|         child.arch = mutate_arch(parent.arch) | ||||
|         child.accuracy, _, _, total_cost = api.simulate_train_eval( | ||||
|             child.arch, dataset, hp="12" if use_proxy else api.full_train_epochs | ||||
|         ) | ||||
|         # Append the info | ||||
|         population.append(child) | ||||
|         history.append((child.accuracy, child.arch)) | ||||
|         current_best_index.append( | ||||
|             api.query_index_by_arch(max(history, key=lambda x: x[0])[1]) | ||||
|         ) | ||||
|         total_time_cost.append(total_cost) | ||||
|  | ||||
|         # Remove the oldest model. | ||||
|         population.popleft() | ||||
|     return history, current_best_index, total_time_cost | ||||
|  | ||||
|  | ||||
| def main(xargs, api): | ||||
|     torch.set_num_threads(4) | ||||
|     prepare_seed(xargs.rand_seed) | ||||
|     logger = prepare_logger(args) | ||||
|  | ||||
|     search_space = get_search_spaces(xargs.search_space, "nats-bench") | ||||
|     if xargs.search_space == "tss": | ||||
|         random_arch = random_topology_func(search_space) | ||||
|         mutate_arch = mutate_topology_func(search_space) | ||||
|     else: | ||||
|         random_arch = random_size_func(search_space) | ||||
|         mutate_arch = mutate_size_func(search_space) | ||||
|  | ||||
|     x_start_time = time.time() | ||||
|     logger.log("{:} use api : {:}".format(time_string(), api)) | ||||
|     logger.log( | ||||
|         "-" * 30 | ||||
|         + " start searching with the time budget of {:} s".format(xargs.time_budget) | ||||
|     ) | ||||
|     history, current_best_index, total_times = regularized_evolution( | ||||
|         xargs.ea_cycles, | ||||
|         xargs.ea_population, | ||||
|         xargs.ea_sample_size, | ||||
|         xargs.time_budget, | ||||
|         random_arch, | ||||
|         mutate_arch, | ||||
|         api, | ||||
|         xargs.use_proxy > 0, | ||||
|         xargs.dataset, | ||||
|     ) | ||||
|     logger.log( | ||||
|         "{:} regularized_evolution finished: visited {:} archs within {:.1f} s of simulated time (real cost = {:.2f} s).".format( | ||||
|             time_string(), len(history), total_times[-1], time.time() - x_start_time | ||||
|         ) | ||||
|     ) | ||||
|     best_arch = max(history, key=lambda x: x[0])[1] | ||||
|     logger.log("{:} best arch is {:}".format(time_string(), best_arch)) | ||||
|  | ||||
|     info = api.query_info_str_by_arch( | ||||
|         best_arch, "200" if xargs.search_space == "tss" else "90" | ||||
|     ) | ||||
|     logger.log("{:}".format(info)) | ||||
|     logger.log("-" * 100) | ||||
|     logger.close() | ||||
|     return logger.log_dir, current_best_index, total_times | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     parser = argparse.ArgumentParser("Regularized Evolution Algorithm") | ||||
|     parser.add_argument( | ||||
|         "--dataset", | ||||
|         type=str, | ||||
|         choices=["cifar10", "cifar100", "ImageNet16-120"], | ||||
|         help="Choose between CIFAR-10/100 and ImageNet16-120.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--search_space", | ||||
|         type=str, | ||||
|         choices=["tss", "sss"], | ||||
|         help="Choose the search space.", | ||||
|     ) | ||||
|     # hyperparameters for REA | ||||
|     parser.add_argument("--ea_cycles", type=int, help="The number of cycles in EA.") | ||||
|     parser.add_argument("--ea_population", type=int, help="The population size in EA.") | ||||
|     parser.add_argument("--ea_sample_size", type=int, help="The sample size in EA.") | ||||
|     parser.add_argument( | ||||
|         "--time_budget", | ||||
|         type=int, | ||||
|         default=20000, | ||||
|         help="The total time cost budget for searching (in seconds).", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--use_proxy", | ||||
|         type=int, | ||||
|         default=1, | ||||
|         help="Whether to use the proxy (H0) task or not.", | ||||
|     ) | ||||
|     # | ||||
|     parser.add_argument( | ||||
|         "--loops_if_rand", type=int, default=500, help="The total runs for evaluation." | ||||
|     ) | ||||
|     # log | ||||
|     parser.add_argument( | ||||
|         "--save_dir", | ||||
|         type=str, | ||||
|         default="./output/search", | ||||
|         help="Folder to save checkpoints and log.", | ||||
|     ) | ||||
|     parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed") | ||||
|     args = parser.parse_args() | ||||
|  | ||||
|     api = create(None, args.search_space, fast_mode=True, verbose=False) | ||||
|  | ||||
|     args.save_dir = os.path.join( | ||||
|         "{:}-{:}".format(args.save_dir, args.search_space), | ||||
|         "{:}-T{:}{:}".format( | ||||
|             args.dataset, args.time_budget, "" if args.use_proxy > 0 else "-FULL" | ||||
|         ), | ||||
|         "R-EA-SS{:}".format(args.ea_sample_size), | ||||
|     ) | ||||
|     print("save-dir : {:}".format(args.save_dir)) | ||||
|     print("xargs : {:}".format(args)) | ||||
|  | ||||
|     if args.rand_seed < 0: | ||||
|         save_dir, all_info = None, collections.OrderedDict() | ||||
|         for i in range(args.loops_if_rand): | ||||
|             print("{:} : {:03d}/{:03d}".format(time_string(), i, args.loops_if_rand)) | ||||
|             args.rand_seed = random.randint(1, 100000) | ||||
|             save_dir, all_archs, all_total_times = main(args, api) | ||||
|             all_info[i] = {"all_archs": all_archs, "all_total_times": all_total_times} | ||||
|         save_path = save_dir / "results.pth" | ||||
|         print("save into {:}".format(save_path)) | ||||
|         torch.save(all_info, save_path) | ||||
|     else: | ||||
|         main(args, api) | ||||

AutoDL-Projects/exps/NATS-algos/reinforce.py (new file, 268 lines)
							| @@ -0,0 +1,268 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 # | ||||
| ##################################################################################################### | ||||
| # modified from https://github.com/pytorch/examples/blob/master/reinforcement_learning/reinforce.py # | ||||
| ##################################################################################################### | ||||
| # python ./exps/NATS-algos/reinforce.py --dataset cifar10 --search_space tss --learning_rate 0.01 | ||||
| # python ./exps/NATS-algos/reinforce.py --dataset cifar100 --search_space tss --learning_rate 0.01 | ||||
| # python ./exps/NATS-algos/reinforce.py --dataset ImageNet16-120 --search_space tss --learning_rate 0.01 | ||||
| # python ./exps/NATS-algos/reinforce.py --dataset cifar10 --search_space sss --learning_rate 0.01 | ||||
| # python ./exps/NATS-algos/reinforce.py --dataset cifar100 --search_space sss --learning_rate 0.01 | ||||
| # python ./exps/NATS-algos/reinforce.py --dataset ImageNet16-120 --search_space sss --learning_rate 0.01 | ||||
| ##################################################################################################### | ||||
| import os, sys, time, glob, random, argparse | ||||
| import numpy as np, collections | ||||
| from copy import deepcopy | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| from torch.distributions import Categorical | ||||
|  | ||||
| from xautodl.config_utils import load_config, dict2config, configure2str | ||||
| from xautodl.datasets import get_datasets, SearchDataset | ||||
| from xautodl.procedures import ( | ||||
|     prepare_seed, | ||||
|     prepare_logger, | ||||
|     save_checkpoint, | ||||
|     copy_checkpoint, | ||||
|     get_optim_scheduler, | ||||
| ) | ||||
| from xautodl.utils import get_model_infos, obtain_accuracy | ||||
| from xautodl.log_utils import AverageMeter, time_string, convert_secs2time | ||||
| from xautodl.models import CellStructure, get_search_spaces | ||||
| from nats_bench import create | ||||
|  | ||||
|  | ||||
| class PolicyTopology(nn.Module): | ||||
|     def __init__(self, search_space, max_nodes=4): | ||||
|         super(PolicyTopology, self).__init__() | ||||
|         self.max_nodes = max_nodes | ||||
|         self.search_space = deepcopy(search_space) | ||||
|         self.edge2index = {} | ||||
|         for i in range(1, max_nodes): | ||||
|             for j in range(i): | ||||
|                 node_str = "{:}<-{:}".format(i, j) | ||||
|                 self.edge2index[node_str] = len(self.edge2index) | ||||
|         self.arch_parameters = nn.Parameter( | ||||
|             1e-3 * torch.randn(len(self.edge2index), len(search_space)) | ||||
|         ) | ||||
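        # One row per edge of the cell (max_nodes * (max_nodes - 1) / 2 rows) and
        # one column per candidate operation; forward() softmaxes each row into a
        # categorical distribution from which per-edge actions are sampled.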
|  | ||||
|     def generate_arch(self, actions): | ||||
|         genotypes = [] | ||||
|         for i in range(1, self.max_nodes): | ||||
|             xlist = [] | ||||
|             for j in range(i): | ||||
|                 node_str = "{:}<-{:}".format(i, j) | ||||
|                 op_name = self.search_space[actions[self.edge2index[node_str]]] | ||||
|                 xlist.append((op_name, j)) | ||||
|             genotypes.append(tuple(xlist)) | ||||
|         return CellStructure(genotypes) | ||||
|  | ||||
|     def genotype(self): | ||||
|         genotypes = [] | ||||
|         for i in range(1, self.max_nodes): | ||||
|             xlist = [] | ||||
|             for j in range(i): | ||||
|                 node_str = "{:}<-{:}".format(i, j) | ||||
|                 with torch.no_grad(): | ||||
|                     weights = self.arch_parameters[self.edge2index[node_str]] | ||||
|                     op_name = self.search_space[weights.argmax().item()] | ||||
|                 xlist.append((op_name, j)) | ||||
|             genotypes.append(tuple(xlist)) | ||||
|         return CellStructure(genotypes) | ||||
|  | ||||
|     def forward(self): | ||||
|         alphas = nn.functional.softmax(self.arch_parameters, dim=-1) | ||||
|         return alphas | ||||
|  | ||||
|  | ||||
| class PolicySize(nn.Module): | ||||
|     def __init__(self, search_space): | ||||
|         super(PolicySize, self).__init__() | ||||
|         self.candidates = search_space["candidates"] | ||||
|         self.numbers = search_space["numbers"] | ||||
|         self.arch_parameters = nn.Parameter( | ||||
|             1e-3 * torch.randn(self.numbers, len(self.candidates)) | ||||
|         ) | ||||
|  | ||||
|     def generate_arch(self, actions): | ||||
|         channels = [str(self.candidates[i]) for i in actions] | ||||
|         return ":".join(channels) | ||||
|  | ||||
|     def genotype(self): | ||||
|         channels = [] | ||||
|         for i in range(self.numbers): | ||||
|             index = self.arch_parameters[i].argmax().item() | ||||
|             channels.append(str(self.candidates[index])) | ||||
|         return ":".join(channels) | ||||
|  | ||||
|     def forward(self): | ||||
|         alphas = nn.functional.softmax(self.arch_parameters, dim=-1) | ||||
|         return alphas | ||||
|  | ||||
|  | ||||
| class ExponentialMovingAverage(object): | ||||
|     """Class that maintains an exponential moving average.""" | ||||
|  | ||||
|     def __init__(self, momentum): | ||||
|         self._numerator = 0 | ||||
|         self._denominator = 0 | ||||
|         self._momentum = momentum | ||||
|  | ||||
|     def update(self, value): | ||||
|         self._numerator = ( | ||||
|             self._momentum * self._numerator + (1 - self._momentum) * value | ||||
|         ) | ||||
|         self._denominator = self._momentum * self._denominator + (1 - self._momentum) | ||||
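        # The denominator tends to 1 over time and corrects the bias introduced
        # by initializing the average at zero (the same bias-correction idea as
        # in Adam).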
|  | ||||
|     def value(self): | ||||
|         """Return the current value of the moving average""" | ||||
|         return self._numerator / self._denominator | ||||
|  | ||||
|  | ||||
| def select_action(policy): | ||||
|     probs = policy() | ||||
|     m = Categorical(probs) | ||||
|     action = m.sample() | ||||
|     # policy.saved_log_probs.append(m.log_prob(action)) | ||||
|     return m.log_prob(action), action.cpu().tolist() | ||||
|  | ||||
|  | ||||
| def main(xargs, api): | ||||
|     # torch.set_num_threads(4) | ||||
|     prepare_seed(xargs.rand_seed) | ||||
|     logger = prepare_logger(args) | ||||
|  | ||||
|     search_space = get_search_spaces(xargs.search_space, "nats-bench") | ||||
|     if xargs.search_space == "tss": | ||||
|         policy = PolicyTopology(search_space) | ||||
|     else: | ||||
|         policy = PolicySize(search_space) | ||||
|     optimizer = torch.optim.Adam(policy.parameters(), lr=xargs.learning_rate) | ||||
|     # optimizer = torch.optim.SGD(policy.parameters(), lr=xargs.learning_rate) | ||||
|     eps = np.finfo(np.float32).eps.item() | ||||
|     baseline = ExponentialMovingAverage(xargs.EMA_momentum) | ||||
|     logger.log("policy    : {:}".format(policy)) | ||||
|     logger.log("optimizer : {:}".format(optimizer)) | ||||
|     logger.log("eps       : {:}".format(eps)) | ||||
|  | ||||
|     # nas dataset load | ||||
|     logger.log("{:} use api : {:}".format(time_string(), api)) | ||||
|     api.reset_time() | ||||
|  | ||||
|     # REINFORCE | ||||
|     x_start_time = time.time() | ||||
|     logger.log( | ||||
|         "Will start searching with a time budget of {:} s.".format(xargs.time_budget) | ||||
|     ) | ||||
|     total_steps, total_costs, trace = 0, [], [] | ||||
|     current_best_index = [] | ||||
|     while len(total_costs) == 0 or total_costs[-1] < xargs.time_budget: | ||||
|         start_time = time.time() | ||||
|         log_prob, action = select_action(policy) | ||||
|         arch = policy.generate_arch(action) | ||||
|         reward, _, _, current_total_cost = api.simulate_train_eval( | ||||
|             arch, xargs.dataset, hp="12" | ||||
|         ) | ||||
|         trace.append((reward, arch)) | ||||
|         total_costs.append(current_total_cost) | ||||
|  | ||||
|         baseline.update(reward) | ||||
|         # calculate loss | ||||
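        # REINFORCE with a moving-average baseline: minimizing
        # -log pi(action) * (reward - baseline) increases the probability of
        # architectures whose simulated accuracy beats the running average.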
|         policy_loss = (-log_prob * (reward - baseline.value())).sum() | ||||
|         optimizer.zero_grad() | ||||
|         policy_loss.backward() | ||||
|         optimizer.step() | ||||
|         # accumulate time | ||||
|         total_steps += 1 | ||||
|         logger.log( | ||||
|             "step [{:3d}] : average-reward={:.3f} : policy_loss={:.4f} : {:}".format( | ||||
|                 total_steps, baseline.value(), policy_loss.item(), policy.genotype() | ||||
|             ) | ||||
|         ) | ||||
|         # to analyze | ||||
|         current_best_index.append( | ||||
|             api.query_index_by_arch(max(trace, key=lambda x: x[0])[1]) | ||||
|         ) | ||||
|     # best_arch = policy.genotype() # first version | ||||
|     best_arch = max(trace, key=lambda x: x[0])[1] | ||||
|     logger.log( | ||||
|         "REINFORCE finished after {:} steps and {:.1f} s of simulated time (real cost = {:.3f} s).".format( | ||||
|             total_steps, total_costs[-1], time.time() - x_start_time | ||||
|         ) | ||||
|     ) | ||||
|     info = api.query_info_str_by_arch( | ||||
|         best_arch, "200" if xargs.search_space == "tss" else "90" | ||||
|     ) | ||||
|     logger.log("{:}".format(info)) | ||||
|     logger.log("-" * 100) | ||||
|     logger.close() | ||||
|  | ||||
|     return logger.log_dir, current_best_index, total_costs | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     parser = argparse.ArgumentParser("The REINFORCE Algorithm") | ||||
|     parser.add_argument( | ||||
|         "--dataset", | ||||
|         type=str, | ||||
|         choices=["cifar10", "cifar100", "ImageNet16-120"], | ||||
|         help="Choose between CIFAR-10/100 and ImageNet16-120.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--search_space", | ||||
|         type=str, | ||||
|         choices=["tss", "sss"], | ||||
|         help="Choose the search space.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--learning_rate", type=float, help="The learning rate for REINFORCE." | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--EMA_momentum", type=float, default=0.9, help="The momentum value for EMA." | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--time_budget", | ||||
|         type=int, | ||||
|         default=20000, | ||||
|         help="The total time cost budget for searching (in seconds).", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--loops_if_rand", type=int, default=500, help="The total runs for evaluation." | ||||
|     ) | ||||
|     # log | ||||
|     parser.add_argument( | ||||
|         "--save_dir", | ||||
|         type=str, | ||||
|         default="./output/search", | ||||
|         help="Folder to save checkpoints and log.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--arch_nas_dataset", | ||||
|         type=str, | ||||
|         help="The path to load the architecture dataset (tiny-nas-benchmark).", | ||||
|     ) | ||||
|     parser.add_argument("--print_freq", type=int, help="print frequency (default: 200)") | ||||
|     parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed") | ||||
|     args = parser.parse_args() | ||||
|  | ||||
|     api = create(None, args.search_space, fast_mode=True, verbose=False) | ||||
|  | ||||
|     args.save_dir = os.path.join( | ||||
|         "{:}-{:}".format(args.save_dir, args.search_space), | ||||
|         "{:}-T{:}".format(args.dataset, args.time_budget), | ||||
|         "REINFORCE-{:}".format(args.learning_rate), | ||||
|     ) | ||||
|     print("save-dir : {:}".format(args.save_dir)) | ||||
|  | ||||
|     if args.rand_seed < 0: | ||||
|         save_dir, all_info = None, collections.OrderedDict() | ||||
|         for i in range(args.loops_if_rand): | ||||
|             print("{:} : {:03d}/{:03d}".format(time_string(), i, args.loops_if_rand)) | ||||
|             args.rand_seed = random.randint(1, 100000) | ||||
|             save_dir, all_archs, all_total_times = main(args, api) | ||||
|             all_info[i] = {"all_archs": all_archs, "all_total_times": all_total_times} | ||||
|         save_path = save_dir / "results.pth" | ||||
|         print("save into {:}".format(save_path)) | ||||
|         torch.save(all_info, save_path) | ||||
|     else: | ||||
|         main(args, api) | ||||

AutoDL-Projects/exps/NATS-algos/run-all.sh (new file, 51 lines)
							| @@ -0,0 +1,51 @@ | ||||
| #!/bin/bash | ||||
| # bash ./exps/NATS-algos/run-all.sh mul | ||||
| # bash ./exps/NATS-algos/run-all.sh ws | ||||
| set -e | ||||
| echo script name: $0 | ||||
| echo $# arguments | ||||
| if [ "$#" -ne 1 ] ;then | ||||
|   echo "Illegal number of parameters: $#" | ||||
|   echo "Need exactly 1 parameter: the type of algorithm (mul or ws)." | ||||
|   exit 1 | ||||
| fi | ||||
|  | ||||
| alg_type=$1 | ||||
|  | ||||
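# "mul": run the multi-trial searchers (REINFORCE, REA, random search, BOHB)
#        against the simulated NATS-Bench training cost.
# any other value (e.g. "ws"): run the weight-sharing methods via search-cell.py.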
| if [ "$alg_type" == "mul" ]; then | ||||
|   # datasets="cifar10 cifar100 ImageNet16-120" | ||||
|   run_four_algorithms(){ | ||||
|     dataset=$1 | ||||
|     search_space=$2 | ||||
|     time_budget=$3 | ||||
|     python ./exps/NATS-algos/reinforce.py       --dataset ${dataset} --search_space ${search_space} --time_budget ${time_budget} --learning_rate 0.01 | ||||
|     python ./exps/NATS-algos/regularized_ea.py  --dataset ${dataset} --search_space ${search_space} --time_budget ${time_budget} --ea_cycles 200 --ea_population 10 --ea_sample_size 3 | ||||
|     python ./exps/NATS-algos/random_wo_share.py --dataset ${dataset} --search_space ${search_space} --time_budget ${time_budget} | ||||
|     python ./exps/NATS-algos/bohb.py            --dataset ${dataset} --search_space ${search_space} --time_budget ${time_budget} --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3 | ||||
|   } | ||||
|   # The topology search space | ||||
|   run_four_algorithms "cifar10"        "tss" "20000" | ||||
|   run_four_algorithms "cifar100"       "tss" "40000" | ||||
|   run_four_algorithms "ImageNet16-120" "tss" "120000" | ||||
|  | ||||
|   # The size search space | ||||
|   run_four_algorithms "cifar10"        "sss" "20000" | ||||
|   run_four_algorithms "cifar100"       "sss" "40000" | ||||
|   run_four_algorithms "ImageNet16-120" "sss" "60000" | ||||
|   # python exps/experimental/vis-bench-algos.py --search_space tss | ||||
|   # python exps/experimental/vis-bench-algos.py --search_space sss | ||||
| else | ||||
|   seeds="777 888 999" | ||||
|   algos="darts-v1 darts-v2 gdas setn random enas" | ||||
|   epoch=200 | ||||
|   for seed in ${seeds} | ||||
|   do | ||||
|     for alg in ${algos} | ||||
|     do | ||||
|       python ./exps/NATS-algos/search-cell.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo ${alg} --rand_seed ${seed} --overwite_epochs ${epoch} | ||||
|       python ./exps/NATS-algos/search-cell.py --dataset cifar100  --data_path $TORCH_HOME/cifar.python --algo ${alg} --rand_seed ${seed} --overwite_epochs ${epoch} | ||||
|       python ./exps/NATS-algos/search-cell.py --dataset ImageNet16-120  --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo ${alg} --rand_seed ${seed} --overwite_epochs ${epoch} | ||||
|     done | ||||
|   done | ||||
| fi | ||||
|  | ||||

AutoDL-Projects/exps/NATS-algos/search-cell.py (new file, 879 lines)
							| @@ -0,0 +1,879 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 # | ||||
| ###################################################################################### | ||||
| # python ./exps/NATS-algos/search-cell.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo darts-v1 --rand_seed 777 | ||||
| # python ./exps/NATS-algos/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo darts-v1 --drop_path_rate 0.3 | ||||
| # python ./exps/NATS-algos/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo darts-v1 | ||||
| #### | ||||
| # python ./exps/NATS-algos/search-cell.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo darts-v2 --rand_seed 777 | ||||
| # python ./exps/NATS-algos/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo darts-v2 | ||||
| # python ./exps/NATS-algos/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo darts-v2 | ||||
| #### | ||||
| # python ./exps/NATS-algos/search-cell.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo gdas --rand_seed 777 | ||||
| # python ./exps/NATS-algos/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo gdas | ||||
| # python ./exps/NATS-algos/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo gdas | ||||
| #### | ||||
| # python ./exps/NATS-algos/search-cell.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo setn --rand_seed 777 | ||||
| # python ./exps/NATS-algos/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo setn | ||||
| # python ./exps/NATS-algos/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo setn | ||||
| #### | ||||
| # python ./exps/NATS-algos/search-cell.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo random --rand_seed 777 | ||||
| # python ./exps/NATS-algos/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo random | ||||
| # python ./exps/NATS-algos/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo random | ||||
| #### | ||||
| # python ./exps/NATS-algos/search-cell.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo enas --arch_weight_decay 0 --arch_learning_rate 0.001 --arch_eps 0.001 --rand_seed 777 | ||||
| # python ./exps/NATS-algos/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo enas --arch_weight_decay 0 --arch_learning_rate 0.001 --arch_eps 0.001 --rand_seed 777 | ||||
| # python ./exps/NATS-algos/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo enas --arch_weight_decay 0 --arch_learning_rate 0.001 --arch_eps 0.001 --rand_seed 777 | ||||
| #### | ||||
| # The following command was added on 20 Mar 2022 | ||||
| # python ./exps/NATS-algos/search-cell.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo gdas_v1 --rand_seed 777 | ||||
| ###################################################################################### | ||||
| import os, sys, time, random, argparse | ||||
| import numpy as np | ||||
| from copy import deepcopy | ||||
| import torch | ||||
| import torch.nn as nn | ||||
|  | ||||
| from xautodl.config_utils import load_config, dict2config, configure2str | ||||
| from xautodl.datasets import get_datasets, get_nas_search_loaders | ||||
| from xautodl.procedures import ( | ||||
|     prepare_seed, | ||||
|     prepare_logger, | ||||
|     save_checkpoint, | ||||
|     copy_checkpoint, | ||||
|     get_optim_scheduler, | ||||
| ) | ||||
| from xautodl.utils import count_parameters_in_MB, obtain_accuracy | ||||
| from xautodl.log_utils import AverageMeter, time_string, convert_secs2time | ||||
| from xautodl.models import get_cell_based_tiny_net, get_search_spaces | ||||
| from nats_bench import create | ||||
|  | ||||
|  | ||||
| # The following three functions are used for DARTS-V2 | ||||
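| # They implement the second-order (unrolled) architecture gradient of DARTS: | ||||
| #   grad_alpha L_val(w', alpha) - LR * H_{alpha,w} L_train(w, alpha) @ grad_w' L_val(w', alpha), | ||||
| # where w' is the result of one virtual SGD step of the weights on L_train. | ||||
| # The Hessian-vector product is approximated by the finite difference | ||||
| #   (grad_alpha L_train(w + R*v) - grad_alpha L_train(w - R*v)) / (2R),  with R = r / ||v||, | ||||
| # which is what _hessian_vector_product computes below. | ||||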
| def _concat(xs): | ||||
|     return torch.cat([x.view(-1) for x in xs]) | ||||
|  | ||||
|  | ||||
| def _hessian_vector_product( | ||||
|     vector, network, criterion, base_inputs, base_targets, r=1e-2 | ||||
| ): | ||||
|     R = r / _concat(vector).norm() | ||||
|     for p, v in zip(network.weights, vector): | ||||
|         p.data.add_(R, v) | ||||
|     _, logits = network(base_inputs) | ||||
|     loss = criterion(logits, base_targets) | ||||
|     grads_p = torch.autograd.grad(loss, network.alphas) | ||||
|  | ||||
|     for p, v in zip(network.weights, vector): | ||||
|         p.data.sub_(2 * R, v) | ||||
|     _, logits = network(base_inputs) | ||||
|     loss = criterion(logits, base_targets) | ||||
|     grads_n = torch.autograd.grad(loss, network.alphas) | ||||
|  | ||||
|     for p, v in zip(network.weights, vector): | ||||
|         p.data.add_(R, v) | ||||
|     return [(x - y).div_(2 * R) for x, y in zip(grads_p, grads_n)] | ||||
|  | ||||
|  | ||||
| def backward_step_unrolled( | ||||
|     network, | ||||
|     criterion, | ||||
|     base_inputs, | ||||
|     base_targets, | ||||
|     w_optimizer, | ||||
|     arch_inputs, | ||||
|     arch_targets, | ||||
| ): | ||||
|     # _compute_unrolled_model | ||||
|     _, logits = network(base_inputs) | ||||
|     loss = criterion(logits, base_targets) | ||||
|     LR, WD, momentum = ( | ||||
|         w_optimizer.param_groups[0]["lr"], | ||||
|         w_optimizer.param_groups[0]["weight_decay"], | ||||
|         w_optimizer.param_groups[0]["momentum"], | ||||
|     ) | ||||
|     with torch.no_grad(): | ||||
|         theta = _concat(network.weights) | ||||
|         try: | ||||
|             moment = _concat( | ||||
|                 w_optimizer.state[v]["momentum_buffer"] for v in network.weights | ||||
|             ) | ||||
|             moment = moment.mul_(momentum) | ||||
|         except: | ||||
|             moment = torch.zeros_like(theta) | ||||
|         dtheta = _concat(torch.autograd.grad(loss, network.weights)) + WD * theta | ||||
|         params = theta.sub(LR, moment + dtheta) | ||||
|     unrolled_model = deepcopy(network) | ||||
|     model_dict = unrolled_model.state_dict() | ||||
|     new_params, offset = {}, 0 | ||||
|     for k, v in network.named_parameters(): | ||||
|         if "arch_parameters" in k: | ||||
|             continue | ||||
|         v_length = np.prod(v.size()) | ||||
|         new_params[k] = params[offset : offset + v_length].view(v.size()) | ||||
|         offset += v_length | ||||
|     model_dict.update(new_params) | ||||
|     unrolled_model.load_state_dict(model_dict) | ||||
|  | ||||
|     unrolled_model.zero_grad() | ||||
|     _, unrolled_logits = unrolled_model(arch_inputs) | ||||
|     unrolled_loss = criterion(unrolled_logits, arch_targets) | ||||
|     unrolled_loss.backward() | ||||
|  | ||||
|     dalpha = unrolled_model.arch_parameters.grad | ||||
|     vector = [v.grad.data for v in unrolled_model.weights] | ||||
|     [implicit_grads] = _hessian_vector_product( | ||||
|         vector, network, criterion, base_inputs, base_targets | ||||
|     ) | ||||
|  | ||||
|     dalpha.data.sub_(LR, implicit_grads.data) | ||||
|  | ||||
|     if network.arch_parameters.grad is None: | ||||
|         network.arch_parameters.grad = deepcopy(dalpha) | ||||
|     else: | ||||
|         network.arch_parameters.grad.data.copy_(dalpha.data) | ||||
|     return unrolled_loss.detach(), unrolled_logits.detach() | ||||
|  | ||||
|  | ||||
| def search_func( | ||||
|     xloader, | ||||
|     network, | ||||
|     criterion, | ||||
|     scheduler, | ||||
|     w_optimizer, | ||||
|     a_optimizer, | ||||
|     epoch_str, | ||||
|     print_freq, | ||||
|     algo, | ||||
|     logger, | ||||
| ): | ||||
|     data_time, batch_time = AverageMeter(), AverageMeter() | ||||
|     base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter() | ||||
|     arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter() | ||||
|     end = time.time() | ||||
|     network.train() | ||||
|     for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate( | ||||
|         xloader | ||||
|     ): | ||||
|         scheduler.update(None, 1.0 * step / len(xloader)) | ||||
|         base_inputs = base_inputs.cuda(non_blocking=True) | ||||
|         arch_inputs = arch_inputs.cuda(non_blocking=True) | ||||
|         base_targets = base_targets.cuda(non_blocking=True) | ||||
|         arch_targets = arch_targets.cuda(non_blocking=True) | ||||
|         # measure data loading time | ||||
|         data_time.update(time.time() - end) | ||||
|  | ||||
|         # Update the weights | ||||
|         if algo == "setn": | ||||
|             sampled_arch = network.dync_genotype(True) | ||||
|             network.set_cal_mode("dynamic", sampled_arch) | ||||
|         elif algo == "gdas": | ||||
|             network.set_cal_mode("gdas", None) | ||||
|         elif algo == "gdas_v1": | ||||
|             network.set_cal_mode("gdas_v1", None) | ||||
|         elif algo.startswith("darts"): | ||||
|             network.set_cal_mode("joint", None) | ||||
|         elif algo == "random": | ||||
|             network.set_cal_mode("urs", None) | ||||
|         elif algo == "enas": | ||||
|             with torch.no_grad(): | ||||
|                 network.controller.eval() | ||||
|                 _, _, sampled_arch = network.controller() | ||||
|             network.set_cal_mode("dynamic", sampled_arch) | ||||
|         else: | ||||
|             raise ValueError("Invalid algo name : {:}".format(algo)) | ||||
|  | ||||
|         network.zero_grad() | ||||
|         _, logits = network(base_inputs) | ||||
|         base_loss = criterion(logits, base_targets) | ||||
|         base_loss.backward() | ||||
|         w_optimizer.step() | ||||
|         # record | ||||
|         base_prec1, base_prec5 = obtain_accuracy( | ||||
|             logits.data, base_targets.data, topk=(1, 5) | ||||
|         ) | ||||
|         base_losses.update(base_loss.item(), base_inputs.size(0)) | ||||
|         base_top1.update(base_prec1.item(), base_inputs.size(0)) | ||||
|         base_top5.update(base_prec5.item(), base_inputs.size(0)) | ||||
|  | ||||
|         # update the architecture-weight | ||||
|         if algo == "setn": | ||||
|             network.set_cal_mode("joint") | ||||
|         elif algo == "gdas": | ||||
|             network.set_cal_mode("gdas", None) | ||||
|         elif algo == "gdas_v1": | ||||
|             network.set_cal_mode("gdas_v1", None) | ||||
|         elif algo.startswith("darts"): | ||||
|             network.set_cal_mode("joint", None) | ||||
|         elif algo == "random": | ||||
|             network.set_cal_mode("urs", None) | ||||
|         elif algo != "enas": | ||||
|             raise ValueError("Invalid algo name : {:}".format(algo)) | ||||
|         network.zero_grad() | ||||
|         if algo == "darts-v2": | ||||
|             arch_loss, logits = backward_step_unrolled( | ||||
|                 network, | ||||
|                 criterion, | ||||
|                 base_inputs, | ||||
|                 base_targets, | ||||
|                 w_optimizer, | ||||
|                 arch_inputs, | ||||
|                 arch_targets, | ||||
|             ) | ||||
|             a_optimizer.step() | ||||
|         elif algo == "random" or algo == "enas": | ||||
|             with torch.no_grad(): | ||||
|                 _, logits = network(arch_inputs) | ||||
|                 arch_loss = criterion(logits, arch_targets) | ||||
|         else: | ||||
|             _, logits = network(arch_inputs) | ||||
|             arch_loss = criterion(logits, arch_targets) | ||||
|             arch_loss.backward() | ||||
|             a_optimizer.step() | ||||
|         # record | ||||
|         arch_prec1, arch_prec5 = obtain_accuracy( | ||||
|             logits.data, arch_targets.data, topk=(1, 5) | ||||
|         ) | ||||
|         arch_losses.update(arch_loss.item(), arch_inputs.size(0)) | ||||
|         arch_top1.update(arch_prec1.item(), arch_inputs.size(0)) | ||||
|         arch_top5.update(arch_prec5.item(), arch_inputs.size(0)) | ||||
|  | ||||
|         # measure elapsed time | ||||
|         batch_time.update(time.time() - end) | ||||
|         end = time.time() | ||||
|  | ||||
|         if step % print_freq == 0 or step + 1 == len(xloader): | ||||
|             Sstr = ( | ||||
|                 "*SEARCH* " | ||||
|                 + time_string() | ||||
|                 + " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(xloader)) | ||||
|             ) | ||||
|             Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format( | ||||
|                 batch_time=batch_time, data_time=data_time | ||||
|             ) | ||||
|             Wstr = "Base [Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]".format( | ||||
|                 loss=base_losses, top1=base_top1, top5=base_top5 | ||||
|             ) | ||||
|             Astr = "Arch [Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]".format( | ||||
|                 loss=arch_losses, top1=arch_top1, top5=arch_top5 | ||||
|             ) | ||||
|             logger.log(Sstr + " " + Tstr + " " + Wstr + " " + Astr) | ||||
|     return ( | ||||
|         base_losses.avg, | ||||
|         base_top1.avg, | ||||
|         base_top5.avg, | ||||
|         arch_losses.avg, | ||||
|         arch_top1.avg, | ||||
|         arch_top5.avg, | ||||
|     ) | ||||
|  | ||||
|  | ||||
| def train_controller( | ||||
|     xloader, network, criterion, optimizer, prev_baseline, epoch_str, print_freq, logger | ||||
| ): | ||||
|     # Arguments: | ||||
|     #   prev_baseline: the baseline score (i.e., the average val_acc) from the previous epoch | ||||
|     data_time, batch_time = AverageMeter(), AverageMeter() | ||||
|     ( | ||||
|         GradnormMeter, | ||||
|         LossMeter, | ||||
|         ValAccMeter, | ||||
|         EntropyMeter, | ||||
|         BaselineMeter, | ||||
|         RewardMeter, | ||||
|         xend, | ||||
|     ) = ( | ||||
|         AverageMeter(), | ||||
|         AverageMeter(), | ||||
|         AverageMeter(), | ||||
|         AverageMeter(), | ||||
|         AverageMeter(), | ||||
|         AverageMeter(), | ||||
|         time.time(), | ||||
|     ) | ||||
|  | ||||
|     controller_num_aggregate = 20 | ||||
|     controller_train_steps = 50 | ||||
|     controller_bl_dec = 0.99 | ||||
|     controller_entropy_weight = 0.0001 | ||||
|  | ||||
|     network.eval() | ||||
|     network.controller.train() | ||||
|     network.controller.zero_grad() | ||||
|     loader_iter = iter(xloader) | ||||
|     for step in range(controller_train_steps * controller_num_aggregate): | ||||
|         try: | ||||
|             inputs, targets = next(loader_iter) | ||||
|         except: | ||||
|             loader_iter = iter(xloader) | ||||
|             inputs, targets = next(loader_iter) | ||||
|         inputs = inputs.cuda(non_blocking=True) | ||||
|         targets = targets.cuda(non_blocking=True) | ||||
|         # measure data loading time | ||||
|         data_time.update(time.time() - xend) | ||||
|  | ||||
|         log_prob, entropy, sampled_arch = network.controller() | ||||
|         with torch.no_grad(): | ||||
|             network.set_cal_mode("dynamic", sampled_arch) | ||||
|             _, logits = network(inputs) | ||||
|             val_top1, val_top5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5)) | ||||
|             val_top1 = val_top1.view(-1) / 100 | ||||
|         reward = val_top1 + controller_entropy_weight * entropy | ||||
|         if prev_baseline is None: | ||||
|             baseline = val_top1 | ||||
|         else: | ||||
|             baseline = prev_baseline - (1 - controller_bl_dec) * ( | ||||
|                 prev_baseline - reward | ||||
|             ) | ||||
|  | ||||
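|         # REINFORCE objective with a moving baseline: minimizing | ||||
|         #   -(reward - baseline) * log_prob | ||||
|         # raises the log-probability of architectures whose reward beats the | ||||
|         # baseline; the baseline is an exponential moving average of past | ||||
|         # rewards (decay = controller_bl_dec), which reduces gradient variance. | ||||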
|         loss = -1 * log_prob * (reward - baseline) | ||||
|  | ||||
|         # account | ||||
|         RewardMeter.update(reward.item()) | ||||
|         BaselineMeter.update(baseline.item()) | ||||
|         ValAccMeter.update(val_top1.item() * 100) | ||||
|         LossMeter.update(loss.item()) | ||||
|         EntropyMeter.update(entropy.item()) | ||||
|  | ||||
|         # Average gradient over controller_num_aggregate samples | ||||
|         loss = loss / controller_num_aggregate | ||||
|         loss.backward(retain_graph=True) | ||||
|  | ||||
|         # measure elapsed time | ||||
|         batch_time.update(time.time() - xend) | ||||
|         xend = time.time() | ||||
|         if (step + 1) % controller_num_aggregate == 0: | ||||
|             grad_norm = torch.nn.utils.clip_grad_norm_( | ||||
|                 network.controller.parameters(), 5.0 | ||||
|             ) | ||||
|             GradnormMeter.update(grad_norm) | ||||
|             optimizer.step() | ||||
|             network.controller.zero_grad() | ||||
|  | ||||
|         if step % print_freq == 0: | ||||
|             Sstr = ( | ||||
|                 "*Train-Controller* " | ||||
|                 + time_string() | ||||
|                 + " [{:}][{:03d}/{:03d}]".format( | ||||
|                     epoch_str, step, controller_train_steps * controller_num_aggregate | ||||
|                 ) | ||||
|             ) | ||||
|             Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format( | ||||
|                 batch_time=batch_time, data_time=data_time | ||||
|             ) | ||||
|             Wstr = "[Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Reward {reward.val:.2f} ({reward.avg:.2f})] Baseline {basel.val:.2f} ({basel.avg:.2f})".format( | ||||
|                 loss=LossMeter, | ||||
|                 top1=ValAccMeter, | ||||
|                 reward=RewardMeter, | ||||
|                 basel=BaselineMeter, | ||||
|             ) | ||||
|             Estr = "Entropy={:.4f} ({:.4f})".format(EntropyMeter.val, EntropyMeter.avg) | ||||
|             logger.log(Sstr + " " + Tstr + " " + Wstr + " " + Estr) | ||||
|  | ||||
|     return LossMeter.avg, ValAccMeter.avg, BaselineMeter.avg, RewardMeter.avg | ||||
|  | ||||
|  | ||||
| def get_best_arch(xloader, network, n_samples, algo): | ||||
|     with torch.no_grad(): | ||||
|         network.eval() | ||||
|         if algo == "random": | ||||
|             archs, valid_accs = network.return_topK(n_samples, True), [] | ||||
|         elif algo == "setn": | ||||
|             archs, valid_accs = network.return_topK(n_samples, False), [] | ||||
|         elif algo.startswith("darts") or algo == "gdas" or algo == "gdas_v1": | ||||
|             arch = network.genotype | ||||
|             archs, valid_accs = [arch], [] | ||||
|         elif algo == "enas": | ||||
|             archs, valid_accs = [], [] | ||||
|             for _ in range(n_samples): | ||||
|                 _, _, sampled_arch = network.controller() | ||||
|                 archs.append(sampled_arch) | ||||
|         else: | ||||
|             raise ValueError("Invalid algorithm name : {:}".format(algo)) | ||||
|         loader_iter = iter(xloader) | ||||
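|         # Score each candidate on a single validation mini-batch; this cheap | ||||
|         # proxy accuracy is only used to rank the candidates and pick the best one. | ||||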
|         for i, sampled_arch in enumerate(archs): | ||||
|             network.set_cal_mode("dynamic", sampled_arch) | ||||
|             try: | ||||
|                 inputs, targets = next(loader_iter) | ||||
|             except: | ||||
|                 loader_iter = iter(xloader) | ||||
|                 inputs, targets = next(loader_iter) | ||||
|             _, logits = network(inputs.cuda(non_blocking=True)) | ||||
|             val_top1, val_top5 = obtain_accuracy( | ||||
|                 logits.cpu().data, targets.data, topk=(1, 5) | ||||
|             ) | ||||
|             valid_accs.append(val_top1.item()) | ||||
|         best_idx = np.argmax(valid_accs) | ||||
|         best_arch, best_valid_acc = archs[best_idx], valid_accs[best_idx] | ||||
|         return best_arch, best_valid_acc | ||||
|  | ||||
|  | ||||
| def valid_func(xloader, network, criterion, algo, logger): | ||||
|     data_time, batch_time = AverageMeter(), AverageMeter() | ||||
|     arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter() | ||||
|     end = time.time() | ||||
|     with torch.no_grad(): | ||||
|         network.eval() | ||||
|         for step, (arch_inputs, arch_targets) in enumerate(xloader): | ||||
|             arch_targets = arch_targets.cuda(non_blocking=True) | ||||
|             # measure data loading time | ||||
|             data_time.update(time.time() - end) | ||||
|             # prediction | ||||
|             _, logits = network(arch_inputs.cuda(non_blocking=True)) | ||||
|             arch_loss = criterion(logits, arch_targets) | ||||
|             # record | ||||
|             arch_prec1, arch_prec5 = obtain_accuracy( | ||||
|                 logits.data, arch_targets.data, topk=(1, 5) | ||||
|             ) | ||||
|             arch_losses.update(arch_loss.item(), arch_inputs.size(0)) | ||||
|             arch_top1.update(arch_prec1.item(), arch_inputs.size(0)) | ||||
|             arch_top5.update(arch_prec5.item(), arch_inputs.size(0)) | ||||
|             # measure elapsed time | ||||
|             batch_time.update(time.time() - end) | ||||
|             end = time.time() | ||||
|     return arch_losses.avg, arch_top1.avg, arch_top5.avg | ||||
|  | ||||
|  | ||||
| def main(xargs): | ||||
|     assert torch.cuda.is_available(), "CUDA is not available." | ||||
|     torch.backends.cudnn.enabled = True | ||||
|     torch.backends.cudnn.benchmark = False | ||||
|     torch.backends.cudnn.deterministic = True | ||||
|     torch.set_num_threads(xargs.workers) | ||||
|     prepare_seed(xargs.rand_seed) | ||||
|     logger = prepare_logger(args) | ||||
|  | ||||
|     train_data, valid_data, xshape, class_num = get_datasets( | ||||
|         xargs.dataset, xargs.data_path, -1 | ||||
|     ) | ||||
|     if xargs.overwite_epochs is None: | ||||
|         extra_info = {"class_num": class_num, "xshape": xshape} | ||||
|     else: | ||||
|         extra_info = { | ||||
|             "class_num": class_num, | ||||
|             "xshape": xshape, | ||||
|             "epochs": xargs.overwite_epochs, | ||||
|         } | ||||
|     config = load_config(xargs.config_path, extra_info, logger) | ||||
|     search_loader, train_loader, valid_loader = get_nas_search_loaders( | ||||
|         train_data, | ||||
|         valid_data, | ||||
|         xargs.dataset, | ||||
|         "configs/nas-benchmark/", | ||||
|         (config.batch_size, config.test_batch_size), | ||||
|         xargs.workers, | ||||
|     ) | ||||
|     logger.log( | ||||
|         "||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}".format( | ||||
|             xargs.dataset, len(search_loader), len(valid_loader), config.batch_size | ||||
|         ) | ||||
|     ) | ||||
|     logger.log("||||||| {:10s} ||||||| Config={:}".format(xargs.dataset, config)) | ||||
|  | ||||
|     search_space = get_search_spaces(xargs.search_space, "nats-bench") | ||||
|  | ||||
|     model_config = dict2config( | ||||
|         dict( | ||||
|             name="generic", | ||||
|             C=xargs.channel, | ||||
|             N=xargs.num_cells, | ||||
|             max_nodes=xargs.max_nodes, | ||||
|             num_classes=class_num, | ||||
|             space=search_space, | ||||
|             affine=bool(xargs.affine), | ||||
|             track_running_stats=bool(xargs.track_running_stats), | ||||
|         ), | ||||
|         None, | ||||
|     ) | ||||
|     logger.log("search space : {:}".format(search_space)) | ||||
|     logger.log("model config : {:}".format(model_config)) | ||||
|     search_model = get_cell_based_tiny_net(model_config) | ||||
|     search_model.set_algo(xargs.algo) | ||||
|     logger.log("{:}".format(search_model)) | ||||
|  | ||||
|     w_optimizer, w_scheduler, criterion = get_optim_scheduler( | ||||
|         search_model.weights, config | ||||
|     ) | ||||
|     a_optimizer = torch.optim.Adam( | ||||
|         search_model.alphas, | ||||
|         lr=xargs.arch_learning_rate, | ||||
|         betas=(0.5, 0.999), | ||||
|         weight_decay=xargs.arch_weight_decay, | ||||
|         eps=xargs.arch_eps, | ||||
|     ) | ||||
|     logger.log("w-optimizer : {:}".format(w_optimizer)) | ||||
|     logger.log("a-optimizer : {:}".format(a_optimizer)) | ||||
|     logger.log("w-scheduler : {:}".format(w_scheduler)) | ||||
|     logger.log("criterion   : {:}".format(criterion)) | ||||
|     params = count_parameters_in_MB(search_model) | ||||
|     logger.log("The parameters of the search model = {:.2f} MB".format(params)) | ||||
|     logger.log("search-space : {:}".format(search_space)) | ||||
|     if bool(xargs.use_api): | ||||
|         api = create(None, "topology", fast_mode=True, verbose=False) | ||||
|     else: | ||||
|         api = None | ||||
|     logger.log("{:} create API = {:} done".format(time_string(), api)) | ||||
|  | ||||
|     last_info, model_base_path, model_best_path = ( | ||||
|         logger.path("info"), | ||||
|         logger.path("model"), | ||||
|         logger.path("best"), | ||||
|     ) | ||||
|     network, criterion = search_model.cuda(), criterion.cuda()  # use a single GPU | ||||
|  | ||||
|     if last_info.exists():  # automatically resume from previous checkpoint | ||||
|         logger.log( | ||||
|             "=> loading checkpoint of the last-info '{:}' start".format(last_info) | ||||
|         ) | ||||
|         last_info = torch.load(last_info) | ||||
|         start_epoch = last_info["epoch"] | ||||
|         checkpoint = torch.load(last_info["last_checkpoint"]) | ||||
|         genotypes = checkpoint["genotypes"] | ||||
|         baseline = checkpoint["baseline"] | ||||
|         valid_accuracies = checkpoint["valid_accuracies"] | ||||
|         search_model.load_state_dict(checkpoint["search_model"]) | ||||
|         w_scheduler.load_state_dict(checkpoint["w_scheduler"]) | ||||
|         w_optimizer.load_state_dict(checkpoint["w_optimizer"]) | ||||
|         a_optimizer.load_state_dict(checkpoint["a_optimizer"]) | ||||
|         logger.log( | ||||
|             "=> loaded checkpoint of the last-info '{:}'; resume from the {:}-th epoch.".format( | ||||
|                 last_info, start_epoch | ||||
|             ) | ||||
|         ) | ||||
|     else: | ||||
|         logger.log("=> do not find the last-info file : {:}".format(last_info)) | ||||
|         start_epoch, valid_accuracies, genotypes = ( | ||||
|             0, | ||||
|             {"best": -1}, | ||||
|             {-1: network.return_topK(1, True)[0]}, | ||||
|         ) | ||||
|         baseline = None | ||||
|  | ||||
|     # start training | ||||
|     start_time, search_time, epoch_time, total_epoch = ( | ||||
|         time.time(), | ||||
|         AverageMeter(), | ||||
|         AverageMeter(), | ||||
|         config.epochs + config.warmup, | ||||
|     ) | ||||
|     for epoch in range(start_epoch, total_epoch): | ||||
|         w_scheduler.update(epoch, 0.0) | ||||
|         need_time = "Time Left: {:}".format( | ||||
|             convert_secs2time(epoch_time.val * (total_epoch - epoch), True) | ||||
|         ) | ||||
|         epoch_str = "{:03d}-{:03d}".format(epoch, total_epoch) | ||||
|         logger.log( | ||||
|             "\n[Search the {:}-th epoch] {:}, LR={:}".format( | ||||
|                 epoch_str, need_time, min(w_scheduler.get_lr()) | ||||
|             ) | ||||
|         ) | ||||
|  | ||||
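|         # Schedules: drop-path is updated with the training progress | ||||
|         # (epoch + 1) / total_epoch, and for GDAS the Gumbel temperature tau is | ||||
|         # annealed linearly from tau_max down to tau_min over the search. | ||||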
|         network.set_drop_path(float(epoch + 1) / total_epoch, xargs.drop_path_rate) | ||||
|         if xargs.algo == "gdas" or xargs.algo == "gdas_v1": | ||||
|             network.set_tau( | ||||
|                 xargs.tau_max | ||||
|                 - (xargs.tau_max - xargs.tau_min) * epoch / (total_epoch - 1) | ||||
|             ) | ||||
|             logger.log( | ||||
|                 "[RESET tau as : {:} and drop_path as {:}]".format( | ||||
|                     network.tau, network.drop_path | ||||
|                 ) | ||||
|             ) | ||||
|         ( | ||||
|             search_w_loss, | ||||
|             search_w_top1, | ||||
|             search_w_top5, | ||||
|             search_a_loss, | ||||
|             search_a_top1, | ||||
|             search_a_top5, | ||||
|         ) = search_func( | ||||
|             search_loader, | ||||
|             network, | ||||
|             criterion, | ||||
|             w_scheduler, | ||||
|             w_optimizer, | ||||
|             a_optimizer, | ||||
|             epoch_str, | ||||
|             xargs.print_freq, | ||||
|             xargs.algo, | ||||
|             logger, | ||||
|         ) | ||||
|         search_time.update(time.time() - start_time) | ||||
|         logger.log( | ||||
|             "[{:}] search [base] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s".format( | ||||
|                 epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum | ||||
|             ) | ||||
|         ) | ||||
|         logger.log( | ||||
|             "[{:}] search [arch] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%".format( | ||||
|                 epoch_str, search_a_loss, search_a_top1, search_a_top5 | ||||
|             ) | ||||
|         ) | ||||
|         if xargs.algo == "enas": | ||||
|             ctl_loss, ctl_acc, baseline, ctl_reward = train_controller( | ||||
|                 valid_loader, | ||||
|                 network, | ||||
|                 criterion, | ||||
|                 a_optimizer, | ||||
|                 baseline, | ||||
|                 epoch_str, | ||||
|                 xargs.print_freq, | ||||
|                 logger, | ||||
|             ) | ||||
|             logger.log( | ||||
|                 "[{:}] controller : loss={:}, acc={:}, baseline={:}, reward={:}".format( | ||||
|                     epoch_str, ctl_loss, ctl_acc, baseline, ctl_reward | ||||
|                 ) | ||||
|             ) | ||||
|  | ||||
|         genotype, temp_accuracy = get_best_arch( | ||||
|             valid_loader, network, xargs.eval_candidate_num, xargs.algo | ||||
|         ) | ||||
|         if xargs.algo == "setn" or xargs.algo == "enas": | ||||
|             network.set_cal_mode("dynamic", genotype) | ||||
|         elif xargs.algo == "gdas": | ||||
|             network.set_cal_mode("gdas", None) | ||||
|         elif xargs.algo == "gdas_v1": | ||||
|             network.set_cal_mode("gdas_v1", None) | ||||
|         elif xargs.algo.startswith("darts"): | ||||
|             network.set_cal_mode("joint", None) | ||||
|         elif xargs.algo == "random": | ||||
|             network.set_cal_mode("urs", None) | ||||
|         else: | ||||
|             raise ValueError("Invalid algorithm name : {:}".format(xargs.algo)) | ||||
|         logger.log( | ||||
|             "[{:}] - [get_best_arch] : {:} -> {:}".format( | ||||
|                 epoch_str, genotype, temp_accuracy | ||||
|             ) | ||||
|         ) | ||||
|         valid_a_loss, valid_a_top1, valid_a_top5 = valid_func( | ||||
|             valid_loader, network, criterion, xargs.algo, logger | ||||
|         ) | ||||
|         logger.log( | ||||
|             "[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}% | {:}".format( | ||||
|                 epoch_str, valid_a_loss, valid_a_top1, valid_a_top5, genotype | ||||
|             ) | ||||
|         ) | ||||
|         valid_accuracies[epoch] = valid_a_top1 | ||||
|  | ||||
|         genotypes[epoch] = genotype | ||||
|         logger.log( | ||||
|             "<<<--->>> The {:}-th epoch : {:}".format(epoch_str, genotypes[epoch]) | ||||
|         ) | ||||
|         # save checkpoint | ||||
|         save_path = save_checkpoint( | ||||
|             { | ||||
|                 "epoch": epoch + 1, | ||||
|                 "args": deepcopy(xargs), | ||||
|                 "baseline": baseline, | ||||
|                 "search_model": search_model.state_dict(), | ||||
|                 "w_optimizer": w_optimizer.state_dict(), | ||||
|                 "a_optimizer": a_optimizer.state_dict(), | ||||
|                 "w_scheduler": w_scheduler.state_dict(), | ||||
|                 "genotypes": genotypes, | ||||
|                 "valid_accuracies": valid_accuracies, | ||||
|             }, | ||||
|             model_base_path, | ||||
|             logger, | ||||
|         ) | ||||
|         last_info = save_checkpoint( | ||||
|             { | ||||
|                 "epoch": epoch + 1, | ||||
|                 "args": deepcopy(args), | ||||
|                 "last_checkpoint": save_path, | ||||
|             }, | ||||
|             logger.path("info"), | ||||
|             logger, | ||||
|         ) | ||||
|         with torch.no_grad(): | ||||
|             logger.log("{:}".format(search_model.show_alphas())) | ||||
|         if api is not None: | ||||
|             logger.log("{:}".format(api.query_by_arch(genotypes[epoch], "200"))) | ||||
|         # measure elapsed time | ||||
|         epoch_time.update(time.time() - start_time) | ||||
|         start_time = time.time() | ||||
|  | ||||
|     # the final post procedure : count the time | ||||
|     start_time = time.time() | ||||
|     genotype, temp_accuracy = get_best_arch( | ||||
|         valid_loader, network, xargs.eval_candidate_num, xargs.algo | ||||
|     ) | ||||
|     if xargs.algo == "setn" or xargs.algo == "enas": | ||||
|         network.set_cal_mode("dynamic", genotype) | ||||
|     elif xargs.algo == "gdas": | ||||
|         network.set_cal_mode("gdas", None) | ||||
|     elif xargs.algo == "gdas_v1": | ||||
|         network.set_cal_mode("gdas_v1", None) | ||||
|     elif xargs.algo.startswith("darts"): | ||||
|         network.set_cal_mode("joint", None) | ||||
|     elif xargs.algo == "random": | ||||
|         network.set_cal_mode("urs", None) | ||||
|     else: | ||||
|         raise ValueError("Invalid algorithm name : {:}".format(xargs.algo)) | ||||
|     search_time.update(time.time() - start_time) | ||||
|  | ||||
|     valid_a_loss, valid_a_top1, valid_a_top5 = valid_func( | ||||
|         valid_loader, network, criterion, xargs.algo, logger | ||||
|     ) | ||||
|     logger.log( | ||||
|         "Last : the genotype is {:}, with the validation accuracy of {:.3f}%.".format( | ||||
|             genotype, valid_a_top1 | ||||
|         ) | ||||
|     ) | ||||
|  | ||||
|     logger.log("\n" + "-" * 100) | ||||
|     # check the performance from the architecture dataset | ||||
|     logger.log( | ||||
|         "[{:}] run {:} epochs, cost {:.1f} s, last-geno is {:}.".format( | ||||
|             xargs.algo, total_epoch, search_time.sum, genotype | ||||
|         ) | ||||
|     ) | ||||
|     if api is not None: | ||||
|         logger.log("{:}".format(api.query_by_arch(genotype, "200"))) | ||||
|     logger.close() | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     parser = argparse.ArgumentParser("Weight sharing NAS methods to search for cells.") | ||||
|     parser.add_argument("--data_path", type=str, help="Path to dataset") | ||||
|     parser.add_argument( | ||||
|         "--dataset", | ||||
|         type=str, | ||||
|         choices=["cifar10", "cifar100", "ImageNet16-120"], | ||||
|         help="Choose between Cifar10/100 and ImageNet-16.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--search_space", | ||||
|         type=str, | ||||
|         default="tss", | ||||
|         choices=["tss"], | ||||
|         help="The search space name.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--algo", | ||||
|         type=str, | ||||
|         choices=["darts-v1", "darts-v2", "gdas", "gdas_v1", "setn", "random", "enas"], | ||||
|         help="The name of the search algorithm.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--use_api", | ||||
|         type=int, | ||||
|         default=1, | ||||
|         choices=[0, 1], | ||||
|         help="Whether to use the API or not (using the API costs much memory).", | ||||
|     ) | ||||
|     # FOR GDAS | ||||
|     parser.add_argument( | ||||
|         "--tau_min", type=float, default=0.1, help="The minimum tau for Gumbel Softmax." | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--tau_max", type=float, default=10, help="The maximum tau for Gumbel Softmax." | ||||
|     ) | ||||
|     # channels and number-of-cells | ||||
|     parser.add_argument( | ||||
|         "--max_nodes", type=int, default=4, help="The maximum number of nodes." | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--channel", type=int, default=16, help="The number of channels." | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--num_cells", type=int, default=5, help="The number of cells in one stage." | ||||
|     ) | ||||
|     # | ||||
|     parser.add_argument( | ||||
|         "--eval_candidate_num", | ||||
|         type=int, | ||||
|         default=100, | ||||
|         help="The number of selected architectures to evaluate.", | ||||
|     ) | ||||
|     # | ||||
|     parser.add_argument( | ||||
|         "--track_running_stats", | ||||
|         type=int, | ||||
|         default=0, | ||||
|         choices=[0, 1], | ||||
|         help="Whether to use track_running_stats in the BN layer or not.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--affine", | ||||
|         type=int, | ||||
|         default=0, | ||||
|         choices=[0, 1], | ||||
|         help="Whether to use affine=True or affine=False in the BN layer.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--config_path", | ||||
|         type=str, | ||||
|         default="./configs/nas-benchmark/algos/weight-sharing.config", | ||||
|         help="The path of configuration.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--overwite_epochs", | ||||
|         type=int, | ||||
|         help="The number of epochs, overriding the value in the config file.", | ||||
|     ) | ||||
|     # architecture learning rate | ||||
|     parser.add_argument( | ||||
|         "--arch_learning_rate", | ||||
|         type=float, | ||||
|         default=3e-4, | ||||
|         help="learning rate for arch encoding", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--arch_weight_decay", | ||||
|         type=float, | ||||
|         default=1e-3, | ||||
|         help="weight decay for arch encoding", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--arch_eps", type=float, default=1e-8, help="epsilon for the architecture optimizer (Adam)" | ||||
|     ) | ||||
|     parser.add_argument("--drop_path_rate", type=float, help="The drop path rate.") | ||||
|     # log | ||||
|     parser.add_argument( | ||||
|         "--workers", | ||||
|         type=int, | ||||
|         default=2, | ||||
|         help="number of data loading workers (default: 2)", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--save_dir", | ||||
|         type=str, | ||||
|         default="./output/search", | ||||
|         help="Folder to save checkpoints and log.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--print_freq", type=int, default=200, help="print frequency (default: 200)" | ||||
|     ) | ||||
|     parser.add_argument("--rand_seed", type=int, help="manual seed") | ||||
|     args = parser.parse_args() | ||||
|     if args.rand_seed is None or args.rand_seed < 0: | ||||
|         args.rand_seed = random.randint(1, 100000) | ||||
|     if args.overwite_epochs is None: | ||||
|         args.save_dir = os.path.join( | ||||
|             "{:}-{:}".format(args.save_dir, args.search_space), | ||||
|             args.dataset, | ||||
|             "{:}-affine{:}_BN{:}-{:}".format( | ||||
|                 args.algo, args.affine, args.track_running_stats, args.drop_path_rate | ||||
|             ), | ||||
|         ) | ||||
|     else: | ||||
|         args.save_dir = os.path.join( | ||||
|             "{:}-{:}".format(args.save_dir, args.search_space), | ||||
|             args.dataset, | ||||
|             "{:}-affine{:}_BN{:}-E{:}-{:}".format( | ||||
|                 args.algo, | ||||
|                 args.affine, | ||||
|                 args.track_running_stats, | ||||
|                 args.overwite_epochs, | ||||
|                 args.drop_path_rate, | ||||
|             ), | ||||
|         ) | ||||
|  | ||||
|     main(args) | ||||
							
								
								
									
									
										582
									
								
								AutoDL-Projects/exps/NATS-algos/search-size.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,582 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 # | ||||
| ########################################################################################################################################### | ||||
| # | ||||
| # In this file, we aim to evaluate three kinds of channel-searching strategies: | ||||
| # - channel-wise interpolation from "Network Pruning via Transformable Architecture Search, NeurIPS 2019" | ||||
| # - masking + Gumbel-Softmax (mask_gumbel) from "FBNetV2: Differentiable Neural Architecture Search for Spatial and Channel Dimensions, CVPR 2020" | ||||
| # - masking + sampling (mask_rl) from "Can Weight Sharing Outperform Random Architecture Search? An Investigation With TuNAS, CVPR 2020" | ||||
| # | ||||
| # For simplicity, we use tas, mask_gumbel, and mask_rl to refer to these three strategies. Their official implementations are at the following links: | ||||
| # - TAS: https://github.com/D-X-Y/AutoDL-Projects/blob/main/docs/NeurIPS-2019-TAS.md | ||||
| # - FBNetV2: https://github.com/facebookresearch/mobile-vision | ||||
| # - TuNAS: https://github.com/google-research/google-research/tree/master/tunas | ||||
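| # At a high level, all three strategies share one super-network and differ in how | ||||
| # the channel choice is relaxed: tas interpolates feature maps between candidate | ||||
| # widths, mask_gumbel applies a soft channel mask weighted by Gumbel-Softmax, and | ||||
| # mask_rl samples a hard mask and updates its logits with REINFORCE. | ||||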
| #### | ||||
| # python ./exps/NATS-algos/search-size.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo mask_rl --arch_weight_decay 0 --warmup_ratio 0.25 | ||||
| #### | ||||
| # python ./exps/NATS-algos/search-size.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo tas --rand_seed 777 | ||||
| # python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo tas --rand_seed 777 | ||||
| # python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo tas --rand_seed 777 | ||||
| #### | ||||
| # python ./exps/NATS-algos/search-size.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo mask_gumbel --rand_seed 777 | ||||
| # python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo mask_gumbel --rand_seed 777 | ||||
| # python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo mask_gumbel --rand_seed 777 | ||||
| #### | ||||
| # python ./exps/NATS-algos/search-size.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo mask_rl --arch_weight_decay 0 --rand_seed 777 --use_api 0 | ||||
| # python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo mask_rl --arch_weight_decay 0 --rand_seed 777 | ||||
| # python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo mask_rl --arch_weight_decay 0 --rand_seed 777 | ||||
| ########################################################################################################################################### | ||||
| import os, sys, time, random, argparse | ||||
| import numpy as np | ||||
| from copy import deepcopy | ||||
| import torch | ||||
| import torch.nn as nn | ||||
|  | ||||
| from xautodl.config_utils import load_config, dict2config, configure2str | ||||
| from xautodl.datasets import get_datasets, get_nas_search_loaders | ||||
| from xautodl.procedures import ( | ||||
|     prepare_seed, | ||||
|     prepare_logger, | ||||
|     save_checkpoint, | ||||
|     copy_checkpoint, | ||||
|     get_optim_scheduler, | ||||
| ) | ||||
| from xautodl.utils import count_parameters_in_MB, obtain_accuracy | ||||
| from xautodl.log_utils import AverageMeter, time_string, convert_secs2time | ||||
| from xautodl.models import get_cell_based_tiny_net, get_search_spaces | ||||
| from nats_bench import create | ||||
|  | ||||
|  | ||||
| # Ad-hoc for RL algorithms. | ||||
| class ExponentialMovingAverage(object): | ||||
|     """Class that maintains an exponential moving average.""" | ||||
|  | ||||
|     def __init__(self, momentum): | ||||
|         self._numerator = 0 | ||||
|         self._denominator = 0 | ||||
|         self._momentum = momentum | ||||
|  | ||||
|     def update(self, value): | ||||
|         self._numerator = ( | ||||
|             self._momentum * self._numerator + (1 - self._momentum) * value | ||||
|         ) | ||||
|         self._denominator = self._momentum * self._denominator + (1 - self._momentum) | ||||
|  | ||||
|     @property | ||||
|     def value(self): | ||||
|         """Return the current value of the moving average""" | ||||
|         return self._numerator / self._denominator | ||||
|  | ||||
|  | ||||
| RL_BASELINE_EMA = ExponentialMovingAverage(0.95) | ||||
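|  | ||||
| # The denominator above provides bias correction (as in Adam), so `value` is a | ||||
| # proper average even after only a few updates. A minimal illustrative sketch: | ||||
| #   ema = ExponentialMovingAverage(0.95) | ||||
| #   ema.update(1.0) | ||||
| #   ema.value  # -> 1.0 (0.05 / 0.05), rather than the heavily damped 0.05 | ||||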
|  | ||||
|  | ||||
| def search_func( | ||||
|     xloader, | ||||
|     network, | ||||
|     criterion, | ||||
|     scheduler, | ||||
|     w_optimizer, | ||||
|     a_optimizer, | ||||
|     enable_controller, | ||||
|     algo, | ||||
|     epoch_str, | ||||
|     print_freq, | ||||
|     logger, | ||||
| ): | ||||
|     data_time, batch_time = AverageMeter(), AverageMeter() | ||||
|     base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter() | ||||
|     arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter() | ||||
|     end = time.time() | ||||
|     network.train() | ||||
|     for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate( | ||||
|         xloader | ||||
|     ): | ||||
|         scheduler.update(None, 1.0 * step / len(xloader)) | ||||
|         base_inputs = base_inputs.cuda(non_blocking=True) | ||||
|         arch_inputs = arch_inputs.cuda(non_blocking=True) | ||||
|         base_targets = base_targets.cuda(non_blocking=True) | ||||
|         arch_targets = arch_targets.cuda(non_blocking=True) | ||||
|         # measure data loading time | ||||
|         data_time.update(time.time() - end) | ||||
|  | ||||
|         # Update the weights | ||||
|         network.zero_grad() | ||||
|         _, logits, _ = network(base_inputs) | ||||
|         base_loss = criterion(logits, base_targets) | ||||
|         base_loss.backward() | ||||
|         w_optimizer.step() | ||||
|         # record | ||||
|         base_prec1, base_prec5 = obtain_accuracy( | ||||
|             logits.data, base_targets.data, topk=(1, 5) | ||||
|         ) | ||||
|         base_losses.update(base_loss.item(), base_inputs.size(0)) | ||||
|         base_top1.update(base_prec1.item(), base_inputs.size(0)) | ||||
|         base_top5.update(base_prec5.item(), base_inputs.size(0)) | ||||
|  | ||||
|         # update the architecture-weight | ||||
|         network.zero_grad() | ||||
|         a_optimizer.zero_grad() | ||||
|         _, logits, log_probs = network(arch_inputs) | ||||
|         arch_prec1, arch_prec5 = obtain_accuracy( | ||||
|             logits.data, arch_targets.data, topk=(1, 5) | ||||
|         ) | ||||
|         if algo == "mask_rl": | ||||
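|             # REINFORCE for the channel masks: the advantage is the sampled | ||||
|             # architecture's top-1 accuracy minus an exponential-moving-average | ||||
|             # baseline, and the loss -advantage * sum(log_probs) increases the | ||||
|             # probability of above-baseline channel configurations. | ||||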
|             with torch.no_grad(): | ||||
|                 RL_BASELINE_EMA.update(arch_prec1.item()) | ||||
|                 rl_advantage = arch_prec1 - RL_BASELINE_EMA.value | ||||
|             rl_log_prob = sum(log_probs) | ||||
|             arch_loss = -rl_advantage * rl_log_prob | ||||
|         elif algo == "tas" or algo == "mask_gumbel": | ||||
|             arch_loss = criterion(logits, arch_targets) | ||||
|         else: | ||||
|             raise ValueError("invalid algorithm name: {:}".format(algo)) | ||||
|         if enable_controller: | ||||
|             arch_loss.backward() | ||||
|             a_optimizer.step() | ||||
|         # record | ||||
|         arch_losses.update(arch_loss.item(), arch_inputs.size(0)) | ||||
|         arch_top1.update(arch_prec1.item(), arch_inputs.size(0)) | ||||
|         arch_top5.update(arch_prec5.item(), arch_inputs.size(0)) | ||||
|  | ||||
|         # measure elapsed time | ||||
|         batch_time.update(time.time() - end) | ||||
|         end = time.time() | ||||
|  | ||||
|         if step % print_freq == 0 or step + 1 == len(xloader): | ||||
|             Sstr = ( | ||||
|                 "*SEARCH* " | ||||
|                 + time_string() | ||||
|                 + " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(xloader)) | ||||
|             ) | ||||
|             Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format( | ||||
|                 batch_time=batch_time, data_time=data_time | ||||
|             ) | ||||
|             Wstr = "Base [Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]".format( | ||||
|                 loss=base_losses, top1=base_top1, top5=base_top5 | ||||
|             ) | ||||
|             Astr = "Arch [Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]".format( | ||||
|                 loss=arch_losses, top1=arch_top1, top5=arch_top5 | ||||
|             ) | ||||
|             logger.log(Sstr + " " + Tstr + " " + Wstr + " " + Astr) | ||||
|     return ( | ||||
|         base_losses.avg, | ||||
|         base_top1.avg, | ||||
|         base_top5.avg, | ||||
|         arch_losses.avg, | ||||
|         arch_top1.avg, | ||||
|         arch_top5.avg, | ||||
|     ) | ||||
|  | ||||
|  | ||||
| def valid_func(xloader, network, criterion, logger): | ||||
|     data_time, batch_time = AverageMeter(), AverageMeter() | ||||
|     arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter() | ||||
|     end = time.time() | ||||
|     with torch.no_grad(): | ||||
|         network.eval() | ||||
|         for step, (arch_inputs, arch_targets) in enumerate(xloader): | ||||
|             arch_targets = arch_targets.cuda(non_blocking=True) | ||||
|             # measure data loading time | ||||
|             data_time.update(time.time() - end) | ||||
|             # prediction | ||||
|             _, logits, _ = network(arch_inputs.cuda(non_blocking=True)) | ||||
|             arch_loss = criterion(logits, arch_targets) | ||||
|             # record | ||||
|             arch_prec1, arch_prec5 = obtain_accuracy( | ||||
|                 logits.data, arch_targets.data, topk=(1, 5) | ||||
|             ) | ||||
|             arch_losses.update(arch_loss.item(), arch_inputs.size(0)) | ||||
|             arch_top1.update(arch_prec1.item(), arch_inputs.size(0)) | ||||
|             arch_top5.update(arch_prec5.item(), arch_inputs.size(0)) | ||||
|             # measure elapsed time | ||||
|             batch_time.update(time.time() - end) | ||||
|             end = time.time() | ||||
|     return arch_losses.avg, arch_top1.avg, arch_top5.avg | ||||
|  | ||||
|  | ||||
| def main(xargs): | ||||
|     assert torch.cuda.is_available(), "CUDA is not available." | ||||
|     torch.backends.cudnn.enabled = True | ||||
|     torch.backends.cudnn.benchmark = False | ||||
|     torch.backends.cudnn.deterministic = True | ||||
|     # torch.set_num_threads(xargs.workers) | ||||
|     prepare_seed(xargs.rand_seed) | ||||
|     logger = prepare_logger(args) | ||||
|  | ||||
|     train_data, valid_data, xshape, class_num = get_datasets( | ||||
|         xargs.dataset, xargs.data_path, -1 | ||||
|     ) | ||||
|     if xargs.overwite_epochs is None: | ||||
|         extra_info = {"class_num": class_num, "xshape": xshape} | ||||
|     else: | ||||
|         extra_info = { | ||||
|             "class_num": class_num, | ||||
|             "xshape": xshape, | ||||
|             "epochs": xargs.overwite_epochs, | ||||
|         } | ||||
|     config = load_config(xargs.config_path, extra_info, logger) | ||||
|     search_loader, train_loader, valid_loader = get_nas_search_loaders( | ||||
|         train_data, | ||||
|         valid_data, | ||||
|         xargs.dataset, | ||||
|         "configs/nas-benchmark/", | ||||
|         (config.batch_size, config.test_batch_size), | ||||
|         xargs.workers, | ||||
|     ) | ||||
|     logger.log( | ||||
|         "||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}".format( | ||||
|             xargs.dataset, len(search_loader), len(valid_loader), config.batch_size | ||||
|         ) | ||||
|     ) | ||||
|     logger.log("||||||| {:10s} ||||||| Config={:}".format(xargs.dataset, config)) | ||||
|  | ||||
|     search_space = get_search_spaces(xargs.search_space, "nats-bench") | ||||
|  | ||||
|     model_config = dict2config( | ||||
|         dict( | ||||
|             name="generic", | ||||
|             super_type="search-shape", | ||||
|             candidate_Cs=search_space["candidates"], | ||||
|             max_num_Cs=search_space["numbers"], | ||||
|             num_classes=class_num, | ||||
|             genotype=args.genotype, | ||||
|             affine=bool(xargs.affine), | ||||
|             track_running_stats=bool(xargs.track_running_stats), | ||||
|         ), | ||||
|         None, | ||||
|     ) | ||||
|     logger.log("search space : {:}".format(search_space)) | ||||
|     logger.log("model config : {:}".format(model_config)) | ||||
|     search_model = get_cell_based_tiny_net(model_config) | ||||
|     search_model.set_algo(xargs.algo) | ||||
|     logger.log("{:}".format(search_model)) | ||||
|  | ||||
|     w_optimizer, w_scheduler, criterion = get_optim_scheduler( | ||||
|         search_model.weights, config | ||||
|     ) | ||||
|     a_optimizer = torch.optim.Adam( | ||||
|         search_model.alphas, | ||||
|         lr=xargs.arch_learning_rate, | ||||
|         betas=(0.5, 0.999), | ||||
|         weight_decay=xargs.arch_weight_decay, | ||||
|         eps=xargs.arch_eps, | ||||
|     ) | ||||
|     logger.log("w-optimizer : {:}".format(w_optimizer)) | ||||
|     logger.log("a-optimizer : {:}".format(a_optimizer)) | ||||
|     logger.log("w-scheduler : {:}".format(w_scheduler)) | ||||
|     logger.log("criterion   : {:}".format(criterion)) | ||||
|     params = count_parameters_in_MB(search_model) | ||||
|     logger.log("The parameters of the search model = {:.2f} MB".format(params)) | ||||
|     logger.log("search-space : {:}".format(search_space)) | ||||
|     if bool(xargs.use_api): | ||||
|         api = create(None, "size", fast_mode=True, verbose=False) | ||||
|     else: | ||||
|         api = None | ||||
|     logger.log("{:} create API = {:} done".format(time_string(), api)) | ||||
|  | ||||
|     last_info, model_base_path, model_best_path = ( | ||||
|         logger.path("info"), | ||||
|         logger.path("model"), | ||||
|         logger.path("best"), | ||||
|     ) | ||||
|     network, criterion = search_model.cuda(), criterion.cuda()  # use a single GPU | ||||
|  | ||||
|     if last_info.exists():  # automatically resume from previous checkpoint | ||||
|         logger.log( | ||||
|             "=> loading checkpoint of the last-info '{:}' start".format(last_info) | ||||
|         ) | ||||
|         last_info = torch.load(last_info) | ||||
|         start_epoch = last_info["epoch"] | ||||
|         checkpoint = torch.load(last_info["last_checkpoint"]) | ||||
|         genotypes = checkpoint["genotypes"] | ||||
|         valid_accuracies = checkpoint["valid_accuracies"] | ||||
|         search_model.load_state_dict(checkpoint["search_model"]) | ||||
|         w_scheduler.load_state_dict(checkpoint["w_scheduler"]) | ||||
|         w_optimizer.load_state_dict(checkpoint["w_optimizer"]) | ||||
|         a_optimizer.load_state_dict(checkpoint["a_optimizer"]) | ||||
|         logger.log( | ||||
|             "=> loaded checkpoint of the last-info '{:}'; resume from the {:}-th epoch.".format( | ||||
|                 last_info, start_epoch | ||||
|             ) | ||||
|         ) | ||||
|     else: | ||||
|         logger.log("=> do not find the last-info file : {:}".format(last_info)) | ||||
|         start_epoch, valid_accuracies, genotypes = 0, {"best": -1}, {-1: network.random} | ||||
|  | ||||
|     # start training | ||||
|     start_time, search_time, epoch_time, total_epoch = ( | ||||
|         time.time(), | ||||
|         AverageMeter(), | ||||
|         AverageMeter(), | ||||
|         config.epochs + config.warmup, | ||||
|     ) | ||||
|     for epoch in range(start_epoch, total_epoch): | ||||
|         w_scheduler.update(epoch, 0.0) | ||||
|         need_time = "Time Left: {:}".format( | ||||
|             convert_secs2time(epoch_time.val * (total_epoch - epoch), True) | ||||
|         ) | ||||
|         epoch_str = "{:03d}-{:03d}".format(epoch, total_epoch) | ||||
|  | ||||
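|         # Warm-up: for the first `warmup_ratio` fraction of epochs the controller | ||||
|         # (architecture parameters) is frozen and only the super-network weights | ||||
|         # are trained; afterwards architecture updates are enabled. | ||||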
|         if ( | ||||
|             xargs.warmup_ratio is None | ||||
|             or xargs.warmup_ratio <= float(epoch) / total_epoch | ||||
|         ): | ||||
|             enable_controller = True | ||||
|             network.set_warmup_ratio(None) | ||||
|         else: | ||||
|             enable_controller = False | ||||
|             network.set_warmup_ratio( | ||||
|                 1.0 - float(epoch) / total_epoch / xargs.warmup_ratio | ||||
|             ) | ||||
|  | ||||
|         logger.log( | ||||
|             "\n[Search the {:}-th epoch] {:}, LR={:}, controller-warmup={:}, enable_controller={:}".format( | ||||
|                 epoch_str, | ||||
|                 need_time, | ||||
|                 min(w_scheduler.get_lr()), | ||||
|                 network.warmup_ratio, | ||||
|                 enable_controller, | ||||
|             ) | ||||
|         ) | ||||
|  | ||||
|         if xargs.algo == "mask_gumbel" or xargs.algo == "tas": | ||||
|             network.set_tau( | ||||
|                 xargs.tau_max | ||||
|                 - (xargs.tau_max - xargs.tau_min) * epoch / (total_epoch - 1) | ||||
|             ) | ||||
|             logger.log("[RESET tau as : {:}]".format(network.tau)) | ||||
|         ( | ||||
|             search_w_loss, | ||||
|             search_w_top1, | ||||
|             search_w_top5, | ||||
|             search_a_loss, | ||||
|             search_a_top1, | ||||
|             search_a_top5, | ||||
|         ) = search_func( | ||||
|             search_loader, | ||||
|             network, | ||||
|             criterion, | ||||
|             w_scheduler, | ||||
|             w_optimizer, | ||||
|             a_optimizer, | ||||
|             enable_controller, | ||||
|             xargs.algo, | ||||
|             epoch_str, | ||||
|             xargs.print_freq, | ||||
|             logger, | ||||
|         ) | ||||
|         search_time.update(time.time() - start_time) | ||||
|         logger.log( | ||||
|             "[{:}] search [base] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s".format( | ||||
|                 epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum | ||||
|             ) | ||||
|         ) | ||||
|         logger.log( | ||||
|             "[{:}] search [arch] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%".format( | ||||
|                 epoch_str, search_a_loss, search_a_top1, search_a_top5 | ||||
|             ) | ||||
|         ) | ||||
|  | ||||
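|         # Derive the current architecture from the super-network's architecture | ||||
|         # parameters and evaluate it on the validation split. | ||||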
|         genotype = network.genotype | ||||
|         logger.log("[{:}] - [get_best_arch] : {:}".format(epoch_str, genotype)) | ||||
|         valid_a_loss, valid_a_top1, valid_a_top5 = valid_func( | ||||
|             valid_loader, network, criterion, logger | ||||
|         ) | ||||
|         logger.log( | ||||
|             "[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}% | {:}".format( | ||||
|                 epoch_str, valid_a_loss, valid_a_top1, valid_a_top5, genotype | ||||
|             ) | ||||
|         ) | ||||
|         valid_accuracies[epoch] = valid_a_top1 | ||||
|  | ||||
|         genotypes[epoch] = genotype | ||||
|         logger.log( | ||||
|             "<<<--->>> The {:}-th epoch : {:}".format(epoch_str, genotypes[epoch]) | ||||
|         ) | ||||
|         # save checkpoint | ||||
|         save_path = save_checkpoint( | ||||
|             { | ||||
|                 "epoch": epoch + 1, | ||||
|                 "args": deepcopy(xargs), | ||||
|                 "search_model": search_model.state_dict(), | ||||
|                 "w_optimizer": w_optimizer.state_dict(), | ||||
|                 "a_optimizer": a_optimizer.state_dict(), | ||||
|                 "w_scheduler": w_scheduler.state_dict(), | ||||
|                 "genotypes": genotypes, | ||||
|                 "valid_accuracies": valid_accuracies, | ||||
|             }, | ||||
|             model_base_path, | ||||
|             logger, | ||||
|         ) | ||||
|         last_info = save_checkpoint( | ||||
|             { | ||||
|                 "epoch": epoch + 1, | ||||
|                 "args": deepcopy(args), | ||||
|                 "last_checkpoint": save_path, | ||||
|             }, | ||||
|             logger.path("info"), | ||||
|             logger, | ||||
|         ) | ||||
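|         # Log the current architecture weights (alphas) and, if the NATS-Bench API | ||||
|         # is loaded, the benchmarked results of the derived genotype. | ||||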
|         with torch.no_grad(): | ||||
|             logger.log("{:}".format(search_model.show_alphas())) | ||||
|         if api is not None: | ||||
|             logger.log("{:}".format(api.query_by_arch(genotypes[epoch], "90"))) | ||||
|         # measure elapsed time | ||||
|         epoch_time.update(time.time() - start_time) | ||||
|         start_time = time.time() | ||||
|  | ||||
|     # the final post procedure : count the time | ||||
|     start_time = time.time() | ||||
|     genotype = network.genotype | ||||
|     search_time.update(time.time() - start_time) | ||||
|  | ||||
|     valid_a_loss, valid_a_top1, valid_a_top5 = valid_func( | ||||
|         valid_loader, network, criterion, logger | ||||
|     ) | ||||
|     logger.log( | ||||
|         "Last : the gentotype is : {:}, with the validation accuracy of {:.3f}%.".format( | ||||
|             genotype, valid_a_top1 | ||||
|         ) | ||||
|     ) | ||||
|  | ||||
|     logger.log("\n" + "-" * 100) | ||||
|     # check the performance from the architecture dataset | ||||
|     logger.log( | ||||
|         "[{:}] run {:} epochs, cost {:.1f} s, last-geno is {:}.".format( | ||||
|             xargs.algo, total_epoch, search_time.sum, genotype | ||||
|         ) | ||||
|     ) | ||||
|     if api is not None: | ||||
|         logger.log("{:}".format(api.query_by_arch(genotype, "90"))) | ||||
|     logger.close() | ||||
|  | ||||
|  | ||||
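| # Example invocation (hypothetical; the script path and data path below are | ||||
| # assumptions and must be adapted to your checkout and dataset layout): | ||||
| #   python ./exps/NATS-algos/search-size.py --dataset cifar10 \ | ||||
| #       --data_path ${TORCH_HOME}/cifar.python --algo tas --rand_seed 777 | ||||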
| if __name__ == "__main__": | ||||
|     parser = argparse.ArgumentParser("Weight sharing NAS methods to search for the architecture size.") | ||||
|     parser.add_argument("--data_path", type=str, help="Path to dataset") | ||||
|     parser.add_argument( | ||||
|         "--dataset", | ||||
|         type=str, | ||||
|         choices=["cifar10", "cifar100", "ImageNet16-120"], | ||||
|         help="Choose between Cifar10/100 and ImageNet-16.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--search_space", | ||||
|         type=str, | ||||
|         default="sss", | ||||
|         choices=["sss"], | ||||
|         help="The search space name.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--algo", | ||||
|         type=str, | ||||
|         choices=["tas", "mask_gumbel", "mask_rl"], | ||||
|         help="The search space name.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--genotype", | ||||
|         type=str, | ||||
|         default="|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|skip_connect~0|nor_conv_3x3~1|nor_conv_3x3~2|", | ||||
|         help="The genotype.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--use_api", | ||||
|         type=int, | ||||
|         default=1, | ||||
|         choices=[0, 1], | ||||
|         help="Whether use API or not (which will cost much memory).", | ||||
|     ) | ||||
|     # FOR GDAS | ||||
|     parser.add_argument( | ||||
|         "--tau_min", type=float, default=0.1, help="The minimum tau for Gumbel Softmax." | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--tau_max", type=float, default=10, help="The maximum tau for Gumbel Softmax." | ||||
|     ) | ||||
|     # FOR ALL | ||||
|     parser.add_argument( | ||||
|         "--warmup_ratio", type=float, help="The warmup ratio, if None, not use warmup." | ||||
|     ) | ||||
|     # | ||||
|     parser.add_argument( | ||||
|         "--track_running_stats", | ||||
|         type=int, | ||||
|         default=0, | ||||
|         choices=[0, 1], | ||||
|         help="Whether use track_running_stats or not in the BN layer.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--affine", | ||||
|         type=int, | ||||
|         default=0, | ||||
|         choices=[0, 1], | ||||
|         help="Whether use affine=True or False in the BN layer.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--config_path", | ||||
|         type=str, | ||||
|         default="./configs/nas-benchmark/algos/weight-sharing.config", | ||||
|         help="The path of configuration.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--overwite_epochs", | ||||
|         type=int, | ||||
|         help="The number of epochs to overwrite that value in config files.", | ||||
|     ) | ||||
|     # architecture learning rate | ||||
|     parser.add_argument( | ||||
|         "--arch_learning_rate", | ||||
|         type=float, | ||||
|         default=3e-4, | ||||
|         help="learning rate for arch encoding", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--arch_weight_decay", | ||||
|         type=float, | ||||
|         default=1e-3, | ||||
|         help="weight decay for arch encoding", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--arch_eps", type=float, default=1e-8, help="weight decay for arch encoding" | ||||
|     ) | ||||
|     # log | ||||
|     parser.add_argument( | ||||
|         "--workers", | ||||
|         type=int, | ||||
|         default=2, | ||||
|         help="number of data loading workers (default: 2)", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--save_dir", | ||||
|         type=str, | ||||
|         default="./output/search", | ||||
|         help="Folder to save checkpoints and log.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--print_freq", type=int, default=200, help="print frequency (default: 200)" | ||||
|     ) | ||||
|     parser.add_argument("--rand_seed", type=int, help="manual seed") | ||||
|     args = parser.parse_args() | ||||
|     if args.rand_seed is None or args.rand_seed < 0: | ||||
|         args.rand_seed = random.randint(1, 100000) | ||||
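|     # The output folder name encodes the algorithm and the key settings: BN affine, | ||||
|     # track_running_stats, architecture weight decay, and the warmup ratio. | ||||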
|     dirname = "{:}-affine{:}_BN{:}-AWD{:}-WARM{:}".format( | ||||
|         args.algo, | ||||
|         args.affine, | ||||
|         args.track_running_stats, | ||||
|         args.arch_weight_decay, | ||||
|         args.warmup_ratio, | ||||
|     ) | ||||
|     if args.overwite_epochs is not None: | ||||
|         dirname = dirname + "-E{:}".format(args.overwite_epochs) | ||||
|     args.save_dir = os.path.join( | ||||
|         "{:}-{:}".format(args.save_dir, args.search_space), args.dataset, dirname | ||||
|     ) | ||||
|  | ||||
|     main(args) | ||||