diff --git a/lib/nats_bench/__init__.py b/lib/nats_bench/__init__.py
index 050aae7..0702d52 100644
--- a/lib/nats_bench/__init__.py
+++ b/lib/nats_bench/__init__.py
@@ -12,7 +12,8 @@
 from nats_bench.api_utils import pickle_save
 from nats_bench.api_utils import ResultsCount
 
-NATS_BENCH_API_VERSIONs = ['v1.0']  # [2020.08.31]
+NATS_BENCH_API_VERSIONs = ['v1.0',  # [2020.08.31]
+                           'v1.1']  # [2020.12.20] adding unit tests
 
 NATS_BENCH_SSS_NAMEs = ('sss', 'size')
 NATS_BENCH_TSS_NAMEs = ('tss', 'topology')
diff --git a/lib/nats_bench/api_test.py b/lib/nats_bench/api_test.py
index a30118f..39a8de2 100644
--- a/lib/nats_bench/api_test.py
+++ b/lib/nats_bench/api_test.py
@@ -3,28 +3,69 @@
 ##############################################################################
 # NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size #
 ##############################################################################
+# pytest --capture=tee-sys                                                   #
+##############################################################################
 """This file is used to quickly test the API."""
+import os
+import pytest
 import random
 
 from nats_bench.api_size import NATSsize
+from nats_bench.api_size import ALL_BASE_NAMES as sss_base_names
 from nats_bench.api_topology import NATStopology
+from nats_bench.api_topology import ALL_BASE_NAMES as tss_base_names
 
 
-def test_nats_bench_tss(benchmark_dir):
-  return test_nats_bench(benchmark_dir, True)
+def get_fake_torch_home_dir():
+  return os.environ['FAKE_TORCH_HOME']
 
 
-def test_nats_bench_sss(benchmark_dir):
-  return test_nats_bench(benchmark_dir, False)
+class TestNATSBench(object):
 
+  def test_nats_bench_tss(self, benchmark_dir=None, fake_random=True):
+    if benchmark_dir is None:
+      benchmark_dir = os.path.join(get_fake_torch_home_dir(), tss_base_names[-1] + '-simple')
+    return _test_nats_bench(benchmark_dir, True, fake_random)
 
-def test_nats_bench(benchmark_dir, is_tss, verbose=False):
+  def test_nats_bench_sss(self, benchmark_dir=None, fake_random=True):
+    if benchmark_dir is None:
+      benchmark_dir = os.path.join(get_fake_torch_home_dir(), sss_base_names[-1] + '-simple')
+    return _test_nats_bench(benchmark_dir, False, fake_random)
+
+  def test_01_th_issue(self):
+    # Link: https://github.com/D-X-Y/NATS-Bench/issues/1
+    print('')
+    tss_benchmark_dir = os.path.join(get_fake_torch_home_dir(), tss_base_names[-1] + '-simple')
+    api = NATStopology(tss_benchmark_dir, True, False)
+    # The performance of the 0-th architecture on CIFAR-10 (trained for 12 epochs)
+    info = api.get_more_info(0, 'cifar10', hp=12)
+    print('The loss on the training set of CIFAR-10: {:}'.format(info['train-loss']))
+    print('The total training time for 12 epochs on CIFAR-10: {:}'.format(info['train-all-time']))
+    print('The per-epoch training time on CIFAR-10: {:}'.format(info['train-per-time']))
+    print('The total evaluation time on the test set of CIFAR-10 over 12 evaluations: {:}'.format(info['test-all-time']))
+    print('The per-evaluation time on the test set of CIFAR-10: {:}'.format(info['test-per-time']))
+    # Please note that the train/validation/test splits of CIFAR-10 in our NATS-Bench paper differ from those in the original CIFAR paper.
+    cost_info = api.get_cost_info(0, 'cifar10')
+    xkeys = ['T-train@epoch',     # The per-epoch training cost on CIFAR-10; the training set of CIFAR-10 in NATS-Bench is a subset of the original CIFAR-10 training set.
+             'T-train@total',
+             'T-ori-test@epoch',  # The time cost of evaluating on the original test split of CIFAR-10, which is the union of the validation and test sets in NATS-Bench.
+             'T-ori-test@total']  # T-ori-test@epoch * 12 times.
+    for xkey in xkeys:
+      print('The cost info [{:}] for the 0-th architecture on CIFAR-10 is {:}'.format(xkey, cost_info[xkey]))
+
+
+def _test_nats_bench(benchmark_dir, is_tss, fake_random, verbose=False):
+  """The main test entry for NATS-Bench."""
   if is_tss:
     api = NATStopology(benchmark_dir, True, verbose)
   else:
     api = NATSsize(benchmark_dir, True, verbose)
 
-  test_indexes = [random.randint(0, len(api) - 1) for _ in range(10)]
+  if fake_random:
+    test_indexes = [0, 11, 241]
+  else:
+    test_indexes = [random.randint(0, len(api) - 1) for _ in range(10)]
+
   key2dataset = {'cifar10': 'CIFAR-10',
                  'cifar100': 'CIFAR-100',
                  'ImageNet16-120': 'ImageNet16-120'}
@@ -57,3 +98,6 @@ def test_nats_bench(benchmark_dir, is_tss, verbose=False):
 
   # Show the information of the `index`-th architecture
   api.show(index)
+
+  with pytest.raises(ValueError):
+    api.get_more_info(100000, 'cifar10')
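
The new tests locate the simplified benchmark files through the FAKE_TORCH_HOME environment variable (see get_fake_torch_home_dir) and are intended to be run with `pytest --capture=tee-sys`, as noted in the file header. Below is a minimal sketch, not part of the diff, of driving the test class directly from Python; it assumes lib/ is on PYTHONPATH and that the '-simple' benchmark files have already been downloaded locally (the path below is a placeholder).

    import os
    # Placeholder location of the '-simple' benchmark files; adjust to your setup.
    os.environ['FAKE_TORCH_HOME'] = '/path/to/fake_torch_home'

    from nats_bench.api_test import TestNATSBench

    suite = TestNATSBench()
    suite.test_nats_bench_tss()  # loads NATStopology and queries a few fixed architecture indexes
    suite.test_nats_bench_sss()  # loads NATSsize and queries a few fixed architecture indexes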