simplify baselines
exps/NAS-Bench-102/test-correlation.py (new file, 193 lines)
@@ -0,0 +1,193 @@
+##################################################
+# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
+##################################################
+# python exps/NAS-Bench-102/test-correlation.py --api_path $HOME/.torch/NAS-Bench-102-v1_0-e61699.pth
+##################################################
+import os, sys, time, glob, random, argparse
+import numpy as np
+from copy import deepcopy
+import torch
+import torch.nn as nn
+from pathlib import Path
+lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
+if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
+from config_utils import load_config, dict2config, configure2str
+from datasets     import get_datasets, SearchDataset
+from procedures   import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
+from utils        import get_model_infos, obtain_accuracy
+from log_utils    import AverageMeter, time_string, convert_secs2time
+from models       import get_cell_based_tiny_net, get_search_spaces, CellStructure
+from nas_102_api  import NASBench102API as API
+
+
+def valid_func(xloader, network, criterion):
+  data_time, batch_time = AverageMeter(), AverageMeter()
+  arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
+  network.eval()
+  end = time.time()
+  with torch.no_grad():
+    for step, (arch_inputs, arch_targets) in enumerate(xloader):
+      arch_targets = arch_targets.cuda(non_blocking=True)
+      # measure data loading time
+      data_time.update(time.time() - end)
+      # prediction
+      _, logits = network(arch_inputs)
+      arch_loss = criterion(logits, arch_targets)
+      # record
+      arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5))
+      arch_losses.update(arch_loss.item(),  arch_inputs.size(0))
+      arch_top1.update  (arch_prec1.item(), arch_inputs.size(0))
+      arch_top5.update  (arch_prec5.item(), arch_inputs.size(0))
+      # measure elapsed time
+      batch_time.update(time.time() - end)
+      end = time.time()
+  return arch_losses.avg, arch_top1.avg, arch_top5.avg
+
+
+def main(xargs):
+  assert torch.cuda.is_available(), 'CUDA is not available.'
+  torch.backends.cudnn.enabled   = True
+  torch.backends.cudnn.benchmark = False
+  torch.backends.cudnn.deterministic = True
+  torch.set_num_threads( xargs.workers )
+  prepare_seed(xargs.rand_seed)
+  logger = prepare_logger(args)
+
+  train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
+  if xargs.dataset == 'cifar10' or xargs.dataset == 'cifar100':
+    split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
+    cifar_split = load_config(split_Fpath, None, None)
+    train_split, valid_split = cifar_split.train, cifar_split.valid
+    logger.log('Load split file from {:}'.format(split_Fpath))
+  elif xargs.dataset.startswith('ImageNet16'):
+    split_Fpath = 'configs/nas-benchmark/{:}-split.txt'.format(xargs.dataset)
+    imagenet16_split = load_config(split_Fpath, None, None)
+    train_split, valid_split = imagenet16_split.train, imagenet16_split.valid
+    logger.log('Load split file from {:}'.format(split_Fpath))
+  else:
+    raise ValueError('invalid dataset : {:}'.format(xargs.dataset))
+  config_path = 'configs/nas-benchmark/algos/DARTS.config'
+  config = load_config(config_path, {'class_num': class_num, 'xshape': xshape}, logger)
+  # To split data
+  train_data_v2 = deepcopy(train_data)
+  train_data_v2.transform = valid_data.transform
+  valid_data    = train_data_v2
+  search_data   = SearchDataset(xargs.dataset, train_data, train_split, valid_split)
+  # data loader
+  search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True , num_workers=xargs.workers, pin_memory=True)
+  valid_loader  = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True)
+  logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size))
+  logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
+
+  search_space = get_search_spaces('cell', xargs.search_space_name)
+  model_config = dict2config({'name': 'DARTS-V2', 'C': xargs.channel, 'N': xargs.num_cells,
+                              'max_nodes': xargs.max_nodes, 'num_classes': class_num,
+                              'space'    : search_space}, None)
+  search_model = get_cell_based_tiny_net(model_config)
+  logger.log('search-model :\n{:}'.format(search_model))
+
+  w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.get_weights(), config)
+  a_optimizer = torch.optim.Adam(search_model.get_alphas(), lr=xargs.arch_learning_rate, betas=(0.5, 0.999), weight_decay=xargs.arch_weight_decay)
+  logger.log('w-optimizer : {:}'.format(w_optimizer))
+  logger.log('a-optimizer : {:}'.format(a_optimizer))
+  logger.log('w-scheduler : {:}'.format(w_scheduler))
+  logger.log('criterion   : {:}'.format(criterion))
+  flop, param  = get_model_infos(search_model, xshape)
+  #logger.log('{:}'.format(search_model))
+  logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param))
+  if xargs.arch_nas_dataset is None:
+    api = None
+  else:
+    api = API(xargs.arch_nas_dataset)
+  logger.log('{:} create API = {:} done'.format(time_string(), api))
+
+  last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best')
+  network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda()
+
+  logger.close()
+
+
+def check_unique_arch(meta_file):
+  api = API(str(meta_file))
+  arch_strs = deepcopy(api.meta_archs)
+  xarchs = [CellStructure.str2structure(x) for x in arch_strs]
+  def get_unique_matrix(archs, consider_zero):
+    UniquStrs = [arch.to_unique_str(consider_zero) for arch in archs]
+    print ('{:} create unique-string ({:}/{:}) done'.format(time_string(), len(set(UniquStrs)), len(UniquStrs)))
+    Unique2Index = dict()
+    for index, xstr in enumerate(UniquStrs):
+      if xstr not in Unique2Index: Unique2Index[xstr] = list()
+      Unique2Index[xstr].append( index )
+    sm_matrix = torch.eye(len(archs)).bool()
+    for _, xlist in Unique2Index.items():
+      for i in xlist:
+        for j in xlist:
+          sm_matrix[i,j] = True
+    unique_ids, unique_num = [-1 for _ in archs], 0
+    for i in range(len(unique_ids)):
+      if unique_ids[i] > -1: continue
+      neighbours = sm_matrix[i].nonzero().view(-1).tolist()
+      for nghb in neighbours:
+        assert unique_ids[nghb] == -1, 'impossible'
+        unique_ids[nghb] = unique_num
+      unique_num += 1
+    return sm_matrix, unique_ids, unique_num
+
+  print ('There are {:} valid-archs'.format( sum(arch.check_valid() for arch in xarchs) ))
+  sm_matrix, uniqueIDs, unique_num = get_unique_matrix(xarchs, None)
+  print ('{:} There are {:} unique architectures (considering nothing).'.format(time_string(), unique_num))
+  sm_matrix, uniqueIDs, unique_num = get_unique_matrix(xarchs, False)
+  print ('{:} There are {:} unique architectures (not considering zero).'.format(time_string(), unique_num))
+  sm_matrix, uniqueIDs, unique_num = get_unique_matrix(xarchs, True)
+  print ('{:} There are {:} unique architectures (considering zero).'.format(time_string(), unique_num))
+
+
+def check_cor_for_bandit(meta_file, test_epoch, use_less_or_not, is_rand=True):
+  if isinstance(meta_file, API):
+    api = meta_file
+  else:
+    api = API(str(meta_file))
+  cifar10_valid     = []
+  cifar10_test      = []
+  cifar100_test     = []
+  imagenet_test     = []
+  for idx, arch in enumerate(api):
+    results = api.get_more_info(idx, 'cifar10-valid' , test_epoch-1, use_less_or_not, is_rand)
+    cifar10_valid.append( results['valid-accuracy'] )
+    results = api.get_more_info(idx, 'cifar10'       , None, False, is_rand)
+    cifar10_test.append( results['test-accuracy'] )
+    results = api.get_more_info(idx, 'cifar100'      , None, False, is_rand)
+    cifar100_test.append( results['test-accuracy'] )
+    results = api.get_more_info(idx, 'ImageNet16-120', None, False, is_rand)
+    imagenet_test.append( results['test-accuracy'] )
+  def get_cor(A, B):
+    return float(np.corrcoef(A, B)[0,1])
+  cors = []
+  for basestr, xlist in zip(['CIFAR-010', 'CIFAR-100', 'ImageNet16'], [cifar10_test, cifar100_test, imagenet_test]):
+    correlation = get_cor(cifar10_valid, xlist)
+    print ('With {:3d}/{:}-epochs-training, the correlation between cifar10-valid and {:} is : {:}'.format(test_epoch, '012' if use_less_or_not else '200', basestr, correlation))
+    cors.append( correlation )
+    #print ('With {:3d}/200-epochs-training, the correlation between cifar10-valid and {:} is : {:}'.format(test_epoch, basestr, get_cor(cifar10_valid_200, xlist)))
+    #print('-'*200)
+  #print('*'*230)
+  return cors
+
+
+if __name__ == '__main__':
+  parser = argparse.ArgumentParser("Analysis of NAS-Bench-102")
+  parser.add_argument('--save_dir',  type=str, default='./output/search-cell-nas-bench-102/visuals', help='The base-name of folder to save checkpoints and log.')
+  parser.add_argument('--api_path',  type=str, default=None,                                         help='The path to the NAS-Bench-102 benchmark file.')
+  args = parser.parse_args()
+
+  vis_save_dir = Path(args.save_dir)
+  vis_save_dir.mkdir(parents=True, exist_ok=True)
+  meta_file = Path(args.api_path)
+  assert meta_file.exists(), 'invalid path for api : {:}'.format(meta_file)
+
+  #check_unique_arch(meta_file)
+  api = API(str(meta_file))
+  #for iepoch in [11, 25, 50, 100, 150, 175, 200]:
+  #  check_cor_for_bandit(api,  6, iepoch)
+  #  check_cor_for_bandit(api, 12, iepoch)
+  correlations = check_cor_for_bandit(api, 6, True, True)
+  import pdb; pdb.set_trace()
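Note: the correlation reported above is a plain Pearson coefficient between two per-architecture accuracy lists. A minimal, self-contained sketch of the same computation (the accuracy numbers are made up for illustration, not taken from the benchmark):

    import numpy as np
    # hypothetical accuracies for five architectures
    valid_acc_early = [84.2, 79.5, 88.1, 90.3, 75.0]  # cifar10-valid, few-epoch proxy
    test_acc_final  = [91.0, 87.2, 92.5, 93.1, 84.8]  # cifar10 test, fully trained
    # same as get_cor() above: off-diagonal entry of the 2x2 correlation matrix
    print(float(np.corrcoef(valid_acc_early, test_acc_final)[0, 1]))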
@@ -370,17 +370,17 @@ def write_video(save_dir):
 if __name__ == '__main__':
 
   parser = argparse.ArgumentParser(description='NAS-Bench-102', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
-  parser.add_argument('--save_dir',  type=str, default='./output/search-cell-nas-bench-102/visual', help='The base-name of folder to save checkpoints and log.')
-  parser.add_argument('--api_path',  type=str, default=None,                                        help='The path to the NAS-Bench-102 benchmark file.')
+  parser.add_argument('--save_dir',  type=str, default='./output/search-cell-nas-bench-102/visuals', help='The base-name of folder to save checkpoints and log.')
+  parser.add_argument('--api_path',  type=str, default=None,                                         help='The path to the NAS-Bench-102 benchmark file.')
   args = parser.parse_args()

-  vis_save_dir = Path(args.save_dir) / 'visuals'
+  vis_save_dir = Path(args.save_dir)
   vis_save_dir.mkdir(parents=True, exist_ok=True)
   meta_file = Path(args.api_path)
   assert meta_file.exists(), 'invalid path for api : {:}'.format(meta_file)
-  visualize_rank_over_time(str(meta_file), vis_save_dir / 'over-time')
-  write_video(vis_save_dir / 'over-time')
-  visualize_info(str(meta_file), 'cifar10' , vis_save_dir)
-  visualize_info(str(meta_file), 'cifar100', vis_save_dir)
-  visualize_info(str(meta_file), 'ImageNet16-120', vis_save_dir)
+  #visualize_rank_over_time(str(meta_file), vis_save_dir / 'over-time')
+  #write_video(vis_save_dir / 'over-time')
+  #visualize_info(str(meta_file), 'cifar10' , vis_save_dir)
+  #visualize_info(str(meta_file), 'cifar100', vis_save_dir)
+  #visualize_info(str(meta_file), 'ImageNet16-120', vis_save_dir)
   visualize_relative_ranking(vis_save_dir)

@@ -110,25 +110,30 @@ def main(xargs, nas_bench):
   logger = prepare_logger(args)

   assert xargs.dataset == 'cifar10', 'currently only support CIFAR-10'
-  train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
-  split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
-  cifar_split = load_config(split_Fpath, None, None)
-  train_split, valid_split = cifar_split.train, cifar_split.valid
-  logger.log('Load split file from {:}'.format(split_Fpath))
-  config_path = 'configs/nas-benchmark/algos/R-EA.config'
-  config = load_config(config_path, {'class_num': class_num, 'xshape': xshape}, logger)
-  # To split data
-  train_data_v2 = deepcopy(train_data)
-  train_data_v2.transform = valid_data.transform
-  valid_data    = train_data_v2
-  search_data   = SearchDataset(xargs.dataset, train_data, train_split, valid_split)
-  # data loader
-  train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split) , num_workers=xargs.workers, pin_memory=True)
-  valid_loader  = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True)
-  logger.log('||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(train_loader), len(valid_loader), config.batch_size))
-  logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
-  extra_info = {'config': config, 'train_loader': train_loader, 'valid_loader': valid_loader}
-
+  if xargs.data_path is not None:
+    train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
+    split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
+    cifar_split = load_config(split_Fpath, None, None)
+    train_split, valid_split = cifar_split.train, cifar_split.valid
+    logger.log('Load split file from {:}'.format(split_Fpath))
+    config_path = 'configs/nas-benchmark/algos/R-EA.config'
+    config = load_config(config_path, {'class_num': class_num, 'xshape': xshape}, logger)
+    # To split data
+    train_data_v2 = deepcopy(train_data)
+    train_data_v2.transform = valid_data.transform
+    valid_data    = train_data_v2
+    search_data   = SearchDataset(xargs.dataset, train_data, train_split, valid_split)
+    # data loader
+    train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split) , num_workers=xargs.workers, pin_memory=True)
+    valid_loader  = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True)
+    logger.log('||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(train_loader), len(valid_loader), config.batch_size))
+    logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
+    extra_info = {'config': config, 'train_loader': train_loader, 'valid_loader': valid_loader}
+  else:
+    config_path = 'configs/nas-benchmark/algos/R-EA.config'
+    config = load_config(config_path, None, logger)
+    logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
+    extra_info = {'config': config, 'train_loader': None, 'valid_loader': None}

   # nas dataset load
   assert xargs.arch_nas_dataset is not None and os.path.isfile(xargs.arch_nas_dataset)

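Note: this hunk, like the three matching ones below, makes --data_path optional. When a search runs purely against the NAS-Bench-102 table there is nothing to train, so the else branch only loads the R-EA config and passes None loaders via extra_info. A rough sketch of the query-only evaluation this enables, using API calls that appear elsewhere in this commit (the function name is hypothetical; the exact wiring inside the search loop is not shown here):

    # hypothetical benchmark-only fitness; api is a NASBench102API, arch a CellStructure
    def benchmark_fitness(api, arch, use_12epochs_result=True):
      index = api.query_index_by_arch(arch)  # table lookup instead of training
      info  = api.get_more_info(index, 'cifar10-valid', None, use_12epochs_result)
      return info['valid-accuracy']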
@@ -29,25 +29,30 @@ def main(xargs, nas_bench):
   logger = prepare_logger(args)

   assert xargs.dataset == 'cifar10', 'currently only support CIFAR-10'
-  train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
-  split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
-  cifar_split = load_config(split_Fpath, None, None)
-  train_split, valid_split = cifar_split.train, cifar_split.valid
-  logger.log('Load split file from {:}'.format(split_Fpath))
-  config_path = 'configs/nas-benchmark/algos/R-EA.config'
-  config = load_config(config_path, {'class_num': class_num, 'xshape': xshape}, logger)
-  # To split data
-  train_data_v2 = deepcopy(train_data)
-  train_data_v2.transform = valid_data.transform
-  valid_data    = train_data_v2
-  search_data   = SearchDataset(xargs.dataset, train_data, train_split, valid_split)
-  # data loader
-  train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split) , num_workers=xargs.workers, pin_memory=True)
-  valid_loader  = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True)
-  logger.log('||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(train_loader), len(valid_loader), config.batch_size))
-  logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
-  extra_info = {'config': config, 'train_loader': train_loader, 'valid_loader': valid_loader}
-
+  if xargs.data_path is not None:
+    train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
+    split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
+    cifar_split = load_config(split_Fpath, None, None)
+    train_split, valid_split = cifar_split.train, cifar_split.valid
+    logger.log('Load split file from {:}'.format(split_Fpath))
+    config_path = 'configs/nas-benchmark/algos/R-EA.config'
+    config = load_config(config_path, {'class_num': class_num, 'xshape': xshape}, logger)
+    # To split data
+    train_data_v2 = deepcopy(train_data)
+    train_data_v2.transform = valid_data.transform
+    valid_data    = train_data_v2
+    search_data   = SearchDataset(xargs.dataset, train_data, train_split, valid_split)
+    # data loader
+    train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split) , num_workers=xargs.workers, pin_memory=True)
+    valid_loader  = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True)
+    logger.log('||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(train_loader), len(valid_loader), config.batch_size))
+    logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
+    extra_info = {'config': config, 'train_loader': train_loader, 'valid_loader': valid_loader}
+  else:
+    config_path = 'configs/nas-benchmark/algos/R-EA.config'
+    config = load_config(config_path, None, logger)
+    logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
+    extra_info = {'config': config, 'train_loader': None, 'valid_loader': None}
   search_space = get_search_spaces('cell', xargs.search_space_name)
   random_arch = random_architecture_func(xargs.max_nodes, search_space)
   #x =random_arch() ; y = mutate_arch(x)
@@ -71,7 +76,7 @@ def main(xargs, nas_bench):
   logger.log('-'*100)
   logger.close()
   return logger.log_dir, nas_bench.query_index_by_arch( best_arch )
-  
+



 if __name__ == '__main__':

@@ -172,24 +172,30 @@ def main(xargs, nas_bench):
   logger = prepare_logger(args)

   assert xargs.dataset == 'cifar10', 'currently only support CIFAR-10'
-  train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
-  split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
-  cifar_split = load_config(split_Fpath, None, None)
-  train_split, valid_split = cifar_split.train, cifar_split.valid
-  logger.log('Load split file from {:}'.format(split_Fpath))
-  config_path = 'configs/nas-benchmark/algos/R-EA.config'
-  config = load_config(config_path, {'class_num': class_num, 'xshape': xshape}, logger)
-  # To split data
-  train_data_v2 = deepcopy(train_data)
-  train_data_v2.transform = valid_data.transform
-  valid_data    = train_data_v2
-  search_data   = SearchDataset(xargs.dataset, train_data, train_split, valid_split)
-  # data loader
-  train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split) , num_workers=xargs.workers, pin_memory=True)
-  valid_loader  = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True)
-  logger.log('||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(train_loader), len(valid_loader), config.batch_size))
-  logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
-  extra_info = {'config': config, 'train_loader': train_loader, 'valid_loader': valid_loader}
+  if xargs.data_path is not None:
+    train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
+    split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
+    cifar_split = load_config(split_Fpath, None, None)
+    train_split, valid_split = cifar_split.train, cifar_split.valid
+    logger.log('Load split file from {:}'.format(split_Fpath))
+    config_path = 'configs/nas-benchmark/algos/R-EA.config'
+    config = load_config(config_path, {'class_num': class_num, 'xshape': xshape}, logger)
+    # To split data
+    train_data_v2 = deepcopy(train_data)
+    train_data_v2.transform = valid_data.transform
+    valid_data    = train_data_v2
+    search_data   = SearchDataset(xargs.dataset, train_data, train_split, valid_split)
+    # data loader
+    train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split) , num_workers=xargs.workers, pin_memory=True)
+    valid_loader  = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True)
+    logger.log('||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(train_loader), len(valid_loader), config.batch_size))
+    logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
+    extra_info = {'config': config, 'train_loader': train_loader, 'valid_loader': valid_loader}
+  else:
+    config_path = 'configs/nas-benchmark/algos/R-EA.config'
+    config = load_config(config_path, None, logger)
+    logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
+    extra_info = {'config': config, 'train_loader': None, 'valid_loader': None}

   search_space = get_search_spaces('cell', xargs.search_space_name)
   random_arch = random_architecture_func(xargs.max_nodes, search_space)

@@ -99,24 +99,31 @@ def main(xargs, nas_bench):
   logger = prepare_logger(args)

   assert xargs.dataset == 'cifar10', 'currently only support CIFAR-10'
-  train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
-  split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
-  cifar_split = load_config(split_Fpath, None, None)
-  train_split, valid_split = cifar_split.train, cifar_split.valid
-  logger.log('Load split file from {:}'.format(split_Fpath))
-  config_path = 'configs/nas-benchmark/algos/R-EA.config'
-  config = load_config(config_path, {'class_num': class_num, 'xshape': xshape}, logger)
-  # To split data
-  train_data_v2 = deepcopy(train_data)
-  train_data_v2.transform = valid_data.transform
-  valid_data    = train_data_v2
-  search_data   = SearchDataset(xargs.dataset, train_data, train_split, valid_split)
-  # data loader
-  train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split) , num_workers=xargs.workers, pin_memory=True)
-  valid_loader  = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True)
-  logger.log('||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(train_loader), len(valid_loader), config.batch_size))
-  logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
-  extra_info = {'config': config, 'train_loader': train_loader, 'valid_loader': valid_loader}
+  if xargs.data_path is not None:
+    train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
+    split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
+    cifar_split = load_config(split_Fpath, None, None)
+    train_split, valid_split = cifar_split.train, cifar_split.valid
+    logger.log('Load split file from {:}'.format(split_Fpath))
+    config_path = 'configs/nas-benchmark/algos/R-EA.config'
+    config = load_config(config_path, {'class_num': class_num, 'xshape': xshape}, logger)
+    # To split data
+    train_data_v2 = deepcopy(train_data)
+    train_data_v2.transform = valid_data.transform
+    valid_data    = train_data_v2
+    search_data   = SearchDataset(xargs.dataset, train_data, train_split, valid_split)
+    # data loader
+    train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split) , num_workers=xargs.workers, pin_memory=True)
+    valid_loader  = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True)
+    logger.log('||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(train_loader), len(valid_loader), config.batch_size))
+    logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
+    extra_info = {'config': config, 'train_loader': train_loader, 'valid_loader': valid_loader}
+  else:
+    config_path = 'configs/nas-benchmark/algos/R-EA.config'
+    config = load_config(config_path, None, logger)
+    extra_info = {'config': config, 'train_loader': None, 'valid_loader': None}
+    logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
+

   search_space = get_search_spaces('cell', xargs.search_space_name)
   policy    = Policy(xargs.max_nodes, search_space)

@@ -74,15 +74,22 @@ class Structure:
       nodes[i+1] = sum(sums) > 0
     return nodes[len(self.nodes)]

-  def to_unique_str(self):
+  def to_unique_str(self, consider_zero=False):
     # this is used to identify the isomorphic cell, which requires prior knowledge of the operations
     # two operations are special, i.e., none and skip_connect
     nodes = {0: '0'}
     for i_node, node_info in enumerate(self.nodes):
       cur_node = []
       for op, xin in node_info:
-        if op == 'skip_connect': x = nodes[xin]
-        else: x = '('+nodes[xin]+')' + '@{:}'.format(op)
+        if consider_zero is None:
+          x = '('+nodes[xin]+')' + '@{:}'.format(op)
+        elif consider_zero:
+          if op == 'none' or nodes[xin] == '#': x = '#' # zero
+          elif op == 'skip_connect': x = nodes[xin]
+          else: x = '('+nodes[xin]+')' + '@{:}'.format(op)
+        else:
+          if op == 'skip_connect': x = nodes[xin]
+          else: x = '('+nodes[xin]+')' + '@{:}'.format(op)
         cur_node.append(x)
       nodes[i_node+1] = '+'.join( sorted(cur_node) )
     return nodes[ len(self.nodes) ]

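Note: check_unique_arch in test-correlation.py above exercises all three modes of the new consider_zero flag: None hashes every edge verbatim, False folds skip_connect edges through to their input, and True additionally maps none edges to a '#' marker (propagating it through fully zeroed nodes), so cells that differ only behind a zeroed edge compare equal. A usage sketch; the |op~input| arch string is an arbitrary example, not an entry from the benchmark:

    from models import CellStructure  # as imported in test-correlation.py
    arch = CellStructure.str2structure(
      '|nor_conv_3x3~0|+|none~0|skip_connect~1|+|none~0|nor_conv_1x1~1|avg_pool_3x3~2|')
    print(arch.to_unique_str(None))   # raw canonical form
    print(arch.to_unique_str(False))  # skip_connect folded through
    print(arch.to_unique_str(True))   # 'none' edges collapsed to '#'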
@@ -41,8 +41,9 @@ class NASBench102API(object):
       if verbose: print('try to create the NAS-Bench-102 api from {:}'.format(file_path_or_dict))
       assert os.path.isfile(file_path_or_dict), 'invalid path : {:}'.format(file_path_or_dict)
       file_path_or_dict = torch.load(file_path_or_dict)
-    else:
+    elif isinstance(file_path_or_dict, dict):
       file_path_or_dict = copy.deepcopy( file_path_or_dict )
+    else: raise ValueError('invalid type : {:} not in [str, dict]'.format(type(file_path_or_dict)))
     assert isinstance(file_path_or_dict, dict), 'It should be a dict instead of {:}'.format(type(file_path_or_dict))
     keys = ('meta_archs', 'arch2infos', 'evaluated_indexes')
     for key in keys: assert key in file_path_or_dict, 'Can not find key[{:}] in the dict'.format(key)
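Note: the constructor's contract is now explicit: a str is treated as a checkpoint path for torch.load, a dict is deep-copied, and any other type fails fast with a ValueError. Usage sketch (file name as in the header of test-correlation.py):

    api  = API('NAS-Bench-102-v1_0-e61699.pth')     # load from disk
    data = torch.load('NAS-Bench-102-v1_0-e61699.pth')
    api2 = API(data)                                # reuse an already-loaded dict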
@@ -152,26 +153,40 @@ class NASBench102API(object):
     archresult = arch2infos[index]
     return archresult.get_net_param(dataset, seed)

-  def get_more_info(self, index, dataset, iepoch=None, use_12epochs_result=False):
+  # obtain the metric for the `index`-th architecture
+  def get_more_info(self, index, dataset, iepoch=None, use_12epochs_result=False, is_random=True):
     if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less
     else                  : basestr, arch2infos = '200epochs', self.arch2infos_full
     archresult = arch2infos[index]
     if dataset == 'cifar10-valid':
-      train_info = archresult.get_metrics(dataset, 'train'   , iepoch=iepoch, is_random=True)
-      valid_info = archresult.get_metrics(dataset, 'x-valid' , iepoch=iepoch, is_random=True)
-      test__info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=True)
+      train_info = archresult.get_metrics(dataset, 'train'   , iepoch=iepoch, is_random=is_random)
+      valid_info = archresult.get_metrics(dataset, 'x-valid' , iepoch=iepoch, is_random=is_random)
+      try:
+        test__info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
+      except:
+        test__info = None
       total      = train_info['iepoch'] + 1
-      return {'train-loss'    : train_info['loss'],
+      xifo = {'train-loss'    : train_info['loss'],
              'train-accuracy': train_info['accuracy'],
              'train-all-time': train_info['all_time'],
              'valid-loss'    : valid_info['loss'],
              'valid-accuracy': valid_info['accuracy'],
              'valid-all-time': valid_info['all_time'],
-             'valid-per-time': valid_info['all_time'] / total,
+             'valid-per-time': None if valid_info['all_time'] is None else valid_info['all_time'] / total}
+      if test__info is not None:
+        xifo['test-loss']     = test__info['loss']
+        xifo['test-accuracy'] = test__info['accuracy']
+      return xifo
+    else:
+      train_info = archresult.get_metrics(dataset, 'train'   , iepoch=iepoch, is_random=is_random)
+      if dataset == 'cifar10':
+        test__info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
+      else:
+        test__info = archresult.get_metrics(dataset, 'x-test', iepoch=iepoch, is_random=is_random)
+      return {'train-loss'    : train_info['loss'],
+             'train-accuracy': train_info['accuracy'],
              'test-loss'     : test__info['loss'],
             'test-accuracy' : test__info['accuracy']}
-    else:
-      raise ValueError('coming soon...')

   def show(self, index=-1):
     if index < 0: # show all architectures
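Note: the reworked get_more_info forwards a new is_random flag to get_metrics (previously hard-coded to True) and tolerates datasets whose test split is missing or named x-test rather than ori-test. A usage sketch consistent with the calls in check_cor_for_bandit above; whether is_random=False averages over seeds is get_metrics' behavior and assumed here:

    info = api.get_more_info(0, 'cifar10-valid', 11, True, True)       # 12-epoch run, one random seed
    print(info['valid-accuracy'])
    info = api.get_more_info(0, 'ImageNet16-120', None, False, False)  # 200-epoch run, seed-averaged (assumed)
    print(info['test-accuracy'])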
@@ -369,7 +384,7 @@ class ResultsCount(object):
   def update_latency(self, latency):
     self.latency = copy.deepcopy( latency )

-  def update_eval(self, accs, losses, times): # old version
+  def update_eval(self, accs, losses, times):  # new version
     data_names = set([x.split('@')[0] for x in accs.keys()])
     for data_name in data_names:
       assert data_name not in self.eval_names, '{:} has already been added into eval-names'.format(data_name)

@@ -21,17 +21,11 @@ num_cells=5
 max_nodes=4
 space=nas-bench-102

-if [ "$dataset" == "cifar10" ] || [ "$dataset" == "cifar100" ]; then
-  data_path="$TORCH_HOME/cifar.python"
-else
-  data_path="$TORCH_HOME/cifar.python/ImageNet16"
-fi
-
 save_dir=./output/search-cell-${space}/BOHB-${dataset}

 OMP_NUM_THREADS=4 python ./exps/algos/BOHB.py \
 	--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \
-	--dataset ${dataset} --data_path ${data_path} \
+	--dataset ${dataset} \
 	--search_space_name ${space} \
 	--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \
 	--time_budget 12000  \

@@ -22,17 +22,11 @@ num_cells=5
 max_nodes=4
 space=nas-bench-102

-if [ "$dataset" == "cifar10" ] || [ "$dataset" == "cifar100" ]; then
-  data_path="$TORCH_HOME/cifar.python"
-else
-  data_path="$TORCH_HOME/cifar.python/ImageNet16"
-fi
-
 save_dir=./output/search-cell-${space}/R-EA-${dataset}

 OMP_NUM_THREADS=4 python ./exps/algos/R_EA.py \
 	--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \
-	--dataset ${dataset} --data_path ${data_path} \
+	--dataset ${dataset} \
 	--search_space_name ${space} \
 	--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \
 	--time_budget 12000 \

@@ -21,17 +21,11 @@ num_cells=5
 max_nodes=4
 space=nas-bench-102

-if [ "$dataset" == "cifar10" ] || [ "$dataset" == "cifar100" ]; then
-  data_path="$TORCH_HOME/cifar.python"
-else
-  data_path="$TORCH_HOME/cifar.python/ImageNet16"
-fi
-
 save_dir=./output/search-cell-${space}/REINFORCE-${dataset}

 OMP_NUM_THREADS=4 python ./exps/algos/reinforce.py \
 	--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \
-	--dataset ${dataset} --data_path ${data_path} \
+	--dataset ${dataset} \
 	--search_space_name ${space} \
 	--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \
 	--time_budget 12000 \

@@ -21,17 +21,11 @@ num_cells=5
 max_nodes=4
 space=nas-bench-102

-if [ "$dataset" == "cifar10" ] || [ "$dataset" == "cifar100" ]; then
-  data_path="$TORCH_HOME/cifar.python"
-else
-  data_path="$TORCH_HOME/cifar.python/ImageNet16"
-fi
-
 save_dir=./output/search-cell-${space}/RAND-${dataset}

 OMP_NUM_THREADS=4 python ./exps/algos/RANDOM.py \
 	--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \
-	--dataset ${dataset} --data_path ${data_path} \
+	--dataset ${dataset} \
 	--search_space_name ${space} \
 	--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \
 	--time_budget 12000 \
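Note: all four scripts drop --data_path for the same reason: with --arch_nas_dataset supplied, the searcher queries recorded results and charges the logged training time against --time_budget instead of training anything. A rough sketch of that simulated-budget accounting (the loop structure is an assumption; the keys and helpers come from this commit):

    total_time, budget, history = 0.0, 12000.0, []  # seconds, matching --time_budget
    while total_time < budget:
      arch  = random_arch()                         # from random_architecture_func(...)
      index = api.query_index_by_arch(arch)
      info  = api.get_more_info(index, 'cifar10-valid', None, True)
      total_time += info['train-all-time'] + info['valid-per-time']
      history.append((info['valid-accuracy'], index))
    best_accuracy, best_index = max(history)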