120 lines
		
	
	
		
			6.7 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			120 lines
		
	
	
		
			6.7 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| ##################################################
 | |
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 #
 | |
| ##############################################################################
 | |
| import os, sys, time, glob, random, argparse
 | |
| import numpy as np, collections
 | |
| from copy import deepcopy
 | |
| import torch
 | |
| import torch.nn as nn
 | |
| from pathlib import Path
 | |
| lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
 | |
| if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
 | |
| from config_utils import load_config, dict2config, configure2str
 | |
| from datasets     import get_datasets, SearchDataset
 | |
| from procedures   import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
 | |
| from utils        import get_model_infos, obtain_accuracy
 | |
| from log_utils    import AverageMeter, time_string, convert_secs2time
 | |
| from models       import get_search_spaces
 | |
| from nas_201_api  import NASBench201API as API
 | |
| from R_EA         import train_and_eval, random_architecture_func
 | |
| 
 | |
| 
 | |
def main(xargs, nas_bench):
  """Random architecture search over a NAS cell space under a time budget.

  Repeatedly samples random architectures and evaluates them (either by
  training on real data, or by querying the NAS-Bench-201 API) until the
  accumulated evaluation cost exceeds ``xargs.time_budget`` seconds.

  Args:
    xargs     : parsed argparse namespace (dataset, data_path, search_space_name,
                max_nodes, time_budget, workers, rand_seed, ...).
    nas_bench : a NASBench201API instance used by ``train_and_eval`` and for the
                final architecture lookup; may be None only if data_path is given
                and the final queries are expected to fail gracefully.

  Returns:
    (log_dir, best_index) — the logger's output directory and the benchmark
    index of the best architecture found.
  """
  assert torch.cuda.is_available(), 'CUDA is not available.'
  # Deterministic cuDNN so repeated runs with the same seed are comparable.
  torch.backends.cudnn.enabled   = True
  torch.backends.cudnn.benchmark = False
  torch.backends.cudnn.deterministic = True
  torch.set_num_threads( xargs.workers )
  prepare_seed(xargs.rand_seed)
  # BUGFIX: this previously read the *global* `args` (prepare_logger(args)),
  # which only worked when called from this file's __main__ block; use the
  # function's own parameter instead.
  logger = prepare_logger(xargs)

  if xargs.dataset == 'cifar10':
    # The benchmark records CIFAR-10 with a held-out validation split under this key.
    dataname = 'cifar10-valid'
  else:
    dataname = xargs.dataset
  if xargs.data_path is not None:
    # Real-data path: build loaders and train-and-evaluate each sampled arch.
    train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
    split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
    cifar_split = load_config(split_Fpath, None, None)
    train_split, valid_split = cifar_split.train, cifar_split.valid
    logger.log('Load split file from {:}'.format(split_Fpath))
    config_path = 'configs/nas-benchmark/algos/R-EA.config'
    config = load_config(config_path, {'class_num': class_num, 'xshape': xshape}, logger)
    # Validation set = a copy of the training data with validation-time transforms.
    train_data_v2 = deepcopy(train_data)
    train_data_v2.transform = valid_data.transform
    valid_data    = train_data_v2
    # NOTE(review): `search_data` is constructed but never used below; kept for
    # parity with the sibling search scripts — confirm before removing.
    search_data   = SearchDataset(xargs.dataset, train_data, train_split, valid_split)
    # Loaders restricted to their respective index splits.
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split) , num_workers=xargs.workers, pin_memory=True)
    valid_loader  = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True)
    logger.log('||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(train_loader), len(valid_loader), config.batch_size))
    logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
    extra_info = {'config': config, 'train_loader': train_loader, 'valid_loader': valid_loader}
  else:
    # Benchmark-only path: no data loaders; evaluations come from nas_bench records.
    config_path = 'configs/nas-benchmark/algos/R-EA.config'
    config = load_config(config_path, None, logger)
    logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
    extra_info = {'config': config, 'train_loader': None, 'valid_loader': None}

  search_space = get_search_spaces('cell', xargs.search_space_name)
  random_arch = random_architecture_func(xargs.max_nodes, search_space)
  x_start_time = time.time()
  logger.log('{:} use nas_bench : {:}'.format(time_string(), nas_bench))
  best_arch, best_acc, total_time_cost, history = None, -1, 0, []
  # Sample random architectures until the (simulated) evaluation budget is spent.
  while total_time_cost < xargs.time_budget:
    arch = random_arch()
    accuracy, cost_time = train_and_eval(arch, nas_bench, extra_info, dataname)
    # Discard an evaluation that would overshoot the budget; otherwise charge it.
    if total_time_cost + cost_time > xargs.time_budget: break
    else: total_time_cost += cost_time
    history.append(arch)
    if best_arch is None or best_acc < accuracy:
      best_acc, best_arch = accuracy, arch
    logger.log('[{:03d}] : {:} : accuracy = {:.2f}%'.format(len(history), arch, accuracy))
  logger.log('{:} best arch is {:}, accuracy = {:.2f}%, visit {:} archs with {:.1f} s (real-cost = {:.3f} s).'.format(time_string(), best_arch, best_acc, len(history), total_time_cost, time.time()-x_start_time))

  info = nas_bench.query_by_arch( best_arch )
  if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch))
  else           : logger.log('{:}'.format(info))
  logger.log('-'*100)
  logger.close()
  return logger.log_dir, nas_bench.query_index_by_arch( best_arch )
 | |
| 
 | |
| 
 | |
| 
 | |
if __name__ == '__main__':
  parser = argparse.ArgumentParser("Regularized Evolution Algorithm")
  parser.add_argument('--data_path',          type=str,   help='Path to dataset')
  parser.add_argument('--dataset',            type=str,   choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.')
  # channels and number-of-cells
  parser.add_argument('--search_space_name',  type=str,   help='The search space name.')
  parser.add_argument('--max_nodes',          type=int,   help='The maximum number of nodes.')
  parser.add_argument('--channel',            type=int,   help='The number of channels.')
  parser.add_argument('--num_cells',          type=int,   help='The number of cells in one stage.')
  parser.add_argument('--time_budget',        type=int,   help='The total time cost budge for searching (in seconds).')
  # log
  parser.add_argument('--workers',            type=int,   default=2,    help='number of data loading workers (default: 2)')
  parser.add_argument('--save_dir',           type=str,   help='Folder to save checkpoints and log.')
  parser.add_argument('--arch_nas_dataset',   type=str,   help='The path to load the architecture dataset (tiny-nas-benchmark).')
  parser.add_argument('--print_freq',         type=int,   help='print frequency (default: 200)')
  parser.add_argument('--rand_seed',          type=int,   help='manual seed')
  args = parser.parse_args()
  # Build the benchmark API only when a valid file path was supplied.
  if args.arch_nas_dataset is None or not os.path.isfile(args.arch_nas_dataset):
    nas_bench = None
  else:
    print ('{:} build NAS-Benchmark-API from {:}'.format(time_string(), args.arch_nas_dataset))
    nas_bench = API(args.arch_nas_dataset)
  # BUGFIX: --rand_seed defaults to None, and `None < 0` raises TypeError.
  # A missing or negative seed means: run `num` independent trials, each with a
  # freshly drawn random seed, and save all resulting architecture indexes.
  if args.rand_seed is None or args.rand_seed < 0:
    save_dir, all_indexes, num = None, [], 500
    for i in range(num):
      print ('{:} : {:03d}/{:03d}'.format(time_string(), i, num))
      args.rand_seed = random.randint(1, 100000)
      save_dir, index = main(args, nas_bench)
      all_indexes.append( index )
    torch.save(all_indexes, save_dir / 'results.pth')
  else:
    # A fixed non-negative seed: a single reproducible run.
    main(args, nas_bench)
 |