| 
									
										
										
										
											2020-03-09 19:38:00 +11:00
										 |  |  | ##################################################### | 
					
						
							|  |  |  | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.01 # | 
					
						
							|  |  |  | ################################################################################################ | 
					
						
							|  |  |  | # python exps/NAS-Bench-201/show-best.py --api_path $HOME/.torch/NAS-Bench-201-v1_0-e61699.pth # | 
					
						
							|  |  |  | ################################################################################################ | 
					
						
							| 
									
										
										
										
											2020-03-21 01:33:07 -07:00
										 |  |  | import sys, argparse | 
					
						
							| 
									
										
										
										
											2020-03-09 19:38:00 +11:00
										 |  |  | from pathlib import Path | 
					
						
							|  |  |  | lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() | 
					
						
							|  |  |  | if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) | 
					
						
							|  |  |  | from nas_201_api  import NASBench201API as API | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | if __name__ == '__main__': | 
					
						
							|  |  |  |   parser = argparse.ArgumentParser("Analysis of NAS-Bench-201") | 
					
						
							|  |  |  |   parser.add_argument('--api_path',  type=str, default=None,                                         help='The path to the NAS-Bench-201 benchmark file.') | 
					
						
							|  |  |  |   args = parser.parse_args() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   meta_file = Path(args.api_path) | 
					
						
							|  |  |  |   assert meta_file.exists(), 'invalid path for api : {:}'.format(meta_file) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   api = API(str(meta_file)) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   # This will show the results of the best architecture based on the validation set of each dataset. | 
					
						
							|  |  |  |   arch_index, accuracy = api.find_best('cifar10-valid', 'x-valid', None, None, False) | 
					
						
							|  |  |  |   print('FOR CIFAR-010, using the hyper-parameters with 200 training epochs :::') | 
					
						
							|  |  |  |   print('arch-index={:5d}, arch={:}'.format(arch_index, api.arch(arch_index))) | 
					
						
							|  |  |  |   api.show(arch_index) | 
					
						
							|  |  |  |   print('') | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   arch_index, accuracy = api.find_best('cifar100', 'x-valid', None, None, False) | 
					
						
							|  |  |  |   print('FOR CIFAR-100, using the hyper-parameters with 200 training epochs :::') | 
					
						
							|  |  |  |   print('arch-index={:5d}, arch={:}'.format(arch_index, api.arch(arch_index))) | 
					
						
							|  |  |  |   api.show(arch_index) | 
					
						
							|  |  |  |   print('') | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   arch_index, accuracy = api.find_best('ImageNet16-120', 'x-valid', None, None, False) | 
					
						
							|  |  |  |   print('FOR ImageNet16-120, using the hyper-parameters with 200 training epochs :::') | 
					
						
							|  |  |  |   print('arch-index={:5d}, arch={:}'.format(arch_index, api.arch(arch_index))) | 
					
						
							|  |  |  |   api.show(arch_index) | 
					
						
							|  |  |  |   print('') |