| 
									
										
										
										
											2020-03-09 19:38:00 +11:00
										 |  |  | ##################################################### | 
					
						
							|  |  |  | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.01 # | 
					
						
							|  |  |  | ################################################################################################ | 
					
						
							|  |  |  | # python exps/NAS-Bench-201/show-best.py --api_path $HOME/.torch/NAS-Bench-201-v1_0-e61699.pth # | 
					
						
							|  |  |  | ################################################################################################ | 
					
						
							| 
									
										
										
										
											2020-03-21 01:33:07 -07:00
										 |  |  | import sys, argparse | 
					
						
							| 
									
										
										
										
											2020-03-09 19:38:00 +11:00
										 |  |  | from pathlib import Path | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  | lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() | 
					
						
							|  |  |  | if str(lib_dir) not in sys.path: | 
					
						
							|  |  |  |     sys.path.insert(0, str(lib_dir)) | 
					
						
							|  |  |  | from nas_201_api import NASBench201API as API | 
					
						
							| 
									
										
										
										
											2020-03-09 19:38:00 +11:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  | if __name__ == "__main__": | 
					
						
							|  |  |  |     parser = argparse.ArgumentParser("Analysis of NAS-Bench-201") | 
					
						
							| 
									
										
										
										
											2021-03-18 16:02:55 +08:00
										 |  |  |     parser.add_argument( | 
					
						
							|  |  |  |         "--api_path", | 
					
						
							|  |  |  |         type=str, | 
					
						
							|  |  |  |         default=None, | 
					
						
							|  |  |  |         help="The path to the NAS-Bench-201 benchmark file.", | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  |     args = parser.parse_args() | 
					
						
							| 
									
										
										
										
											2020-03-09 19:38:00 +11:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  |     meta_file = Path(args.api_path) | 
					
						
							|  |  |  |     assert meta_file.exists(), "invalid path for api : {:}".format(meta_file) | 
					
						
							| 
									
										
										
										
											2020-03-09 19:38:00 +11:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  |     api = API(str(meta_file)) | 
					
						
							| 
									
										
										
										
											2020-03-09 19:38:00 +11:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  |     # This will show the results of the best architecture based on the validation set of each dataset. | 
					
						
							|  |  |  |     arch_index, accuracy = api.find_best("cifar10-valid", "x-valid", None, None, False) | 
					
						
							|  |  |  |     print("FOR CIFAR-010, using the hyper-parameters with 200 training epochs :::") | 
					
						
							|  |  |  |     print("arch-index={:5d}, arch={:}".format(arch_index, api.arch(arch_index))) | 
					
						
							|  |  |  |     api.show(arch_index) | 
					
						
							|  |  |  |     print("") | 
					
						
							| 
									
										
										
										
											2020-03-09 19:38:00 +11:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  |     arch_index, accuracy = api.find_best("cifar100", "x-valid", None, None, False) | 
					
						
							|  |  |  |     print("FOR CIFAR-100, using the hyper-parameters with 200 training epochs :::") | 
					
						
							|  |  |  |     print("arch-index={:5d}, arch={:}".format(arch_index, api.arch(arch_index))) | 
					
						
							|  |  |  |     api.show(arch_index) | 
					
						
							|  |  |  |     print("") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     arch_index, accuracy = api.find_best("ImageNet16-120", "x-valid", None, None, False) | 
					
						
							|  |  |  |     print("FOR ImageNet16-120, using the hyper-parameters with 200 training epochs :::") | 
					
						
							|  |  |  |     print("arch-index={:5d}, arch={:}".format(arch_index, api.arch(arch_index))) | 
					
						
							|  |  |  |     api.show(arch_index) | 
					
						
							|  |  |  |     print("") |