57 lines
		
	
	
		
			2.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			57 lines
		
	
	
		
			2.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| ##############################################################################
 | |
| # NATS-Bench: Benchmarking NAS Algorithms for Architecture Topology and Size #
 | |
| ##############################################################################
 | |
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.07                          #
 | |
| ##############################################################################
 | |
| # python ./exps/NATS-Bench/Analyze-time.py                                   #
 | |
| ##############################################################################
 | |
| import os, sys, time, tqdm, argparse
 | |
| from pathlib import Path
 | |
| 
 | |
| lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
 | |
| if str(lib_dir) not in sys.path:
 | |
|     sys.path.insert(0, str(lib_dir))
 | |
| from config_utils import dict2config, load_config
 | |
| from datasets import get_datasets
 | |
| from nats_bench import create
 | |
| 
 | |
| 
 | |
def show_time(api, epoch=12):
    """Print the summed training time of every architecture in ``api``.

    For each index in ``range(len(api))`` the benchmark is queried once per
    dataset ("ImageNet16-120", "cifar10-valid", "cifar100") and the
    ``"train-all-time"`` entry of the returned info is accumulated. Three
    summary lines are printed: the CIFAR-10 total, then the CIFAR-100 and
    ImageNet-16-120 totals together with their ratio to the CIFAR-10 total.

    Args:
        api: benchmark object supporting ``len(api)`` and
            ``api.get_more_info(index, dataset, hp=epoch)``; the latter must
            return a mapping containing ``"train-all-time"``.
        epoch: value forwarded as the ``hp`` argument of ``get_more_info``.

    Returns:
        None. The results are only printed.

    Note:
        When the CIFAR-10 total is zero (for example an empty ``api``) the
        ratios are reported as ``nan`` instead of raising
        ``ZeroDivisionError``.
    """
    print("Show the time for {:} with {:}-epoch-training".format(api, epoch))
    # Query order per architecture is kept as ImageNet -> CIFAR-10 -> CIFAR-100.
    datasets = ("ImageNet16-120", "cifar10-valid", "cifar100")
    totals = dict.fromkeys(datasets, 0)
    for index in tqdm.tqdm(range(len(api))):
        for dataset in datasets:
            info = api.get_more_info(index, dataset, hp=epoch)
            # accumulate the time
            totals[dataset] += info["train-all-time"]

    all_cifar10_time = totals["cifar10-valid"]
    all_cifar100_time = totals["cifar100"]
    all_imagenet_time = totals["ImageNet16-120"]

    def ratio_to_cifar10(total):
        # Guard the division: a zero CIFAR-10 total has no meaningful ratio.
        return total / all_cifar10_time if all_cifar10_time else float("nan")

    print(
        "The total training time for CIFAR-10        (held-out train set) is {:} seconds".format(
            all_cifar10_time
        )
    )
    print(
        "The total training time for CIFAR-100       (held-out train set) is {:} seconds, {:.2f} times longer than that on CIFAR-10".format(
            all_cifar100_time, ratio_to_cifar10(all_cifar100_time)
        )
    )
    print(
        "The total training time for ImageNet-16-120 (held-out train set) is {:} seconds, {:.2f} times longer than that on CIFAR-10".format(
            all_imagenet_time, ratio_to_cifar10(all_imagenet_time)
        )
    )
 | |
| 
 | |
| 
 | |
if __name__ == "__main__":
    # Report the 12-epoch training time for both NATS-Bench search spaces,
    # "tss" first and then "sss"; each API is created with the same options.
    for search_space in ("tss", "sss"):
        benchmark_api = create(None, search_space, fast_mode=True, verbose=False)
        show_time(benchmark_api, 12)
 |