| 
									
										
										
										
											2020-07-13 10:04:52 +00:00
										 |  |  | ################################################## | 
					
						
							|  |  |  | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 # | 
					
						
							|  |  |  | ############################################################################## | 
					
						
							|  |  |  | # Random Search for Hyper-Parameter Optimization, JMLR 2012 ################## | 
					
						
							|  |  |  | ############################################################################## | 
					
						
							| 
									
										
										
										
											2020-08-30 08:04:52 +00:00
										 |  |  | # python ./exps/NATS-algos/random_wo_share.py --dataset cifar10 --search_space tss | 
					
						
							|  |  |  | # python ./exps/NATS-algos/random_wo_share.py --dataset cifar100 --search_space tss | 
					
						
							|  |  |  | # python ./exps/NATS-algos/random_wo_share.py --dataset ImageNet16-120 --search_space tss | 
					
						
							| 
									
										
										
										
											2020-07-13 10:04:52 +00:00
										 |  |  | ############################################################################## | 
					
						
							|  |  |  | import os, sys, time, glob, random, argparse | 
					
						
							|  |  |  | import numpy as np, collections | 
					
						
							|  |  |  | from copy import deepcopy | 
					
						
							|  |  |  | import torch | 
					
						
							|  |  |  | import torch.nn as nn | 
					
						
							|  |  |  | from pathlib import Path | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  | lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() | 
					
						
							|  |  |  | if str(lib_dir) not in sys.path: | 
					
						
							|  |  |  |     sys.path.insert(0, str(lib_dir)) | 
					
						
							| 
									
										
										
										
											2020-07-13 10:04:52 +00:00
										 |  |  | from config_utils import load_config, dict2config, configure2str | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  | from datasets import get_datasets, SearchDataset | 
					
						
							| 
									
										
										
										
											2021-03-18 16:02:55 +08:00
										 |  |  | from procedures import ( | 
					
						
							|  |  |  |     prepare_seed, | 
					
						
							|  |  |  |     prepare_logger, | 
					
						
							|  |  |  |     save_checkpoint, | 
					
						
							|  |  |  |     copy_checkpoint, | 
					
						
							|  |  |  |     get_optim_scheduler, | 
					
						
							|  |  |  | ) | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  | from utils import get_model_infos, obtain_accuracy | 
					
						
							|  |  |  | from log_utils import AverageMeter, time_string, convert_secs2time | 
					
						
							|  |  |  | from models import get_search_spaces | 
					
						
							|  |  |  | from nats_bench import create | 
					
						
							| 
									
										
										
										
											2020-07-13 11:35:13 +00:00
										 |  |  | from regularized_ea import random_topology_func, random_size_func | 
					
						
							| 
									
										
										
										
											2020-07-13 10:04:52 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def main(xargs, api): | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  |     torch.set_num_threads(4) | 
					
						
							|  |  |  |     prepare_seed(xargs.rand_seed) | 
					
						
							|  |  |  |     logger = prepare_logger(args) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     logger.log("{:} use api : {:}".format(time_string(), api)) | 
					
						
							|  |  |  |     api.reset_time() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     search_space = get_search_spaces(xargs.search_space, "nats-bench") | 
					
						
							|  |  |  |     if xargs.search_space == "tss": | 
					
						
							|  |  |  |         random_arch = random_topology_func(search_space) | 
					
						
							|  |  |  |     else: | 
					
						
							|  |  |  |         random_arch = random_size_func(search_space) | 
					
						
							| 
									
										
										
										
											2020-07-13 10:04:52 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  |     best_arch, best_acc, total_time_cost, history = None, -1, [], [] | 
					
						
							|  |  |  |     current_best_index = [] | 
					
						
							|  |  |  |     while len(total_time_cost) == 0 or total_time_cost[-1] < xargs.time_budget: | 
					
						
							|  |  |  |         arch = random_arch() | 
					
						
							| 
									
										
										
										
											2021-03-18 16:02:55 +08:00
										 |  |  |         accuracy, _, _, total_cost = api.simulate_train_eval( | 
					
						
							|  |  |  |             arch, xargs.dataset, hp="12" | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  |         total_time_cost.append(total_cost) | 
					
						
							|  |  |  |         history.append(arch) | 
					
						
							|  |  |  |         if best_arch is None or best_acc < accuracy: | 
					
						
							|  |  |  |             best_acc, best_arch = accuracy, arch | 
					
						
							| 
									
										
										
										
											2021-03-18 16:02:55 +08:00
										 |  |  |         logger.log( | 
					
						
							|  |  |  |             "[{:03d}] : {:} : accuracy = {:.2f}%".format(len(history), arch, accuracy) | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  |         current_best_index.append(api.query_index_by_arch(best_arch)) | 
					
						
							|  |  |  |     logger.log( | 
					
						
							|  |  |  |         "{:} best arch is {:}, accuracy = {:.2f}%, visit {:} archs with {:.1f} s.".format( | 
					
						
							|  |  |  |             time_string(), best_arch, best_acc, len(history), total_time_cost[-1] | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2020-07-13 11:35:13 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-03-18 16:02:55 +08:00
										 |  |  |     info = api.query_info_str_by_arch( | 
					
						
							|  |  |  |         best_arch, "200" if xargs.search_space == "tss" else "90" | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  |     logger.log("{:}".format(info)) | 
					
						
							|  |  |  |     logger.log("-" * 100) | 
					
						
							|  |  |  |     logger.close() | 
					
						
							|  |  |  |     return logger.log_dir, current_best_index, total_time_cost | 
					
						
							| 
									
										
										
										
											2020-07-13 10:04:52 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  | if __name__ == "__main__": | 
					
						
							|  |  |  |     parser = argparse.ArgumentParser("Random NAS") | 
					
						
							|  |  |  |     parser.add_argument( | 
					
						
							|  |  |  |         "--dataset", | 
					
						
							|  |  |  |         type=str, | 
					
						
							|  |  |  |         choices=["cifar10", "cifar100", "ImageNet16-120"], | 
					
						
							|  |  |  |         help="Choose between Cifar10/100 and ImageNet-16.", | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2021-03-18 16:02:55 +08:00
										 |  |  |     parser.add_argument( | 
					
						
							|  |  |  |         "--search_space", | 
					
						
							|  |  |  |         type=str, | 
					
						
							|  |  |  |         choices=["tss", "sss"], | 
					
						
							|  |  |  |         help="Choose the search space.", | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2020-07-13 10:04:52 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  |     parser.add_argument( | 
					
						
							| 
									
										
										
										
											2021-03-18 16:02:55 +08:00
										 |  |  |         "--time_budget", | 
					
						
							|  |  |  |         type=int, | 
					
						
							|  |  |  |         default=20000, | 
					
						
							|  |  |  |         help="The total time cost budge for searching (in seconds).", | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     parser.add_argument( | 
					
						
							|  |  |  |         "--loops_if_rand", type=int, default=500, help="The total runs for evaluation." | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  |     ) | 
					
						
							|  |  |  |     # log | 
					
						
							| 
									
										
										
										
											2021-03-18 16:02:55 +08:00
										 |  |  |     parser.add_argument( | 
					
						
							|  |  |  |         "--save_dir", | 
					
						
							|  |  |  |         type=str, | 
					
						
							|  |  |  |         default="./output/search", | 
					
						
							|  |  |  |         help="Folder to save checkpoints and log.", | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  |     parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed") | 
					
						
							|  |  |  |     args = parser.parse_args() | 
					
						
							| 
									
										
										
										
											2020-07-13 10:04:52 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  |     api = create(None, args.search_space, fast_mode=True, verbose=False) | 
					
						
							| 
									
										
										
										
											2020-07-13 10:04:52 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  |     args.save_dir = os.path.join( | 
					
						
							| 
									
										
										
										
											2021-03-18 16:02:55 +08:00
										 |  |  |         "{:}-{:}".format(args.save_dir, args.search_space), | 
					
						
							|  |  |  |         "{:}-T{:}".format(args.dataset, args.time_budget), | 
					
						
							|  |  |  |         "RANDOM", | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  |     ) | 
					
						
							|  |  |  |     print("save-dir : {:}".format(args.save_dir)) | 
					
						
							| 
									
										
										
										
											2020-07-13 10:04:52 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  |     if args.rand_seed < 0: | 
					
						
							|  |  |  |         save_dir, all_info = None, collections.OrderedDict() | 
					
						
							|  |  |  |         for i in range(args.loops_if_rand): | 
					
						
							|  |  |  |             print("{:} : {:03d}/{:03d}".format(time_string(), i, args.loops_if_rand)) | 
					
						
							|  |  |  |             args.rand_seed = random.randint(1, 100000) | 
					
						
							|  |  |  |             save_dir, all_archs, all_total_times = main(args, api) | 
					
						
							|  |  |  |             all_info[i] = {"all_archs": all_archs, "all_total_times": all_total_times} | 
					
						
							|  |  |  |         save_path = save_dir / "results.pth" | 
					
						
							|  |  |  |         print("save into {:}".format(save_path)) | 
					
						
							|  |  |  |         torch.save(all_info, save_path) | 
					
						
							|  |  |  |     else: | 
					
						
							|  |  |  |         main(args, api) |