Update NATS-Bench API to v1.1
This commit is contained in:
		| @@ -12,7 +12,8 @@ from nats_bench.api_utils import pickle_save | |||||||
| from nats_bench.api_utils import ResultsCount | from nats_bench.api_utils import ResultsCount | ||||||
|  |  | ||||||
|  |  | ||||||
| NATS_BENCH_API_VERSIONs = ['v1.0']    # [2020.08.31] | NATS_BENCH_API_VERSIONs = ['v1.0',    # [2020.08.31] | ||||||
|  |                            'v1.1']    # [2020.12.20] adding unit tests | ||||||
| NATS_BENCH_SSS_NAMEs = ('sss', 'size') | NATS_BENCH_SSS_NAMEs = ('sss', 'size') | ||||||
| NATS_BENCH_TSS_NAMEs = ('tss', 'topology') | NATS_BENCH_TSS_NAMEs = ('tss', 'topology') | ||||||
|  |  | ||||||
|   | |||||||
| @@ -3,28 +3,69 @@ | |||||||
| ############################################################################## | ############################################################################## | ||||||
| # NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size # | # NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size # | ||||||
| ############################################################################## | ############################################################################## | ||||||
|  | # pytest --capture=tee-sys                                                   # | ||||||
|  | ############################################################################## | ||||||
| """This file is used to quickly test the API.""" | """This file is used to quickly test the API.""" | ||||||
|  | import os | ||||||
|  | import pytest | ||||||
| import random | import random | ||||||
|  |  | ||||||
| from nats_bench.api_size import NATSsize | from nats_bench.api_size import NATSsize | ||||||
|  | from nats_bench.api_size import ALL_BASE_NAMES as sss_base_names | ||||||
| from nats_bench.api_topology import NATStopology | from nats_bench.api_topology import NATStopology | ||||||
|  | from nats_bench.api_topology import ALL_BASE_NAMES as tss_base_names | ||||||
|  |  | ||||||
|  |  | ||||||
| def test_nats_bench_tss(benchmark_dir): | def get_fake_torch_home_dir(): | ||||||
|   return test_nats_bench(benchmark_dir, True) |   return os.environ['FAKE_TORCH_HOME'] | ||||||
|  |  | ||||||
|  |  | ||||||
| def test_nats_bench_sss(benchmark_dir): | class TestNATSBench(object): | ||||||
|   return test_nats_bench(benchmark_dir, False) |  | ||||||
|  |   def test_nats_bench_tss(self, benchmark_dir=None, fake_random=True): | ||||||
|  |     if benchmark_dir is None: | ||||||
|  |       benchmark_dir = os.path.join(get_fake_torch_home_dir(), sss_base_names[-1] + '-simple') | ||||||
|  |     return _test_nats_bench(benchmark_dir, True, fake_random) | ||||||
|  |  | ||||||
|  |   def test_nats_bench_sss(self, benchmark_dir=None, fake_random=True): | ||||||
|  |     if benchmark_dir is None: | ||||||
|  |       benchmark_dir = os.path.join(get_fake_torch_home_dir(), tss_base_names[-1] + '-simple') | ||||||
|  |     return _test_nats_bench(benchmark_dir, False, fake_random) | ||||||
|  |  | ||||||
|  |   def test_01_th_issue(self): | ||||||
|  |     # Link: https://github.com/D-X-Y/NATS-Bench/issues/1 | ||||||
|  |     print('') | ||||||
|  |     tss_benchmark_dir = os.path.join(get_fake_torch_home_dir(), sss_base_names[-1] + '-simple') | ||||||
|  |     api = NATStopology(tss_benchmark_dir, True, False) | ||||||
|  |     # The performance of 0-th architecture on CIFAR-10 (trained by 12 epochs) | ||||||
|  |     info = api.get_more_info(0, 'cifar10', hp=12) | ||||||
|  |     print('The loss on the training set of CIFAR-10: {:}'.format(info['train-loss'])) | ||||||
|  |     print('The total training time for 12 epochs on CIFAR-10: {:}'.format(info['train-all-time'])) | ||||||
|  |     print('The per-epoch training time on CIFAR-10: {:}'.format(info['train-per-time'])) | ||||||
|  |     print('The total evaluation time on the test set of CIFAR-10 for 12 times: {:}'.format(info['test-all-time'])) | ||||||
|  |     print('The evaluation time on the test set of CIFAR-10: {:}'.format(info['test-per-time'])) | ||||||
|  |     # Please note that the splits of train/validation/test on CIFAR-10 in our NATS-Bench paper is different from the original CIFAR paper. | ||||||
|  |     cost_info = api.get_cost_info(0, 'cifar10') | ||||||
|  |     xkeys = ['T-train@epoch',     # The per epoch training cost for CIFAR-10. Note that the training set of CIFAR-10 in NATS-Bench is a subset of the original training set in CIFAR paper. | ||||||
|  |              'T-train@total', | ||||||
|  |              'T-ori-test@epoch',  # The time cost for the evaluation on the original test split of CIFAR-10, which is the validation + test sets of CIFAR-10 on NATS-Bench. | ||||||
|  |              'T-ori-test@total']  # T-ori-test@epoch * 12 times. | ||||||
|  |     for xkey in xkeys: | ||||||
|  |       print('The cost info [{:}] for 0-th architecture on CIFAR-10 is {:}'.format(xkey, cost_info[xkey])) | ||||||
|   |   | ||||||
|  |  | ||||||
| def test_nats_bench(benchmark_dir, is_tss, verbose=False): | def _test_nats_bench(benchmark_dir, is_tss, fake_random, verbose=False): | ||||||
|  |   """The main test entry for NATS-Bench.""" | ||||||
|   if is_tss: |   if is_tss: | ||||||
|     api = NATStopology(benchmark_dir, True, verbose) |     api = NATStopology(benchmark_dir, True, verbose) | ||||||
|   else: |   else: | ||||||
|     api = NATSsize(benchmark_dir, True, verbose) |     api = NATSsize(benchmark_dir, True, verbose) | ||||||
|  |  | ||||||
|   test_indexes = [random.randint(0, len(api) - 1) for _ in range(10)] |   if fake_random: | ||||||
|  |     test_indexes = [0, 11, 241] | ||||||
|  |   else: | ||||||
|  |     test_indexes = [random.randint(0, len(api) - 1) for _ in range(10)] | ||||||
|  |  | ||||||
|   key2dataset = {'cifar10': 'CIFAR-10', |   key2dataset = {'cifar10': 'CIFAR-10', | ||||||
|                  'cifar100': 'CIFAR-100', |                  'cifar100': 'CIFAR-100', | ||||||
|                  'ImageNet16-120': 'ImageNet16-120'} |                  'ImageNet16-120': 'ImageNet16-120'} | ||||||
| @@ -57,3 +98,6 @@ def test_nats_bench(benchmark_dir, is_tss, verbose=False): | |||||||
|  |  | ||||||
|     # Show the information of the `index`-th architecture |     # Show the information of the `index`-th architecture | ||||||
|     api.show(index) |     api.show(index) | ||||||
|  |  | ||||||
|  |   with pytest.raises(ValueError): | ||||||
|  |     api.get_more_info(100000, 'cifar10') | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user