Update NATS-Bench API to v1.1
This commit is contained in:
		| @@ -7,4 +7,4 @@ | ||||
| - [2020.07.01] [a45808b] Upgrade NAS-API to the 2.0 version. | ||||
| - [2020.09.16] [7052265] Create NATS-BENCH. | ||||
| - [2020.10.15] [446262a] Update NATS-BENCH to version 1.0 | ||||
| - [2020.12.20] [59b5696] Update NATS-BENCH to version 1.1 | ||||
| - [2020.12.20] [dae387a] Update NATS-BENCH to version 1.1 | ||||
|   | ||||
| @@ -17,7 +17,13 @@ from nats_bench.api_topology import ALL_BASE_NAMES as tss_base_names | ||||
|  | ||||
|  | ||||
| def get_fake_torch_home_dir(): | ||||
|   return os.environ['FAKE_TORCH_HOME'] | ||||
|   print('This file is {:}'.format(os.path.abspath(__file__))) | ||||
|   print('The current directory is {:}'.format(os.path.abspath(os.getcwd()))) | ||||
|   xname = 'FAKE_TORCH_HOME' | ||||
|   if xname in os.environ: | ||||
|     return os.environ['FAKE_TORCH_HOME'] | ||||
|   else: | ||||
|     return os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'fake_torch_dir') | ||||
|  | ||||
|  | ||||
| class TestNATSBench(object): | ||||
| @@ -70,8 +76,10 @@ class TestNATSBench(object): | ||||
|     print(xinfo) | ||||
|     print(data[777].train_acc1es) | ||||
|  | ||||
|     info_012_epochs = api.get_more_info(284, 'cifar10', hp=200) | ||||
|     print(info_012_epochs['train-accuracy']) | ||||
|     info_012_epochs = api.get_more_info(284, 'cifar10', hp= 12) | ||||
|     print('Train accuracy for  12 epochs is {:}'.format(info_012_epochs['train-accuracy'])) | ||||
|     info_200_epochs = api.get_more_info(284, 'cifar10', hp=200) | ||||
|     print('Train accuracy for 200 epochs is {:}'.format(info_200_epochs['train-accuracy'])) | ||||
|   | ||||
|  | ||||
| def _test_nats_bench(benchmark_dir, is_tss, fake_random, verbose=False): | ||||
|   | ||||
		Reference in New Issue
	
	Block a user