| 
									
										
										
										
											2020-11-26 14:43:28 +08:00
										 |  |  | ############################################################################## | 
					
						
							| 
									
										
										
										
											2021-01-25 21:48:14 +08:00
										 |  |  | # NATS-Bench: Benchmarking NAS Algorithms for Architecture Topology and Size # | 
					
						
							| 
									
										
										
										
											2020-11-26 14:43:28 +08:00
										 |  |  | ############################################################################## | 
					
						
							|  |  |  | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.07                          # | 
					
						
							|  |  |  | ############################################################################## | 
					
						
							|  |  |  | # python ./exps/NATS-Bench/Analyze-time.py                                   # | 
					
						
							|  |  |  | ############################################################################## | 
					
						
							| 
									
										
										
										
											2021-01-25 21:50:47 +08:00
										 |  |  | import os, sys, time, tqdm, argparse | 
					
						
							| 
									
										
										
										
											2020-11-26 14:43:28 +08:00
										 |  |  | from pathlib import Path | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  | lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() | 
					
						
							|  |  |  | if str(lib_dir) not in sys.path: | 
					
						
							|  |  |  |     sys.path.insert(0, str(lib_dir)) | 
					
						
							| 
									
										
										
										
											2020-11-26 14:43:28 +08:00
										 |  |  | from config_utils import dict2config, load_config | 
					
						
							|  |  |  | from datasets import get_datasets | 
					
						
							|  |  |  | from nats_bench import create | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-01-25 21:48:14 +08:00
										 |  |  | def show_time(api, epoch=12): | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  |     print("Show the time for {:} with {:}-epoch-training".format(api, epoch)) | 
					
						
							|  |  |  |     all_cifar10_time, all_cifar100_time, all_imagenet_time = 0, 0, 0 | 
					
						
							|  |  |  |     for index in tqdm.tqdm(range(len(api))): | 
					
						
							|  |  |  |         info = api.get_more_info(index, "ImageNet16-120", hp=epoch) | 
					
						
							|  |  |  |         imagenet_time = info["train-all-time"] | 
					
						
							|  |  |  |         info = api.get_more_info(index, "cifar10-valid", hp=epoch) | 
					
						
							|  |  |  |         cifar10_time = info["train-all-time"] | 
					
						
							|  |  |  |         info = api.get_more_info(index, "cifar100", hp=epoch) | 
					
						
							|  |  |  |         cifar100_time = info["train-all-time"] | 
					
						
							|  |  |  |         # accumulate the time | 
					
						
							|  |  |  |         all_cifar10_time += cifar10_time | 
					
						
							|  |  |  |         all_cifar100_time += cifar100_time | 
					
						
							|  |  |  |         all_imagenet_time += imagenet_time | 
					
						
							| 
									
										
										
										
											2021-03-18 16:02:55 +08:00
										 |  |  |     print( | 
					
						
							|  |  |  |         "The total training time for CIFAR-10        (held-out train set) is {:} seconds".format( | 
					
						
							|  |  |  |             all_cifar10_time | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  |     print( | 
					
						
							|  |  |  |         "The total training time for CIFAR-100       (held-out train set) is {:} seconds, {:.2f} times longer than that on CIFAR-10".format( | 
					
						
							|  |  |  |             all_cifar100_time, all_cifar100_time / all_cifar10_time | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     print( | 
					
						
							|  |  |  |         "The total training time for ImageNet-16-120 (held-out train set) is {:} seconds, {:.2f} times longer than that on CIFAR-10".format( | 
					
						
							|  |  |  |             all_imagenet_time, all_imagenet_time / all_cifar10_time | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | if __name__ == "__main__": | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     api_nats_tss = create(None, "tss", fast_mode=True, verbose=False) | 
					
						
							|  |  |  |     show_time(api_nats_tss, 12) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     api_nats_sss = create(None, "sss", fast_mode=True, verbose=False) | 
					
						
							|  |  |  |     show_time(api_nats_sss, 12) |