Update NATS-Bench (tss version 0.8)
This commit is contained in:
		| @@ -27,6 +27,7 @@ from procedures   import bench_evaluate_for_seed | ||||
| from procedures   import get_machine_info | ||||
| from datasets     import get_datasets | ||||
| from log_utils    import Logger, AverageMeter, time_string, convert_secs2time | ||||
| from utils        import split_str2indexes | ||||
|  | ||||
|  | ||||
| def evaluate_all_datasets(channels: Text, datasets: List[Text], xpaths: List[Text], | ||||
| @@ -107,7 +108,6 @@ def main(save_dir: Path, workers: int, datasets: List[Text], xpaths: List[Text], | ||||
|   logger.log('xargs : seeds      = {:}'.format(seeds)) | ||||
|   logger.log('xargs : cover_mode = {:}'.format(cover_mode)) | ||||
|   logger.log('-' * 100) | ||||
|  | ||||
|   logger.log( | ||||
|     'Start evaluating range =: {:06d} - {:06d}'.format(min(to_evaluate_indexes), max(to_evaluate_indexes)) | ||||
|    +'({:} in total) / {:06d} with cover-mode={:}'.format(len(to_evaluate_indexes), len(nets), cover_mode)) | ||||
| @@ -115,7 +115,6 @@ def main(save_dir: Path, workers: int, datasets: List[Text], xpaths: List[Text], | ||||
|     logger.log( | ||||
|       '--->>> Evaluate {:}/{:} : dataset={:9s}, path={:}, split={:}'.format(i, len(datasets), dataset, xpath, split)) | ||||
|   logger.log('--->>> optimization config : {:}'.format(opt_config)) | ||||
|   #to_evaluate_indexes = list(range(srange[0], srange[1] + 1)) | ||||
|  | ||||
|   start_time, epoch_time = time.time(), AverageMeter() | ||||
|   for i, index in enumerate(to_evaluate_indexes): | ||||
| @@ -136,7 +135,9 @@ def main(save_dir: Path, workers: int, datasets: List[Text], xpaths: List[Text], | ||||
|           logger.log('Find existing file : {:}, skip this evaluation'.format(to_save_name)) | ||||
|           has_continue = True | ||||
|           continue | ||||
|       results = evaluate_all_datasets(channelstr, datasets, xpaths, splits, opt_config, seed, workers, logger) | ||||
|       results = evaluate_all_datasets(channelstr, | ||||
|                                       datasets, xpaths, splits, opt_config, seed, | ||||
|                                        workers, logger) | ||||
|       torch.save(results, to_save_name) | ||||
|       logger.log('\n{:} evaluate {:06d}/{:06d} ({:06d}/{:06d})-th arch [seeds={:}] ===>>> {:}'.format(time_string(), i, | ||||
|                   len(to_evaluate_indexes), index, len(nets), seeds, to_save_name)) | ||||
| @@ -224,20 +225,7 @@ if __name__ == '__main__': | ||||
|     raise ValueError('{:} is not a file.'.format(opt_config)) | ||||
|   save_dir = Path(args.save_dir) / 'raw-data-{:}'.format(args.hyper) | ||||
|   save_dir.mkdir(parents=True, exist_ok=True) | ||||
|   if not isinstance(args.srange, str): | ||||
|     raise ValueError('Invalid scheme for {:}'.format(args.srange)) | ||||
|   srangestr = "".join(args.srange.split()) | ||||
|   to_evaluate_indexes = set() | ||||
|   for srange in srangestr.split(','): | ||||
|     srange = srange.split('-') | ||||
|     if len(srange) != 2: | ||||
|       raise ValueError('invalid srange : {:}'.format(srange)) | ||||
|     assert len(srange[0]) == len(srange[1]) == 5, 'invalid srange : {:}'.format(srange) | ||||
|     srange = (int(srange[0]), int(srange[1])) | ||||
|     if not (0 <= srange[0] <= srange[1] < args.check_N): | ||||
|       raise ValueError('{:} vs {:} vs {:}'.format(srange[0], srange[1], args.check_N)) | ||||
|     for i in range(srange[0], srange[1]+1): | ||||
|       to_evaluate_indexes.add(i) | ||||
|   to_evaluate_indexes = split_str2indexes(args.srange, args.check_N, 5) | ||||
|  | ||||
|   if not len(args.seeds): | ||||
|     raise ValueError('invalid length of seeds args: {:}'.format(args.seeds)) | ||||
|   | ||||
| @@ -5,13 +5,20 @@ | ||||
| ############################################################################## | ||||
| # This file is used to train (all) architecture candidate in the topology    # | ||||
| # search space in NATS-Bench (tss) with different hyper-parameters.          # | ||||
| # When use mode=meta, | ||||
| ### | ||||
| # When use mode=new, it will automatically detect whether the checkpoint of  # | ||||
| # a trial exists, if so, it will skip this trial. When use mode=cover, it    # | ||||
| # will ignore the (possible) existing checkpoint, run each trial, and save.  # | ||||
| ############################################################################## | ||||
| # 1, generate meta data:                                                     # | ||||
| # Please use the script of scripts/NATS-Bench/train-topology.sh to run.      # | ||||
| # bash scripts/NATS-Bench/train-topology.sh 00000-15624 12 777               # | ||||
| # bash scripts/NATS-Bench/train-topology.sh 00000-15624 200 '777 888 999'    # | ||||
| #                                                                            # | ||||
| ################                                                             # | ||||
| # [Deprecated Function: Generate the meta information]                       # | ||||
| # python ./exps/NATS-Bench/main-tss.py --mode meta                           # | ||||
| ############################################################################## | ||||
| import os, sys, time, torch, random, argparse | ||||
| from typing import List, Text, Dict, Any | ||||
| from PIL     import ImageFile | ||||
| ImageFile.LOAD_TRUNCATED_IMAGES = True | ||||
| from copy    import deepcopy | ||||
| @@ -19,16 +26,18 @@ from pathlib import Path | ||||
|  | ||||
| lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() | ||||
| if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) | ||||
| from config_utils import load_config | ||||
| from config_utils import dict2config, load_config | ||||
| from procedures   import bench_evaluate_for_seed | ||||
| from procedures   import get_machine_info | ||||
| from datasets     import get_datasets | ||||
| from log_utils    import Logger, AverageMeter, time_string, convert_secs2time | ||||
| from models       import CellStructure, CellArchitectures, get_search_spaces | ||||
| from utils        import split_str2indexes | ||||
|  | ||||
|  | ||||
| def evaluate_all_datasets(arch, datasets, xpaths, splits, use_less, seed, arch_config, workers, logger): | ||||
|   machine_info, arch_config = get_machine_info(), deepcopy(arch_config) | ||||
| def evaluate_all_datasets(arch: Text, datasets: List[Text], xpaths: List[Text], | ||||
|                           splits: List[Text], config_path: Text, seed: int, raw_arch_config, workers, logger): | ||||
|   machine_info, raw_arch_config = get_machine_info(), deepcopy(raw_arch_config) | ||||
|   all_infos = {'info': machine_info} | ||||
|   all_dataset_keys = [] | ||||
|   # look all the datasets | ||||
| @@ -37,19 +46,12 @@ def evaluate_all_datasets(arch, datasets, xpaths, splits, use_less, seed, arch_c | ||||
|     train_data, valid_data, xshape, class_num = get_datasets(dataset, xpath, -1) | ||||
|     # load the configuration | ||||
|     if dataset == 'cifar10' or dataset == 'cifar100': | ||||
|       if use_less: config_path = 'configs/nas-benchmark/LESS.config' | ||||
|       else       : config_path = 'configs/nas-benchmark/CIFAR.config' | ||||
|       split_info  = load_config('configs/nas-benchmark/cifar-split.txt', None, None) | ||||
|     elif dataset.startswith('ImageNet16'): | ||||
|       if use_less: config_path = 'configs/nas-benchmark/LESS.config' | ||||
|       else       : config_path = 'configs/nas-benchmark/ImageNet-16.config' | ||||
|       split_info  = load_config('configs/nas-benchmark/{:}-split.txt'.format(dataset), None, None) | ||||
|     else: | ||||
|       raise ValueError('invalid dataset : {:}'.format(dataset)) | ||||
|     config = load_config(config_path, \ | ||||
|                             {'class_num': class_num, | ||||
|                              'xshape'   : xshape}, \ | ||||
|                             logger) | ||||
|     config = load_config(config_path, dict(class_num=class_num, xshape=xshape), logger) | ||||
|     # check whether use splited validation set | ||||
|     if bool(split): | ||||
|       assert dataset == 'cifar10' | ||||
| @@ -89,6 +91,8 @@ def evaluate_all_datasets(arch, datasets, xpaths, splits, use_less, seed, arch_c | ||||
|     logger.log('Evaluate ||||||| {:10s} ||||||| Config={:}'.format(dataset_key, config)) | ||||
|     for key, value in ValLoaders.items(): | ||||
|       logger.log('Evaluate ---->>>> {:10s} with {:} batchs'.format(key, len(value))) | ||||
|     arch_config = dict2config(dict(name='infer.tiny', C=raw_arch_config['channel'], N=raw_arch_config['num_cells'], | ||||
|                                    genotype=arch, num_classes=config.class_num), None) | ||||
|     results = bench_evaluate_for_seed(arch_config, config, train_loader, ValLoaders, seed, logger) | ||||
|     all_infos[dataset_key] = results | ||||
|     all_dataset_keys.append( dataset_key ) | ||||
| @@ -96,71 +100,59 @@ def evaluate_all_datasets(arch, datasets, xpaths, splits, use_less, seed, arch_c | ||||
|   return all_infos | ||||
|  | ||||
|  | ||||
| def main(save_dir, workers, datasets, xpaths, splits, use_less, srange, arch_index, seeds, cover_mode, meta_info, arch_config): | ||||
|   assert torch.cuda.is_available(), 'CUDA is not available.' | ||||
|   torch.backends.cudnn.enabled   = True | ||||
|   #torch.backends.cudnn.benchmark = True | ||||
|   torch.backends.cudnn.deterministic = True | ||||
|   torch.set_num_threads( workers ) | ||||
| def main(save_dir: Path, workers: int, datasets: List[Text], xpaths: List[Text], | ||||
|          splits: List[int], seeds: List[int], nets: List[str], opt_config: Dict[Text, Any], | ||||
|          to_evaluate_indexes: tuple, cover_mode: bool, arch_config: Dict[Text, Any]): | ||||
|  | ||||
|   assert len(srange) == 2 and 0 <= srange[0] <= srange[1], 'invalid srange : {:}'.format(srange) | ||||
|   log_dir = save_dir / 'logs' | ||||
|   log_dir.mkdir(parents=True, exist_ok=True) | ||||
|   logger = Logger(str(log_dir), os.getpid(), False) | ||||
|  | ||||
|   if use_less: | ||||
|     sub_dir = Path(save_dir) / '{:06d}-{:06d}-C{:}-N{:}-LESS'.format(srange[0], srange[1], arch_config['channel'], arch_config['num_cells']) | ||||
|   else: | ||||
|     sub_dir = Path(save_dir) / '{:06d}-{:06d}-C{:}-N{:}'.format(srange[0], srange[1], arch_config['channel'], arch_config['num_cells']) | ||||
|   logger  = Logger(str(sub_dir), 0, False) | ||||
|  | ||||
|   all_archs = meta_info['archs'] | ||||
|   assert srange[1] < meta_info['total'], 'invalid range : {:}-{:} vs. {:}'.format(srange[0], srange[1], meta_info['total']) | ||||
|   assert arch_index == -1 or srange[0] <= arch_index <= srange[1], 'invalid range : {:} vs. {:} vs. {:}'.format(srange[0], arch_index, srange[1]) | ||||
|   if arch_index == -1: | ||||
|     to_evaluate_indexes = list(range(srange[0], srange[1]+1)) | ||||
|   else: | ||||
|     to_evaluate_indexes = [arch_index] | ||||
|   logger.log('xargs : seeds      = {:}'.format(seeds)) | ||||
|   logger.log('xargs : arch_index = {:}'.format(arch_index)) | ||||
|   logger.log('xargs : cover_mode = {:}'.format(cover_mode)) | ||||
|   logger.log('-'*100) | ||||
|  | ||||
|   logger.log('Start evaluating range =: {:06d} vs. {:06d} vs. {:06d} / {:06d} with cover-mode={:}'.format(srange[0], arch_index, srange[1], meta_info['total'], cover_mode)) | ||||
|   logger.log('-' * 100) | ||||
|   logger.log( | ||||
|     'Start evaluating range =: {:06d} - {:06d}'.format(min(to_evaluate_indexes), max(to_evaluate_indexes)) | ||||
|    +'({:} in total) / {:06d} with cover-mode={:}'.format(len(to_evaluate_indexes), len(nets), cover_mode)) | ||||
|   for i, (dataset, xpath, split) in enumerate(zip(datasets, xpaths, splits)): | ||||
|     logger.log('--->>> Evaluate {:}/{:} : dataset={:9s}, path={:}, split={:}'.format(i, len(datasets), dataset, xpath, split)) | ||||
|   logger.log('--->>> architecture config : {:}'.format(arch_config)) | ||||
|    | ||||
|     logger.log( | ||||
|       '--->>> Evaluate {:}/{:} : dataset={:9s}, path={:}, split={:}'.format(i, len(datasets), dataset, xpath, split)) | ||||
|   logger.log('--->>> optimization config : {:}'.format(opt_config)) | ||||
|  | ||||
|   start_time, epoch_time = time.time(), AverageMeter() | ||||
|   for i, index in enumerate(to_evaluate_indexes): | ||||
|     arch = all_archs[index] | ||||
|     logger.log('\n{:} evaluate {:06d}/{:06d} ({:06d}/{:06d})-th architecture [seeds={:}] {:}'.format('-'*15, i, len(to_evaluate_indexes), index, meta_info['total'], seeds, '-'*15)) | ||||
|     #logger.log('{:} {:} {:}'.format('-'*15, arch.tostr(), '-'*15)) | ||||
|     logger.log('{:} {:} {:}'.format('-'*15, arch, '-'*15)) | ||||
|     arch = nets[index] | ||||
|     logger.log('\n{:} evaluate {:06d}/{:06d} ({:06d}/{:06d})-th arch [seeds={:}] {:}'.format(time_string(), i, | ||||
|                        len(to_evaluate_indexes), index, len(nets), seeds, '-' * 15)) | ||||
|     logger.log('{:} {:} {:}'.format('-' * 15, arch, '-' * 15)) | ||||
|  | ||||
|     # test this arch on different datasets with different seeds | ||||
|     has_continue = False | ||||
|     for seed in seeds: | ||||
|       to_save_name = sub_dir / 'arch-{:06d}-seed-{:04d}.pth'.format(index, seed) | ||||
|       to_save_name = save_dir / 'arch-{:06d}-seed-{:04d}.pth'.format(index, seed) | ||||
|       if to_save_name.exists(): | ||||
|         if cover_mode: | ||||
|           logger.log('Find existing file : {:}, remove it before evaluation'.format(to_save_name)) | ||||
|           os.remove(str(to_save_name)) | ||||
|         else         : | ||||
|         else: | ||||
|           logger.log('Find existing file : {:}, skip this evaluation'.format(to_save_name)) | ||||
|           has_continue = True | ||||
|           continue | ||||
|       results = evaluate_all_datasets(CellStructure.str2structure(arch), \ | ||||
|                                         datasets, xpaths, splits, use_less, seed, \ | ||||
|       results = evaluate_all_datasets(CellStructure.str2structure(arch), | ||||
|                                       datasets, xpaths, splits, opt_config, seed, | ||||
|                                       arch_config, workers, logger) | ||||
|       torch.save(results, to_save_name) | ||||
|       logger.log('{:} --evaluate-- {:06d}/{:06d} ({:06d}/{:06d})-th seed={:} done, save into {:}'.format('-'*15, i, len(to_evaluate_indexes), index, meta_info['total'], seed, to_save_name)) | ||||
|       logger.log('\n{:} evaluate {:06d}/{:06d} ({:06d}/{:06d})-th arch [seeds={:}] ===>>> {:}'.format(time_string(), i, | ||||
|                   len(to_evaluate_indexes), index, len(nets), seeds, to_save_name)) | ||||
|     # measure elapsed time | ||||
|     if not has_continue: epoch_time.update(time.time() - start_time) | ||||
|     start_time = time.time() | ||||
|     need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.avg * (len(to_evaluate_indexes)-i-1), True) ) | ||||
|     logger.log('This arch costs : {:}'.format( convert_secs2time(epoch_time.val, True) )) | ||||
|     logger.log('{:}'.format('*'*100)) | ||||
|     logger.log('{:}   {:74s}   {:}'.format('*'*10, '{:06d}/{:06d} ({:06d}/{:06d})-th done, left {:}'.format(i, len(to_evaluate_indexes), index, meta_info['total'], need_time), '*'*10)) | ||||
|     logger.log('{:}'.format('*'*100)) | ||||
|     need_time = 'Time Left: {:}'.format(convert_secs2time(epoch_time.avg * (len(to_evaluate_indexes)-i-1), True) ) | ||||
|     logger.log('This arch costs : {:}'.format(convert_secs2time(epoch_time.val, True) )) | ||||
|     logger.log('{:}'.format('*' * 100)) | ||||
|     logger.log('{:}   {:74s}   {:}'.format('*' * 10, '{:06d}/{:06d} ({:06d}/{:06d})-th done, left {:}'.format(i, len( | ||||
|       to_evaluate_indexes), index, len(nets), need_time), '*' * 10)) | ||||
|     logger.log('{:}'.format('*' * 100)) | ||||
|  | ||||
|   logger.close() | ||||
|  | ||||
| @@ -256,28 +248,34 @@ def generate_meta_info(save_dir, max_node, divide=40): | ||||
|   torch.save(info, save_name) | ||||
|   print ('save the meta file into {:}'.format(save_name)) | ||||
|  | ||||
|   """ | ||||
|   script_name_full = save_dir / 'BENCH-201-N{:}.opt-full.script'.format(max_node) | ||||
|   script_name_less = save_dir / 'BENCH-201-N{:}.opt-less.script'.format(max_node) | ||||
|   full_file = open(str(script_name_full), 'w') | ||||
|   less_file = open(str(script_name_less), 'w') | ||||
|   gaps = total_arch // divide | ||||
|   for start in range(0, total_arch, gaps): | ||||
|     xend = min(start+gaps, total_arch) | ||||
|     full_file.write('bash ./scripts-search/NAS-Bench-201/train-models.sh 0 {:5d} {:5d} -1 \'777 888 999\'\n'.format(start, xend-1)) | ||||
|     less_file.write('bash ./scripts-search/NAS-Bench-201/train-models.sh 1 {:5d} {:5d} -1 \'777 888 999\'\n'.format(start, xend-1)) | ||||
|   print ('save the training script into {:} and {:}'.format(script_name_full, script_name_less)) | ||||
|   full_file.close() | ||||
|   less_file.close() | ||||
|  | ||||
|   script_name = save_dir / 'meta-node-{:}.cal-script.txt'.format(max_node) | ||||
|   macro = 'OMP_NUM_THREADS=6 CUDA_VISIBLE_DEVICES=0' | ||||
|   with open(str(script_name), 'w') as cfile: | ||||
|     for start in range(0, total_arch, gaps): | ||||
|       xend = min(start+gaps, total_arch) | ||||
|       cfile.write('{:} python exps/NAS-Bench-201/statistics.py --mode cal --target_dir {:06d}-{:06d}-C16-N5\n'.format(macro, start, xend-1)) | ||||
|   print ('save the post-processing script into {:}'.format(script_name)) | ||||
|   """ | ||||
| def traverse_net(max_node): | ||||
|   aa_nas_bench_ss = get_search_spaces('cell', 'nats-bench') | ||||
|   archs = CellStructure.gen_all(aa_nas_bench_ss, max_node, False) | ||||
|   print ('There are {:} archs vs {:}.'.format(len(archs), len(aa_nas_bench_ss) ** ((max_node-1)*max_node/2))) | ||||
|  | ||||
|   random.seed( 88 ) # please do not change this line for reproducibility | ||||
|   random.shuffle( archs ) | ||||
|   assert archs[0  ].tostr() == '|avg_pool_3x3~0|+|nor_conv_1x1~0|skip_connect~1|+|nor_conv_1x1~0|skip_connect~1|skip_connect~2|', 'please check the 0-th architecture : {:}'.format(archs[0]) | ||||
|   assert archs[9  ].tostr() == '|avg_pool_3x3~0|+|none~0|none~1|+|skip_connect~0|none~1|nor_conv_3x3~2|', 'please check the 9-th architecture : {:}'.format(archs[9]) | ||||
|   assert archs[123].tostr() == '|avg_pool_3x3~0|+|avg_pool_3x3~0|nor_conv_1x1~1|+|none~0|avg_pool_3x3~1|nor_conv_3x3~2|', 'please check the 123-th architecture : {:}'.format(archs[123]) | ||||
|   return [x.tostr() for x in archs] | ||||
|  | ||||
|  | ||||
| def filter_indexes(xlist, mode, save_dir, seeds): | ||||
|   all_indexes = [] | ||||
|   for index in xlist: | ||||
|     if mode == 'cover': | ||||
|       all_indexes.append(index) | ||||
|     else: | ||||
|       for seed in seeds: | ||||
|         temp_path = save_dir / 'arch-{:06d}-seed-{:04d}.pth'.format(index, seed) | ||||
|         if not temp_path.exists(): | ||||
|           all_indexes.append(index) | ||||
|           break | ||||
|   print('{:} [FILTER-INDEXES] : there are {:}/{:} architectures in total'.format(time_string(), len(all_indexes), len(xlist))) | ||||
|   return all_indexes | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|   # mode_choices = ['meta', 'new', 'cover'] + ['specific-{:}'.format(_) for _ in CellArchitectures.keys()] | ||||
| @@ -291,11 +289,12 @@ if __name__ == '__main__': | ||||
|   parser.add_argument('--datasets',    type=str,   nargs='+',      help='The applied datasets.') | ||||
|   parser.add_argument('--xpaths',      type=str,   nargs='+',      help='The root path for this dataset.') | ||||
|   parser.add_argument('--splits',      type=int,   nargs='+',      help='The root path for this dataset.') | ||||
|   parser.add_argument('--hyper',       type=str, default='12', choices=['01', '12', '90'], help='The tag for hyper-parameters.') | ||||
|   parser.add_argument('--hyper',       type=str, default='12', choices=['01', '12', '200'], help='The tag for hyper-parameters.') | ||||
|  | ||||
|   parser.add_argument('--seeds'  ,     type=int,   nargs='+',      help='The range of models to be evaluated') | ||||
|   parser.add_argument('--channel',     type=int,   default=16,     help='The number of channels.') | ||||
|   parser.add_argument('--num_cells',   type=int,   default=5,      help='The number of cells in one stage.') | ||||
|   parser.add_argument('--check_N',     type=int, default=15625,  help='For safety.') | ||||
|   args = parser.parse_args() | ||||
|  | ||||
|   assert args.mode in ['meta', 'new', 'cover'] or args.mode.startswith('specific-'), 'invalid mode : {:}'.format(args.mode) | ||||
| @@ -308,16 +307,28 @@ if __name__ == '__main__': | ||||
|     train_single_model(args.save_dir, args.workers, args.datasets, args.xpaths, args.splits, args.use_less>0, \ | ||||
|                          tuple(args.seeds), model_str, {'channel': args.channel, 'num_cells': args.num_cells}) | ||||
|   else: | ||||
|     meta_path = Path(args.save_dir) / 'meta-node-{:}.pth'.format(args.max_node) | ||||
|     assert meta_path.exists(), '{:} does not exist.'.format(meta_path) | ||||
|     meta_info = torch.load( meta_path ) | ||||
|     # check whether args is ok | ||||
|     assert len(args.srange) == 2 and args.srange[0] <= args.srange[1], 'invalid length of srange args: {:}'.format(args.srange) | ||||
|     assert len(args.seeds) > 0, 'invalid length of seeds args: {:}'.format(args.seeds) | ||||
|     assert len(args.datasets) == len(args.xpaths) == len(args.splits), 'invalid infos : {:} vs {:} vs {:}'.format(len(args.datasets), len(args.xpaths), len(args.splits)) | ||||
|     assert args.workers > 0, 'invalid number of workers : {:}'.format(args.workers) | ||||
|     nets = traverse_net(args.max_node) | ||||
|     if len(nets) != args.check_N: | ||||
|       raise ValueError('Pre-num-check failed : {:} vs {:}'.format(len(nets), args.check_N)) | ||||
|     opt_config = './configs/nas-benchmark/hyper-opts/{:}E.config'.format(args.hyper) | ||||
|     if not os.path.isfile(opt_config): | ||||
|       raise ValueError('{:} is not a file.'.format(opt_config)) | ||||
|     save_dir = Path(args.save_dir) / 'raw-data-{:}'.format(args.hyper) | ||||
|     save_dir.mkdir(parents=True, exist_ok=True) | ||||
|     to_evaluate_indexes = split_str2indexes(args.srange, args.check_N, 5) | ||||
|     if not len(args.seeds): | ||||
|       raise ValueError('invalid length of seeds args: {:}'.format(args.seeds)) | ||||
|     if not (len(args.datasets) == len(args.xpaths) == len(args.splits)): | ||||
|       raise ValueError('invalid infos : {:} vs {:} vs {:}'.format(len(args.datasets), len(args.xpaths), len(args.splits))) | ||||
|     if args.workers <= 0: | ||||
|       raise ValueError('invalid number of workers : {:}'.format(args.workers)) | ||||
|  | ||||
|     main(args.save_dir, args.workers, args.datasets, args.xpaths, args.splits, args.hyper, \ | ||||
|            tuple(args.srange), args.arch_index, tuple(args.seeds), \ | ||||
|            args.mode == 'cover', meta_info, \ | ||||
|            {'channel': args.channel, 'num_cells': args.num_cells}) | ||||
|     target_indexes = filter_indexes(to_evaluate_indexes, args.mode, save_dir, args.seeds) | ||||
|  | ||||
|     assert torch.cuda.is_available(), 'CUDA is not available.' | ||||
|     torch.backends.cudnn.enabled = True | ||||
|     torch.backends.cudnn.deterministic = True | ||||
|     torch.set_num_threads(args.workers) | ||||
|  | ||||
|     main(save_dir, args.workers, args.datasets, args.xpaths, args.splits, tuple(args.seeds), nets, opt_config, target_indexes, args.mode == 'cover', \ | ||||
|          {'name': 'infer.tiny', 'channel': args.channel, 'num_cells': args.num_cells}) | ||||
|   | ||||
| @@ -4,3 +4,4 @@ from .flop_benchmark   import get_model_infos, count_parameters_in_MB | ||||
| from .affine_utils     import normalize_points, denormalize_points | ||||
| from .affine_utils     import identity2affine, solve2theta, affine2image | ||||
| from .hash_utils       import get_md5_file | ||||
| from .str_utils        import split_str2indexes | ||||
|   | ||||
							
								
								
									
										18
									
								
								lib/utils/str_utils.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										18
									
								
								lib/utils/str_utils.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,18 @@ | ||||
|  | ||||
| def split_str2indexes(string: str, max_check: int, length_limit=5): | ||||
|   if not isinstance(string, str): | ||||
|     raise ValueError('Invalid scheme for {:}'.format(string)) | ||||
|   srangestr = "".join(string.split()) | ||||
|   indexes = set() | ||||
|   for srange in srangestr.split(','): | ||||
|     srange = srange.split('-') | ||||
|     if len(srange) != 2: | ||||
|       raise ValueError('invalid srange : {:}'.format(srange)) | ||||
|     if length_limit is not None: | ||||
|       assert len(srange[0]) == len(srange[1]) == length_limit, 'invalid srange : {:}'.format(srange) | ||||
|     srange = (int(srange[0]), int(srange[1])) | ||||
|     if not (0 <= srange[0] <= srange[1] < max_check): | ||||
|       raise ValueError('{:} vs {:} vs {:}'.format(srange[0], srange[1], max_check)) | ||||
|     for i in range(srange[0], srange[1]+1): | ||||
|       indexes.add(i) | ||||
|   return indexes | ||||
							
								
								
									
										43
									
								
								scripts/NATS-Bench/train-topology.sh
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										43
									
								
								scripts/NATS-Bench/train-topology.sh
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,43 @@ | ||||
| #!/bin/bash | ||||
| ############################################################################## | ||||
| # NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size # | ||||
| ############################################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.01                          # | ||||
| ############################################################################## | ||||
| # CUDA_VISIBLE_DEVICES=0 bash scripts/NATS-Bench/train-topology.sh 00000-05000 12 777 | ||||
| # bash ./scripts/NATS-Bench/train-topology.sh 05001-10000 12 777 | ||||
| # bash ./scripts/NATS-Bench/train-topology.sh 10001-14500 12 777 | ||||
| # bash ./scripts/NATS-Bench/train-topology.sh 14501-15624 12 777 | ||||
| # | ||||
| ############################################################################## | ||||
| echo script name: $0 | ||||
| echo $# arguments | ||||
| if [ "$#" -ne 3 ] ;then | ||||
|   echo "Input illegal number of parameters " $# | ||||
|   echo "Need 3 parameters for start-and-end, hyper-parameters-opt-file, and seeds" | ||||
|   exit 1 | ||||
| fi | ||||
| if [ "$TORCH_HOME" = "" ]; then | ||||
|   echo "Must set TORCH_HOME envoriment variable for data dir saving" | ||||
|   exit 1 | ||||
| else | ||||
|   echo "TORCH_HOME : $TORCH_HOME" | ||||
| fi | ||||
|  | ||||
| srange=$1 | ||||
| opt=$2 | ||||
| all_seeds=$3 | ||||
| cpus=4 | ||||
|  | ||||
| save_dir=./output/NATS-Bench-topology/ | ||||
|  | ||||
| OMP_NUM_THREADS=${cpus} python exps/NATS-Bench/main-tss.py \ | ||||
| 	--mode new --srange ${srange} --hyper ${opt} --save_dir ${save_dir} \ | ||||
| 	--datasets cifar10 cifar10 cifar100 ImageNet16-120 \ | ||||
| 	--splits   1       0       0        0 \ | ||||
| 	--xpaths $TORCH_HOME/cifar.python \ | ||||
| 		 $TORCH_HOME/cifar.python \ | ||||
| 		 $TORCH_HOME/cifar.python \ | ||||
| 		 $TORCH_HOME/cifar.python/ImageNet16 \ | ||||
| 	--workers ${cpus} \ | ||||
| 	--seeds ${all_seeds} | ||||
		Reference in New Issue
	
	Block a user