Update NATS-Bench API to v1.1
This commit is contained in:
		| @@ -17,7 +17,13 @@ from nats_bench.api_topology import ALL_BASE_NAMES as tss_base_names | ||||
|  | ||||
|  | ||||
| def get_fake_torch_home_dir(): | ||||
|   return os.environ['FAKE_TORCH_HOME'] | ||||
|   print('This file is {:}'.format(os.path.abspath(__file__))) | ||||
|   print('The current directory is {:}'.format(os.path.abspath(os.getcwd()))) | ||||
|   xname = 'FAKE_TORCH_HOME' | ||||
|   if xname in os.environ: | ||||
|     return os.environ['FAKE_TORCH_HOME'] | ||||
|   else: | ||||
|     return os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'fake_torch_dir') | ||||
|  | ||||
|  | ||||
| class TestNATSBench(object): | ||||
| @@ -70,8 +76,10 @@ class TestNATSBench(object): | ||||
|     print(xinfo) | ||||
|     print(data[777].train_acc1es) | ||||
|  | ||||
|     info_012_epochs = api.get_more_info(284, 'cifar10', hp=200) | ||||
|     print(info_012_epochs['train-accuracy']) | ||||
|     info_012_epochs = api.get_more_info(284, 'cifar10', hp= 12) | ||||
|     print('Train accuracy for  12 epochs is {:}'.format(info_012_epochs['train-accuracy'])) | ||||
|     info_200_epochs = api.get_more_info(284, 'cifar10', hp=200) | ||||
|     print('Train accuracy for 200 epochs is {:}'.format(info_200_epochs['train-accuracy'])) | ||||
|   | ||||
|  | ||||
| def _test_nats_bench(benchmark_dir, is_tss, fake_random, verbose=False): | ||||
|   | ||||
		Reference in New Issue
	
	Block a user