2020-11-26 07:43:28 +01:00
##############################################################################
2021-01-25 14:48:14 +01:00
# NATS-Bench: Benchmarking NAS Algorithms for Architecture Topology and Size #
2020-11-26 07:43:28 +01:00
##############################################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.07 #
##############################################################################
# python ./exps/NATS-Bench/Analyze-time.py #
##############################################################################
2021-01-25 14:50:47 +01:00
import os , sys , time , tqdm , argparse
2020-11-26 07:43:28 +01:00
from pathlib import Path
# Make the repository's `lib` directory (two directories above this script's
# folder) importable, so that the project imports below can be resolved.
lib_dir = Path(__file__).parent.joinpath('..', '..', 'lib').resolve()
lib_path = str(lib_dir)
if lib_path not in sys.path:
    sys.path.insert(0, lib_path)
from config_utils import dict2config , load_config
from datasets import get_datasets
from nats_bench import create
2021-01-25 14:48:14 +01:00
def show_time(api, epoch=12):
    """Print the accumulated training time of every architecture in ``api``.

    For each architecture index ``0 .. len(api) - 1``, the ``'train-all-time'``
    entry returned by ``api.get_more_info`` is summed separately for
    ImageNet16-120, cifar10-valid and cifar100. The CIFAR-100 and
    ImageNet-16-120 totals are also reported as a multiple of the CIFAR-10
    total.

    Args:
        api: A NATS-Bench API object (as built by ``create`` in this script)
            that supports ``len(api)`` and
            ``api.get_more_info(index, dataset, hp=...)``.
        epoch: Value forwarded as ``hp`` to ``get_more_info`` and printed as
            the number of training epochs. Defaults to 12.

    Returns:
        dict: Maps each queried dataset name to its summed
        ``'train-all-time'`` value (printed as seconds). Previously the
        function returned ``None``; callers that ignore the result are
        unaffected.
    """
    # Same datasets, queried in the same order, as the original implementation.
    datasets = ('ImageNet16-120', 'cifar10-valid', 'cifar100')
    print('Show the time for {:} with {:}-epoch-training'.format(api, epoch))

    totals = {dataset: 0 for dataset in datasets}
    for index in tqdm.tqdm(range(len(api))):
        for dataset in datasets:
            info = api.get_more_info(index, dataset, hp=epoch)
            totals[dataset] += info['train-all-time']

    all_cifar10_time = totals['cifar10-valid']
    all_cifar100_time = totals['cifar100']
    all_imagenet_time = totals['ImageNet16-120']

    # An empty API (or all-zero CIFAR-10 times) used to raise
    # ZeroDivisionError only after the whole loop had finished; report NaN
    # ratios instead so that the totals are still printed.
    if all_cifar10_time:
        cifar100_ratio = all_cifar100_time / all_cifar10_time
        imagenet_ratio = all_imagenet_time / all_cifar10_time
    else:
        cifar100_ratio = imagenet_ratio = float('nan')

    print('The total training time for CIFAR-10 (held-out train set) is {:} seconds'.format(all_cifar10_time))
    print('The total training time for CIFAR-100 (held-out train set) is {:} seconds, {:.2f} times longer than that on CIFAR-10'.format(all_cifar100_time, cifar100_ratio))
    print('The total training time for ImageNet-16-120 (held-out train set) is {:} seconds, {:.2f} times longer than that on CIFAR-10'.format(all_imagenet_time, imagenet_ratio))
    return totals
if __name__ == '__main__':
    # Report timings for each benchmark search space in turn ('tss' then
    # 'sss' -- presumably the topology and size spaces named in the header),
    # always with the 12-epoch schedule.
    for search_space in ('tss', 'sss'):
        bench_api = create(None, search_space, fast_mode=True, verbose=False)
        show_time(bench_api, 12)
2020-11-26 07:43:28 +01:00