Update time_budget for NATS (algos)
This commit is contained in:
		
							
								
								
									
										48
									
								
								exps/NATS-Bench/Analyze-time.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										48
									
								
								exps/NATS-Bench/Analyze-time.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,48 @@ | |||||||
|  | ############################################################################## | ||||||
|  | # NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size # | ||||||
|  | ############################################################################## | ||||||
|  | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.07                          # | ||||||
|  | ############################################################################## | ||||||
|  | # python ./exps/NATS-Bench/Analyze-time.py                                   # | ||||||
|  | ############################################################################## | ||||||
|  | import os, sys, time, tqdm, torch, random, argparse | ||||||
|  | from typing import List, Text, Dict, Any | ||||||
|  | from PIL     import ImageFile | ||||||
|  | ImageFile.LOAD_TRUNCATED_IMAGES = True | ||||||
|  | from copy    import deepcopy | ||||||
|  | from pathlib import Path | ||||||
|  |  | ||||||
|  | lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() | ||||||
|  | if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) | ||||||
|  | from config_utils import dict2config, load_config | ||||||
|  | from datasets import get_datasets | ||||||
|  | from nats_bench import create | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def show_time(api): | ||||||
|  |   print('Show the time for {:} with 12-epoch-training'.format(api)) | ||||||
|  |   all_cifar10_time, all_cifar100_time, all_imagenet_time = 0, 0, 0 | ||||||
|  |   for index in tqdm.tqdm(range(len(api))): | ||||||
|  |     info = api.get_more_info(index, 'ImageNet16-120', hp='12') | ||||||
|  |     imagenet_time = info['train-all-time'] | ||||||
|  |     info = api.get_more_info(index, 'cifar10-valid', hp='12') | ||||||
|  |     cifar10_time = info['train-all-time'] | ||||||
|  |     info = api.get_more_info(index, 'cifar100', hp='12') | ||||||
|  |     cifar100_time = info['train-all-time'] | ||||||
|  |     # accumulate the time | ||||||
|  |     all_cifar10_time += cifar10_time | ||||||
|  |     all_cifar100_time += cifar100_time | ||||||
|  |     all_imagenet_time += imagenet_time | ||||||
|  |   print('The total training time for CIFAR-10        (held-out train set) is {:} seconds'.format(all_cifar10_time)) | ||||||
|  |   print('The total training time for CIFAR-100       (held-out train set) is {:} seconds, {:.2f} times longer than that on CIFAR-10'.format(all_cifar100_time, all_cifar100_time / all_cifar10_time)) | ||||||
|  |   print('The total training time for ImageNet-16-120 (held-out train set) is {:} seconds, {:.2f} times longer than that on CIFAR-10'.format(all_imagenet_time, all_imagenet_time / all_cifar10_time)) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | if __name__ == '__main__': | ||||||
|  |  | ||||||
|  |   api_nats_tss = create(None, 'tss', fast_mode=True, verbose=False) | ||||||
|  |   show_time(api_nats_tss) | ||||||
|  |  | ||||||
|  |   api_nats_sss = create(None, 'sss', fast_mode=True, verbose=False) | ||||||
|  |   show_time(api_nats_sss) | ||||||
|  |  | ||||||
| @@ -169,7 +169,8 @@ if __name__ == '__main__': | |||||||
|    |    | ||||||
|   api = create(None, args.search_space, fast_mode=True, verbose=False) |   api = create(None, args.search_space, fast_mode=True, verbose=False) | ||||||
|  |  | ||||||
|   args.save_dir = os.path.join('{:}-{:}'.format(args.save_dir, args.search_space), args.dataset, 'BOHB') |   args.save_dir = os.path.join('{:}-{:}'.format(args.save_dir, args.search_space), | ||||||
|  |                                '{:}-T{:}'.format(args.dataset, args.time_budget), 'BOHB') | ||||||
|   print('save-dir : {:}'.format(args.save_dir)) |   print('save-dir : {:}'.format(args.save_dir)) | ||||||
|  |  | ||||||
|   if args.rand_seed < 0: |   if args.rand_seed < 0: | ||||||
|   | |||||||
| @@ -73,7 +73,8 @@ if __name__ == '__main__': | |||||||
|    |    | ||||||
|   api = create(None, args.search_space, fast_mode=True, verbose=False) |   api = create(None, args.search_space, fast_mode=True, verbose=False) | ||||||
|  |  | ||||||
|   args.save_dir = os.path.join('{:}-{:}'.format(args.save_dir, args.search_space), args.dataset, 'RANDOM') |   args.save_dir = os.path.join('{:}-{:}'.format(args.save_dir, args.search_space), | ||||||
|  |                                '{:}-T{:}'.format(args.dataset, args.time_budget), 'RANDOM') | ||||||
|   print('save-dir : {:}'.format(args.save_dir)) |   print('save-dir : {:}'.format(args.save_dir)) | ||||||
|  |  | ||||||
|   if args.rand_seed < 0: |   if args.rand_seed < 0: | ||||||
|   | |||||||
| @@ -200,7 +200,8 @@ if __name__ == '__main__': | |||||||
|  |  | ||||||
|   api = create(None, args.search_space, fast_mode=True, verbose=False) |   api = create(None, args.search_space, fast_mode=True, verbose=False) | ||||||
|  |  | ||||||
|   args.save_dir = os.path.join('{:}-{:}'.format(args.save_dir, args.search_space), args.dataset, 'R-EA-SS{:}'.format(args.ea_sample_size)) |   args.save_dir = os.path.join('{:}-{:}'.format(args.save_dir, args.search_space), | ||||||
|  |                                '{:}-T{:}'.format(args.dataset, args.time_budget), 'R-EA-SS{:}'.format(args.ea_sample_size)) | ||||||
|   print('save-dir : {:}'.format(args.save_dir)) |   print('save-dir : {:}'.format(args.save_dir)) | ||||||
|   print('xargs : {:}'.format(args)) |   print('xargs : {:}'.format(args)) | ||||||
|  |  | ||||||
|   | |||||||
| @@ -194,7 +194,8 @@ if __name__ == '__main__': | |||||||
|  |  | ||||||
|   api = create(None, args.search_space, fast_mode=True, verbose=False) |   api = create(None, args.search_space, fast_mode=True, verbose=False) | ||||||
|  |  | ||||||
|   args.save_dir = os.path.join('{:}-{:}'.format(args.save_dir, args.search_space), args.dataset, 'REINFORCE-{:}'.format(args.learning_rate)) |   args.save_dir = os.path.join('{:}-{:}'.format(args.save_dir, args.search_space), | ||||||
|  |                                '{:}-T{:}'.format(args.dataset, args.time_budget), 'REINFORCE-{:}'.format(args.learning_rate)) | ||||||
|   print('save-dir : {:}'.format(args.save_dir)) |   print('save-dir : {:}'.format(args.save_dir)) | ||||||
|  |  | ||||||
|   if args.rand_seed < 0: |   if args.rand_seed < 0: | ||||||
|   | |||||||
| @@ -10,26 +10,61 @@ if [ "$#" -ne 1 ] ;then | |||||||
|   exit 1 |   exit 1 | ||||||
| fi | fi | ||||||
|  |  | ||||||
|  |  | ||||||
| datasets="cifar10 cifar100 ImageNet16-120" |  | ||||||
| alg_type=$1 | alg_type=$1 | ||||||
|  |  | ||||||
| if [ "$alg_type" == "mul" ]; then | if [ "$alg_type" == "mul" ]; then | ||||||
|   search_spaces="tss sss" |   # datasets="cifar10 cifar100 ImageNet16-120" | ||||||
|  |   # The topology search space | ||||||
|  |   dataset="cifar10" | ||||||
|  |   search_space="tss" | ||||||
|  |   time_budget="20000" | ||||||
|  |   python ./exps/NATS-algos/reinforce.py       --dataset ${dataset} --search_space ${search_space} --time_budget ${time_budget} --learning_rate 0.01 | ||||||
|  |   python ./exps/NATS-algos/regularized_ea.py  --dataset ${dataset} --search_space ${search_space} --time_budget ${time_budget} --ea_cycles 200 --ea_population 10 --ea_sample_size 3 | ||||||
|  |   python ./exps/NATS-algos/random_wo_share.py --dataset ${dataset} --search_space ${search_space} --time_budget ${time_budget} | ||||||
|  |   python ./exps/NATS-algos/bohb.py            --dataset ${dataset} --search_space ${search_space} --time_budget ${time_budget} --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3 | ||||||
|  |  | ||||||
|   for dataset in ${datasets} |   dataset="cifar100" | ||||||
|   do |   search_space="tss" | ||||||
|     for search_space in ${search_spaces} |   time_budget="40000" | ||||||
|     do |   python ./exps/NATS-algos/reinforce.py       --dataset ${dataset} --search_space ${search_space} --time_budget ${time_budget} --learning_rate 0.01 | ||||||
|       python ./exps/NATS-algos/reinforce.py --dataset ${dataset} --search_space ${search_space} --learning_rate 0.01 |   python ./exps/NATS-algos/regularized_ea.py  --dataset ${dataset} --search_space ${search_space} --time_budget ${time_budget} --ea_cycles 200 --ea_population 10 --ea_sample_size 3 | ||||||
|       python ./exps/NATS-algos/regularized_ea.py --dataset ${dataset} --search_space ${search_space} --ea_cycles 200 --ea_population 10 --ea_sample_size 3 |   python ./exps/NATS-algos/random_wo_share.py --dataset ${dataset} --search_space ${search_space} --time_budget ${time_budget} | ||||||
|       python ./exps/NATS-algos/random_wo_share.py --dataset ${dataset} --search_space ${search_space} |   python ./exps/NATS-algos/bohb.py            --dataset ${dataset} --search_space ${search_space} --time_budget ${time_budget} --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3 | ||||||
|       python ./exps/NATS-algos/bohb.py --dataset ${dataset} --search_space ${search_space} --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3 |  | ||||||
|     done |  | ||||||
|   done |  | ||||||
|  |  | ||||||
|   python exps/experimental/vis-bench-algos.py --search_space tss |   dataset="ImageNet16-120" | ||||||
|   python exps/experimental/vis-bench-algos.py --search_space sss |   search_space="tss" | ||||||
|  |   time_budget="120000" | ||||||
|  |   python ./exps/NATS-algos/reinforce.py       --dataset ${dataset} --search_space ${search_space} --time_budget ${time_budget} --learning_rate 0.01 | ||||||
|  |   python ./exps/NATS-algos/regularized_ea.py  --dataset ${dataset} --search_space ${search_space} --time_budget ${time_budget} --ea_cycles 200 --ea_population 10 --ea_sample_size 3 | ||||||
|  |   python ./exps/NATS-algos/random_wo_share.py --dataset ${dataset} --search_space ${search_space} | ||||||
|  |   python ./exps/NATS-algos/bohb.py            --dataset ${dataset} --search_space ${search_space} --time_budget ${time_budget} --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3 | ||||||
|  |  | ||||||
|  |   # The size search space | ||||||
|  |   dataset="cifar10" | ||||||
|  |   search_space="sss" | ||||||
|  |   time_budget="20000" | ||||||
|  |   python ./exps/NATS-algos/reinforce.py       --dataset ${dataset} --search_space ${search_space} --time_budget ${time_budget} --learning_rate 0.01 | ||||||
|  |   python ./exps/NATS-algos/regularized_ea.py  --dataset ${dataset} --search_space ${search_space} --time_budget ${time_budget} --ea_cycles 200 --ea_population 10 --ea_sample_size 3 | ||||||
|  |   python ./exps/NATS-algos/random_wo_share.py --dataset ${dataset} --search_space ${search_space} --time_budget ${time_budget} | ||||||
|  |   python ./exps/NATS-algos/bohb.py            --dataset ${dataset} --search_space ${search_space} --time_budget ${time_budget} --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3 | ||||||
|  |  | ||||||
|  |   dataset="cifar100" | ||||||
|  |   search_space="sss" | ||||||
|  |   time_budget="40000" | ||||||
|  |   python ./exps/NATS-algos/reinforce.py       --dataset ${dataset} --search_space ${search_space} --time_budget ${time_budget} --learning_rate 0.01 | ||||||
|  |   python ./exps/NATS-algos/regularized_ea.py  --dataset ${dataset} --search_space ${search_space} --time_budget ${time_budget} --ea_cycles 200 --ea_population 10 --ea_sample_size 3 | ||||||
|  |   python ./exps/NATS-algos/random_wo_share.py --dataset ${dataset} --search_space ${search_space} --time_budget ${time_budget} | ||||||
|  |   python ./exps/NATS-algos/bohb.py            --dataset ${dataset} --search_space ${search_space} --time_budget ${time_budget} --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3 | ||||||
|  |  | ||||||
|  |   dataset="ImageNet16-120" | ||||||
|  |   search_space="tss" | ||||||
|  |   time_budget="60000" | ||||||
|  |   python ./exps/NATS-algos/reinforce.py       --dataset ${dataset} --search_space ${search_space} --time_budget ${time_budget} --learning_rate 0.01 | ||||||
|  |   python ./exps/NATS-algos/regularized_ea.py  --dataset ${dataset} --search_space ${search_space} --time_budget ${time_budget} --ea_cycles 200 --ea_population 10 --ea_sample_size 3 | ||||||
|  |   python ./exps/NATS-algos/random_wo_share.py --dataset ${dataset} --search_space ${search_space} --time_budget ${time_budget} | ||||||
|  |   python ./exps/NATS-algos/bohb.py            --dataset ${dataset} --search_space ${search_space} --time_budget ${time_budget} --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3 | ||||||
|  |   # python exps/experimental/vis-bench-algos.py --search_space tss | ||||||
|  |   # python exps/experimental/vis-bench-algos.py --search_space sss | ||||||
| else | else | ||||||
|   seeds="777 888 999" |   seeds="777 888 999" | ||||||
|   algos="darts-v1 darts-v2 gdas setn random enas" |   algos="darts-v1 darts-v2 gdas setn random enas" | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user