Move algos to NAS-Bench-201-algos
exps/NAS-Bench-201-algos/BOHB.py  (367 lines, Normal file)
							| @@ -0,0 +1,367 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 # | ||||
| ################################################################### | ||||
| # BOHB: Robust and Efficient Hyperparameter Optimization at Scale # | ||||
| # requires the hpbandster package ################################ | ||||
| # install it via: pip install hpbandster ########################## | ||||
| ################################################################### | ||||
| # bash ./scripts-search/algos/BOHB.sh -1         ################## | ||||
| ################################################################### | ||||
| import os, sys, time, random, argparse | ||||
| from copy import deepcopy | ||||
| from pathlib import Path | ||||
| import torch | ||||
|  | ||||
| from xautodl.config_utils import load_config | ||||
| from xautodl.datasets import get_datasets, SearchDataset | ||||
| from xautodl.procedures import prepare_seed, prepare_logger | ||||
| from xautodl.log_utils import AverageMeter, time_string, convert_secs2time | ||||
| from xautodl.models import CellStructure, get_search_spaces | ||||
| from nas_201_api import NASBench201API as API | ||||
|  | ||||
| # BOHB: Robust and Efficient Hyperparameter Optimization at Scale, ICML 2018 | ||||
| import ConfigSpace | ||||
| from hpbandster.optimizers.bohb import BOHB | ||||
| import hpbandster.core.nameserver as hpns | ||||
| from hpbandster.core.worker import Worker | ||||
|  | ||||
|  | ||||
| def get_configuration_space(max_nodes, search_space): | ||||
|     cs = ConfigSpace.ConfigurationSpace() | ||||
|     # edge2index   = {} | ||||
|     for i in range(1, max_nodes): | ||||
|         for j in range(i): | ||||
|             node_str = "{:}<-{:}".format(i, j) | ||||
|             cs.add_hyperparameter( | ||||
|                 ConfigSpace.CategoricalHyperparameter(node_str, search_space) | ||||
|             ) | ||||
|     return cs | ||||
|  | ||||
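| # Illustrative example: with the standard NAS-Bench-201 cell (max_nodes=4 and | ||||
| # five candidate operations), this creates six categorical hyperparameters, | ||||
| # one per edge: "1<-0", "2<-0", "2<-1", "3<-0", "3<-1" and "3<-2". | ||||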
|  | ||||
| def config2structure_func(max_nodes): | ||||
|     def config2structure(config): | ||||
|         genotypes = [] | ||||
|         for i in range(1, max_nodes): | ||||
|             xlist = [] | ||||
|             for j in range(i): | ||||
|                 node_str = "{:}<-{:}".format(i, j) | ||||
|                 op_name = config[node_str] | ||||
|                 xlist.append((op_name, j)) | ||||
|             genotypes.append(tuple(xlist)) | ||||
|         return CellStructure(genotypes) | ||||
|  | ||||
|     return config2structure | ||||
|  | ||||
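| # Illustrative example: a sampled configuration such as | ||||
| #   {"1<-0": "nor_conv_3x3", "2<-0": "skip_connect", "2<-1": "nor_conv_1x1", ...} | ||||
| # becomes CellStructure([(("nor_conv_3x3", 0),), | ||||
| #                        (("skip_connect", 0), ("nor_conv_1x1", 1)), ...]), | ||||
| # which can then be looked up in the NAS-Bench-201 API. | ||||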
|  | ||||
| class MyWorker(Worker): | ||||
|     def __init__( | ||||
|         self, | ||||
|         *args, | ||||
|         convert_func=None, | ||||
|         dataname=None, | ||||
|         nas_bench=None, | ||||
|         time_budget=None, | ||||
|         **kwargs | ||||
|     ): | ||||
|         super().__init__(*args, **kwargs) | ||||
|         self.convert_func = convert_func | ||||
|         self._dataname = dataname | ||||
|         self._nas_bench = nas_bench | ||||
|         self.time_budget = time_budget | ||||
|         self.seen_archs = [] | ||||
|         self.sim_cost_time = 0 | ||||
|         self.real_cost_time = 0 | ||||
|         self.is_end = False | ||||
|  | ||||
|     def get_the_best(self): | ||||
|         assert len(self.seen_archs) > 0 | ||||
|         best_index, best_acc = -1, None | ||||
|         for arch_index in self.seen_archs: | ||||
|             info = self._nas_bench.get_more_info( | ||||
|                 arch_index, self._dataname, None, hp="200", is_random=True | ||||
|             ) | ||||
|             vacc = info["valid-accuracy"] | ||||
|             if best_acc is None or best_acc < vacc: | ||||
|                 best_acc = vacc | ||||
|                 best_index = arch_index | ||||
|         assert best_index != -1 | ||||
|         return best_index | ||||
|  | ||||
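|     # Note: the per-trial `budget` assigned by BOHB is not used below; every | ||||
|     # query retrieves the full 200-epoch statistics (hp="200", from a randomly | ||||
|     # chosen training seed), and the search is instead capped by the simulated | ||||
|     # time budget accumulated from the benchmark's recorded train/valid times. | ||||
|     # BOHB minimizes the returned "loss", hence 100 - validation accuracy. | ||||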
|     def compute(self, config, budget, **kwargs): | ||||
|         start_time = time.time() | ||||
|         structure = self.convert_func(config) | ||||
|         arch_index = self._nas_bench.query_index_by_arch(structure) | ||||
|         info = self._nas_bench.get_more_info( | ||||
|             arch_index, self._dataname, None, hp="200", is_random=True | ||||
|         ) | ||||
|         cur_time = info["train-all-time"] + info["valid-per-time"] | ||||
|         cur_vacc = info["valid-accuracy"] | ||||
|         self.real_cost_time += time.time() - start_time | ||||
|         if self.sim_cost_time + cur_time <= self.time_budget and not self.is_end: | ||||
|             self.sim_cost_time += cur_time | ||||
|             self.seen_archs.append(arch_index) | ||||
|             return { | ||||
|                 "loss": 100 - float(cur_vacc), | ||||
|                 "info": { | ||||
|                     "seen-arch": len(self.seen_archs), | ||||
|                     "sim-test-time": self.sim_cost_time, | ||||
|                     "current-arch": arch_index, | ||||
|                 }, | ||||
|             } | ||||
|         else: | ||||
|             self.is_end = True | ||||
|             return { | ||||
|                 "loss": 100, | ||||
|                 "info": { | ||||
|                     "seen-arch": len(self.seen_archs), | ||||
|                     "sim-test-time": self.sim_cost_time, | ||||
|                     "current-arch": None, | ||||
|                 }, | ||||
|             } | ||||
|  | ||||
|  | ||||
| def main(xargs, nas_bench): | ||||
|     assert torch.cuda.is_available(), "CUDA is not available." | ||||
|     torch.backends.cudnn.enabled = True | ||||
|     torch.backends.cudnn.benchmark = False | ||||
|     torch.backends.cudnn.deterministic = True | ||||
|     torch.set_num_threads(xargs.workers) | ||||
|     prepare_seed(xargs.rand_seed) | ||||
|     logger = prepare_logger(xargs) | ||||
|  | ||||
|     if xargs.dataset == "cifar10": | ||||
|         dataname = "cifar10-valid" | ||||
|     else: | ||||
|         dataname = xargs.dataset | ||||
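|     # NAS-Bench-201 convention: CIFAR-10 searches are scored on the | ||||
|     # "cifar10-valid" split (train on the training split, evaluate on the | ||||
|     # held-out validation split), so the CIFAR-10 test set is never queried | ||||
|     # during the search. | ||||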
|     if xargs.data_path is not None: | ||||
|         train_data, valid_data, xshape, class_num = get_datasets( | ||||
|             xargs.dataset, xargs.data_path, -1 | ||||
|         ) | ||||
|         split_Fpath = "configs/nas-benchmark/cifar-split.txt" | ||||
|         cifar_split = load_config(split_Fpath, None, None) | ||||
|         train_split, valid_split = cifar_split.train, cifar_split.valid | ||||
|         logger.log("Load split file from {:}".format(split_Fpath)) | ||||
|         config_path = "configs/nas-benchmark/algos/R-EA.config" | ||||
|         config = load_config( | ||||
|             config_path, {"class_num": class_num, "xshape": xshape}, logger | ||||
|         ) | ||||
|         # To split data | ||||
|         train_data_v2 = deepcopy(train_data) | ||||
|         train_data_v2.transform = valid_data.transform | ||||
|         valid_data = train_data_v2 | ||||
|         search_data = SearchDataset(xargs.dataset, train_data, train_split, valid_split) | ||||
|         # data loader | ||||
|         train_loader = torch.utils.data.DataLoader( | ||||
|             train_data, | ||||
|             batch_size=config.batch_size, | ||||
|             sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split), | ||||
|             num_workers=xargs.workers, | ||||
|             pin_memory=True, | ||||
|         ) | ||||
|         valid_loader = torch.utils.data.DataLoader( | ||||
|             valid_data, | ||||
|             batch_size=config.batch_size, | ||||
|             sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), | ||||
|             num_workers=xargs.workers, | ||||
|             pin_memory=True, | ||||
|         ) | ||||
|         logger.log( | ||||
|             "||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}".format( | ||||
|                 xargs.dataset, len(train_loader), len(valid_loader), config.batch_size | ||||
|             ) | ||||
|         ) | ||||
|         logger.log("||||||| {:10s} ||||||| Config={:}".format(xargs.dataset, config)) | ||||
|         extra_info = { | ||||
|             "config": config, | ||||
|             "train_loader": train_loader, | ||||
|             "valid_loader": valid_loader, | ||||
|         } | ||||
|     else: | ||||
|         config_path = "configs/nas-benchmark/algos/R-EA.config" | ||||
|         config = load_config(config_path, None, logger) | ||||
|         logger.log("||||||| {:10s} ||||||| Config={:}".format(xargs.dataset, config)) | ||||
|         extra_info = {"config": config, "train_loader": None, "valid_loader": None} | ||||
|  | ||||
|     # nas dataset load | ||||
|     assert xargs.arch_nas_dataset is not None and os.path.isfile(xargs.arch_nas_dataset) | ||||
|     search_space = get_search_spaces("cell", xargs.search_space_name) | ||||
|     cs = get_configuration_space(xargs.max_nodes, search_space) | ||||
|  | ||||
|     config2structure = config2structure_func(xargs.max_nodes) | ||||
|     hb_run_id = "0" | ||||
|  | ||||
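|     # hpbandster communicates through a (Pyro4-based) nameserver; even for a | ||||
|     # single in-process worker it must be started (here on localhost with a | ||||
|     # random free port) so that the worker and the optimizer can register. | ||||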
|     NS = hpns.NameServer(run_id=hb_run_id, host="localhost", port=0) | ||||
|     ns_host, ns_port = NS.start() | ||||
|     num_workers = 1 | ||||
|  | ||||
|     # nas_bench = AANASBenchAPI(xargs.arch_nas_dataset) | ||||
|     # logger.log('{:} Create NAS-BENCH-API DONE'.format(time_string())) | ||||
|     workers = [] | ||||
|     for i in range(num_workers): | ||||
|         w = MyWorker( | ||||
|             nameserver=ns_host, | ||||
|             nameserver_port=ns_port, | ||||
|             convert_func=config2structure, | ||||
|             dataname=dataname, | ||||
|             nas_bench=nas_bench, | ||||
|             time_budget=xargs.time_budget, | ||||
|             run_id=hb_run_id, | ||||
|             id=i, | ||||
|         ) | ||||
|         w.run(background=True) | ||||
|         workers.append(w) | ||||
|  | ||||
|     start_time = time.time() | ||||
|     bohb = BOHB( | ||||
|         configspace=cs, | ||||
|         run_id=hb_run_id, | ||||
|         eta=3, | ||||
|         min_budget=12, | ||||
|         max_budget=200, | ||||
|         nameserver=ns_host, | ||||
|         nameserver_port=ns_port, | ||||
|         num_samples=xargs.num_samples, | ||||
|         random_fraction=xargs.random_fraction, | ||||
|         bandwidth_factor=xargs.bandwidth_factor, | ||||
|         ping_interval=10, | ||||
|         min_bandwidth=xargs.min_bandwidth, | ||||
|     ) | ||||
|  | ||||
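|     # bohb.run executes xargs.n_iters successive-halving iterations (eta=3 | ||||
|     # between the nominal budgets 12 and 200) once at least min_n_workers | ||||
|     # workers have registered; the returned results object is queried below | ||||
|     # for the incumbent configuration. | ||||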
|     results = bohb.run(xargs.n_iters, min_n_workers=num_workers) | ||||
|  | ||||
|     bohb.shutdown(shutdown_workers=True) | ||||
|     NS.shutdown() | ||||
|  | ||||
|     real_cost_time = time.time() - start_time | ||||
|  | ||||
|     id2config = results.get_id2config_mapping() | ||||
|     incumbent = results.get_incumbent_id() | ||||
|     logger.log( | ||||
|         "Best found configuration: {:} within {:.3f} s".format( | ||||
|             id2config[incumbent]["config"], real_cost_time | ||||
|         ) | ||||
|     ) | ||||
|     best_arch = config2structure(id2config[incumbent]["config"]) | ||||
|  | ||||
|     info = nas_bench.query_by_arch(best_arch, "200") | ||||
|     if info is None: | ||||
|         logger.log("Did not find this architecture : {:}.".format(best_arch)) | ||||
|     else: | ||||
|         logger.log("{:}".format(info)) | ||||
|     logger.log("-" * 100) | ||||
|  | ||||
|     logger.log( | ||||
|         "workers : {:.1f}s with {:} archs".format( | ||||
|             workers[0].time_budget, len(workers[0].seen_archs) | ||||
|         ) | ||||
|     ) | ||||
|     logger.close() | ||||
|     return logger.log_dir, nas_bench.query_index_by_arch(best_arch), real_cost_time | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     parser = argparse.ArgumentParser( | ||||
|         "BOHB: Robust and Efficient Hyperparameter Optimization at Scale" | ||||
|     ) | ||||
|     parser.add_argument("--data_path", type=str, help="Path to dataset") | ||||
|     parser.add_argument( | ||||
|         "--dataset", | ||||
|         type=str, | ||||
|         choices=["cifar10", "cifar100", "ImageNet16-120"], | ||||
|         help="Choose between Cifar10/100 and ImageNet-16.", | ||||
|     ) | ||||
|     # channels and number-of-cells | ||||
|     parser.add_argument("--search_space_name", type=str, help="The search space name.") | ||||
|     parser.add_argument("--max_nodes", type=int, help="The maximum number of nodes.") | ||||
|     parser.add_argument("--channel", type=int, help="The number of channels.") | ||||
|     parser.add_argument( | ||||
|         "--num_cells", type=int, help="The number of cells in one stage." | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--time_budget", | ||||
|         type=int, | ||||
|         help="The total time budget for searching (in seconds).", | ||||
|     ) | ||||
|     # BOHB | ||||
|     parser.add_argument( | ||||
|         "--strategy", | ||||
|         default="sampling", | ||||
|         type=str, | ||||
|         nargs="?", | ||||
|         help="optimization strategy for the acquisition function", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--min_bandwidth", | ||||
|         default=0.3, | ||||
|         type=float, | ||||
|         nargs="?", | ||||
|         help="minimum bandwidth for KDE", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--num_samples", | ||||
|         default=64, | ||||
|         type=int, | ||||
|         nargs="?", | ||||
|         help="number of samples for the acquisition function", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--random_fraction", | ||||
|         default=0.33, | ||||
|         type=float, | ||||
|         nargs="?", | ||||
|         help="fraction of random configurations", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--bandwidth_factor", | ||||
|         default=3, | ||||
|         type=int, | ||||
|         nargs="?", | ||||
|         help="factor multiplied to the bandwidth", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--n_iters", | ||||
|         default=100, | ||||
|         type=int, | ||||
|         nargs="?", | ||||
|         help="number of iterations for the optimization method", | ||||
|     ) | ||||
|     # log | ||||
|     parser.add_argument( | ||||
|         "--workers", | ||||
|         type=int, | ||||
|         default=2, | ||||
|         help="number of data loading workers (default: 2)", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--save_dir", type=str, help="Folder to save checkpoints and log." | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--arch_nas_dataset", | ||||
|         type=str, | ||||
|         help="The path to load the architecture dataset (tiny-nas-benchmark).", | ||||
|     ) | ||||
|     parser.add_argument("--print_freq", type=int, help="print frequency (default: 200)") | ||||
|     parser.add_argument("--rand_seed", type=int, help="manual seed") | ||||
|     args = parser.parse_args() | ||||
|     # if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000) | ||||
|     if args.arch_nas_dataset is None or not os.path.isfile(args.arch_nas_dataset): | ||||
|         nas_bench = None | ||||
|     else: | ||||
|         print( | ||||
|             "{:} build NAS-Benchmark-API from {:}".format( | ||||
|                 time_string(), args.arch_nas_dataset | ||||
|             ) | ||||
|         ) | ||||
|         nas_bench = API(args.arch_nas_dataset) | ||||
|     if args.rand_seed is None or args.rand_seed < 0: | ||||
|         save_dir, all_indexes, num, all_times = None, [], 500, [] | ||||
|         for i in range(num): | ||||
|             print("{:} : {:03d}/{:03d}".format(time_string(), i, num)) | ||||
|             args.rand_seed = random.randint(1, 100000) | ||||
|             save_dir, index, ctime = main(args, nas_bench) | ||||
|             all_indexes.append(index) | ||||
|             all_times.append(ctime) | ||||
|         print("\n average time : {:.3f} s".format(sum(all_times) / len(all_times))) | ||||
|         torch.save(all_indexes, save_dir / "results.pth") | ||||
|     else: | ||||
|         main(args, nas_bench) | ||||
							
								
								
									
exps/NAS-Bench-201-algos/DARTS-V1.py  (417 lines, Normal file)
							| @@ -0,0 +1,417 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 # | ||||
| ######################################################## | ||||
| # DARTS: Differentiable Architecture Search, ICLR 2019 # | ||||
| ######################################################## | ||||
| import sys, time, random, argparse | ||||
| from copy import deepcopy | ||||
| import torch | ||||
| from pathlib import Path | ||||
|  | ||||
| from xautodl.config_utils import load_config, dict2config, configure2str | ||||
| from xautodl.datasets import get_datasets, get_nas_search_loaders | ||||
| from xautodl.procedures import ( | ||||
|     prepare_seed, | ||||
|     prepare_logger, | ||||
|     save_checkpoint, | ||||
|     copy_checkpoint, | ||||
|     get_optim_scheduler, | ||||
| ) | ||||
| from xautodl.utils import get_model_infos, obtain_accuracy | ||||
| from xautodl.log_utils import AverageMeter, time_string, convert_secs2time | ||||
| from xautodl.models import get_cell_based_tiny_net, get_search_spaces | ||||
| from nas_201_api import NASBench201API as API | ||||
|  | ||||
|  | ||||
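| # search_func performs the first-order DARTS update: for every mini-batch it | ||||
| # first takes one SGD step on the supernet weights w using the training split | ||||
| # (base_*), then one Adam step on the architecture parameters alpha using the | ||||
| # validation split (arch_*), evaluating alpha at the current w rather than at | ||||
| # the unrolled w' (that second-order variant lives in DARTS-V2.py). | ||||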
| def search_func( | ||||
|     xloader, | ||||
|     network, | ||||
|     criterion, | ||||
|     scheduler, | ||||
|     w_optimizer, | ||||
|     a_optimizer, | ||||
|     epoch_str, | ||||
|     print_freq, | ||||
|     logger, | ||||
|     gradient_clip, | ||||
| ): | ||||
|     data_time, batch_time = AverageMeter(), AverageMeter() | ||||
|     base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter() | ||||
|     arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter() | ||||
|     network.train() | ||||
|     end = time.time() | ||||
|     for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate( | ||||
|         xloader | ||||
|     ): | ||||
|         scheduler.update(None, 1.0 * step / len(xloader)) | ||||
|         base_targets = base_targets.cuda(non_blocking=True) | ||||
|         arch_targets = arch_targets.cuda(non_blocking=True) | ||||
|         # measure data loading time | ||||
|         data_time.update(time.time() - end) | ||||
|  | ||||
|         # update the weights | ||||
|         w_optimizer.zero_grad() | ||||
|         _, logits = network(base_inputs) | ||||
|         base_loss = criterion(logits, base_targets) | ||||
|         base_loss.backward() | ||||
|         if gradient_clip > 0: | ||||
|             torch.nn.utils.clip_grad_norm_(network.parameters(), gradient_clip) | ||||
|         w_optimizer.step() | ||||
|         # record | ||||
|         base_prec1, base_prec5 = obtain_accuracy( | ||||
|             logits.data, base_targets.data, topk=(1, 5) | ||||
|         ) | ||||
|         base_losses.update(base_loss.item(), base_inputs.size(0)) | ||||
|         base_top1.update(base_prec1.item(), base_inputs.size(0)) | ||||
|         base_top5.update(base_prec5.item(), base_inputs.size(0)) | ||||
|  | ||||
|         # update the architecture-weight | ||||
|         a_optimizer.zero_grad() | ||||
|         _, logits = network(arch_inputs) | ||||
|         arch_loss = criterion(logits, arch_targets) | ||||
|         arch_loss.backward() | ||||
|         a_optimizer.step() | ||||
|         # record | ||||
|         arch_prec1, arch_prec5 = obtain_accuracy( | ||||
|             logits.data, arch_targets.data, topk=(1, 5) | ||||
|         ) | ||||
|         arch_losses.update(arch_loss.item(), arch_inputs.size(0)) | ||||
|         arch_top1.update(arch_prec1.item(), arch_inputs.size(0)) | ||||
|         arch_top5.update(arch_prec5.item(), arch_inputs.size(0)) | ||||
|  | ||||
|         # measure elapsed time | ||||
|         batch_time.update(time.time() - end) | ||||
|         end = time.time() | ||||
|  | ||||
|         if step % print_freq == 0 or step + 1 == len(xloader): | ||||
|             Sstr = ( | ||||
|                 "*SEARCH* " | ||||
|                 + time_string() | ||||
|                 + " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(xloader)) | ||||
|             ) | ||||
|             Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format( | ||||
|                 batch_time=batch_time, data_time=data_time | ||||
|             ) | ||||
|             Wstr = "Base [Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]".format( | ||||
|                 loss=base_losses, top1=base_top1, top5=base_top5 | ||||
|             ) | ||||
|             Astr = "Arch [Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]".format( | ||||
|                 loss=arch_losses, top1=arch_top1, top5=arch_top5 | ||||
|             ) | ||||
|             logger.log(Sstr + " " + Tstr + " " + Wstr + " " + Astr) | ||||
|     return base_losses.avg, base_top1.avg, base_top5.avg | ||||
|  | ||||
|  | ||||
| def valid_func(xloader, network, criterion): | ||||
|     data_time, batch_time = AverageMeter(), AverageMeter() | ||||
|     arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter() | ||||
|     network.eval() | ||||
|     end = time.time() | ||||
|     with torch.no_grad(): | ||||
|         for step, (arch_inputs, arch_targets) in enumerate(xloader): | ||||
|             arch_targets = arch_targets.cuda(non_blocking=True) | ||||
|             # measure data loading time | ||||
|             data_time.update(time.time() - end) | ||||
|             # prediction | ||||
|             _, logits = network(arch_inputs) | ||||
|             arch_loss = criterion(logits, arch_targets) | ||||
|             # record | ||||
|             arch_prec1, arch_prec5 = obtain_accuracy( | ||||
|                 logits.data, arch_targets.data, topk=(1, 5) | ||||
|             ) | ||||
|             arch_losses.update(arch_loss.item(), arch_inputs.size(0)) | ||||
|             arch_top1.update(arch_prec1.item(), arch_inputs.size(0)) | ||||
|             arch_top5.update(arch_prec5.item(), arch_inputs.size(0)) | ||||
|             # measure elapsed time | ||||
|             batch_time.update(time.time() - end) | ||||
|             end = time.time() | ||||
|     return arch_losses.avg, arch_top1.avg, arch_top5.avg | ||||
|  | ||||
|  | ||||
| def main(xargs): | ||||
|     assert torch.cuda.is_available(), "CUDA is not available." | ||||
|     torch.backends.cudnn.enabled = True | ||||
|     torch.backends.cudnn.benchmark = False | ||||
|     torch.backends.cudnn.deterministic = True | ||||
|     torch.set_num_threads(xargs.workers) | ||||
|     prepare_seed(xargs.rand_seed) | ||||
|     logger = prepare_logger(xargs) | ||||
|  | ||||
|     train_data, valid_data, xshape, class_num = get_datasets( | ||||
|         xargs.dataset, xargs.data_path, -1 | ||||
|     ) | ||||
|     # config_path = 'configs/nas-benchmark/algos/DARTS.config' | ||||
|     config = load_config( | ||||
|         xargs.config_path, {"class_num": class_num, "xshape": xshape}, logger | ||||
|     ) | ||||
|     search_loader, _, valid_loader = get_nas_search_loaders( | ||||
|         train_data, | ||||
|         valid_data, | ||||
|         xargs.dataset, | ||||
|         "configs/nas-benchmark/", | ||||
|         config.batch_size, | ||||
|         xargs.workers, | ||||
|     ) | ||||
|     logger.log( | ||||
|         "||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}".format( | ||||
|             xargs.dataset, len(search_loader), len(valid_loader), config.batch_size | ||||
|         ) | ||||
|     ) | ||||
|     logger.log("||||||| {:10s} ||||||| Config={:}".format(xargs.dataset, config)) | ||||
|  | ||||
|     search_space = get_search_spaces("cell", xargs.search_space_name) | ||||
|     if xargs.model_config is None: | ||||
|         model_config = dict2config( | ||||
|             { | ||||
|                 "name": "DARTS-V1", | ||||
|                 "C": xargs.channel, | ||||
|                 "N": xargs.num_cells, | ||||
|                 "max_nodes": xargs.max_nodes, | ||||
|                 "num_classes": class_num, | ||||
|                 "space": search_space, | ||||
|                 "affine": False, | ||||
|                 "track_running_stats": bool(xargs.track_running_stats), | ||||
|             }, | ||||
|             None, | ||||
|         ) | ||||
|     else: | ||||
|         model_config = load_config( | ||||
|             xargs.model_config, | ||||
|             { | ||||
|                 "num_classes": class_num, | ||||
|                 "space": search_space, | ||||
|                 "affine": False, | ||||
|                 "track_running_stats": bool(xargs.track_running_stats), | ||||
|             }, | ||||
|             None, | ||||
|         ) | ||||
|     search_model = get_cell_based_tiny_net(model_config) | ||||
|     logger.log("search-model :\n{:}".format(search_model)) | ||||
|  | ||||
|     w_optimizer, w_scheduler, criterion = get_optim_scheduler( | ||||
|         search_model.get_weights(), config | ||||
|     ) | ||||
|     a_optimizer = torch.optim.Adam( | ||||
|         search_model.get_alphas(), | ||||
|         lr=xargs.arch_learning_rate, | ||||
|         betas=(0.5, 0.999), | ||||
|         weight_decay=xargs.arch_weight_decay, | ||||
|     ) | ||||
|     logger.log("w-optimizer : {:}".format(w_optimizer)) | ||||
|     logger.log("a-optimizer : {:}".format(a_optimizer)) | ||||
|     logger.log("w-scheduler : {:}".format(w_scheduler)) | ||||
|     logger.log("criterion   : {:}".format(criterion)) | ||||
|     flop, param = get_model_infos(search_model, xshape) | ||||
|     # logger.log('{:}'.format(search_model)) | ||||
|     logger.log("FLOP = {:.2f} M, Params = {:.2f} MB".format(flop, param)) | ||||
|     if xargs.arch_nas_dataset is None: | ||||
|         api = None | ||||
|     else: | ||||
|         api = API(xargs.arch_nas_dataset) | ||||
|     logger.log("{:} create API = {:} done".format(time_string(), api)) | ||||
|  | ||||
|     last_info, model_base_path, model_best_path = ( | ||||
|         logger.path("info"), | ||||
|         logger.path("model"), | ||||
|         logger.path("best"), | ||||
|     ) | ||||
|     network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda() | ||||
|  | ||||
|     if last_info.exists():  # automatically resume from previous checkpoint | ||||
|         logger.log( | ||||
|             "=> loading checkpoint of the last-info '{:}' start".format(last_info) | ||||
|         ) | ||||
|         last_info = torch.load(last_info) | ||||
|         start_epoch = last_info["epoch"] | ||||
|         checkpoint = torch.load(last_info["last_checkpoint"]) | ||||
|         genotypes = checkpoint["genotypes"] | ||||
|         valid_accuracies = checkpoint["valid_accuracies"] | ||||
|         search_model.load_state_dict(checkpoint["search_model"]) | ||||
|         w_scheduler.load_state_dict(checkpoint["w_scheduler"]) | ||||
|         w_optimizer.load_state_dict(checkpoint["w_optimizer"]) | ||||
|         a_optimizer.load_state_dict(checkpoint["a_optimizer"]) | ||||
|         logger.log( | ||||
|             "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format( | ||||
|                 last_info, start_epoch | ||||
|             ) | ||||
|         ) | ||||
|     else: | ||||
|         logger.log("=> did not find the last-info file : {:}".format(last_info)) | ||||
|         start_epoch, valid_accuracies, genotypes = ( | ||||
|             0, | ||||
|             {"best": -1}, | ||||
|             {-1: search_model.genotype()}, | ||||
|         ) | ||||
|  | ||||
|     # start training | ||||
|     start_time, search_time, epoch_time, total_epoch = ( | ||||
|         time.time(), | ||||
|         AverageMeter(), | ||||
|         AverageMeter(), | ||||
|         config.epochs + config.warmup, | ||||
|     ) | ||||
|     for epoch in range(start_epoch, total_epoch): | ||||
|         w_scheduler.update(epoch, 0.0) | ||||
|         need_time = "Time Left: {:}".format( | ||||
|             convert_secs2time(epoch_time.val * (total_epoch - epoch), True) | ||||
|         ) | ||||
|         epoch_str = "{:03d}-{:03d}".format(epoch, total_epoch) | ||||
|         logger.log( | ||||
|             "\n[Search the {:}-th epoch] {:}, LR={:}".format( | ||||
|                 epoch_str, need_time, min(w_scheduler.get_lr()) | ||||
|             ) | ||||
|         ) | ||||
|  | ||||
|         search_w_loss, search_w_top1, search_w_top5 = search_func( | ||||
|             search_loader, | ||||
|             network, | ||||
|             criterion, | ||||
|             w_scheduler, | ||||
|             w_optimizer, | ||||
|             a_optimizer, | ||||
|             epoch_str, | ||||
|             xargs.print_freq, | ||||
|             logger, | ||||
|             xargs.gradient_clip, | ||||
|         ) | ||||
|         search_time.update(time.time() - start_time) | ||||
|         logger.log( | ||||
|             "[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s".format( | ||||
|                 epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum | ||||
|             ) | ||||
|         ) | ||||
|         valid_a_loss, valid_a_top1, valid_a_top5 = valid_func( | ||||
|             valid_loader, network, criterion | ||||
|         ) | ||||
|         logger.log( | ||||
|             "[{:}] evaluate  : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%".format( | ||||
|                 epoch_str, valid_a_loss, valid_a_top1, valid_a_top5 | ||||
|             ) | ||||
|         ) | ||||
|         # check the best accuracy | ||||
|         valid_accuracies[epoch] = valid_a_top1 | ||||
|         if valid_a_top1 > valid_accuracies["best"]: | ||||
|             valid_accuracies["best"] = valid_a_top1 | ||||
|             genotypes["best"] = search_model.genotype() | ||||
|             find_best = True | ||||
|         else: | ||||
|             find_best = False | ||||
|  | ||||
|         genotypes[epoch] = search_model.genotype() | ||||
|         logger.log( | ||||
|             "<<<--->>> The {:}-th epoch : {:}".format(epoch_str, genotypes[epoch]) | ||||
|         ) | ||||
|         # save checkpoint | ||||
|         save_path = save_checkpoint( | ||||
|             { | ||||
|                 "epoch": epoch + 1, | ||||
|                 "args": deepcopy(xargs), | ||||
|                 "search_model": search_model.state_dict(), | ||||
|                 "w_optimizer": w_optimizer.state_dict(), | ||||
|                 "a_optimizer": a_optimizer.state_dict(), | ||||
|                 "w_scheduler": w_scheduler.state_dict(), | ||||
|                 "genotypes": genotypes, | ||||
|                 "valid_accuracies": valid_accuracies, | ||||
|             }, | ||||
|             model_base_path, | ||||
|             logger, | ||||
|         ) | ||||
|         last_info = save_checkpoint( | ||||
|             { | ||||
|                 "epoch": epoch + 1, | ||||
|                 "args": deepcopy(args), | ||||
|                 "last_checkpoint": save_path, | ||||
|             }, | ||||
|             logger.path("info"), | ||||
|             logger, | ||||
|         ) | ||||
|         if find_best: | ||||
|             logger.log( | ||||
|                 "<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.".format( | ||||
|                     epoch_str, valid_a_top1 | ||||
|                 ) | ||||
|             ) | ||||
|             copy_checkpoint(model_base_path, model_best_path, logger) | ||||
|         with torch.no_grad(): | ||||
|             # logger.log('arch-parameters :\n{:}'.format( nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu() )) | ||||
|             logger.log("{:}".format(search_model.show_alphas())) | ||||
|         if api is not None: | ||||
|             logger.log("{:}".format(api.query_by_arch(genotypes[epoch], "200"))) | ||||
|         # measure elapsed time | ||||
|         epoch_time.update(time.time() - start_time) | ||||
|         start_time = time.time() | ||||
|  | ||||
|     logger.log("\n" + "-" * 100) | ||||
|     logger.log( | ||||
|         "DARTS-V1 : run {:} epochs, cost {:.1f} s, last-geno is {:}.".format( | ||||
|             total_epoch, search_time.sum, genotypes[total_epoch - 1] | ||||
|         ) | ||||
|     ) | ||||
|     if api is not None: | ||||
|         logger.log("{:}".format(api.query_by_arch(genotypes[total_epoch - 1], "200"))) | ||||
|     logger.close() | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     parser = argparse.ArgumentParser("DARTS First Order") | ||||
|     parser.add_argument("--data_path", type=str, help="Path to dataset") | ||||
|     parser.add_argument( | ||||
|         "--dataset", | ||||
|         type=str, | ||||
|         choices=["cifar10", "cifar100", "ImageNet16-120"], | ||||
|         help="Choose between Cifar10/100 and ImageNet-16.", | ||||
|     ) | ||||
|     # channels and number-of-cells | ||||
|     parser.add_argument("--search_space_name", type=str, help="The search space name.") | ||||
|     parser.add_argument("--max_nodes", type=int, help="The maximum number of nodes.") | ||||
|     parser.add_argument("--channel", type=int, help="The number of channels.") | ||||
|     parser.add_argument( | ||||
|         "--num_cells", type=int, help="The number of cells in one stage." | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--track_running_stats", | ||||
|         type=int, | ||||
|         choices=[0, 1], | ||||
|         help="Whether to use track_running_stats in the BN layer.", | ||||
|     ) | ||||
|     parser.add_argument("--config_path", type=str, help="The config path.") | ||||
|     parser.add_argument( | ||||
|         "--model_config", | ||||
|         type=str, | ||||
|         help="The path of the model configuration. When this arg is set, it will cover max_nodes / channels / num_cells.", | ||||
|         help="The path of the model configuration. When this arg is set, it overrides max_nodes / channel / num_cells.", | ||||
|     parser.add_argument("--gradient_clip", type=float, default=5, help="") | ||||
|     # architecture learning rate | ||||
|     parser.add_argument( | ||||
|         "--arch_learning_rate", | ||||
|         type=float, | ||||
|         default=3e-4, | ||||
|         help="learning rate for arch encoding", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--arch_weight_decay", | ||||
|         type=float, | ||||
|         default=1e-3, | ||||
|         help="weight decay for arch encoding", | ||||
|     ) | ||||
|     # log | ||||
|     parser.add_argument( | ||||
|         "--workers", | ||||
|         type=int, | ||||
|         default=2, | ||||
|         help="number of data loading workers (default: 2)", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--save_dir", type=str, help="Folder to save checkpoints and log." | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--arch_nas_dataset", | ||||
|         type=str, | ||||
|         help="The path to load the architecture dataset (nas-benchmark).", | ||||
|     ) | ||||
|     parser.add_argument("--print_freq", type=int, help="print frequency (default: 200)") | ||||
|     parser.add_argument("--rand_seed", type=int, help="manual seed") | ||||
|     args = parser.parse_args() | ||||
|     if args.rand_seed is None or args.rand_seed < 0: | ||||
|         args.rand_seed = random.randint(1, 100000) | ||||
|     main(args) | ||||
							
								
								
									
exps/NAS-Bench-201-algos/DARTS-V2.py  (496 lines, Normal file)
							| @@ -0,0 +1,496 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 # | ||||
| ######################################################## | ||||
| # DARTS: Differentiable Architecture Search, ICLR 2019 # | ||||
| ######################################################## | ||||
| import os, sys, time, glob, random, argparse | ||||
| import numpy as np | ||||
| from copy import deepcopy | ||||
| import torch | ||||
| import torch.nn as nn | ||||
|  | ||||
| from xautodl.config_utils import load_config, dict2config, configure2str | ||||
| from xautodl.datasets import get_datasets, get_nas_search_loaders | ||||
| from xautodl.procedures import ( | ||||
|     prepare_seed, | ||||
|     prepare_logger, | ||||
|     save_checkpoint, | ||||
|     copy_checkpoint, | ||||
|     get_optim_scheduler, | ||||
| ) | ||||
| from xautodl.utils import get_model_infos, obtain_accuracy | ||||
| from xautodl.log_utils import AverageMeter, time_string, convert_secs2time | ||||
| from xautodl.models import get_cell_based_tiny_net, get_search_spaces | ||||
| from nas_201_api import NASBench201API as API | ||||
|  | ||||
|  | ||||
| def _concat(xs): | ||||
|     return torch.cat([x.view(-1) for x in xs]) | ||||
|  | ||||
|  | ||||
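| # _hessian_vector_product approximates (d^2 L_train / d alpha d w) * vector by | ||||
| # central finite differences, as in the DARTS paper: the weights are perturbed | ||||
| # to w +/- R * vector with R = r / ||vector||, the gradient of the training | ||||
| # loss w.r.t. alpha is taken at both points, and the difference divided by 2R | ||||
| # is returned. | ||||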
| def _hessian_vector_product( | ||||
|     vector, network, criterion, base_inputs, base_targets, r=1e-2 | ||||
| ): | ||||
|     R = r / _concat(vector).norm() | ||||
|     for p, v in zip(network.module.get_weights(), vector): | ||||
|         p.data.add_(R, v) | ||||
|     _, logits = network(base_inputs) | ||||
|     loss = criterion(logits, base_targets) | ||||
|     grads_p = torch.autograd.grad(loss, network.module.get_alphas()) | ||||
|  | ||||
|     for p, v in zip(network.module.get_weights(), vector): | ||||
|         p.data.sub_(2 * R, v) | ||||
|     _, logits = network(base_inputs) | ||||
|     loss = criterion(logits, base_targets) | ||||
|     grads_n = torch.autograd.grad(loss, network.module.get_alphas()) | ||||
|  | ||||
|     for p, v in zip(network.module.get_weights(), vector): | ||||
|         p.data.add_(R, v) | ||||
|     return [(x - y).div_(2 * R) for x, y in zip(grads_p, grads_n)] | ||||
|  | ||||
|  | ||||
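| # backward_step_unrolled implements the second-order DARTS architecture step: | ||||
| # it builds a copy of the network with one simulated SGD step applied to the | ||||
| # weights, w' = w - LR * (momentum * buffer + dL_train/dw + WD * w), | ||||
| # backpropagates the validation loss through that unrolled copy, subtracts | ||||
| # LR times the finite-difference Hessian-vector term, and writes the result | ||||
| # into the arch_parameters gradient of the original network. | ||||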
| def backward_step_unrolled( | ||||
|     network, | ||||
|     criterion, | ||||
|     base_inputs, | ||||
|     base_targets, | ||||
|     w_optimizer, | ||||
|     arch_inputs, | ||||
|     arch_targets, | ||||
| ): | ||||
|     # _compute_unrolled_model | ||||
|     _, logits = network(base_inputs) | ||||
|     loss = criterion(logits, base_targets) | ||||
|     LR, WD, momentum = ( | ||||
|         w_optimizer.param_groups[0]["lr"], | ||||
|         w_optimizer.param_groups[0]["weight_decay"], | ||||
|         w_optimizer.param_groups[0]["momentum"], | ||||
|     ) | ||||
|     with torch.no_grad(): | ||||
|         theta = _concat(network.module.get_weights()) | ||||
|         try: | ||||
|             moment = _concat( | ||||
|                 w_optimizer.state[v]["momentum_buffer"] | ||||
|                 for v in network.module.get_weights() | ||||
|             ) | ||||
|             moment = moment.mul_(momentum) | ||||
|         except Exception: | ||||
|             moment = torch.zeros_like(theta) | ||||
|         dtheta = ( | ||||
|             _concat(torch.autograd.grad(loss, network.module.get_weights())) | ||||
|             + WD * theta | ||||
|         ) | ||||
|         params = theta.sub(LR, moment + dtheta) | ||||
|     unrolled_model = deepcopy(network) | ||||
|     model_dict = unrolled_model.state_dict() | ||||
|     new_params, offset = {}, 0 | ||||
|     for k, v in network.named_parameters(): | ||||
|         if "arch_parameters" in k: | ||||
|             continue | ||||
|         v_length = np.prod(v.size()) | ||||
|         new_params[k] = params[offset : offset + v_length].view(v.size()) | ||||
|         offset += v_length | ||||
|     model_dict.update(new_params) | ||||
|     unrolled_model.load_state_dict(model_dict) | ||||
|  | ||||
|     unrolled_model.zero_grad() | ||||
|     _, unrolled_logits = unrolled_model(arch_inputs) | ||||
|     unrolled_loss = criterion(unrolled_logits, arch_targets) | ||||
|     unrolled_loss.backward() | ||||
|  | ||||
|     dalpha = unrolled_model.module.arch_parameters.grad | ||||
|     vector = [v.grad.data for v in unrolled_model.module.get_weights()] | ||||
|     [implicit_grads] = _hessian_vector_product( | ||||
|         vector, network, criterion, base_inputs, base_targets | ||||
|     ) | ||||
|  | ||||
|     dalpha.data.sub_(LR, implicit_grads.data) | ||||
|  | ||||
|     if network.module.arch_parameters.grad is None: | ||||
|         network.module.arch_parameters.grad = deepcopy(dalpha) | ||||
|     else: | ||||
|         network.module.arch_parameters.grad.data.copy_(dalpha.data) | ||||
|     return unrolled_loss.detach(), unrolled_logits.detach() | ||||
|  | ||||
|  | ||||
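| # Compared with DARTS-V1, the per-batch order is reversed: alpha is updated | ||||
| # first via the unrolled (second-order) backward_step_unrolled, and the | ||||
| # supernet weights are then updated with a standard SGD step on the training | ||||
| # split, with gradients clipped to norm 5. | ||||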
| def search_func( | ||||
|     xloader, | ||||
|     network, | ||||
|     criterion, | ||||
|     scheduler, | ||||
|     w_optimizer, | ||||
|     a_optimizer, | ||||
|     epoch_str, | ||||
|     print_freq, | ||||
|     logger, | ||||
| ): | ||||
|     data_time, batch_time = AverageMeter(), AverageMeter() | ||||
|     base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter() | ||||
|     arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter() | ||||
|     network.train() | ||||
|     end = time.time() | ||||
|     for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate( | ||||
|         xloader | ||||
|     ): | ||||
|         scheduler.update(None, 1.0 * step / len(xloader)) | ||||
|         base_targets = base_targets.cuda(non_blocking=True) | ||||
|         arch_targets = arch_targets.cuda(non_blocking=True) | ||||
|         # measure data loading time | ||||
|         data_time.update(time.time() - end) | ||||
|  | ||||
|         # update the architecture-weight | ||||
|         a_optimizer.zero_grad() | ||||
|         arch_loss, arch_logits = backward_step_unrolled( | ||||
|             network, | ||||
|             criterion, | ||||
|             base_inputs, | ||||
|             base_targets, | ||||
|             w_optimizer, | ||||
|             arch_inputs, | ||||
|             arch_targets, | ||||
|         ) | ||||
|         a_optimizer.step() | ||||
|         # record | ||||
|         arch_prec1, arch_prec5 = obtain_accuracy( | ||||
|             arch_logits.data, arch_targets.data, topk=(1, 5) | ||||
|         ) | ||||
|         arch_losses.update(arch_loss.item(), arch_inputs.size(0)) | ||||
|         arch_top1.update(arch_prec1.item(), arch_inputs.size(0)) | ||||
|         arch_top5.update(arch_prec5.item(), arch_inputs.size(0)) | ||||
|  | ||||
|         # update the weights | ||||
|         w_optimizer.zero_grad() | ||||
|         _, logits = network(base_inputs) | ||||
|         base_loss = criterion(logits, base_targets) | ||||
|         base_loss.backward() | ||||
|         torch.nn.utils.clip_grad_norm_(network.parameters(), 5) | ||||
|         w_optimizer.step() | ||||
|         # record | ||||
|         base_prec1, base_prec5 = obtain_accuracy( | ||||
|             logits.data, base_targets.data, topk=(1, 5) | ||||
|         ) | ||||
|         base_losses.update(base_loss.item(), base_inputs.size(0)) | ||||
|         base_top1.update(base_prec1.item(), base_inputs.size(0)) | ||||
|         base_top5.update(base_prec5.item(), base_inputs.size(0)) | ||||
|  | ||||
|         # measure elapsed time | ||||
|         batch_time.update(time.time() - end) | ||||
|         end = time.time() | ||||
|  | ||||
|         if step % print_freq == 0 or step + 1 == len(xloader): | ||||
|             Sstr = ( | ||||
|                 "*SEARCH* " | ||||
|                 + time_string() | ||||
|                 + " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(xloader)) | ||||
|             ) | ||||
|             Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format( | ||||
|                 batch_time=batch_time, data_time=data_time | ||||
|             ) | ||||
|             Wstr = "Base [Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]".format( | ||||
|                 loss=base_losses, top1=base_top1, top5=base_top5 | ||||
|             ) | ||||
|             Astr = "Arch [Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]".format( | ||||
|                 loss=arch_losses, top1=arch_top1, top5=arch_top5 | ||||
|             ) | ||||
|             logger.log(Sstr + " " + Tstr + " " + Wstr + " " + Astr) | ||||
|     return base_losses.avg, base_top1.avg, base_top5.avg | ||||
|  | ||||
|  | ||||
| def valid_func(xloader, network, criterion): | ||||
|     data_time, batch_time = AverageMeter(), AverageMeter() | ||||
|     arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter() | ||||
|     network.eval() | ||||
|     end = time.time() | ||||
|     with torch.no_grad(): | ||||
|         for step, (arch_inputs, arch_targets) in enumerate(xloader): | ||||
|             arch_targets = arch_targets.cuda(non_blocking=True) | ||||
|             # measure data loading time | ||||
|             data_time.update(time.time() - end) | ||||
|             # prediction | ||||
|             _, logits = network(arch_inputs) | ||||
|             arch_loss = criterion(logits, arch_targets) | ||||
|             # record | ||||
|             arch_prec1, arch_prec5 = obtain_accuracy( | ||||
|                 logits.data, arch_targets.data, topk=(1, 5) | ||||
|             ) | ||||
|             arch_losses.update(arch_loss.item(), arch_inputs.size(0)) | ||||
|             arch_top1.update(arch_prec1.item(), arch_inputs.size(0)) | ||||
|             arch_top5.update(arch_prec5.item(), arch_inputs.size(0)) | ||||
|             # measure elapsed time | ||||
|             batch_time.update(time.time() - end) | ||||
|             end = time.time() | ||||
|     return arch_losses.avg, arch_top1.avg, arch_top5.avg | ||||
|  | ||||
|  | ||||
| def main(xargs): | ||||
|     assert torch.cuda.is_available(), "CUDA is not available." | ||||
|     torch.backends.cudnn.enabled = True | ||||
|     torch.backends.cudnn.benchmark = False | ||||
|     torch.backends.cudnn.deterministic = True | ||||
|     torch.set_num_threads(xargs.workers) | ||||
|     prepare_seed(xargs.rand_seed) | ||||
|     logger = prepare_logger(xargs) | ||||
|  | ||||
|     train_data, valid_data, xshape, class_num = get_datasets( | ||||
|         xargs.dataset, xargs.data_path, -1 | ||||
|     ) | ||||
|     config = load_config( | ||||
|         xargs.config_path, {"class_num": class_num, "xshape": xshape}, logger | ||||
|     ) | ||||
|     search_loader, _, valid_loader = get_nas_search_loaders( | ||||
|         train_data, | ||||
|         valid_data, | ||||
|         xargs.dataset, | ||||
|         "configs/nas-benchmark/", | ||||
|         config.batch_size, | ||||
|         xargs.workers, | ||||
|     ) | ||||
|     logger.log( | ||||
|         "||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}".format( | ||||
|             xargs.dataset, len(search_loader), len(valid_loader), config.batch_size | ||||
|         ) | ||||
|     ) | ||||
|     logger.log("||||||| {:10s} ||||||| Config={:}".format(xargs.dataset, config)) | ||||
|  | ||||
|     search_space = get_search_spaces("cell", xargs.search_space_name) | ||||
|     model_config = dict2config( | ||||
|         { | ||||
|             "name": "DARTS-V2", | ||||
|             "C": xargs.channel, | ||||
|             "N": xargs.num_cells, | ||||
|             "max_nodes": xargs.max_nodes, | ||||
|             "num_classes": class_num, | ||||
|             "space": search_space, | ||||
|             "affine": False, | ||||
|             "track_running_stats": bool(xargs.track_running_stats), | ||||
|         }, | ||||
|         None, | ||||
|     ) | ||||
|     search_model = get_cell_based_tiny_net(model_config) | ||||
|     logger.log("search-model :\n{:}".format(search_model)) | ||||
|  | ||||
|     w_optimizer, w_scheduler, criterion = get_optim_scheduler( | ||||
|         search_model.get_weights(), config | ||||
|     ) | ||||
|     a_optimizer = torch.optim.Adam( | ||||
|         search_model.get_alphas(), | ||||
|         lr=xargs.arch_learning_rate, | ||||
|         betas=(0.5, 0.999), | ||||
|         weight_decay=xargs.arch_weight_decay, | ||||
|     ) | ||||
|     logger.log("w-optimizer : {:}".format(w_optimizer)) | ||||
|     logger.log("a-optimizer : {:}".format(a_optimizer)) | ||||
|     logger.log("w-scheduler : {:}".format(w_scheduler)) | ||||
|     logger.log("criterion   : {:}".format(criterion)) | ||||
|     flop, param = get_model_infos(search_model, xshape) | ||||
|     # logger.log('{:}'.format(search_model)) | ||||
|     logger.log("FLOP = {:.2f} M, Params = {:.2f} MB".format(flop, param)) | ||||
|     if xargs.arch_nas_dataset is None: | ||||
|         api = None | ||||
|     else: | ||||
|         api = API(xargs.arch_nas_dataset) | ||||
|     logger.log("{:} create API = {:} done".format(time_string(), api)) | ||||
|  | ||||
|     last_info, model_base_path, model_best_path = ( | ||||
|         logger.path("info"), | ||||
|         logger.path("model"), | ||||
|         logger.path("best"), | ||||
|     ) | ||||
|     network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda() | ||||
|  | ||||
|     if last_info.exists():  # automatically resume from previous checkpoint | ||||
|         logger.log( | ||||
|             "=> loading checkpoint of the last-info '{:}' start".format(last_info) | ||||
|         ) | ||||
|         last_info = torch.load(last_info) | ||||
|         start_epoch = last_info["epoch"] | ||||
|         checkpoint = torch.load(last_info["last_checkpoint"]) | ||||
|         genotypes = checkpoint["genotypes"] | ||||
|         valid_accuracies = checkpoint["valid_accuracies"] | ||||
|         search_model.load_state_dict(checkpoint["search_model"]) | ||||
|         w_scheduler.load_state_dict(checkpoint["w_scheduler"]) | ||||
|         w_optimizer.load_state_dict(checkpoint["w_optimizer"]) | ||||
|         a_optimizer.load_state_dict(checkpoint["a_optimizer"]) | ||||
|         logger.log( | ||||
|             "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format( | ||||
|                 last_info, start_epoch | ||||
|             ) | ||||
|         ) | ||||
|     else: | ||||
|         logger.log("=> did not find the last-info file : {:}".format(last_info)) | ||||
|         start_epoch, valid_accuracies, genotypes = ( | ||||
|             0, | ||||
|             {"best": -1}, | ||||
|             {-1: search_model.genotype()}, | ||||
|         ) | ||||
|  | ||||
|     # start training | ||||
|     start_time, search_time, epoch_time, total_epoch = ( | ||||
|         time.time(), | ||||
|         AverageMeter(), | ||||
|         AverageMeter(), | ||||
|         config.epochs + config.warmup, | ||||
|     ) | ||||
|     for epoch in range(start_epoch, total_epoch): | ||||
|         w_scheduler.update(epoch, 0.0) | ||||
|         need_time = "Time Left: {:}".format( | ||||
|             convert_secs2time(epoch_time.val * (total_epoch - epoch), True) | ||||
|         ) | ||||
|         epoch_str = "{:03d}-{:03d}".format(epoch, total_epoch) | ||||
|         min_LR = min(w_scheduler.get_lr()) | ||||
|         logger.log( | ||||
|             "\n[Search the {:}-th epoch] {:}, LR={:}".format( | ||||
|                 epoch_str, need_time, min_LR | ||||
|             ) | ||||
|         ) | ||||
|  | ||||
|         search_w_loss, search_w_top1, search_w_top5 = search_func( | ||||
|             search_loader, | ||||
|             network, | ||||
|             criterion, | ||||
|             w_scheduler, | ||||
|             w_optimizer, | ||||
|             a_optimizer, | ||||
|             epoch_str, | ||||
|             xargs.print_freq, | ||||
|             logger, | ||||
|         ) | ||||
|         search_time.update(time.time() - start_time) | ||||
|         logger.log( | ||||
|             "[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s".format( | ||||
|                 epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum | ||||
|             ) | ||||
|         ) | ||||
|         valid_a_loss, valid_a_top1, valid_a_top5 = valid_func( | ||||
|             valid_loader, network, criterion | ||||
|         ) | ||||
|         logger.log( | ||||
|             "[{:}] evaluate  : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%".format( | ||||
|                 epoch_str, valid_a_loss, valid_a_top1, valid_a_top5 | ||||
|             ) | ||||
|         ) | ||||
|         # check the best accuracy | ||||
|         valid_accuracies[epoch] = valid_a_top1 | ||||
|         if valid_a_top1 > valid_accuracies["best"]: | ||||
|             valid_accuracies["best"] = valid_a_top1 | ||||
|             genotypes["best"] = search_model.genotype() | ||||
|             find_best = True | ||||
|         else: | ||||
|             find_best = False | ||||
|  | ||||
|         genotypes[epoch] = search_model.genotype() | ||||
|         logger.log( | ||||
|             "<<<--->>> The {:}-th epoch : {:}".format(epoch_str, genotypes[epoch]) | ||||
|         ) | ||||
|         # save checkpoint | ||||
|         save_path = save_checkpoint( | ||||
|             { | ||||
|                 "epoch": epoch + 1, | ||||
|                 "args": deepcopy(xargs), | ||||
|                 "search_model": search_model.state_dict(), | ||||
|                 "w_optimizer": w_optimizer.state_dict(), | ||||
|                 "a_optimizer": a_optimizer.state_dict(), | ||||
|                 "w_scheduler": w_scheduler.state_dict(), | ||||
|                 "genotypes": genotypes, | ||||
|                 "valid_accuracies": valid_accuracies, | ||||
|             }, | ||||
|             model_base_path, | ||||
|             logger, | ||||
|         ) | ||||
|         last_info = save_checkpoint( | ||||
|             { | ||||
|                 "epoch": epoch + 1, | ||||
|                 "args": deepcopy(args), | ||||
|                 "last_checkpoint": save_path, | ||||
|             }, | ||||
|             logger.path("info"), | ||||
|             logger, | ||||
|         ) | ||||
|         if find_best: | ||||
|             logger.log( | ||||
|                 "<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.".format( | ||||
|                     epoch_str, valid_a_top1 | ||||
|                 ) | ||||
|             ) | ||||
|             copy_checkpoint(model_base_path, model_best_path, logger) | ||||
|         with torch.no_grad(): | ||||
|             logger.log( | ||||
|                 "arch-parameters :\n{:}".format( | ||||
|                     nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu() | ||||
|                 ) | ||||
|             ) | ||||
|         if api is not None: | ||||
|             logger.log("{:}".format(api.query_by_arch(genotypes[epoch], "200"))) | ||||
|         # measure elapsed time | ||||
|         epoch_time.update(time.time() - start_time) | ||||
|         start_time = time.time() | ||||
|  | ||||
|     logger.log("\n" + "-" * 100) | ||||
|     # check the performance from the architecture dataset | ||||
|     logger.log( | ||||
|         "DARTS-V2 : run {:} epochs, cost {:.1f} s, last-geno is {:}.".format( | ||||
|             total_epoch, search_time.sum, genotypes[total_epoch - 1] | ||||
|         ) | ||||
|     ) | ||||
|     if api is not None: | ||||
|         logger.log("{:}".format(api.query_by_arch(genotypes[total_epoch - 1], "200"))) | ||||
|     logger.close() | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     parser = argparse.ArgumentParser("DARTS Second Order") | ||||
|     parser.add_argument("--data_path", type=str, help="The path to dataset") | ||||
|     parser.add_argument( | ||||
|         "--dataset", | ||||
|         type=str, | ||||
|         choices=["cifar10", "cifar100", "ImageNet16-120"], | ||||
|         help="Choose between Cifar10/100 and ImageNet-16.", | ||||
|     ) | ||||
|     # channels and number-of-cells | ||||
|     parser.add_argument("--config_path", type=str, help="The config path.") | ||||
|     parser.add_argument("--search_space_name", type=str, help="The search space name.") | ||||
|     parser.add_argument("--max_nodes", type=int, help="The maximum number of nodes.") | ||||
|     parser.add_argument("--channel", type=int, help="The number of channels.") | ||||
|     parser.add_argument( | ||||
|         "--num_cells", type=int, help="The number of cells in one stage." | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--track_running_stats", | ||||
|         type=int, | ||||
|         choices=[0, 1], | ||||
|         help="Whether use track_running_stats or not in the BN layer.", | ||||
|     ) | ||||
|     # architecture learning rate | ||||
|     parser.add_argument( | ||||
|         "--arch_learning_rate", | ||||
|         type=float, | ||||
|         default=3e-4, | ||||
|         help="learning rate for arch encoding", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--arch_weight_decay", | ||||
|         type=float, | ||||
|         default=1e-3, | ||||
|         help="weight decay for arch encoding", | ||||
|     ) | ||||
|     # log | ||||
|     parser.add_argument( | ||||
|         "--workers", | ||||
|         type=int, | ||||
|         default=2, | ||||
|         help="number of data loading workers (default: 2)", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--save_dir", type=str, help="Folder to save checkpoints and log." | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--arch_nas_dataset", | ||||
|         type=str, | ||||
|         help="The path to load the architecture dataset (tiny-nas-benchmark).", | ||||
|     ) | ||||
|     parser.add_argument("--print_freq", type=int, help="print frequency (default: 200)") | ||||
|     parser.add_argument("--rand_seed", type=int, help="manual seed") | ||||
|     args = parser.parse_args() | ||||
|     if args.rand_seed is None or args.rand_seed < 0: | ||||
|         args.rand_seed = random.randint(1, 100000) | ||||
|     main(args) | ||||
							
								
								
									
578  exps/NAS-Bench-201-algos/ENAS.py  Normal file
							| @@ -0,0 +1,578 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 # | ||||
| ########################################################################## | ||||
| # Efficient Neural Architecture Search via Parameters Sharing, ICML 2018 # | ||||
| ########################################################################## | ||||
| import os, sys, time, glob, random, argparse | ||||
| import numpy as np | ||||
| from copy import deepcopy | ||||
| import torch | ||||
| import torch.nn as nn | ||||
|  | ||||
| from xautodl.config_utils import load_config, dict2config, configure2str | ||||
| from xautodl.datasets import get_datasets, get_nas_search_loaders | ||||
| from xautodl.procedures import ( | ||||
|     prepare_seed, | ||||
|     prepare_logger, | ||||
|     save_checkpoint, | ||||
|     copy_checkpoint, | ||||
|     get_optim_scheduler, | ||||
| ) | ||||
| from xautodl.utils import get_model_infos, obtain_accuracy | ||||
| from xautodl.log_utils import AverageMeter, time_string, convert_secs2time | ||||
| from xautodl.models import get_cell_based_tiny_net, get_search_spaces | ||||
| from nas_201_api import NASBench201API as API | ||||
|  | ||||
|  | ||||
| def train_shared_cnn( | ||||
|     xloader, | ||||
|     shared_cnn, | ||||
|     controller, | ||||
|     criterion, | ||||
|     scheduler, | ||||
|     optimizer, | ||||
|     epoch_str, | ||||
|     print_freq, | ||||
|     logger, | ||||
| ): | ||||
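|     # Train the shared CNN weights: for each batch, sample an architecture from the | ||||
|     # frozen controller, activate it in the shared network, and take one SGD step | ||||
|     # (gradients clipped at norm 5). | ||||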
|     data_time, batch_time = AverageMeter(), AverageMeter() | ||||
|     losses, top1s, top5s, xend = ( | ||||
|         AverageMeter(), | ||||
|         AverageMeter(), | ||||
|         AverageMeter(), | ||||
|         time.time(), | ||||
|     ) | ||||
|  | ||||
|     shared_cnn.train() | ||||
|     controller.eval() | ||||
|  | ||||
|     for step, (inputs, targets) in enumerate(xloader): | ||||
|         scheduler.update(None, 1.0 * step / len(xloader)) | ||||
|         targets = targets.cuda(non_blocking=True) | ||||
|         # measure data loading time | ||||
|         data_time.update(time.time() - xend) | ||||
|  | ||||
|         with torch.no_grad(): | ||||
|             _, _, sampled_arch = controller() | ||||
|  | ||||
|         optimizer.zero_grad() | ||||
|         shared_cnn.module.update_arch(sampled_arch) | ||||
|         _, logits = shared_cnn(inputs) | ||||
|         loss = criterion(logits, targets) | ||||
|         loss.backward() | ||||
|         torch.nn.utils.clip_grad_norm_(shared_cnn.parameters(), 5) | ||||
|         optimizer.step() | ||||
|         # record | ||||
|         prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5)) | ||||
|         losses.update(loss.item(), inputs.size(0)) | ||||
|         top1s.update(prec1.item(), inputs.size(0)) | ||||
|         top5s.update(prec5.item(), inputs.size(0)) | ||||
|  | ||||
|         # measure elapsed time | ||||
|         batch_time.update(time.time() - xend) | ||||
|         xend = time.time() | ||||
|  | ||||
|         if step % print_freq == 0 or step + 1 == len(xloader): | ||||
|             Sstr = ( | ||||
|                 "*Train-Shared-CNN* " | ||||
|                 + time_string() | ||||
|                 + " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(xloader)) | ||||
|             ) | ||||
|             Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format( | ||||
|                 batch_time=batch_time, data_time=data_time | ||||
|             ) | ||||
|             Wstr = "[Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]".format( | ||||
|                 loss=losses, top1=top1s, top5=top5s | ||||
|             ) | ||||
|             logger.log(Sstr + " " + Tstr + " " + Wstr) | ||||
|     return losses.avg, top1s.avg, top5s.avg | ||||
|  | ||||
|  | ||||
| def train_controller( | ||||
|     xloader, | ||||
|     shared_cnn, | ||||
|     controller, | ||||
|     criterion, | ||||
|     optimizer, | ||||
|     config, | ||||
|     epoch_str, | ||||
|     print_freq, | ||||
|     logger, | ||||
| ): | ||||
|     # config contains the necessary controller-training arguments, e.g.: | ||||
|     #   baseline: the baseline score (i.e., average val_acc) from the previous epoch | ||||
|     data_time, batch_time = AverageMeter(), AverageMeter() | ||||
|     ( | ||||
|         GradnormMeter, | ||||
|         LossMeter, | ||||
|         ValAccMeter, | ||||
|         EntropyMeter, | ||||
|         BaselineMeter, | ||||
|         RewardMeter, | ||||
|         xend, | ||||
|     ) = ( | ||||
|         AverageMeter(), | ||||
|         AverageMeter(), | ||||
|         AverageMeter(), | ||||
|         AverageMeter(), | ||||
|         AverageMeter(), | ||||
|         AverageMeter(), | ||||
|         time.time(), | ||||
|     ) | ||||
|  | ||||
|     shared_cnn.eval() | ||||
|     controller.train() | ||||
|     controller.zero_grad() | ||||
|     # for step, (inputs, targets) in enumerate(xloader): | ||||
|     loader_iter = iter(xloader) | ||||
|     for step in range(config.ctl_train_steps * config.ctl_num_aggre): | ||||
|         try: | ||||
|             inputs, targets = next(loader_iter) | ||||
|         except StopIteration: | ||||
|             loader_iter = iter(xloader) | ||||
|             inputs, targets = next(loader_iter) | ||||
|         targets = targets.cuda(non_blocking=True) | ||||
|         # measure data loading time | ||||
|         data_time.update(time.time() - xend) | ||||
|  | ||||
|         log_prob, entropy, sampled_arch = controller() | ||||
|         with torch.no_grad(): | ||||
|             shared_cnn.module.update_arch(sampled_arch) | ||||
|             _, logits = shared_cnn(inputs) | ||||
|             val_top1, val_top5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5)) | ||||
|             val_top1 = val_top1.view(-1) / 100 | ||||
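|         # REINFORCE: the reward is the sampled architecture's validation accuracy | ||||
|         # plus an entropy bonus; the baseline is an exponential moving average of | ||||
|         # the reward with decay rate ctl_bl_dec. | ||||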
|         reward = val_top1 + config.ctl_entropy_w * entropy | ||||
|         if config.baseline is None: | ||||
|             baseline = val_top1 | ||||
|         else: | ||||
|             baseline = config.baseline - (1 - config.ctl_bl_dec) * ( | ||||
|                 config.baseline - reward | ||||
|             ) | ||||
|  | ||||
|         loss = -1 * log_prob * (reward - baseline) | ||||
|  | ||||
|         # account | ||||
|         RewardMeter.update(reward.item()) | ||||
|         BaselineMeter.update(baseline.item()) | ||||
|         ValAccMeter.update(val_top1.item() * 100) | ||||
|         LossMeter.update(loss.item()) | ||||
|         EntropyMeter.update(entropy.item()) | ||||
|  | ||||
|         # Average gradient over controller_num_aggregate samples | ||||
|         loss = loss / config.ctl_num_aggre | ||||
|         loss.backward(retain_graph=True) | ||||
|  | ||||
|         # measure elapsed time | ||||
|         batch_time.update(time.time() - xend) | ||||
|         xend = time.time() | ||||
|         if (step + 1) % config.ctl_num_aggre == 0: | ||||
|             grad_norm = torch.nn.utils.clip_grad_norm_(controller.parameters(), 5.0) | ||||
|             GradnormMeter.update(grad_norm) | ||||
|             optimizer.step() | ||||
|             controller.zero_grad() | ||||
|  | ||||
|         if step % print_freq == 0: | ||||
|             Sstr = ( | ||||
|                 "*Train-Controller* " | ||||
|                 + time_string() | ||||
|                 + " [{:}][{:03d}/{:03d}]".format( | ||||
|                     epoch_str, step, config.ctl_train_steps * config.ctl_num_aggre | ||||
|                 ) | ||||
|             ) | ||||
|             Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format( | ||||
|                 batch_time=batch_time, data_time=data_time | ||||
|             ) | ||||
|             Wstr = "[Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Reward {reward.val:.2f} ({reward.avg:.2f})] Baseline {basel.val:.2f} ({basel.avg:.2f})".format( | ||||
|                 loss=LossMeter, | ||||
|                 top1=ValAccMeter, | ||||
|                 reward=RewardMeter, | ||||
|                 basel=BaselineMeter, | ||||
|             ) | ||||
|             Estr = "Entropy={:.4f} ({:.4f})".format(EntropyMeter.val, EntropyMeter.avg) | ||||
|             logger.log(Sstr + " " + Tstr + " " + Wstr + " " + Estr) | ||||
|  | ||||
|     return ( | ||||
|         LossMeter.avg, | ||||
|         ValAccMeter.avg, | ||||
|         BaselineMeter.avg, | ||||
|         RewardMeter.avg, | ||||
|         baseline.item(), | ||||
|     ) | ||||
|  | ||||
|  | ||||
| def get_best_arch(controller, shared_cnn, xloader, n_samples=10): | ||||
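|     # Sample n_samples architectures from the controller and return the one with the | ||||
|     # highest top-1 accuracy on a single validation batch of the shared CNN. | ||||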
|     with torch.no_grad(): | ||||
|         controller.eval() | ||||
|         shared_cnn.eval() | ||||
|         archs, valid_accs = [], [] | ||||
|         loader_iter = iter(xloader) | ||||
|         for i in range(n_samples): | ||||
|             try: | ||||
|                 inputs, targets = next(loader_iter) | ||||
|             except StopIteration: | ||||
|                 loader_iter = iter(xloader) | ||||
|                 inputs, targets = next(loader_iter) | ||||
|  | ||||
|             _, _, sampled_arch = controller() | ||||
|             arch = shared_cnn.module.update_arch(sampled_arch) | ||||
|             _, logits = shared_cnn(inputs) | ||||
|             val_top1, val_top5 = obtain_accuracy( | ||||
|                 logits.cpu().data, targets.data, topk=(1, 5) | ||||
|             ) | ||||
|  | ||||
|             archs.append(arch) | ||||
|             valid_accs.append(val_top1.item()) | ||||
|  | ||||
|         best_idx = np.argmax(valid_accs) | ||||
|         best_arch, best_valid_acc = archs[best_idx], valid_accs[best_idx] | ||||
|         return best_arch, best_valid_acc | ||||
|  | ||||
|  | ||||
| def valid_func(xloader, network, criterion): | ||||
|     data_time, batch_time = AverageMeter(), AverageMeter() | ||||
|     arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter() | ||||
|     network.eval() | ||||
|     end = time.time() | ||||
|     with torch.no_grad(): | ||||
|         for step, (arch_inputs, arch_targets) in enumerate(xloader): | ||||
|             arch_targets = arch_targets.cuda(non_blocking=True) | ||||
|             # measure data loading time | ||||
|             data_time.update(time.time() - end) | ||||
|             # prediction | ||||
|             _, logits = network(arch_inputs) | ||||
|             arch_loss = criterion(logits, arch_targets) | ||||
|             # record | ||||
|             arch_prec1, arch_prec5 = obtain_accuracy( | ||||
|                 logits.data, arch_targets.data, topk=(1, 5) | ||||
|             ) | ||||
|             arch_losses.update(arch_loss.item(), arch_inputs.size(0)) | ||||
|             arch_top1.update(arch_prec1.item(), arch_inputs.size(0)) | ||||
|             arch_top5.update(arch_prec5.item(), arch_inputs.size(0)) | ||||
|             # measure elapsed time | ||||
|             batch_time.update(time.time() - end) | ||||
|             end = time.time() | ||||
|     return arch_losses.avg, arch_top1.avg, arch_top5.avg | ||||
|  | ||||
|  | ||||
| def main(xargs): | ||||
|     assert torch.cuda.is_available(), "CUDA is not available." | ||||
|     torch.backends.cudnn.enabled = True | ||||
|     torch.backends.cudnn.benchmark = False | ||||
|     torch.backends.cudnn.deterministic = True | ||||
|     torch.set_num_threads(xargs.workers) | ||||
|     prepare_seed(xargs.rand_seed) | ||||
|     logger = prepare_logger(args) | ||||
|  | ||||
|     train_data, test_data, xshape, class_num = get_datasets( | ||||
|         xargs.dataset, xargs.data_path, -1 | ||||
|     ) | ||||
|     logger.log("use config from : {:}".format(xargs.config_path)) | ||||
|     config = load_config( | ||||
|         xargs.config_path, {"class_num": class_num, "xshape": xshape}, logger | ||||
|     ) | ||||
|     _, train_loader, valid_loader = get_nas_search_loaders( | ||||
|         train_data, | ||||
|         test_data, | ||||
|         xargs.dataset, | ||||
|         "configs/nas-benchmark/", | ||||
|         config.batch_size, | ||||
|         xargs.workers, | ||||
|     ) | ||||
|     # since ENAS will train the controller on valid-loader, we need to use train transformation for valid-loader | ||||
|     valid_loader.dataset.transform = deepcopy(train_loader.dataset.transform) | ||||
|     if hasattr(valid_loader.dataset, "transforms"): | ||||
|         valid_loader.dataset.transforms = deepcopy(train_loader.dataset.transforms) | ||||
|     # data loader | ||||
|     logger.log( | ||||
|         "||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}".format( | ||||
|             xargs.dataset, len(train_loader), len(valid_loader), config.batch_size | ||||
|         ) | ||||
|     ) | ||||
|     logger.log("||||||| {:10s} ||||||| Config={:}".format(xargs.dataset, config)) | ||||
|  | ||||
|     search_space = get_search_spaces("cell", xargs.search_space_name) | ||||
|     model_config = dict2config( | ||||
|         { | ||||
|             "name": "ENAS", | ||||
|             "C": xargs.channel, | ||||
|             "N": xargs.num_cells, | ||||
|             "max_nodes": xargs.max_nodes, | ||||
|             "num_classes": class_num, | ||||
|             "space": search_space, | ||||
|             "affine": False, | ||||
|             "track_running_stats": bool(xargs.track_running_stats), | ||||
|         }, | ||||
|         None, | ||||
|     ) | ||||
|     shared_cnn = get_cell_based_tiny_net(model_config) | ||||
|     controller = shared_cnn.create_controller() | ||||
|  | ||||
|     w_optimizer, w_scheduler, criterion = get_optim_scheduler( | ||||
|         shared_cnn.parameters(), config | ||||
|     ) | ||||
|     a_optimizer = torch.optim.Adam( | ||||
|         controller.parameters(), | ||||
|         lr=config.controller_lr, | ||||
|         betas=config.controller_betas, | ||||
|         eps=config.controller_eps, | ||||
|     ) | ||||
|     logger.log("w-optimizer : {:}".format(w_optimizer)) | ||||
|     logger.log("a-optimizer : {:}".format(a_optimizer)) | ||||
|     logger.log("w-scheduler : {:}".format(w_scheduler)) | ||||
|     logger.log("criterion   : {:}".format(criterion)) | ||||
|     # flop, param  = get_model_infos(shared_cnn, xshape) | ||||
|     # logger.log('{:}'.format(shared_cnn)) | ||||
|     # logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param)) | ||||
|     logger.log("search-space : {:}".format(search_space)) | ||||
|     if xargs.arch_nas_dataset is None: | ||||
|         api = None | ||||
|     else: | ||||
|         api = API(xargs.arch_nas_dataset) | ||||
|     logger.log("{:} create API = {:} done".format(time_string(), api)) | ||||
|     shared_cnn, controller, criterion = ( | ||||
|         torch.nn.DataParallel(shared_cnn).cuda(), | ||||
|         controller.cuda(), | ||||
|         criterion.cuda(), | ||||
|     ) | ||||
|  | ||||
|     last_info, model_base_path, model_best_path = ( | ||||
|         logger.path("info"), | ||||
|         logger.path("model"), | ||||
|         logger.path("best"), | ||||
|     ) | ||||
|  | ||||
|     if last_info.exists():  # automatically resume from previous checkpoint | ||||
|         logger.log( | ||||
|             "=> loading checkpoint of the last-info '{:}' start".format(last_info) | ||||
|         ) | ||||
|         last_info = torch.load(last_info) | ||||
|         start_epoch = last_info["epoch"] | ||||
|         checkpoint = torch.load(last_info["last_checkpoint"]) | ||||
|         genotypes = checkpoint["genotypes"] | ||||
|         baseline = checkpoint["baseline"] | ||||
|         valid_accuracies = checkpoint["valid_accuracies"] | ||||
|         shared_cnn.load_state_dict(checkpoint["shared_cnn"]) | ||||
|         controller.load_state_dict(checkpoint["controller"]) | ||||
|         w_scheduler.load_state_dict(checkpoint["w_scheduler"]) | ||||
|         w_optimizer.load_state_dict(checkpoint["w_optimizer"]) | ||||
|         a_optimizer.load_state_dict(checkpoint["a_optimizer"]) | ||||
|         logger.log( | ||||
|             "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format( | ||||
|                 last_info, start_epoch | ||||
|             ) | ||||
|         ) | ||||
|     else: | ||||
|         logger.log("=> do not find the last-info file : {:}".format(last_info)) | ||||
|         start_epoch, valid_accuracies, genotypes, baseline = 0, {"best": -1}, {}, None | ||||
|  | ||||
|     # start training | ||||
|     start_time, search_time, epoch_time, total_epoch = ( | ||||
|         time.time(), | ||||
|         AverageMeter(), | ||||
|         AverageMeter(), | ||||
|         config.epochs + config.warmup, | ||||
|     ) | ||||
|     for epoch in range(start_epoch, total_epoch): | ||||
|         w_scheduler.update(epoch, 0.0) | ||||
|         need_time = "Time Left: {:}".format( | ||||
|             convert_secs2time(epoch_time.val * (total_epoch - epoch), True) | ||||
|         ) | ||||
|         epoch_str = "{:03d}-{:03d}".format(epoch, total_epoch) | ||||
|         logger.log( | ||||
|             "\n[Search the {:}-th epoch] {:}, LR={:}, baseline={:}".format( | ||||
|                 epoch_str, need_time, min(w_scheduler.get_lr()), baseline | ||||
|             ) | ||||
|         ) | ||||
|  | ||||
|         cnn_loss, cnn_top1, cnn_top5 = train_shared_cnn( | ||||
|             train_loader, | ||||
|             shared_cnn, | ||||
|             controller, | ||||
|             criterion, | ||||
|             w_scheduler, | ||||
|             w_optimizer, | ||||
|             epoch_str, | ||||
|             xargs.print_freq, | ||||
|             logger, | ||||
|         ) | ||||
|         logger.log( | ||||
|             "[{:}] shared-cnn : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%".format( | ||||
|                 epoch_str, cnn_loss, cnn_top1, cnn_top5 | ||||
|             ) | ||||
|         ) | ||||
|         ctl_loss, ctl_acc, ctl_baseline, ctl_reward, baseline = train_controller( | ||||
|             valid_loader, | ||||
|             shared_cnn, | ||||
|             controller, | ||||
|             criterion, | ||||
|             a_optimizer, | ||||
|             dict2config( | ||||
|                 { | ||||
|                     "baseline": baseline, | ||||
|                     "ctl_train_steps": xargs.controller_train_steps, | ||||
|                     "ctl_num_aggre": xargs.controller_num_aggregate, | ||||
|                     "ctl_entropy_w": xargs.controller_entropy_weight, | ||||
|                     "ctl_bl_dec": xargs.controller_bl_dec, | ||||
|                 }, | ||||
|                 None, | ||||
|             ), | ||||
|             epoch_str, | ||||
|             xargs.print_freq, | ||||
|             logger, | ||||
|         ) | ||||
|         search_time.update(time.time() - start_time) | ||||
|         logger.log( | ||||
|             "[{:}] controller : loss={:.2f}, accuracy={:.2f}%, baseline={:.2f}, reward={:.2f}, current-baseline={:.4f}, time-cost={:.1f} s".format( | ||||
|                 epoch_str, | ||||
|                 ctl_loss, | ||||
|                 ctl_acc, | ||||
|                 ctl_baseline, | ||||
|                 ctl_reward, | ||||
|                 baseline, | ||||
|                 search_time.sum, | ||||
|             ) | ||||
|         ) | ||||
|         best_arch, _ = get_best_arch(controller, shared_cnn, valid_loader) | ||||
|         shared_cnn.module.update_arch(best_arch) | ||||
|         _, best_valid_acc, _ = valid_func(valid_loader, shared_cnn, criterion) | ||||
|  | ||||
|         genotypes[epoch] = best_arch | ||||
|         # check the best accuracy | ||||
|         valid_accuracies[epoch] = best_valid_acc | ||||
|         if best_valid_acc > valid_accuracies["best"]: | ||||
|             valid_accuracies["best"] = best_valid_acc | ||||
|             genotypes["best"] = best_arch | ||||
|             find_best = True | ||||
|         else: | ||||
|             find_best = False | ||||
|  | ||||
|         logger.log( | ||||
|             "<<<--->>> The {:}-th epoch : {:}".format(epoch_str, genotypes[epoch]) | ||||
|         ) | ||||
|         # save checkpoint | ||||
|         save_path = save_checkpoint( | ||||
|             { | ||||
|                 "epoch": epoch + 1, | ||||
|                 "args": deepcopy(xargs), | ||||
|                 "baseline": baseline, | ||||
|                 "shared_cnn": shared_cnn.state_dict(), | ||||
|                 "controller": controller.state_dict(), | ||||
|                 "w_optimizer": w_optimizer.state_dict(), | ||||
|                 "a_optimizer": a_optimizer.state_dict(), | ||||
|                 "w_scheduler": w_scheduler.state_dict(), | ||||
|                 "genotypes": genotypes, | ||||
|                 "valid_accuracies": valid_accuracies, | ||||
|             }, | ||||
|             model_base_path, | ||||
|             logger, | ||||
|         ) | ||||
|         last_info = save_checkpoint( | ||||
|             { | ||||
|                 "epoch": epoch + 1, | ||||
|                 "args": deepcopy(args), | ||||
|                 "last_checkpoint": save_path, | ||||
|             }, | ||||
|             logger.path("info"), | ||||
|             logger, | ||||
|         ) | ||||
|         if find_best: | ||||
|             logger.log( | ||||
|                 "<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.".format( | ||||
|                     epoch_str, best_valid_acc | ||||
|                 ) | ||||
|             ) | ||||
|             copy_checkpoint(model_base_path, model_best_path, logger) | ||||
|         if api is not None: | ||||
|             logger.log("{:}".format(api.query_by_arch(genotypes[epoch], "200"))) | ||||
|         # measure elapsed time | ||||
|         epoch_time.update(time.time() - start_time) | ||||
|         start_time = time.time() | ||||
|  | ||||
|     logger.log("\n" + "-" * 100) | ||||
|     logger.log( | ||||
|         "During searching, the best architecture is {:}".format(genotypes["best"]) | ||||
|     ) | ||||
|     logger.log("Its accuracy is {:.2f}%".format(valid_accuracies["best"])) | ||||
|     logger.log( | ||||
|         "Randomly select {:} architectures and select the best.".format( | ||||
|             xargs.controller_num_samples | ||||
|         ) | ||||
|     ) | ||||
|     start_time = time.time() | ||||
|     final_arch, _ = get_best_arch( | ||||
|         controller, shared_cnn, valid_loader, xargs.controller_num_samples | ||||
|     ) | ||||
|     search_time.update(time.time() - start_time) | ||||
|     shared_cnn.module.update_arch(final_arch) | ||||
|     final_loss, final_top1, final_top5 = valid_func(valid_loader, shared_cnn, criterion) | ||||
|     logger.log("The Selected Final Architecture : {:}".format(final_arch)) | ||||
|     logger.log( | ||||
|         "Loss={:.3f}, Accuracy@1={:.2f}%, Accuracy@5={:.2f}%".format( | ||||
|             final_loss, final_top1, final_top5 | ||||
|         ) | ||||
|     ) | ||||
|     logger.log( | ||||
|         "ENAS : run {:} epochs, cost {:.1f} s, last-geno is {:}.".format( | ||||
|             total_epoch, search_time.sum, final_arch | ||||
|         ) | ||||
|     ) | ||||
|     if api is not None: | ||||
|         logger.log("{:}".format(api.query_by_arch(final_arch))) | ||||
|     logger.close() | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     parser = argparse.ArgumentParser("ENAS") | ||||
|     parser.add_argument("--data_path", type=str, help="The path to dataset") | ||||
|     parser.add_argument( | ||||
|         "--dataset", | ||||
|         type=str, | ||||
|         choices=["cifar10", "cifar100", "ImageNet16-120"], | ||||
|         help="Choose between Cifar10/100 and ImageNet-16.", | ||||
|     ) | ||||
|     # channels and number-of-cells | ||||
|     parser.add_argument( | ||||
|         "--track_running_stats", | ||||
|         type=int, | ||||
|         choices=[0, 1], | ||||
|         help="Whether use track_running_stats or not in the BN layer.", | ||||
|     ) | ||||
|     parser.add_argument("--search_space_name", type=str, help="The search space name.") | ||||
|     parser.add_argument("--max_nodes", type=int, help="The maximum number of nodes.") | ||||
|     parser.add_argument("--channel", type=int, help="The number of channels.") | ||||
|     parser.add_argument( | ||||
|         "--num_cells", type=int, help="The number of cells in one stage." | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--config_path", type=str, help="The config file to train ENAS." | ||||
|     ) | ||||
|     parser.add_argument("--controller_train_steps", type=int, help=".") | ||||
|     parser.add_argument("--controller_num_aggregate", type=int, help=".") | ||||
|     parser.add_argument( | ||||
|         "--controller_entropy_weight", | ||||
|         type=float, | ||||
|         help="The weight for the entropy of the controller.", | ||||
|     ) | ||||
|     parser.add_argument("--controller_bl_dec", type=float, help=".") | ||||
|     parser.add_argument("--controller_num_samples", type=int, help=".") | ||||
|     # log | ||||
|     parser.add_argument( | ||||
|         "--workers", | ||||
|         type=int, | ||||
|         default=2, | ||||
|         help="number of data loading workers (default: 2)", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--save_dir", type=str, help="Folder to save checkpoints and log." | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--arch_nas_dataset", | ||||
|         type=str, | ||||
|         help="The path to load the architecture dataset (nas-benchmark).", | ||||
|     ) | ||||
|     parser.add_argument("--print_freq", type=int, help="print frequency (default: 200)") | ||||
|     parser.add_argument("--rand_seed", type=int, help="manual seed") | ||||
|     args = parser.parse_args() | ||||
|     if args.rand_seed is None or args.rand_seed < 0: | ||||
|         args.rand_seed = random.randint(1, 100000) | ||||
|     main(args) | ||||
							
								
								
									
404  exps/NAS-Bench-201-algos/GDAS.py  Normal file
							| @@ -0,0 +1,404 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 # | ||||
| ########################################################################### | ||||
| # Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019 # | ||||
| ########################################################################### | ||||
| import sys, time, random, argparse | ||||
| from copy import deepcopy | ||||
| import torch | ||||
| from xautodl.config_utils import load_config, dict2config | ||||
| from xautodl.datasets import get_datasets, get_nas_search_loaders | ||||
| from xautodl.procedures import ( | ||||
|     prepare_seed, | ||||
|     prepare_logger, | ||||
|     save_checkpoint, | ||||
|     copy_checkpoint, | ||||
|     get_optim_scheduler, | ||||
| ) | ||||
| from xautodl.utils import get_model_infos, obtain_accuracy | ||||
| from xautodl.log_utils import AverageMeter, time_string, convert_secs2time | ||||
| from xautodl.models import get_cell_based_tiny_net, get_search_spaces | ||||
| from nas_201_api import NASBench201API as API | ||||
|  | ||||
|  | ||||
| def search_func( | ||||
|     xloader, | ||||
|     network, | ||||
|     criterion, | ||||
|     scheduler, | ||||
|     w_optimizer, | ||||
|     a_optimizer, | ||||
|     epoch_str, | ||||
|     print_freq, | ||||
|     logger, | ||||
| ): | ||||
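|     # One epoch of bi-level search: for each batch, update the network weights on the | ||||
|     # base split, then update the architecture parameters with a separate optimizer on | ||||
|     # the arch split. | ||||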
|     data_time, batch_time = AverageMeter(), AverageMeter() | ||||
|     base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter() | ||||
|     arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter() | ||||
|     network.train() | ||||
|     end = time.time() | ||||
|     for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate( | ||||
|         xloader | ||||
|     ): | ||||
|         scheduler.update(None, 1.0 * step / len(xloader)) | ||||
|         base_targets = base_targets.cuda(non_blocking=True) | ||||
|         arch_targets = arch_targets.cuda(non_blocking=True) | ||||
|         # measure data loading time | ||||
|         data_time.update(time.time() - end) | ||||
|  | ||||
|         # update the weights | ||||
|         w_optimizer.zero_grad() | ||||
|         _, logits = network(base_inputs) | ||||
|         base_loss = criterion(logits, base_targets) | ||||
|         base_loss.backward() | ||||
|         torch.nn.utils.clip_grad_norm_(network.parameters(), 5) | ||||
|         w_optimizer.step() | ||||
|         # record | ||||
|         base_prec1, base_prec5 = obtain_accuracy( | ||||
|             logits.data, base_targets.data, topk=(1, 5) | ||||
|         ) | ||||
|         base_losses.update(base_loss.item(), base_inputs.size(0)) | ||||
|         base_top1.update(base_prec1.item(), base_inputs.size(0)) | ||||
|         base_top5.update(base_prec5.item(), base_inputs.size(0)) | ||||
|  | ||||
|         # update the architecture-weight | ||||
|         a_optimizer.zero_grad() | ||||
|         _, logits = network(arch_inputs) | ||||
|         arch_loss = criterion(logits, arch_targets) | ||||
|         arch_loss.backward() | ||||
|         a_optimizer.step() | ||||
|         # record | ||||
|         arch_prec1, arch_prec5 = obtain_accuracy( | ||||
|             logits.data, arch_targets.data, topk=(1, 5) | ||||
|         ) | ||||
|         arch_losses.update(arch_loss.item(), arch_inputs.size(0)) | ||||
|         arch_top1.update(arch_prec1.item(), arch_inputs.size(0)) | ||||
|         arch_top5.update(arch_prec5.item(), arch_inputs.size(0)) | ||||
|  | ||||
|         # measure elapsed time | ||||
|         batch_time.update(time.time() - end) | ||||
|         end = time.time() | ||||
|  | ||||
|         if step % print_freq == 0 or step + 1 == len(xloader): | ||||
|             Sstr = ( | ||||
|                 "*SEARCH* " | ||||
|                 + time_string() | ||||
|                 + " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(xloader)) | ||||
|             ) | ||||
|             Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format( | ||||
|                 batch_time=batch_time, data_time=data_time | ||||
|             ) | ||||
|             Wstr = "Base [Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]".format( | ||||
|                 loss=base_losses, top1=base_top1, top5=base_top5 | ||||
|             ) | ||||
|             Astr = "Arch [Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]".format( | ||||
|                 loss=arch_losses, top1=arch_top1, top5=arch_top5 | ||||
|             ) | ||||
|             logger.log(Sstr + " " + Tstr + " " + Wstr + " " + Astr) | ||||
|     return ( | ||||
|         base_losses.avg, | ||||
|         base_top1.avg, | ||||
|         base_top5.avg, | ||||
|         arch_losses.avg, | ||||
|         arch_top1.avg, | ||||
|         arch_top5.avg, | ||||
|     ) | ||||
|  | ||||
|  | ||||
| def main(xargs): | ||||
|     assert torch.cuda.is_available(), "CUDA is not available." | ||||
|     torch.backends.cudnn.enabled = True | ||||
|     torch.backends.cudnn.benchmark = False | ||||
|     torch.backends.cudnn.deterministic = True | ||||
|     torch.set_num_threads(xargs.workers) | ||||
|     prepare_seed(xargs.rand_seed) | ||||
|     logger = prepare_logger(args) | ||||
|  | ||||
|     train_data, valid_data, xshape, class_num = get_datasets( | ||||
|         xargs.dataset, xargs.data_path, -1 | ||||
|     ) | ||||
|     # config_path = 'configs/nas-benchmark/algos/GDAS.config' | ||||
|     config = load_config( | ||||
|         xargs.config_path, {"class_num": class_num, "xshape": xshape}, logger | ||||
|     ) | ||||
|     search_loader, _, valid_loader = get_nas_search_loaders( | ||||
|         train_data, | ||||
|         valid_data, | ||||
|         xargs.dataset, | ||||
|         "configs/nas-benchmark/", | ||||
|         config.batch_size, | ||||
|         xargs.workers, | ||||
|     ) | ||||
|     logger.log( | ||||
|         "||||||| {:10s} ||||||| Search-Loader-Num={:}, batch size={:}".format( | ||||
|             xargs.dataset, len(search_loader), config.batch_size | ||||
|         ) | ||||
|     ) | ||||
|     logger.log("||||||| {:10s} ||||||| Config={:}".format(xargs.dataset, config)) | ||||
|  | ||||
|     search_space = get_search_spaces("cell", xargs.search_space_name) | ||||
|     if xargs.model_config is None: | ||||
|         model_config = dict2config( | ||||
|             { | ||||
|                 "name": "GDAS", | ||||
|                 "C": xargs.channel, | ||||
|                 "N": xargs.num_cells, | ||||
|                 "max_nodes": xargs.max_nodes, | ||||
|                 "num_classes": class_num, | ||||
|                 "space": search_space, | ||||
|                 "affine": False, | ||||
|                 "track_running_stats": bool(xargs.track_running_stats), | ||||
|             }, | ||||
|             None, | ||||
|         ) | ||||
|     else: | ||||
|         model_config = load_config( | ||||
|             xargs.model_config, | ||||
|             { | ||||
|                 "num_classes": class_num, | ||||
|                 "space": search_space, | ||||
|                 "affine": False, | ||||
|                 "track_running_stats": bool(xargs.track_running_stats), | ||||
|             }, | ||||
|             None, | ||||
|         ) | ||||
|     search_model = get_cell_based_tiny_net(model_config) | ||||
|     logger.log("search-model :\n{:}".format(search_model)) | ||||
|     logger.log("model-config : {:}".format(model_config)) | ||||
|  | ||||
|     w_optimizer, w_scheduler, criterion = get_optim_scheduler( | ||||
|         search_model.get_weights(), config | ||||
|     ) | ||||
|     a_optimizer = torch.optim.Adam( | ||||
|         search_model.get_alphas(), | ||||
|         lr=xargs.arch_learning_rate, | ||||
|         betas=(0.5, 0.999), | ||||
|         weight_decay=xargs.arch_weight_decay, | ||||
|     ) | ||||
|     logger.log("w-optimizer : {:}".format(w_optimizer)) | ||||
|     logger.log("a-optimizer : {:}".format(a_optimizer)) | ||||
|     logger.log("w-scheduler : {:}".format(w_scheduler)) | ||||
|     logger.log("criterion   : {:}".format(criterion)) | ||||
|     flop, param = get_model_infos(search_model, xshape) | ||||
|     logger.log("FLOP = {:.2f} M, Params = {:.2f} MB".format(flop, param)) | ||||
|     logger.log("search-space [{:} ops] : {:}".format(len(search_space), search_space)) | ||||
|     if xargs.arch_nas_dataset is None: | ||||
|         api = None | ||||
|     else: | ||||
|         api = API(xargs.arch_nas_dataset) | ||||
|     logger.log("{:} create API = {:} done".format(time_string(), api)) | ||||
|  | ||||
|     last_info, model_base_path, model_best_path = ( | ||||
|         logger.path("info"), | ||||
|         logger.path("model"), | ||||
|         logger.path("best"), | ||||
|     ) | ||||
|     network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda() | ||||
|  | ||||
|     if last_info.exists():  # automatically resume from previous checkpoint | ||||
|         logger.log( | ||||
|             "=> loading checkpoint of the last-info '{:}' start".format(last_info) | ||||
|         ) | ||||
|         last_info = torch.load(last_info) | ||||
|         start_epoch = last_info["epoch"] | ||||
|         checkpoint = torch.load(last_info["last_checkpoint"]) | ||||
|         genotypes = checkpoint["genotypes"] | ||||
|         valid_accuracies = checkpoint["valid_accuracies"] | ||||
|         search_model.load_state_dict(checkpoint["search_model"]) | ||||
|         w_scheduler.load_state_dict(checkpoint["w_scheduler"]) | ||||
|         w_optimizer.load_state_dict(checkpoint["w_optimizer"]) | ||||
|         a_optimizer.load_state_dict(checkpoint["a_optimizer"]) | ||||
|         logger.log( | ||||
|             "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format( | ||||
|                 last_info, start_epoch | ||||
|             ) | ||||
|         ) | ||||
|     else: | ||||
|         logger.log("=> do not find the last-info file : {:}".format(last_info)) | ||||
|         start_epoch, valid_accuracies, genotypes = ( | ||||
|             0, | ||||
|             {"best": -1}, | ||||
|             {-1: search_model.genotype()}, | ||||
|         ) | ||||
|  | ||||
|     # start training | ||||
|     start_time, search_time, epoch_time, total_epoch = ( | ||||
|         time.time(), | ||||
|         AverageMeter(), | ||||
|         AverageMeter(), | ||||
|         config.epochs + config.warmup, | ||||
|     ) | ||||
|     for epoch in range(start_epoch, total_epoch): | ||||
|         w_scheduler.update(epoch, 0.0) | ||||
|         need_time = "Time Left: {:}".format( | ||||
|             convert_secs2time(epoch_time.val * (total_epoch - epoch), True) | ||||
|         ) | ||||
|         epoch_str = "{:03d}-{:03d}".format(epoch, total_epoch) | ||||
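|         # linearly anneal the Gumbel-Softmax temperature from tau_max down to tau_min | ||||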
|         search_model.set_tau( | ||||
|             xargs.tau_max - (xargs.tau_max - xargs.tau_min) * epoch / (total_epoch - 1) | ||||
|         ) | ||||
|         logger.log( | ||||
|             "\n[Search the {:}-th epoch] {:}, tau={:}, LR={:}".format( | ||||
|                 epoch_str, need_time, search_model.get_tau(), min(w_scheduler.get_lr()) | ||||
|             ) | ||||
|         ) | ||||
|  | ||||
|         ( | ||||
|             search_w_loss, | ||||
|             search_w_top1, | ||||
|             search_w_top5, | ||||
|             valid_a_loss, | ||||
|             valid_a_top1, | ||||
|             valid_a_top5, | ||||
|         ) = search_func( | ||||
|             search_loader, | ||||
|             network, | ||||
|             criterion, | ||||
|             w_scheduler, | ||||
|             w_optimizer, | ||||
|             a_optimizer, | ||||
|             epoch_str, | ||||
|             xargs.print_freq, | ||||
|             logger, | ||||
|         ) | ||||
|         search_time.update(time.time() - start_time) | ||||
|         logger.log( | ||||
|             "[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s".format( | ||||
|                 epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum | ||||
|             ) | ||||
|         ) | ||||
|         logger.log( | ||||
|             "[{:}] evaluate  : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%".format( | ||||
|                 epoch_str, valid_a_loss, valid_a_top1, valid_a_top5 | ||||
|             ) | ||||
|         ) | ||||
|         # check the best accuracy | ||||
|         valid_accuracies[epoch] = valid_a_top1 | ||||
|         if valid_a_top1 > valid_accuracies["best"]: | ||||
|             valid_accuracies["best"] = valid_a_top1 | ||||
|             genotypes["best"] = search_model.genotype() | ||||
|             find_best = True | ||||
|         else: | ||||
|             find_best = False | ||||
|  | ||||
|         genotypes[epoch] = search_model.genotype() | ||||
|         logger.log( | ||||
|             "<<<--->>> The {:}-th epoch : {:}".format(epoch_str, genotypes[epoch]) | ||||
|         ) | ||||
|         # save checkpoint | ||||
|         save_path = save_checkpoint( | ||||
|             { | ||||
|                 "epoch": epoch + 1, | ||||
|                 "args": deepcopy(xargs), | ||||
|                 "search_model": search_model.state_dict(), | ||||
|                 "w_optimizer": w_optimizer.state_dict(), | ||||
|                 "a_optimizer": a_optimizer.state_dict(), | ||||
|                 "w_scheduler": w_scheduler.state_dict(), | ||||
|                 "genotypes": genotypes, | ||||
|                 "valid_accuracies": valid_accuracies, | ||||
|             }, | ||||
|             model_base_path, | ||||
|             logger, | ||||
|         ) | ||||
|         last_info = save_checkpoint( | ||||
|             { | ||||
|                 "epoch": epoch + 1, | ||||
|                 "args": deepcopy(args), | ||||
|                 "last_checkpoint": save_path, | ||||
|             }, | ||||
|             logger.path("info"), | ||||
|             logger, | ||||
|         ) | ||||
|         if find_best: | ||||
|             logger.log( | ||||
|                 "<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.".format( | ||||
|                     epoch_str, valid_a_top1 | ||||
|                 ) | ||||
|             ) | ||||
|             copy_checkpoint(model_base_path, model_best_path, logger) | ||||
|         with torch.no_grad(): | ||||
|             logger.log("{:}".format(search_model.show_alphas())) | ||||
|         if api is not None: | ||||
|             logger.log("{:}".format(api.query_by_arch(genotypes[epoch], "200"))) | ||||
|         # measure elapsed time | ||||
|         epoch_time.update(time.time() - start_time) | ||||
|         start_time = time.time() | ||||
|  | ||||
|     logger.log("\n" + "-" * 100) | ||||
|     # check the performance from the architecture dataset | ||||
|     logger.log( | ||||
|         "GDAS : run {:} epochs, cost {:.1f} s, last-geno is {:}.".format( | ||||
|             total_epoch, search_time.sum, genotypes[total_epoch - 1] | ||||
|         ) | ||||
|     ) | ||||
|     if api is not None: | ||||
|         logger.log("{:}".format(api.query_by_arch(genotypes[total_epoch - 1], "200"))) | ||||
|     logger.close() | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     parser = argparse.ArgumentParser("GDAS") | ||||
|     parser.add_argument("--data_path", type=str, help="The path to dataset") | ||||
|     parser.add_argument( | ||||
|         "--dataset", | ||||
|         type=str, | ||||
|         choices=["cifar10", "cifar100", "ImageNet16-120"], | ||||
|         help="Choose between Cifar10/100 and ImageNet-16.", | ||||
|     ) | ||||
|     # channels and number-of-cells | ||||
|     parser.add_argument("--search_space_name", type=str, help="The search space name.") | ||||
|     parser.add_argument("--max_nodes", type=int, help="The maximum number of nodes.") | ||||
|     parser.add_argument("--channel", type=int, help="The number of channels.") | ||||
|     parser.add_argument( | ||||
|         "--num_cells", type=int, help="The number of cells in one stage." | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--track_running_stats", | ||||
|         type=int, | ||||
|         choices=[0, 1], | ||||
|         help="Whether use track_running_stats or not in the BN layer.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--config_path", type=str, help="The path of the configuration." | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--model_config", | ||||
|         type=str, | ||||
|         help="The path of the model configuration. When this arg is set, it will cover max_nodes / channels / num_cells.", | ||||
|     ) | ||||
|     # architecture learning rate | ||||
|     parser.add_argument( | ||||
|         "--arch_learning_rate", | ||||
|         type=float, | ||||
|         default=3e-4, | ||||
|         help="learning rate for arch encoding", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--arch_weight_decay", | ||||
|         type=float, | ||||
|         default=1e-3, | ||||
|         help="weight decay for arch encoding", | ||||
|     ) | ||||
|     parser.add_argument("--tau_min", type=float, help="The minimum tau for Gumbel") | ||||
|     parser.add_argument("--tau_max", type=float, help="The maximum tau for Gumbel") | ||||
|     # log | ||||
|     parser.add_argument( | ||||
|         "--workers", | ||||
|         type=int, | ||||
|         default=2, | ||||
|         help="number of data loading workers (default: 2)", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--save_dir", type=str, help="Folder to save checkpoints and log." | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--arch_nas_dataset", | ||||
|         type=str, | ||||
|         help="The path to load the architecture dataset (tiny-nas-benchmark).", | ||||
|     ) | ||||
|     parser.add_argument("--print_freq", type=int, help="print frequency (default: 200)") | ||||
|     parser.add_argument("--rand_seed", type=int, help="manual seed") | ||||
|     args = parser.parse_args() | ||||
|     if args.rand_seed is None or args.rand_seed < 0: | ||||
|         args.rand_seed = random.randint(1, 100000) | ||||
|     main(args) | ||||
							
								
								
									
382  exps/NAS-Bench-201-algos/RANDOM-NAS.py  Normal file
							| @@ -0,0 +1,382 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 # | ||||
| ############################################################################## | ||||
| # Random Search and Reproducibility for Neural Architecture Search, UAI 2019 # | ||||
| ############################################################################## | ||||
| import os, sys, time, glob, random, argparse | ||||
| import numpy as np | ||||
| from copy import deepcopy | ||||
| import torch | ||||
| import torch.nn as nn | ||||
|  | ||||
| from xautodl.config_utils import load_config, dict2config, configure2str | ||||
| from xautodl.datasets import get_datasets, get_nas_search_loaders | ||||
| from xautodl.procedures import ( | ||||
|     prepare_seed, | ||||
|     prepare_logger, | ||||
|     save_checkpoint, | ||||
|     copy_checkpoint, | ||||
|     get_optim_scheduler, | ||||
| ) | ||||
| from xautodl.utils import get_model_infos, obtain_accuracy | ||||
| from xautodl.log_utils import AverageMeter, time_string, convert_secs2time | ||||
| from xautodl.models import get_cell_based_tiny_net, get_search_spaces | ||||
| from nas_201_api import NASBench201API as API | ||||
|  | ||||
|  | ||||
| def search_func( | ||||
|     xloader, network, criterion, scheduler, w_optimizer, epoch_str, print_freq, logger | ||||
| ): | ||||
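|     # One-shot training with random single paths: each batch activates a randomly | ||||
|     # sampled genotype and updates the shared weights along that path. | ||||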
|     data_time, batch_time = AverageMeter(), AverageMeter() | ||||
|     base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter() | ||||
|     network.train() | ||||
|     end = time.time() | ||||
|     for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate( | ||||
|         xloader | ||||
|     ): | ||||
|         scheduler.update(None, 1.0 * step / len(xloader)) | ||||
|         base_targets = base_targets.cuda(non_blocking=True) | ||||
|         arch_targets = arch_targets.cuda(non_blocking=True) | ||||
|         # measure data loading time | ||||
|         data_time.update(time.time() - end) | ||||
|  | ||||
|         # update the weights | ||||
|         network.module.random_genotype(True) | ||||
|         w_optimizer.zero_grad() | ||||
|         _, logits = network(base_inputs) | ||||
|         base_loss = criterion(logits, base_targets) | ||||
|         base_loss.backward() | ||||
|         nn.utils.clip_grad_norm_(network.parameters(), 5) | ||||
|         w_optimizer.step() | ||||
|         # record | ||||
|         base_prec1, base_prec5 = obtain_accuracy( | ||||
|             logits.data, base_targets.data, topk=(1, 5) | ||||
|         ) | ||||
|         base_losses.update(base_loss.item(), base_inputs.size(0)) | ||||
|         base_top1.update(base_prec1.item(), base_inputs.size(0)) | ||||
|         base_top5.update(base_prec5.item(), base_inputs.size(0)) | ||||
|  | ||||
|         # measure elapsed time | ||||
|         batch_time.update(time.time() - end) | ||||
|         end = time.time() | ||||
|  | ||||
|         if step % print_freq == 0 or step + 1 == len(xloader): | ||||
|             Sstr = ( | ||||
|                 "*SEARCH* " | ||||
|                 + time_string() | ||||
|                 + " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(xloader)) | ||||
|             ) | ||||
|             Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format( | ||||
|                 batch_time=batch_time, data_time=data_time | ||||
|             ) | ||||
|             Wstr = "Base [Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]".format( | ||||
|                 loss=base_losses, top1=base_top1, top5=base_top5 | ||||
|             ) | ||||
|             logger.log(Sstr + " " + Tstr + " " + Wstr) | ||||
|     return base_losses.avg, base_top1.avg, base_top5.avg | ||||
|  | ||||
|  | ||||
| def valid_func(xloader, network, criterion): | ||||
|     data_time, batch_time = AverageMeter(), AverageMeter() | ||||
|     arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter() | ||||
|     network.eval() | ||||
|     end = time.time() | ||||
|     with torch.no_grad(): | ||||
|         for step, (arch_inputs, arch_targets) in enumerate(xloader): | ||||
|             arch_targets = arch_targets.cuda(non_blocking=True) | ||||
|             # measure data loading time | ||||
|             data_time.update(time.time() - end) | ||||
|             # prediction | ||||
|  | ||||
|             network.module.random_genotype(True) | ||||
|             _, logits = network(arch_inputs) | ||||
|             arch_loss = criterion(logits, arch_targets) | ||||
|             # record | ||||
|             arch_prec1, arch_prec5 = obtain_accuracy( | ||||
|                 logits.data, arch_targets.data, topk=(1, 5) | ||||
|             ) | ||||
|             arch_losses.update(arch_loss.item(), arch_inputs.size(0)) | ||||
|             arch_top1.update(arch_prec1.item(), arch_inputs.size(0)) | ||||
|             arch_top5.update(arch_prec5.item(), arch_inputs.size(0)) | ||||
|             # measure elapsed time | ||||
|             batch_time.update(time.time() - end) | ||||
|             end = time.time() | ||||
|     return arch_losses.avg, arch_top1.avg, arch_top5.avg | ||||
|  | ||||
|  | ||||
| def search_find_best(xloader, network, n_samples): | ||||
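|     # Randomly sample n_samples genotypes, score each on one validation batch, and | ||||
|     # return the architecture with the highest top-1 accuracy. | ||||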
|     with torch.no_grad(): | ||||
|         network.eval() | ||||
|         archs, valid_accs = [], [] | ||||
|         # print ('obtain the top-{:} architectures'.format(n_samples)) | ||||
|         loader_iter = iter(xloader) | ||||
|         for i in range(n_samples): | ||||
|             arch = network.module.random_genotype(True) | ||||
|             try: | ||||
|                 inputs, targets = next(loader_iter) | ||||
|             except StopIteration: | ||||
|                 loader_iter = iter(xloader) | ||||
|                 inputs, targets = next(loader_iter) | ||||
|  | ||||
|             _, logits = network(inputs) | ||||
|             val_top1, val_top5 = obtain_accuracy( | ||||
|                 logits.cpu().data, targets.data, topk=(1, 5) | ||||
|             ) | ||||
|  | ||||
|             archs.append(arch) | ||||
|             valid_accs.append(val_top1.item()) | ||||
|  | ||||
|         best_idx = np.argmax(valid_accs) | ||||
|         best_arch, best_valid_acc = archs[best_idx], valid_accs[best_idx] | ||||
|         return best_arch, best_valid_acc | ||||
|  | ||||
|  | ||||
| def main(xargs): | ||||
|     assert torch.cuda.is_available(), "CUDA is not available." | ||||
|     torch.backends.cudnn.enabled = True | ||||
|     torch.backends.cudnn.benchmark = False | ||||
|     torch.backends.cudnn.deterministic = True | ||||
|     torch.set_num_threads(xargs.workers) | ||||
|     prepare_seed(xargs.rand_seed) | ||||
|     logger = prepare_logger(args) | ||||
|  | ||||
|     train_data, valid_data, xshape, class_num = get_datasets( | ||||
|         xargs.dataset, xargs.data_path, -1 | ||||
|     ) | ||||
|     config = load_config( | ||||
|         xargs.config_path, {"class_num": class_num, "xshape": xshape}, logger | ||||
|     ) | ||||
|     search_loader, _, valid_loader = get_nas_search_loaders( | ||||
|         train_data, | ||||
|         valid_data, | ||||
|         xargs.dataset, | ||||
|         "configs/nas-benchmark/", | ||||
|         (config.batch_size, config.test_batch_size), | ||||
|         xargs.workers, | ||||
|     ) | ||||
|     logger.log( | ||||
|         "||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}".format( | ||||
|             xargs.dataset, len(search_loader), len(valid_loader), config.batch_size | ||||
|         ) | ||||
|     ) | ||||
|     logger.log("||||||| {:10s} ||||||| Config={:}".format(xargs.dataset, config)) | ||||
|  | ||||
|     search_space = get_search_spaces("cell", xargs.search_space_name) | ||||
|     model_config = dict2config( | ||||
|         { | ||||
|             "name": "RANDOM", | ||||
|             "C": xargs.channel, | ||||
|             "N": xargs.num_cells, | ||||
|             "max_nodes": xargs.max_nodes, | ||||
|             "num_classes": class_num, | ||||
|             "space": search_space, | ||||
|             "affine": False, | ||||
|             "track_running_stats": bool(xargs.track_running_stats), | ||||
|         }, | ||||
|         None, | ||||
|     ) | ||||
|     search_model = get_cell_based_tiny_net(model_config) | ||||
|  | ||||
|     w_optimizer, w_scheduler, criterion = get_optim_scheduler( | ||||
|         search_model.parameters(), config | ||||
|     ) | ||||
|     logger.log("w-optimizer : {:}".format(w_optimizer)) | ||||
|     logger.log("w-scheduler : {:}".format(w_scheduler)) | ||||
|     logger.log("criterion   : {:}".format(criterion)) | ||||
|     if xargs.arch_nas_dataset is None: | ||||
|         api = None | ||||
|     else: | ||||
|         api = API(xargs.arch_nas_dataset) | ||||
|     logger.log("{:} create API = {:} done".format(time_string(), api)) | ||||
|  | ||||
|     last_info, model_base_path, model_best_path = ( | ||||
|         logger.path("info"), | ||||
|         logger.path("model"), | ||||
|         logger.path("best"), | ||||
|     ) | ||||
|     network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda() | ||||
|  | ||||
|     if last_info.exists():  # automatically resume from previous checkpoint | ||||
|         logger.log( | ||||
|             "=> loading checkpoint of the last-info '{:}' start".format(last_info) | ||||
|         ) | ||||
|         last_info = torch.load(last_info) | ||||
|         start_epoch = last_info["epoch"] | ||||
|         checkpoint = torch.load(last_info["last_checkpoint"]) | ||||
|         genotypes = checkpoint["genotypes"] | ||||
|         valid_accuracies = checkpoint["valid_accuracies"] | ||||
|         search_model.load_state_dict(checkpoint["search_model"]) | ||||
|         w_scheduler.load_state_dict(checkpoint["w_scheduler"]) | ||||
|         w_optimizer.load_state_dict(checkpoint["w_optimizer"]) | ||||
|         logger.log( | ||||
|             "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format( | ||||
|                 last_info, start_epoch | ||||
|             ) | ||||
|         ) | ||||
|     else: | ||||
|         logger.log("=> do not find the last-info file : {:}".format(last_info)) | ||||
|         start_epoch, valid_accuracies, genotypes = 0, {"best": -1}, {} | ||||
|  | ||||
|     # start training | ||||
|     start_time, search_time, epoch_time, total_epoch = ( | ||||
|         time.time(), | ||||
|         AverageMeter(), | ||||
|         AverageMeter(), | ||||
|         config.epochs + config.warmup, | ||||
|     ) | ||||
|     for epoch in range(start_epoch, total_epoch): | ||||
|         w_scheduler.update(epoch, 0.0) | ||||
|         need_time = "Time Left: {:}".format( | ||||
|             convert_secs2time(epoch_time.val * (total_epoch - epoch), True) | ||||
|         ) | ||||
|         epoch_str = "{:03d}-{:03d}".format(epoch, total_epoch) | ||||
|         logger.log( | ||||
|             "\n[Search the {:}-th epoch] {:}, LR={:}".format( | ||||
|                 epoch_str, need_time, min(w_scheduler.get_lr()) | ||||
|             ) | ||||
|         ) | ||||
|  | ||||
|         # selected_arch = search_find_best(valid_loader, network, criterion, xargs.select_num) | ||||
|         search_w_loss, search_w_top1, search_w_top5 = search_func( | ||||
|             search_loader, | ||||
|             network, | ||||
|             criterion, | ||||
|             w_scheduler, | ||||
|             w_optimizer, | ||||
|             epoch_str, | ||||
|             xargs.print_freq, | ||||
|             logger, | ||||
|         ) | ||||
|         search_time.update(time.time() - start_time) | ||||
|         logger.log( | ||||
|             "[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s".format( | ||||
|                 epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum | ||||
|             ) | ||||
|         ) | ||||
|         valid_a_loss, valid_a_top1, valid_a_top5 = valid_func( | ||||
|             valid_loader, network, criterion | ||||
|         ) | ||||
|         logger.log( | ||||
|             "[{:}] evaluate  : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%".format( | ||||
|                 epoch_str, valid_a_loss, valid_a_top1, valid_a_top5 | ||||
|             ) | ||||
|         ) | ||||
|         cur_arch, cur_valid_acc = search_find_best( | ||||
|             valid_loader, network, xargs.select_num | ||||
|         ) | ||||
|         logger.log( | ||||
|             "[{:}] find-the-best : {:}, accuracy@1={:.2f}%".format( | ||||
|                 epoch_str, cur_arch, cur_valid_acc | ||||
|             ) | ||||
|         ) | ||||
|         genotypes[epoch] = cur_arch | ||||
|         # check the best accuracy | ||||
|         valid_accuracies[epoch] = valid_a_top1 | ||||
|         if valid_a_top1 > valid_accuracies["best"]: | ||||
|             valid_accuracies["best"] = valid_a_top1 | ||||
|             find_best = True | ||||
|         else: | ||||
|             find_best = False | ||||
|  | ||||
|         # save checkpoint | ||||
|         save_path = save_checkpoint( | ||||
|             { | ||||
|                 "epoch": epoch + 1, | ||||
|                 "args": deepcopy(xargs), | ||||
|                 "search_model": search_model.state_dict(), | ||||
|                 "w_optimizer": w_optimizer.state_dict(), | ||||
|                 "w_scheduler": w_scheduler.state_dict(), | ||||
|                 "genotypes": genotypes, | ||||
|                 "valid_accuracies": valid_accuracies, | ||||
|             }, | ||||
|             model_base_path, | ||||
|             logger, | ||||
|         ) | ||||
|         last_info = save_checkpoint( | ||||
|             { | ||||
|                 "epoch": epoch + 1, | ||||
|                 "args": deepcopy(args), | ||||
|                 "last_checkpoint": save_path, | ||||
|             }, | ||||
|             logger.path("info"), | ||||
|             logger, | ||||
|         ) | ||||
|         if find_best: | ||||
|             logger.log( | ||||
|                 "<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.".format( | ||||
|                     epoch_str, valid_a_top1 | ||||
|                 ) | ||||
|             ) | ||||
|             copy_checkpoint(model_base_path, model_best_path, logger) | ||||
|         if api is not None: | ||||
|             logger.log("{:}".format(api.query_by_arch(genotypes[epoch], "200"))) | ||||
|         # measure elapsed time | ||||
|         epoch_time.update(time.time() - start_time) | ||||
|         start_time = time.time() | ||||
|  | ||||
|     logger.log("\n" + "-" * 200) | ||||
|     logger.log("Pre-searching costs {:.1f} s".format(search_time.sum)) | ||||
|     start_time = time.time() | ||||
|     best_arch, best_acc = search_find_best(valid_loader, network, xargs.select_num) | ||||
|     search_time.update(time.time() - start_time) | ||||
|     logger.log( | ||||
|         "RANDOM-NAS finds the best one : {:} with accuracy={:.2f}%, with {:.1f} s.".format( | ||||
|             best_arch, best_acc, search_time.sum | ||||
|         ) | ||||
|     ) | ||||
|     if api is not None: | ||||
|         logger.log("{:}".format(api.query_by_arch(best_arch, "200"))) | ||||
|     logger.close() | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     parser = argparse.ArgumentParser("Random search for NAS.") | ||||
|     parser.add_argument("--data_path", type=str, help="The path to dataset") | ||||
|     parser.add_argument( | ||||
|         "--dataset", | ||||
|         type=str, | ||||
|         choices=["cifar10", "cifar100", "ImageNet16-120"], | ||||
|         help="Choose between Cifar10/100 and ImageNet-16.", | ||||
|     ) | ||||
|     # channels and number-of-cells | ||||
|     parser.add_argument("--search_space_name", type=str, help="The search space name.") | ||||
|     parser.add_argument( | ||||
|         "--config_path", type=str, help="The path to the configuration." | ||||
|     ) | ||||
|     parser.add_argument("--max_nodes", type=int, help="The maximum number of nodes.") | ||||
|     parser.add_argument("--channel", type=int, help="The number of channels.") | ||||
|     parser.add_argument( | ||||
|         "--num_cells", type=int, help="The number of cells in one stage." | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--select_num", | ||||
|         type=int, | ||||
|         help="The number of selected architectures to evaluate.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--track_running_stats", | ||||
|         type=int, | ||||
|         choices=[0, 1], | ||||
|         help="Whether use track_running_stats or not in the BN layer.", | ||||
|     ) | ||||
|     # log | ||||
|     parser.add_argument( | ||||
|         "--workers", | ||||
|         type=int, | ||||
|         default=2, | ||||
|         help="number of data loading workers (default: 2)", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--save_dir", type=str, help="Folder to save checkpoints and log." | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--arch_nas_dataset", | ||||
|         type=str, | ||||
|         help="The path to load the architecture dataset (tiny-nas-benchmark).", | ||||
|     ) | ||||
|     parser.add_argument("--print_freq", type=int, help="print frequency (default: 200)") | ||||
|     parser.add_argument("--rand_seed", type=int, help="manual seed") | ||||
|     args = parser.parse_args() | ||||
|     if args.rand_seed is None or args.rand_seed < 0: | ||||
|         args.rand_seed = random.randint(1, 100000) | ||||
|     main(args) | ||||
							
								
								
									
189  exps/NAS-Bench-201-algos/RANDOM.py  Normal file
							| @@ -0,0 +1,189 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 # | ||||
| ############################################################################## | ||||
| import os, sys, time, glob, random, argparse | ||||
| import numpy as np, collections | ||||
| from copy import deepcopy | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| from pathlib import Path | ||||
|  | ||||
| from xautodl.config_utils import load_config, dict2config, configure2str | ||||
| from xautodl.datasets import get_datasets, SearchDataset | ||||
| from xautodl.procedures import ( | ||||
|     prepare_seed, | ||||
|     prepare_logger, | ||||
|     save_checkpoint, | ||||
|     copy_checkpoint, | ||||
|     get_optim_scheduler, | ||||
| ) | ||||
| from xautodl.utils import get_model_infos, obtain_accuracy | ||||
| from xautodl.log_utils import AverageMeter, time_string, convert_secs2time | ||||
| from xautodl.models import get_search_spaces | ||||
|  | ||||
| from nas_201_api import NASBench201API as API | ||||
| from R_EA import train_and_eval, random_architecture_func | ||||
|  | ||||
|  | ||||
| def main(xargs, nas_bench): | ||||
|     assert torch.cuda.is_available(), "CUDA is not available." | ||||
|     torch.backends.cudnn.enabled = True | ||||
|     torch.backends.cudnn.benchmark = False | ||||
|     torch.backends.cudnn.deterministic = True | ||||
|     torch.set_num_threads(xargs.workers) | ||||
|     prepare_seed(xargs.rand_seed) | ||||
|     logger = prepare_logger(args) | ||||
|  | ||||
|     if xargs.dataset == "cifar10": | ||||
|         dataname = "cifar10-valid" | ||||
|     else: | ||||
|         dataname = xargs.dataset | ||||
|     if xargs.data_path is not None: | ||||
|         train_data, valid_data, xshape, class_num = get_datasets( | ||||
|             xargs.dataset, xargs.data_path, -1 | ||||
|         ) | ||||
|         split_Fpath = "configs/nas-benchmark/cifar-split.txt" | ||||
|         cifar_split = load_config(split_Fpath, None, None) | ||||
|         train_split, valid_split = cifar_split.train, cifar_split.valid | ||||
|         logger.log("Load split file from {:}".format(split_Fpath)) | ||||
|         config_path = "configs/nas-benchmark/algos/R-EA.config" | ||||
|         config = load_config( | ||||
|             config_path, {"class_num": class_num, "xshape": xshape}, logger | ||||
|         ) | ||||
|         # To split data | ||||
|         train_data_v2 = deepcopy(train_data) | ||||
|         train_data_v2.transform = valid_data.transform | ||||
|         valid_data = train_data_v2 | ||||
|         search_data = SearchDataset(xargs.dataset, train_data, train_split, valid_split) | ||||
|         # data loader | ||||
|         train_loader = torch.utils.data.DataLoader( | ||||
|             train_data, | ||||
|             batch_size=config.batch_size, | ||||
|             sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split), | ||||
|             num_workers=xargs.workers, | ||||
|             pin_memory=True, | ||||
|         ) | ||||
|         valid_loader = torch.utils.data.DataLoader( | ||||
|             valid_data, | ||||
|             batch_size=config.batch_size, | ||||
|             sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), | ||||
|             num_workers=xargs.workers, | ||||
|             pin_memory=True, | ||||
|         ) | ||||
|         logger.log( | ||||
|             "||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}".format( | ||||
|                 xargs.dataset, len(train_loader), len(valid_loader), config.batch_size | ||||
|             ) | ||||
|         ) | ||||
|         logger.log("||||||| {:10s} ||||||| Config={:}".format(xargs.dataset, config)) | ||||
|         extra_info = { | ||||
|             "config": config, | ||||
|             "train_loader": train_loader, | ||||
|             "valid_loader": valid_loader, | ||||
|         } | ||||
|     else: | ||||
|         config_path = "configs/nas-benchmark/algos/R-EA.config" | ||||
|         config = load_config(config_path, None, logger) | ||||
|         logger.log("||||||| {:10s} ||||||| Config={:}".format(xargs.dataset, config)) | ||||
|         extra_info = {"config": config, "train_loader": None, "valid_loader": None} | ||||
|     search_space = get_search_spaces("cell", xargs.search_space_name) | ||||
|     random_arch = random_architecture_func(xargs.max_nodes, search_space) | ||||
|     # x = random_arch() | ||||
|     x_start_time = time.time() | ||||
|     logger.log("{:} use nas_bench : {:}".format(time_string(), nas_bench)) | ||||
|     best_arch, best_acc, total_time_cost, history = None, -1, 0, [] | ||||
|     # for idx in range(xargs.random_num): | ||||
|     while total_time_cost < xargs.time_budget: | ||||
|         arch = random_arch() | ||||
|         accuracy, cost_time = train_and_eval(arch, nas_bench, extra_info, dataname) | ||||
|         if total_time_cost + cost_time > xargs.time_budget: | ||||
|             break | ||||
|         else: | ||||
|             total_time_cost += cost_time | ||||
|         history.append(arch) | ||||
|         if best_arch is None or best_acc < accuracy: | ||||
|             best_acc, best_arch = accuracy, arch | ||||
|         logger.log( | ||||
|             "[{:03d}] : {:} : accuracy = {:.2f}%".format(len(history), arch, accuracy) | ||||
|         ) | ||||
|     logger.log( | ||||
|         "{:} best arch is {:}, accuracy = {:.2f}%, visit {:} archs with {:.1f} s (real-cost = {:.3f} s).".format( | ||||
|             time_string(), | ||||
|             best_arch, | ||||
|             best_acc, | ||||
|             len(history), | ||||
|             total_time_cost, | ||||
|             time.time() - x_start_time, | ||||
|         ) | ||||
|     ) | ||||
|  | ||||
|     info = nas_bench.query_by_arch(best_arch, "200") | ||||
|     if info is None: | ||||
|         logger.log("Did not find this architecture : {:}.".format(best_arch)) | ||||
|     else: | ||||
|         logger.log("{:}".format(info)) | ||||
|     logger.log("-" * 100) | ||||
|     logger.close() | ||||
|     return logger.log_dir, nas_bench.query_index_by_arch(best_arch) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     parser = argparse.ArgumentParser("Random NAS") | ||||
|     parser.add_argument("--data_path", type=str, help="Path to dataset") | ||||
|     parser.add_argument( | ||||
|         "--dataset", | ||||
|         type=str, | ||||
|         choices=["cifar10", "cifar100", "ImageNet16-120"], | ||||
|         help="Choose between Cifar10/100 and ImageNet-16.", | ||||
|     ) | ||||
|     # channels and number-of-cells | ||||
|     parser.add_argument("--search_space_name", type=str, help="The search space name.") | ||||
|     parser.add_argument("--max_nodes", type=int, help="The maximum number of nodes.") | ||||
|     parser.add_argument("--channel", type=int, help="The number of channels.") | ||||
|     parser.add_argument( | ||||
|         "--num_cells", type=int, help="The number of cells in one stage." | ||||
|     ) | ||||
|     # parser.add_argument('--random_num',         type=int,   help='The number of random selected architectures.') | ||||
|     parser.add_argument( | ||||
|         "--time_budget", | ||||
|         type=int, | ||||
|         help="The total time cost budge for searching (in seconds).", | ||||
|     ) | ||||
|     # log | ||||
|     parser.add_argument( | ||||
|         "--workers", | ||||
|         type=int, | ||||
|         default=2, | ||||
|         help="number of data loading workers (default: 2)", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--save_dir", type=str, help="Folder to save checkpoints and log." | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--arch_nas_dataset", | ||||
|         type=str, | ||||
|         help="The path to load the architecture dataset (tiny-nas-benchmark).", | ||||
|     ) | ||||
|     parser.add_argument("--print_freq", type=int, help="print frequency (default: 200)") | ||||
|     parser.add_argument("--rand_seed", type=int, help="manual seed") | ||||
|     args = parser.parse_args() | ||||
|     # if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000) | ||||
|     if args.arch_nas_dataset is None or not os.path.isfile(args.arch_nas_dataset): | ||||
|         nas_bench = None | ||||
|     else: | ||||
|         print( | ||||
|             "{:} build NAS-Benchmark-API from {:}".format( | ||||
|                 time_string(), args.arch_nas_dataset | ||||
|             ) | ||||
|         ) | ||||
|         nas_bench = API(args.arch_nas_dataset) | ||||
|     if args.rand_seed is None or args.rand_seed < 0: | ||||
|         save_dir, all_indexes, num = None, [], 500 | ||||
|         for i in range(num): | ||||
|             print("{:} : {:03d}/{:03d}".format(time_string(), i, num)) | ||||
|             args.rand_seed = random.randint(1, 100000) | ||||
|             save_dir, index = main(args, nas_bench) | ||||
|             all_indexes.append(index) | ||||
|         torch.save(all_indexes, save_dir / "results.pth") | ||||
|     else: | ||||
|         main(args, nas_bench) | ||||
							
								
								
									
7  exps/NAS-Bench-201-algos/README.md  Normal file
							| @@ -0,0 +1,7 @@ | ||||
| # NAS Algorithms evaluated in NAS-Bench-201 | ||||
|  | ||||
| The Python files in this folder are used to reproduce the results in our NAS-Bench-201 paper. | ||||
|  | ||||
| We have upgraded these codes to be more general and extensible at [NATS-algos](https://github.com/D-X-Y/AutoDL-Projects/tree/main/exps/NATS-algos). | ||||
|  | ||||
| **Notice**: On 24 May 2021, the codes in the `AutoDL` repo were re-organized. If you encounter a `module not found` error, please let me know and I will fix it ASAP. | ||||
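|  | ||||
| As a quick sanity check of the setup, the benchmark file can also be queried directly from Python. The snippet below is a minimal sketch and not part of the original scripts: the benchmark file name and the architecture string are illustrative assumptions. | ||||
|  | ||||
| ```python | ||||
| from nas_201_api import NASBench201API as API | ||||
|  | ||||
| # Path to a locally downloaded NAS-Bench-201 file (the exact file name is an assumption). | ||||
| api = API("NAS-Bench-201-v1_1-096897.pth") | ||||
|  | ||||
| # An example cell description in the NAS-Bench-201 string format (illustrative only). | ||||
| arch = "|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1|+|skip_connect~0|nor_conv_3x3~1|skip_connect~2|" | ||||
| print(api.query_index_by_arch(arch))   # the index of this cell in the benchmark | ||||
| print(api.query_by_arch(arch, "200"))  # a report of its 200-epoch statistics | ||||
| ``` | ||||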
							
								
								
									
399  exps/NAS-Bench-201-algos/R_EA.py  Normal file
							| @@ -0,0 +1,399 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 # | ||||
| ################################################################## | ||||
| # Regularized Evolution for Image Classifier Architecture Search # | ||||
| ################################################################## | ||||
| import os, sys, time, glob, random, argparse | ||||
| import numpy as np, collections | ||||
| from copy import deepcopy | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| from pathlib import Path | ||||
|  | ||||
| from xautodl.config_utils import load_config, dict2config, configure2str | ||||
| from xautodl.datasets import get_datasets, SearchDataset | ||||
| from xautodl.procedures import ( | ||||
|     prepare_seed, | ||||
|     prepare_logger, | ||||
|     save_checkpoint, | ||||
|     copy_checkpoint, | ||||
|     get_optim_scheduler, | ||||
| ) | ||||
| from xautodl.utils import get_model_infos, obtain_accuracy | ||||
| from xautodl.log_utils import AverageMeter, time_string, convert_secs2time | ||||
| from xautodl.models import CellStructure, get_search_spaces | ||||
| from nas_201_api import NASBench201API as API | ||||
|  | ||||
|  | ||||
| class Model(object): | ||||
|     def __init__(self): | ||||
|         self.arch = None | ||||
|         self.accuracy = None | ||||
|  | ||||
|     def __str__(self): | ||||
|         """Prints a readable version of this bitstring.""" | ||||
|         return "{:}".format(self.arch) | ||||
|  | ||||
|  | ||||
| # This function mimics the training and evaluation procedure for a single architecture `arch`. | ||||
| # The time_cost is calculated as the total training time for a few epochs (e.g., 12) plus the evaluation time for one epoch. | ||||
| # For use_012_epoch_training = True, the architecture is trained for 12 epochs, with the LR decayed from 0.1 to 0. | ||||
| #       In this case, the LR schedule has fully converged. | ||||
| # For use_012_epoch_training = False, the architecture is scheduled to be trained for 200 epochs, but we early-stop the procedure. | ||||
| # | ||||
| def train_and_eval( | ||||
|     arch, nas_bench, extra_info, dataname="cifar10-valid", use_012_epoch_training=True | ||||
| ): | ||||
|  | ||||
|     if use_012_epoch_training and nas_bench is not None: | ||||
|         arch_index = nas_bench.query_index_by_arch(arch) | ||||
|         assert arch_index >= 0, "can not find this arch : {:}".format(arch) | ||||
|         info = nas_bench.get_more_info( | ||||
|             arch_index, dataname, iepoch=None, hp="12", is_random=True | ||||
|         ) | ||||
|         valid_acc, time_cost = ( | ||||
|             info["valid-accuracy"], | ||||
|             info["train-all-time"] + info["valid-per-time"], | ||||
|         ) | ||||
|         # _, valid_acc = info.get_metrics('cifar10-valid', 'x-valid' , 25, True) # use the validation accuracy after 25 training epochs | ||||
|     elif not use_012_epoch_training and nas_bench is not None: | ||||
|         # Please contact me if you want to use the following logic, because it has some potential issues. | ||||
|         # Please use `use_012_epoch_training=False` for cifar10 only. | ||||
|         # It did return values for cifar100 and ImageNet16-120, but it has some potential issues. (Please email me for more details) | ||||
|         arch_index, nepoch = nas_bench.query_index_by_arch(arch), 25 | ||||
|         assert arch_index >= 0, "can not find this arch : {:}".format(arch) | ||||
|         xoinfo = nas_bench.get_more_info( | ||||
|             arch_index, "cifar10-valid", iepoch=None, hp="12" | ||||
|         ) | ||||
|         xocost = nas_bench.get_cost_info(arch_index, "cifar10-valid", hp="200") | ||||
|         info = nas_bench.get_more_info( | ||||
|             arch_index, dataname, nepoch, hp="200", is_random=True | ||||
|         )  # use the validation accuracy after 25 training epochs, which is used in our ICLR submission (not the camera ready). | ||||
|         cost = nas_bench.get_cost_info(arch_index, dataname, hp="200") | ||||
|         # The following codes are used to estimate the time cost. | ||||
|         # When we build NAS-Bench-201, architectures are trained on different machines and we can not use that time record. | ||||
|         # When we create checkpoints for converged_LR, we run all experiments on 1080Ti, and thus the time for each architecture can be fairly compared. | ||||
|         nums = { | ||||
|             "ImageNet16-120-train": 151700, | ||||
|             "ImageNet16-120-valid": 3000, | ||||
|             "cifar10-valid-train": 25000, | ||||
|             "cifar10-valid-valid": 25000, | ||||
|             "cifar100-train": 50000, | ||||
|             "cifar100-valid": 5000, | ||||
|         } | ||||
|         estimated_train_cost = ( | ||||
|             xoinfo["train-per-time"] | ||||
|             / nums["cifar10-valid-train"] | ||||
|             * nums["{:}-train".format(dataname)] | ||||
|             / xocost["latency"] | ||||
|             * cost["latency"] | ||||
|             * nepoch | ||||
|         ) | ||||
|         estimated_valid_cost = ( | ||||
|             xoinfo["valid-per-time"] | ||||
|             / nums["cifar10-valid-valid"] | ||||
|             * nums["{:}-valid".format(dataname)] | ||||
|             / xocost["latency"] | ||||
|             * cost["latency"] | ||||
|         ) | ||||
|         try: | ||||
|             valid_acc, time_cost = ( | ||||
|                 info["valid-accuracy"], | ||||
|                 estimated_train_cost + estimated_valid_cost, | ||||
|             ) | ||||
|         except KeyError: | ||||
|             valid_acc, time_cost = ( | ||||
|                 info["valtest-accuracy"], | ||||
|                 estimated_train_cost + estimated_valid_cost, | ||||
|             ) | ||||
|     else: | ||||
|         # train a model from scratch. | ||||
|         raise ValueError("NOT IMPLEMENT YET") | ||||
|     return valid_acc, time_cost | ||||
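|  | ||||
| # A minimal usage sketch of the fast (12-epoch) path above; the benchmark path is an | ||||
| # assumption and `arch` is assumed to be an already-built CellStructure: | ||||
| # | ||||
| #   api = API("NAS-Bench-201-v1_1-096897.pth") | ||||
| #   acc, cost = train_and_eval(arch, api, extra_info=None, dataname="cifar10-valid") | ||||
| #   # `acc` is a 12-epoch validation accuracy; `cost` is train-all-time + valid-per-time. | ||||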
|  | ||||
|  | ||||
| def random_architecture_func(max_nodes, op_names): | ||||
|     # return a random architecture | ||||
|     def random_architecture(): | ||||
|         genotypes = [] | ||||
|         for i in range(1, max_nodes): | ||||
|             xlist = [] | ||||
|             for j in range(i): | ||||
|                 node_str = "{:}<-{:}".format(i, j) | ||||
|                 op_name = random.choice(op_names) | ||||
|                 xlist.append((op_name, j)) | ||||
|             genotypes.append(tuple(xlist)) | ||||
|         return CellStructure(genotypes) | ||||
|  | ||||
|     return random_architecture | ||||
|  | ||||
|  | ||||
| def mutate_arch_func(op_names): | ||||
|     """Computes the architecture for a child of the given parent architecture. | ||||
|     The parent architecture is cloned and mutated to produce the child architecture. The child is obtained by randomly switching one operation to another. | ||||
|     """ | ||||
|  | ||||
|     def mutate_arch_func(parent_arch): | ||||
|         child_arch = deepcopy(parent_arch) | ||||
|         node_id = random.randint(0, len(child_arch.nodes) - 1) | ||||
|         node_info = list(child_arch.nodes[node_id]) | ||||
|         snode_id = random.randint(0, len(node_info) - 1) | ||||
|         xop = random.choice(op_names) | ||||
|         while xop == node_info[snode_id][0]: | ||||
|             xop = random.choice(op_names) | ||||
|         node_info[snode_id] = (xop, node_info[snode_id][1]) | ||||
|         child_arch.nodes[node_id] = tuple(node_info) | ||||
|         return child_arch | ||||
|  | ||||
|     return mutate_arch_func | ||||
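|  | ||||
| # A minimal sketch of how the two factory functions above are composed; max_nodes=4 and | ||||
| # the operation list are assumptions (typical of the NAS-Bench-201 search space): | ||||
| # | ||||
| #   ops = ["none", "skip_connect", "nor_conv_1x1", "nor_conv_3x3", "avg_pool_3x3"] | ||||
| #   random_arch = random_architecture_func(4, ops) | ||||
| #   mutate_arch = mutate_arch_func(ops) | ||||
| #   parent = random_arch()        # a random CellStructure | ||||
| #   child = mutate_arch(parent)   # a copy of the parent with exactly one operation switched | ||||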
|  | ||||
|  | ||||
| def regularized_evolution( | ||||
|     cycles, | ||||
|     population_size, | ||||
|     sample_size, | ||||
|     time_budget, | ||||
|     random_arch, | ||||
|     mutate_arch, | ||||
|     nas_bench, | ||||
|     extra_info, | ||||
|     dataname, | ||||
| ): | ||||
|     """Algorithm for regularized evolution (i.e. aging evolution). | ||||
|  | ||||
|     Follows "Algorithm 1" in Real et al. "Regularized Evolution for Image | ||||
|     Classifier Architecture Search". | ||||
|  | ||||
|     Args: | ||||
|       cycles: the number of cycles the algorithm should run for. | ||||
|       population_size: the number of individuals to keep in the population. | ||||
|       sample_size: the number of individuals that should participate in each tournament. | ||||
|       time_budget: the upper bound of the search cost (in seconds). | ||||
|  | ||||
|     Returns: | ||||
|       history: a list of `Model` instances, representing all the models computed | ||||
|           during the evolution experiment. | ||||
|     """ | ||||
|     population = collections.deque() | ||||
|     history, total_time_cost = ( | ||||
|         [], | ||||
|         0, | ||||
|     )  # Not used by the algorithm, only used to report results. | ||||
|  | ||||
|     # Initialize the population with random models. | ||||
|     while len(population) < population_size: | ||||
|         model = Model() | ||||
|         model.arch = random_arch() | ||||
|         model.accuracy, time_cost = train_and_eval( | ||||
|             model.arch, nas_bench, extra_info, dataname | ||||
|         ) | ||||
|         population.append(model) | ||||
|         history.append(model) | ||||
|         total_time_cost += time_cost | ||||
|  | ||||
|     # Carry out evolution in cycles. Each cycle produces a model and removes | ||||
|     # another. | ||||
|     # while len(history) < cycles: | ||||
|     while total_time_cost < time_budget: | ||||
|         # Sample randomly chosen models from the current population. | ||||
|         start_time, sample = time.time(), [] | ||||
|         while len(sample) < sample_size: | ||||
|             # Inefficient, but written this way for clarity. In the case of neural | ||||
|             # nets, the efficiency of this line is irrelevant because training neural | ||||
|             # nets is the rate-determining step. | ||||
|             candidate = random.choice(list(population)) | ||||
|             sample.append(candidate) | ||||
|  | ||||
|         # The parent is the best model in the sample. | ||||
|         parent = max(sample, key=lambda i: i.accuracy) | ||||
|  | ||||
|         # Create the child model and store it. | ||||
|         child = Model() | ||||
|         child.arch = mutate_arch(parent.arch) | ||||
|         total_time_cost += time.time() - start_time | ||||
|         child.accuracy, time_cost = train_and_eval( | ||||
|             child.arch, nas_bench, extra_info, dataname | ||||
|         ) | ||||
|         if total_time_cost + time_cost > time_budget:  # return | ||||
|             return history, total_time_cost | ||||
|         else: | ||||
|             total_time_cost += time_cost | ||||
|         population.append(child) | ||||
|         history.append(child) | ||||
|  | ||||
|         # Remove the oldest model. | ||||
|         population.popleft() | ||||
|     return history, total_time_cost | ||||
|  | ||||
|  | ||||
| def main(xargs, nas_bench): | ||||
|     assert torch.cuda.is_available(), "CUDA is not available." | ||||
|     torch.backends.cudnn.enabled = True | ||||
|     torch.backends.cudnn.benchmark = False | ||||
|     torch.backends.cudnn.deterministic = True | ||||
|     torch.set_num_threads(xargs.workers) | ||||
|     prepare_seed(xargs.rand_seed) | ||||
|     logger = prepare_logger(args) | ||||
|  | ||||
|     if xargs.dataset == "cifar10": | ||||
|         dataname = "cifar10-valid" | ||||
|     else: | ||||
|         dataname = xargs.dataset | ||||
|     if xargs.data_path is not None: | ||||
|         train_data, valid_data, xshape, class_num = get_datasets( | ||||
|             xargs.dataset, xargs.data_path, -1 | ||||
|         ) | ||||
|         split_Fpath = "configs/nas-benchmark/cifar-split.txt" | ||||
|         cifar_split = load_config(split_Fpath, None, None) | ||||
|         train_split, valid_split = cifar_split.train, cifar_split.valid | ||||
|         logger.log("Load split file from {:}".format(split_Fpath)) | ||||
|         config_path = "configs/nas-benchmark/algos/R-EA.config" | ||||
|         config = load_config( | ||||
|             config_path, {"class_num": class_num, "xshape": xshape}, logger | ||||
|         ) | ||||
|         # To split data | ||||
|         train_data_v2 = deepcopy(train_data) | ||||
|         train_data_v2.transform = valid_data.transform | ||||
|         valid_data = train_data_v2 | ||||
|         search_data = SearchDataset(xargs.dataset, train_data, train_split, valid_split) | ||||
|         # data loader | ||||
|         train_loader = torch.utils.data.DataLoader( | ||||
|             train_data, | ||||
|             batch_size=config.batch_size, | ||||
|             sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split), | ||||
|             num_workers=xargs.workers, | ||||
|             pin_memory=True, | ||||
|         ) | ||||
|         valid_loader = torch.utils.data.DataLoader( | ||||
|             valid_data, | ||||
|             batch_size=config.batch_size, | ||||
|             sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), | ||||
|             num_workers=xargs.workers, | ||||
|             pin_memory=True, | ||||
|         ) | ||||
|         logger.log( | ||||
|             "||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}".format( | ||||
|                 xargs.dataset, len(train_loader), len(valid_loader), config.batch_size | ||||
|             ) | ||||
|         ) | ||||
|         logger.log("||||||| {:10s} ||||||| Config={:}".format(xargs.dataset, config)) | ||||
|         extra_info = { | ||||
|             "config": config, | ||||
|             "train_loader": train_loader, | ||||
|             "valid_loader": valid_loader, | ||||
|         } | ||||
|     else: | ||||
|         config_path = "configs/nas-benchmark/algos/R-EA.config" | ||||
|         config = load_config(config_path, None, logger) | ||||
|         logger.log("||||||| {:10s} ||||||| Config={:}".format(xargs.dataset, config)) | ||||
|         extra_info = {"config": config, "train_loader": None, "valid_loader": None} | ||||
|  | ||||
|     search_space = get_search_spaces("cell", xargs.search_space_name) | ||||
|     random_arch = random_architecture_func(xargs.max_nodes, search_space) | ||||
|     mutate_arch = mutate_arch_func(search_space) | ||||
|     # x =random_arch() ; y = mutate_arch(x) | ||||
|     x_start_time = time.time() | ||||
|     logger.log("{:} use nas_bench : {:}".format(time_string(), nas_bench)) | ||||
|     logger.log( | ||||
|         "-" * 30 | ||||
|         + " start searching with the time budget of {:} s".format(xargs.time_budget) | ||||
|     ) | ||||
|     history, total_cost = regularized_evolution( | ||||
|         xargs.ea_cycles, | ||||
|         xargs.ea_population, | ||||
|         xargs.ea_sample_size, | ||||
|         xargs.time_budget, | ||||
|         random_arch, | ||||
|         mutate_arch, | ||||
|         nas_bench if args.ea_fast_by_api else None, | ||||
|         extra_info, | ||||
|         dataname, | ||||
|     ) | ||||
|     logger.log( | ||||
|         "{:} regularized_evolution finish with history of {:} arch with {:.1f} s (real-cost={:.2f} s).".format( | ||||
|             time_string(), len(history), total_cost, time.time() - x_start_time | ||||
|         ) | ||||
|     ) | ||||
|     best_arch = max(history, key=lambda i: i.accuracy) | ||||
|     best_arch = best_arch.arch | ||||
|     logger.log("{:} best arch is {:}".format(time_string(), best_arch)) | ||||
|  | ||||
|     info = nas_bench.query_by_arch(best_arch, "200") | ||||
|     if info is None: | ||||
|         logger.log("Did not find this architecture : {:}.".format(best_arch)) | ||||
|     else: | ||||
|         logger.log("{:}".format(info)) | ||||
|     logger.log("-" * 100) | ||||
|     logger.close() | ||||
|     return logger.log_dir, nas_bench.query_index_by_arch(best_arch) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     parser = argparse.ArgumentParser("Regularized Evolution Algorithm") | ||||
|     parser.add_argument("--data_path", type=str, help="Path to dataset") | ||||
|     parser.add_argument( | ||||
|         "--dataset", | ||||
|         type=str, | ||||
|         choices=["cifar10", "cifar100", "ImageNet16-120"], | ||||
|         help="Choose between Cifar10/100 and ImageNet-16.", | ||||
|     ) | ||||
|     # channels and number-of-cells | ||||
|     parser.add_argument("--search_space_name", type=str, help="The search space name.") | ||||
|     parser.add_argument("--max_nodes", type=int, help="The maximum number of nodes.") | ||||
|     parser.add_argument("--channel", type=int, help="The number of channels.") | ||||
|     parser.add_argument( | ||||
|         "--num_cells", type=int, help="The number of cells in one stage." | ||||
|     ) | ||||
|     parser.add_argument("--ea_cycles", type=int, help="The number of cycles in EA.") | ||||
|     parser.add_argument("--ea_population", type=int, help="The population size in EA.") | ||||
|     parser.add_argument("--ea_sample_size", type=int, help="The sample size in EA.") | ||||
|     parser.add_argument( | ||||
|         "--ea_fast_by_api", | ||||
|         type=int, | ||||
|         help="Use our API to speed up the experiments or not.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--time_budget", | ||||
|         type=int, | ||||
|         help="The total time cost budge for searching (in seconds).", | ||||
|     ) | ||||
|     # log | ||||
|     parser.add_argument( | ||||
|         "--workers", | ||||
|         type=int, | ||||
|         default=2, | ||||
|         help="number of data loading workers (default: 2)", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--save_dir", type=str, help="Folder to save checkpoints and log." | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--arch_nas_dataset", | ||||
|         type=str, | ||||
|         help="The path to load the architecture dataset (tiny-nas-benchmark).", | ||||
|     ) | ||||
|     parser.add_argument("--print_freq", type=int, help="print frequency (default: 200)") | ||||
|     parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed") | ||||
|     args = parser.parse_args() | ||||
|     # if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000) | ||||
|     args.ea_fast_by_api = args.ea_fast_by_api > 0 | ||||
|  | ||||
|     if args.arch_nas_dataset is None or not os.path.isfile(args.arch_nas_dataset): | ||||
|         nas_bench = None | ||||
|     else: | ||||
|         print( | ||||
|             "{:} build NAS-Benchmark-API from {:}".format( | ||||
|                 time_string(), args.arch_nas_dataset | ||||
|             ) | ||||
|         ) | ||||
|         nas_bench = API(args.arch_nas_dataset) | ||||
|     if args.rand_seed < 0: | ||||
|         save_dir, all_indexes, num = None, [], 500 | ||||
|         for i in range(num): | ||||
|             print("{:} : {:03d}/{:03d}".format(time_string(), i, num)) | ||||
|             args.rand_seed = random.randint(1, 100000) | ||||
|             save_dir, index = main(args, nas_bench) | ||||
|             all_indexes.append(index) | ||||
|         torch.save(all_indexes, save_dir / "results.pth") | ||||
|     else: | ||||
|         main(args, nas_bench) | ||||
							
								
								
									
476  exps/NAS-Bench-201-algos/SETN.py  Normal file
							| @@ -0,0 +1,476 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 # | ||||
| ###################################################################################### | ||||
| # One-Shot Neural Architecture Search via Self-Evaluated Template Network, ICCV 2019 # | ||||
| ###################################################################################### | ||||
| import sys, time, random, argparse | ||||
| import numpy as np | ||||
| from copy import deepcopy | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| from pathlib import Path | ||||
|  | ||||
| from xautodl.config_utils import load_config, dict2config, configure2str | ||||
| from xautodl.datasets import get_datasets, get_nas_search_loaders | ||||
| from xautodl.procedures import ( | ||||
|     prepare_seed, | ||||
|     prepare_logger, | ||||
|     save_checkpoint, | ||||
|     copy_checkpoint, | ||||
|     get_optim_scheduler, | ||||
| ) | ||||
| from xautodl.utils import get_model_infos, obtain_accuracy | ||||
| from xautodl.log_utils import AverageMeter, time_string, convert_secs2time | ||||
| from xautodl.models import get_cell_based_tiny_net, get_search_spaces | ||||
| from nas_201_api import NASBench201API as API | ||||
|  | ||||
|  | ||||
| def search_func( | ||||
|     xloader, | ||||
|     network, | ||||
|     criterion, | ||||
|     scheduler, | ||||
|     w_optimizer, | ||||
|     a_optimizer, | ||||
|     epoch_str, | ||||
|     print_freq, | ||||
|     logger, | ||||
| ): | ||||
|     data_time, batch_time = AverageMeter(), AverageMeter() | ||||
|     base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter() | ||||
|     arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter() | ||||
|     end = time.time() | ||||
|     network.train() | ||||
|     for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate( | ||||
|         xloader | ||||
|     ): | ||||
|         scheduler.update(None, 1.0 * step / len(xloader)) | ||||
|         base_targets = base_targets.cuda(non_blocking=True) | ||||
|         arch_targets = arch_targets.cuda(non_blocking=True) | ||||
|         # measure data loading time | ||||
|         data_time.update(time.time() - end) | ||||
|  | ||||
|         # update the weights | ||||
|         sampled_arch = network.module.dync_genotype(True) | ||||
|         network.module.set_cal_mode("dynamic", sampled_arch) | ||||
|         # network.module.set_cal_mode( 'urs' ) | ||||
|         network.zero_grad() | ||||
|         _, logits = network(base_inputs) | ||||
|         base_loss = criterion(logits, base_targets) | ||||
|         base_loss.backward() | ||||
|         w_optimizer.step() | ||||
|         # record | ||||
|         base_prec1, base_prec5 = obtain_accuracy( | ||||
|             logits.data, base_targets.data, topk=(1, 5) | ||||
|         ) | ||||
|         base_losses.update(base_loss.item(), base_inputs.size(0)) | ||||
|         base_top1.update(base_prec1.item(), base_inputs.size(0)) | ||||
|         base_top5.update(base_prec5.item(), base_inputs.size(0)) | ||||
|  | ||||
|         # update the architecture-weight | ||||
|         network.module.set_cal_mode("joint") | ||||
|         network.zero_grad() | ||||
|         _, logits = network(arch_inputs) | ||||
|         arch_loss = criterion(logits, arch_targets) | ||||
|         arch_loss.backward() | ||||
|         a_optimizer.step() | ||||
|         # record | ||||
|         arch_prec1, arch_prec5 = obtain_accuracy( | ||||
|             logits.data, arch_targets.data, topk=(1, 5) | ||||
|         ) | ||||
|         arch_losses.update(arch_loss.item(), arch_inputs.size(0)) | ||||
|         arch_top1.update(arch_prec1.item(), arch_inputs.size(0)) | ||||
|         arch_top5.update(arch_prec5.item(), arch_inputs.size(0)) | ||||
|  | ||||
|         # measure elapsed time | ||||
|         batch_time.update(time.time() - end) | ||||
|         end = time.time() | ||||
|  | ||||
|         if step % print_freq == 0 or step + 1 == len(xloader): | ||||
|             Sstr = ( | ||||
|                 "*SEARCH* " | ||||
|                 + time_string() | ||||
|                 + " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(xloader)) | ||||
|             ) | ||||
|             Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format( | ||||
|                 batch_time=batch_time, data_time=data_time | ||||
|             ) | ||||
|             Wstr = "Base [Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]".format( | ||||
|                 loss=base_losses, top1=base_top1, top5=base_top5 | ||||
|             ) | ||||
|             Astr = "Arch [Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]".format( | ||||
|                 loss=arch_losses, top1=arch_top1, top5=arch_top5 | ||||
|             ) | ||||
|             logger.log(Sstr + " " + Tstr + " " + Wstr + " " + Astr) | ||||
|             # print (nn.functional.softmax(network.module.arch_parameters, dim=-1)) | ||||
|             # print (network.module.arch_parameters) | ||||
|     return ( | ||||
|         base_losses.avg, | ||||
|         base_top1.avg, | ||||
|         base_top5.avg, | ||||
|         arch_losses.avg, | ||||
|         arch_top1.avg, | ||||
|         arch_top5.avg, | ||||
|     ) | ||||
|  | ||||
|  | ||||
| def get_best_arch(xloader, network, n_samples): | ||||
|     with torch.no_grad(): | ||||
|         network.eval() | ||||
|         archs, valid_accs = network.module.return_topK(n_samples), [] | ||||
|         # print ('obtain the top-{:} architectures'.format(n_samples)) | ||||
|         loader_iter = iter(xloader) | ||||
|         for i, sampled_arch in enumerate(archs): | ||||
|             network.module.set_cal_mode("dynamic", sampled_arch) | ||||
|             try: | ||||
|                 inputs, targets = next(loader_iter) | ||||
|             except StopIteration: | ||||
|                 loader_iter = iter(xloader) | ||||
|                 inputs, targets = next(loader_iter) | ||||
|  | ||||
|             _, logits = network(inputs) | ||||
|             val_top1, val_top5 = obtain_accuracy( | ||||
|                 logits.cpu().data, targets.data, topk=(1, 5) | ||||
|             ) | ||||
|  | ||||
|             valid_accs.append(val_top1.item()) | ||||
|  | ||||
|         best_idx = np.argmax(valid_accs) | ||||
|         best_arch, best_valid_acc = archs[best_idx], valid_accs[best_idx] | ||||
|         return best_arch, best_valid_acc | ||||
|  | ||||
|  | ||||
| def valid_func(xloader, network, criterion): | ||||
|     data_time, batch_time = AverageMeter(), AverageMeter() | ||||
|     arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter() | ||||
|     end = time.time() | ||||
|     with torch.no_grad(): | ||||
|         network.eval() | ||||
|         for step, (arch_inputs, arch_targets) in enumerate(xloader): | ||||
|             arch_targets = arch_targets.cuda(non_blocking=True) | ||||
|             # measure data loading time | ||||
|             data_time.update(time.time() - end) | ||||
|             # prediction | ||||
|             _, logits = network(arch_inputs) | ||||
|             arch_loss = criterion(logits, arch_targets) | ||||
|             # record | ||||
|             arch_prec1, arch_prec5 = obtain_accuracy( | ||||
|                 logits.data, arch_targets.data, topk=(1, 5) | ||||
|             ) | ||||
|             arch_losses.update(arch_loss.item(), arch_inputs.size(0)) | ||||
|             arch_top1.update(arch_prec1.item(), arch_inputs.size(0)) | ||||
|             arch_top5.update(arch_prec5.item(), arch_inputs.size(0)) | ||||
|             # measure elapsed time | ||||
|             batch_time.update(time.time() - end) | ||||
|             end = time.time() | ||||
|     return arch_losses.avg, arch_top1.avg, arch_top5.avg | ||||
|  | ||||
|  | ||||
| def main(xargs): | ||||
|     assert torch.cuda.is_available(), "CUDA is not available." | ||||
|     torch.backends.cudnn.enabled = True | ||||
|     torch.backends.cudnn.benchmark = False | ||||
|     torch.backends.cudnn.deterministic = True | ||||
|     torch.set_num_threads(xargs.workers) | ||||
|     prepare_seed(xargs.rand_seed) | ||||
|     logger = prepare_logger(args) | ||||
|  | ||||
|     train_data, valid_data, xshape, class_num = get_datasets( | ||||
|         xargs.dataset, xargs.data_path, -1 | ||||
|     ) | ||||
|     config = load_config( | ||||
|         xargs.config_path, {"class_num": class_num, "xshape": xshape}, logger | ||||
|     ) | ||||
|     search_loader, _, valid_loader = get_nas_search_loaders( | ||||
|         train_data, | ||||
|         valid_data, | ||||
|         xargs.dataset, | ||||
|         "configs/nas-benchmark/", | ||||
|         (config.batch_size, config.test_batch_size), | ||||
|         xargs.workers, | ||||
|     ) | ||||
|     logger.log( | ||||
|         "||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}".format( | ||||
|             xargs.dataset, len(search_loader), len(valid_loader), config.batch_size | ||||
|         ) | ||||
|     ) | ||||
|     logger.log("||||||| {:10s} ||||||| Config={:}".format(xargs.dataset, config)) | ||||
|  | ||||
|     search_space = get_search_spaces("cell", xargs.search_space_name) | ||||
|     if xargs.model_config is None: | ||||
|         model_config = dict2config( | ||||
|             dict( | ||||
|                 name="SETN", | ||||
|                 C=xargs.channel, | ||||
|                 N=xargs.num_cells, | ||||
|                 max_nodes=xargs.max_nodes, | ||||
|                 num_classes=class_num, | ||||
|                 space=search_space, | ||||
|                 affine=False, | ||||
|                 track_running_stats=bool(xargs.track_running_stats), | ||||
|             ), | ||||
|             None, | ||||
|         ) | ||||
|     else: | ||||
|         model_config = load_config( | ||||
|             xargs.model_config, | ||||
|             dict( | ||||
|                 num_classes=class_num, | ||||
|                 space=search_space, | ||||
|                 affine=False, | ||||
|                 track_running_stats=bool(xargs.track_running_stats), | ||||
|             ), | ||||
|             None, | ||||
|         ) | ||||
|     logger.log("search space : {:}".format(search_space)) | ||||
|     search_model = get_cell_based_tiny_net(model_config) | ||||
|  | ||||
|     w_optimizer, w_scheduler, criterion = get_optim_scheduler( | ||||
|         search_model.get_weights(), config | ||||
|     ) | ||||
|     a_optimizer = torch.optim.Adam( | ||||
|         search_model.get_alphas(), | ||||
|         lr=xargs.arch_learning_rate, | ||||
|         betas=(0.5, 0.999), | ||||
|         weight_decay=xargs.arch_weight_decay, | ||||
|     ) | ||||
|     logger.log("w-optimizer : {:}".format(w_optimizer)) | ||||
|     logger.log("a-optimizer : {:}".format(a_optimizer)) | ||||
|     logger.log("w-scheduler : {:}".format(w_scheduler)) | ||||
|     logger.log("criterion   : {:}".format(criterion)) | ||||
|     flop, param = get_model_infos(search_model, xshape) | ||||
|     logger.log("FLOP = {:.2f} M, Params = {:.2f} MB".format(flop, param)) | ||||
|     logger.log("search-space : {:}".format(search_space)) | ||||
|     if xargs.arch_nas_dataset is None: | ||||
|         api = None | ||||
|     else: | ||||
|         api = API(xargs.arch_nas_dataset) | ||||
|     logger.log("{:} create API = {:} done".format(time_string(), api)) | ||||
|  | ||||
|     last_info, model_base_path, model_best_path = ( | ||||
|         logger.path("info"), | ||||
|         logger.path("model"), | ||||
|         logger.path("best"), | ||||
|     ) | ||||
|     network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda() | ||||
|  | ||||
|     if last_info.exists():  # automatically resume from previous checkpoint | ||||
|         logger.log( | ||||
|             "=> loading checkpoint of the last-info '{:}' start".format(last_info) | ||||
|         ) | ||||
|         last_info = torch.load(last_info) | ||||
|         start_epoch = last_info["epoch"] | ||||
|         checkpoint = torch.load(last_info["last_checkpoint"]) | ||||
|         genotypes = checkpoint["genotypes"] | ||||
|         valid_accuracies = checkpoint["valid_accuracies"] | ||||
|         search_model.load_state_dict(checkpoint["search_model"]) | ||||
|         w_scheduler.load_state_dict(checkpoint["w_scheduler"]) | ||||
|         w_optimizer.load_state_dict(checkpoint["w_optimizer"]) | ||||
|         a_optimizer.load_state_dict(checkpoint["a_optimizer"]) | ||||
|         logger.log( | ||||
|             "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format( | ||||
|                 last_info, start_epoch | ||||
|             ) | ||||
|         ) | ||||
|     else: | ||||
|         logger.log("=> do not find the last-info file : {:}".format(last_info)) | ||||
|         init_genotype, _ = get_best_arch(valid_loader, network, xargs.select_num) | ||||
|         start_epoch, valid_accuracies, genotypes = 0, {"best": -1}, {-1: init_genotype} | ||||
|  | ||||
|     # start training | ||||
|     start_time, search_time, epoch_time, total_epoch = ( | ||||
|         time.time(), | ||||
|         AverageMeter(), | ||||
|         AverageMeter(), | ||||
|         config.epochs + config.warmup, | ||||
|     ) | ||||
|     for epoch in range(start_epoch, total_epoch): | ||||
|         w_scheduler.update(epoch, 0.0) | ||||
|         need_time = "Time Left: {:}".format( | ||||
|             convert_secs2time(epoch_time.val * (total_epoch - epoch), True) | ||||
|         ) | ||||
|         epoch_str = "{:03d}-{:03d}".format(epoch, total_epoch) | ||||
|         logger.log( | ||||
|             "\n[Search the {:}-th epoch] {:}, LR={:}".format( | ||||
|                 epoch_str, need_time, min(w_scheduler.get_lr()) | ||||
|             ) | ||||
|         ) | ||||
|  | ||||
|         ( | ||||
|             search_w_loss, | ||||
|             search_w_top1, | ||||
|             search_w_top5, | ||||
|             search_a_loss, | ||||
|             search_a_top1, | ||||
|             search_a_top5, | ||||
|         ) = search_func( | ||||
|             search_loader, | ||||
|             network, | ||||
|             criterion, | ||||
|             w_scheduler, | ||||
|             w_optimizer, | ||||
|             a_optimizer, | ||||
|             epoch_str, | ||||
|             xargs.print_freq, | ||||
|             logger, | ||||
|         ) | ||||
|         search_time.update(time.time() - start_time) | ||||
|         logger.log( | ||||
|             "[{:}] search [base] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s".format( | ||||
|                 epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum | ||||
|             ) | ||||
|         ) | ||||
|         logger.log( | ||||
|             "[{:}] search [arch] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%".format( | ||||
|                 epoch_str, search_a_loss, search_a_top1, search_a_top5 | ||||
|             ) | ||||
|         ) | ||||
|  | ||||
|         genotype, temp_accuracy = get_best_arch(valid_loader, network, xargs.select_num) | ||||
|         network.module.set_cal_mode("dynamic", genotype) | ||||
|         valid_a_loss, valid_a_top1, valid_a_top5 = valid_func( | ||||
|             valid_loader, network, criterion | ||||
|         ) | ||||
|         logger.log( | ||||
|             "[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}% | {:}".format( | ||||
|                 epoch_str, valid_a_loss, valid_a_top1, valid_a_top5, genotype | ||||
|             ) | ||||
|         ) | ||||
|         # search_model.set_cal_mode('urs') | ||||
|         # valid_a_loss , valid_a_top1 , valid_a_top5  = valid_func(valid_loader, network, criterion) | ||||
|         # logger.log('[{:}] URS---evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5)) | ||||
|         # search_model.set_cal_mode('joint') | ||||
|         # valid_a_loss , valid_a_top1 , valid_a_top5  = valid_func(valid_loader, network, criterion) | ||||
|         # logger.log('[{:}] JOINT-evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5)) | ||||
|         # search_model.set_cal_mode('select') | ||||
|         # valid_a_loss , valid_a_top1 , valid_a_top5  = valid_func(valid_loader, network, criterion) | ||||
|         # logger.log('[{:}] Selec-evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5)) | ||||
|         # check the best accuracy | ||||
|         valid_accuracies[epoch] = valid_a_top1 | ||||
|  | ||||
|         genotypes[epoch] = genotype | ||||
|         logger.log( | ||||
|             "<<<--->>> The {:}-th epoch : {:}".format(epoch_str, genotypes[epoch]) | ||||
|         ) | ||||
|         # save checkpoint | ||||
|         save_path = save_checkpoint( | ||||
|             { | ||||
|                 "epoch": epoch + 1, | ||||
|                 "args": deepcopy(xargs), | ||||
|                 "search_model": search_model.state_dict(), | ||||
|                 "w_optimizer": w_optimizer.state_dict(), | ||||
|                 "a_optimizer": a_optimizer.state_dict(), | ||||
|                 "w_scheduler": w_scheduler.state_dict(), | ||||
|                 "genotypes": genotypes, | ||||
|                 "valid_accuracies": valid_accuracies, | ||||
|             }, | ||||
|             model_base_path, | ||||
|             logger, | ||||
|         ) | ||||
|         last_info = save_checkpoint( | ||||
|             { | ||||
|                 "epoch": epoch + 1, | ||||
|                 "args": deepcopy(xargs), | ||||
|                 "last_checkpoint": save_path, | ||||
|             }, | ||||
|             logger.path("info"), | ||||
|             logger, | ||||
|         ) | ||||
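|         # two checkpoints are kept: the full search state (model_base_path) and a small | ||||
|         # last-info file that only stores the path of the latest checkpoint, which is | ||||
|         # what the resume logic at the top of this function reads | ||||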
|         with torch.no_grad(): | ||||
|             logger.log("{:}".format(search_model.show_alphas())) | ||||
|         if api is not None: | ||||
|             logger.log("{:}".format(api.query_by_arch(genotypes[epoch], "200"))) | ||||
|         # measure elapsed time | ||||
|         epoch_time.update(time.time() - start_time) | ||||
|         start_time = time.time() | ||||
|  | ||||
|     # the final post-procedure: run one more architecture selection and count it in the search time | ||||
|     start_time = time.time() | ||||
|     genotype, temp_accuracy = get_best_arch(valid_loader, network, xargs.select_num) | ||||
|     search_time.update(time.time() - start_time) | ||||
|     network.module.set_cal_mode("dynamic", genotype) | ||||
|     valid_a_loss, valid_a_top1, valid_a_top5 = valid_func( | ||||
|         valid_loader, network, criterion | ||||
|     ) | ||||
|     logger.log( | ||||
|         "Last: the genotype is {:}, with the validation accuracy of {:.3f}%.".format( | ||||
|             genotype, valid_a_top1 | ||||
|         ) | ||||
|     ) | ||||
|  | ||||
|     logger.log("\n" + "-" * 100) | ||||
|     # check the performance from the architecture dataset | ||||
|     logger.log( | ||||
|         "SETN : run {:} epochs, cost {:.1f} s, last-geno is {:}.".format( | ||||
|             total_epoch, search_time.sum, genotype | ||||
|         ) | ||||
|     ) | ||||
|     if api is not None: | ||||
|         logger.log("{:}".format(api.query_by_arch(genotype, "200"))) | ||||
|     logger.close() | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     parser = argparse.ArgumentParser("SETN") | ||||
|     parser.add_argument("--data_path", type=str, help="Path to dataset") | ||||
|     parser.add_argument( | ||||
|         "--dataset", | ||||
|         type=str, | ||||
|         choices=["cifar10", "cifar100", "ImageNet16-120"], | ||||
|         help="Choose between CIFAR-10/100 and ImageNet16-120.", | ||||
|     ) | ||||
|     # channels and number-of-cells | ||||
|     parser.add_argument("--search_space_name", type=str, help="The search space name.") | ||||
|     parser.add_argument("--max_nodes", type=int, help="The maximum number of nodes.") | ||||
|     parser.add_argument("--channel", type=int, help="The number of channels.") | ||||
|     parser.add_argument( | ||||
|         "--num_cells", type=int, help="The number of cells in one stage." | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--select_num", | ||||
|         type=int, | ||||
|         help="The number of selected architectures to evaluate.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--track_running_stats", | ||||
|         type=int, | ||||
|         choices=[0, 1], | ||||
|         help="Whether to use track_running_stats in the BN layers.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--config_path", type=str, help="The path of the configuration." | ||||
|     ) | ||||
|     # architecture learning rate | ||||
|     parser.add_argument( | ||||
|         "--arch_learning_rate", | ||||
|         type=float, | ||||
|         default=3e-4, | ||||
|         help="learning rate for arch encoding", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--arch_weight_decay", | ||||
|         type=float, | ||||
|         default=1e-3, | ||||
|         help="weight decay for arch encoding", | ||||
|     ) | ||||
|     # log | ||||
|     parser.add_argument( | ||||
|         "--workers", | ||||
|         type=int, | ||||
|         default=2, | ||||
|         help="number of data loading workers (default: 2)", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--save_dir", type=str, help="Folder to save checkpoints and log." | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--arch_nas_dataset", | ||||
|         type=str, | ||||
|         help="The path to load the architecture dataset (tiny-nas-benchmark).", | ||||
|     ) | ||||
|     parser.add_argument("--print_freq", type=int, help="print frequency (default: 200)") | ||||
|     parser.add_argument("--rand_seed", type=int, help="manual seed") | ||||
|     args = parser.parse_args() | ||||
|     if args.rand_seed is None or args.rand_seed < 0: | ||||
|         args.rand_seed = random.randint(1, 100000) | ||||
|     main(args) | ||||
							
								
								
									
294 exps/NAS-Bench-201-algos/reinforce.py Normal file
| @@ -0,0 +1,294 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ##################################################################################################### | ||||
| # modified from https://github.com/pytorch/examples/blob/master/reinforcement_learning/reinforce.py # | ||||
| ##################################################################################################### | ||||
| import os, sys, time, glob, random, argparse | ||||
| import numpy as np, collections | ||||
| from copy import deepcopy | ||||
| from pathlib import Path | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| from torch.distributions import Categorical | ||||
|  | ||||
| from xautodl.config_utils import load_config, dict2config, configure2str | ||||
| from xautodl.datasets import get_datasets, SearchDataset | ||||
| from xautodl.procedures import ( | ||||
|     prepare_seed, | ||||
|     prepare_logger, | ||||
|     save_checkpoint, | ||||
|     copy_checkpoint, | ||||
|     get_optim_scheduler, | ||||
| ) | ||||
| from xautodl.utils import get_model_infos, obtain_accuracy | ||||
| from xautodl.log_utils import AverageMeter, time_string, convert_secs2time | ||||
| from xautodl.models import CellStructure, get_search_spaces | ||||
| from nas_201_api import NASBench201API as API | ||||
| from R_EA import train_and_eval | ||||
|  | ||||
|  | ||||
| class Policy(nn.Module): | ||||
|     def __init__(self, max_nodes, search_space): | ||||
|         super(Policy, self).__init__() | ||||
|         self.max_nodes = max_nodes | ||||
|         self.search_space = deepcopy(search_space) | ||||
|         self.edge2index = {} | ||||
|         for i in range(1, max_nodes): | ||||
|             for j in range(i): | ||||
|                 node_str = "{:}<-{:}".format(i, j) | ||||
|                 self.edge2index[node_str] = len(self.edge2index) | ||||
|         self.arch_parameters = nn.Parameter( | ||||
|             1e-3 * torch.randn(len(self.edge2index), len(search_space)) | ||||
|         ) | ||||
|  | ||||
|     def generate_arch(self, actions): | ||||
|         genotypes = [] | ||||
|         for i in range(1, self.max_nodes): | ||||
|             xlist = [] | ||||
|             for j in range(i): | ||||
|                 node_str = "{:}<-{:}".format(i, j) | ||||
|                 op_name = self.search_space[actions[self.edge2index[node_str]]] | ||||
|                 xlist.append((op_name, j)) | ||||
|             genotypes.append(tuple(xlist)) | ||||
|         return CellStructure(genotypes) | ||||
|  | ||||
|     def genotype(self): | ||||
|         genotypes = [] | ||||
|         for i in range(1, self.max_nodes): | ||||
|             xlist = [] | ||||
|             for j in range(i): | ||||
|                 node_str = "{:}<-{:}".format(i, j) | ||||
|                 with torch.no_grad(): | ||||
|                     weights = self.arch_parameters[self.edge2index[node_str]] | ||||
|                     op_name = self.search_space[weights.argmax().item()] | ||||
|                 xlist.append((op_name, j)) | ||||
|             genotypes.append(tuple(xlist)) | ||||
|         return CellStructure(genotypes) | ||||
|  | ||||
|     def forward(self): | ||||
|         alphas = nn.functional.softmax(self.arch_parameters, dim=-1) | ||||
|         return alphas | ||||
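|  | ||||
|     # Note: arch_parameters has shape [num_edges, num_ops]; each row is an independent | ||||
|     # categorical distribution over the candidate operations of one edge (e.g. a 6x5 | ||||
|     # tensor for the NAS-Bench-201 cell with max_nodes=4 and 5 operations). | ||||
|     # generate_arch() decodes a sampled action, while genotype() is the deterministic | ||||
|     # argmax decoding used for logging. | ||||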
|  | ||||
|  | ||||
| class ExponentialMovingAverage(object): | ||||
|     """Class that maintains an exponential moving average.""" | ||||
|  | ||||
|     def __init__(self, momentum): | ||||
|         self._numerator = 0 | ||||
|         self._denominator = 0 | ||||
|         self._momentum = momentum | ||||
|  | ||||
|     def update(self, value): | ||||
|         self._numerator = ( | ||||
|             self._momentum * self._numerator + (1 - self._momentum) * value | ||||
|         ) | ||||
|         self._denominator = self._momentum * self._denominator + (1 - self._momentum) | ||||
|  | ||||
|     def value(self): | ||||
|         """Return the current value of the moving average""" | ||||
|         return self._numerator / self._denominator | ||||
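|  | ||||
|     # The denominator applies bias correction (as in Adam), so early calls to value() | ||||
|     # are not biased toward zero. Hypothetical example with momentum=0.9: after | ||||
|     # update(1.0) then update(0.0), value() returns 0.09 / 0.19 ~= 0.474 instead of | ||||
|     # the uncorrected 0.09. | ||||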
|  | ||||
|  | ||||
| def select_action(policy): | ||||
|     probs = policy() | ||||
|     m = Categorical(probs) | ||||
|     action = m.sample() | ||||
|     # policy.saved_log_probs.append(m.log_prob(action)) | ||||
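|     # log_prob has one entry per edge (the log-probability of each sampled op); the | ||||
|     # op indices are returned as a plain Python list for generate_arch | ||||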
|     return m.log_prob(action), action.cpu().tolist() | ||||
|  | ||||
|  | ||||
| def main(xargs, nas_bench): | ||||
|     assert torch.cuda.is_available(), "CUDA is not available." | ||||
|     torch.backends.cudnn.enabled = True | ||||
|     torch.backends.cudnn.benchmark = False | ||||
|     torch.backends.cudnn.deterministic = True | ||||
|     torch.set_num_threads(xargs.workers) | ||||
|     prepare_seed(xargs.rand_seed) | ||||
|     logger = prepare_logger(xargs) | ||||
|  | ||||
|     if xargs.dataset == "cifar10": | ||||
|         dataname = "cifar10-valid" | ||||
|     else: | ||||
|         dataname = xargs.dataset | ||||
|     if xargs.data_path is not None: | ||||
|         train_data, valid_data, xshape, class_num = get_datasets( | ||||
|             xargs.dataset, xargs.data_path, -1 | ||||
|         ) | ||||
|         split_Fpath = "configs/nas-benchmark/cifar-split.txt" | ||||
|         cifar_split = load_config(split_Fpath, None, None) | ||||
|         train_split, valid_split = cifar_split.train, cifar_split.valid | ||||
|         logger.log("Load split file from {:}".format(split_Fpath)) | ||||
|         config_path = "configs/nas-benchmark/algos/R-EA.config" | ||||
|         config = load_config( | ||||
|             config_path, {"class_num": class_num, "xshape": xshape}, logger | ||||
|         ) | ||||
|         # build the validation set from a copy of the training data with validation transforms | ||||
|         train_data_v2 = deepcopy(train_data) | ||||
|         train_data_v2.transform = valid_data.transform | ||||
|         valid_data = train_data_v2 | ||||
|         search_data = SearchDataset(xargs.dataset, train_data, train_split, valid_split) | ||||
|         # data loader | ||||
|         train_loader = torch.utils.data.DataLoader( | ||||
|             train_data, | ||||
|             batch_size=config.batch_size, | ||||
|             sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split), | ||||
|             num_workers=xargs.workers, | ||||
|             pin_memory=True, | ||||
|         ) | ||||
|         valid_loader = torch.utils.data.DataLoader( | ||||
|             valid_data, | ||||
|             batch_size=config.batch_size, | ||||
|             sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), | ||||
|             num_workers=xargs.workers, | ||||
|             pin_memory=True, | ||||
|         ) | ||||
|         logger.log( | ||||
|             "||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}".format( | ||||
|                 xargs.dataset, len(train_loader), len(valid_loader), config.batch_size | ||||
|             ) | ||||
|         ) | ||||
|         logger.log("||||||| {:10s} ||||||| Config={:}".format(xargs.dataset, config)) | ||||
|         extra_info = { | ||||
|             "config": config, | ||||
|             "train_loader": train_loader, | ||||
|             "valid_loader": valid_loader, | ||||
|         } | ||||
|     else: | ||||
|         config_path = "configs/nas-benchmark/algos/R-EA.config" | ||||
|         config = load_config(config_path, None, logger) | ||||
|         extra_info = {"config": config, "train_loader": None, "valid_loader": None} | ||||
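|         # without a data path the loaders are None, so rewards are expected to come | ||||
|         # from the NAS-Bench-201 lookup inside train_and_eval rather than from training | ||||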
|         logger.log("||||||| {:10s} ||||||| Config={:}".format(xargs.dataset, config)) | ||||
|  | ||||
|     search_space = get_search_spaces("cell", xargs.search_space_name) | ||||
|     policy = Policy(xargs.max_nodes, search_space) | ||||
|     optimizer = torch.optim.Adam(policy.parameters(), lr=xargs.learning_rate) | ||||
|     # optimizer = torch.optim.SGD(policy.parameters(), lr=xargs.learning_rate) | ||||
|     eps = np.finfo(np.float32).eps.item() | ||||
|     baseline = ExponentialMovingAverage(xargs.EMA_momentum) | ||||
|     logger.log("policy    : {:}".format(policy)) | ||||
|     logger.log("optimizer : {:}".format(optimizer)) | ||||
|     logger.log("eps       : {:}".format(eps)) | ||||
|  | ||||
|     # nas dataset load | ||||
|     logger.log("{:} use nas_bench : {:}".format(time_string(), nas_bench)) | ||||
|  | ||||
|     # REINFORCE | ||||
|     # attempts = 0 | ||||
|     x_start_time = time.time() | ||||
|     logger.log( | ||||
|         "Will start searching with time budget of {:} s.".format(xargs.time_budget) | ||||
|     ) | ||||
|     total_steps, total_costs, trace = 0, 0, [] | ||||
|     # for istep in range(xargs.RL_steps): | ||||
|     while total_costs < xargs.time_budget: | ||||
|         start_time = time.time() | ||||
|         log_prob, action = select_action(policy) | ||||
|         arch = policy.generate_arch(action) | ||||
|         reward, cost_time = train_and_eval(arch, nas_bench, extra_info, dataname) | ||||
|         trace.append((reward, arch)) | ||||
|         # add the simulated training cost; stop if the time budget would be exceeded | ||||
|         if total_costs + cost_time < xargs.time_budget: | ||||
|             total_costs += cost_time | ||||
|         else: | ||||
|             break | ||||
|  | ||||
|         baseline.update(reward) | ||||
|         # calculate loss | ||||
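|         # REINFORCE: weight each edge's log-probability by the advantage | ||||
|         # (reward - EMA baseline) and sum over edges; the baseline reduces the | ||||
|         # variance of the policy-gradient estimate | ||||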
|         policy_loss = (-log_prob * (reward - baseline.value())).sum() | ||||
|         optimizer.zero_grad() | ||||
|         policy_loss.backward() | ||||
|         optimizer.step() | ||||
|         # also account for the real wall-clock overhead of this step | ||||
|         total_costs += time.time() - start_time | ||||
|         total_steps += 1 | ||||
|         logger.log( | ||||
|             "step [{:3d}] : average-reward={:.3f} : policy_loss={:.4f} : {:}".format( | ||||
|                 total_steps, baseline.value(), policy_loss.item(), policy.genotype() | ||||
|             ) | ||||
|         ) | ||||
|         # logger.log('----> {:}'.format(policy.arch_parameters)) | ||||
|         # logger.log('') | ||||
|  | ||||
|     # best_arch = policy.genotype() # first version | ||||
|     best_arch = max(trace, key=lambda x: x[0])[1] | ||||
|     logger.log( | ||||
|         "REINFORCE finished with {:} steps and {:.1f} s (real cost={:.3f} s).".format( | ||||
|             total_steps, total_costs, time.time() - x_start_time | ||||
|         ) | ||||
|     ) | ||||
|     info = nas_bench.query_by_arch(best_arch, "200") | ||||
|     if info is None: | ||||
|         logger.log("Did not find this architecture : {:}.".format(best_arch)) | ||||
|     else: | ||||
|         logger.log("{:}".format(info)) | ||||
|     logger.log("-" * 100) | ||||
|     logger.close() | ||||
|     return logger.log_dir, nas_bench.query_index_by_arch(best_arch) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     parser = argparse.ArgumentParser("The REINFORCE Algorithm") | ||||
|     parser.add_argument("--data_path", type=str, help="Path to dataset") | ||||
|     parser.add_argument( | ||||
|         "--dataset", | ||||
|         type=str, | ||||
|         choices=["cifar10", "cifar100", "ImageNet16-120"], | ||||
|         help="Choose between CIFAR-10/100 and ImageNet16-120.", | ||||
|     ) | ||||
|     # channels and number-of-cells | ||||
|     parser.add_argument("--search_space_name", type=str, help="The search space name.") | ||||
|     parser.add_argument("--max_nodes", type=int, help="The maximum number of nodes.") | ||||
|     parser.add_argument("--channel", type=int, help="The number of channels.") | ||||
|     parser.add_argument( | ||||
|         "--num_cells", type=int, help="The number of cells in one stage." | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--learning_rate", type=float, help="The learning rate for REINFORCE." | ||||
|     ) | ||||
|     # parser.add_argument('--RL_steps',           type=int,   help='The steps for REINFORCE.') | ||||
|     parser.add_argument( | ||||
|         "--EMA_momentum", type=float, help="The momentum value for EMA." | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--time_budget", | ||||
|         type=int, | ||||
|         help="The total time cost budget for searching (in seconds).", | ||||
|     ) | ||||
|     # log | ||||
|     parser.add_argument( | ||||
|         "--workers", | ||||
|         type=int, | ||||
|         default=2, | ||||
|         help="number of data loading workers (default: 2)", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--save_dir", type=str, help="Folder to save checkpoints and log." | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--arch_nas_dataset", | ||||
|         type=str, | ||||
|         help="The path to load the architecture dataset (tiny-nas-benchmark).", | ||||
|     ) | ||||
|     parser.add_argument("--print_freq", type=int, help="print frequency (default: 200)") | ||||
|     parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed") | ||||
|     args = parser.parse_args() | ||||
|     # if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000) | ||||
|     if args.arch_nas_dataset is None or not os.path.isfile(args.arch_nas_dataset): | ||||
|         nas_bench = None | ||||
|     else: | ||||
|         print( | ||||
|             "{:} build NAS-Benchmark-API from {:}".format( | ||||
|                 time_string(), args.arch_nas_dataset | ||||
|             ) | ||||
|         ) | ||||
|         nas_bench = API(args.arch_nas_dataset) | ||||
|     if args.rand_seed < 0: | ||||
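|         # with a negative seed (the default), repeat the search num times with random | ||||
|         # seeds and save the benchmark index of the architecture found in each run | ||||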
|         save_dir, all_indexes, num = None, [], 500 | ||||
|         for i in range(num): | ||||
|             print("{:} : {:03d}/{:03d}".format(time_string(), i, num)) | ||||
|             args.rand_seed = random.randint(1, 100000) | ||||
|             save_dir, index = main(args, nas_bench) | ||||
|             all_indexes.append(index) | ||||
|         torch.save(all_indexes, save_dir / "results.pth") | ||||
|     else: | ||||
|         main(args, nas_bench) | ||||