Update NATS-Bench API to v1.1

This commit is contained in:
D-X-Y 2020-12-19 23:42:21 +08:00
parent c8ca1790e9
commit 59b5696a93
2 changed files with 52 additions and 7 deletions

View File

@ -12,7 +12,8 @@ from nats_bench.api_utils import pickle_save
from nats_bench.api_utils import ResultsCount
NATS_BENCH_API_VERSIONs = ['v1.0'] # [2020.08.31]
NATS_BENCH_API_VERSIONs = ['v1.0', # [2020.08.31]
'v1.1'] # [2020.12.20] adding unit tests
NATS_BENCH_SSS_NAMEs = ('sss', 'size')
NATS_BENCH_TSS_NAMEs = ('tss', 'topology')

View File

@ -3,28 +3,69 @@
##############################################################################
# NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size #
##############################################################################
# pytest --capture=tee-sys #
##############################################################################
"""This file is used to quickly test the API."""
import os
import pytest
import random
from nats_bench.api_size import NATSsize
from nats_bench.api_size import ALL_BASE_NAMES as sss_base_names
from nats_bench.api_topology import NATStopology
from nats_bench.api_topology import ALL_BASE_NAMES as tss_base_names
def test_nats_bench_tss(benchmark_dir):
return test_nats_bench(benchmark_dir, True)
def get_fake_torch_home_dir():
return os.environ['FAKE_TORCH_HOME']
def test_nats_bench_sss(benchmark_dir):
return test_nats_bench(benchmark_dir, False)
class TestNATSBench(object):
def test_nats_bench_tss(self, benchmark_dir=None, fake_random=True):
if benchmark_dir is None:
benchmark_dir = os.path.join(get_fake_torch_home_dir(), sss_base_names[-1] + '-simple')
return _test_nats_bench(benchmark_dir, True, fake_random)
def test_nats_bench(benchmark_dir, is_tss, verbose=False):
def test_nats_bench_sss(self, benchmark_dir=None, fake_random=True):
if benchmark_dir is None:
benchmark_dir = os.path.join(get_fake_torch_home_dir(), tss_base_names[-1] + '-simple')
return _test_nats_bench(benchmark_dir, False, fake_random)
def test_01_th_issue(self):
# Link: https://github.com/D-X-Y/NATS-Bench/issues/1
print('')
tss_benchmark_dir = os.path.join(get_fake_torch_home_dir(), sss_base_names[-1] + '-simple')
api = NATStopology(tss_benchmark_dir, True, False)
# The performance of 0-th architecture on CIFAR-10 (trained by 12 epochs)
info = api.get_more_info(0, 'cifar10', hp=12)
print('The loss on the training set of CIFAR-10: {:}'.format(info['train-loss']))
print('The total training time for 12 epochs on CIFAR-10: {:}'.format(info['train-all-time']))
print('The per-epoch training time on CIFAR-10: {:}'.format(info['train-per-time']))
print('The total evaluation time on the test set of CIFAR-10 for 12 times: {:}'.format(info['test-all-time']))
print('The evaluation time on the test set of CIFAR-10: {:}'.format(info['test-per-time']))
# Please note that the splits of train/validation/test on CIFAR-10 in our NATS-Bench paper is different from the original CIFAR paper.
cost_info = api.get_cost_info(0, 'cifar10')
xkeys = ['T-train@epoch', # The per epoch training cost for CIFAR-10. Note that the training set of CIFAR-10 in NATS-Bench is a subset of the original training set in CIFAR paper.
'T-train@total',
'T-ori-test@epoch', # The time cost for the evaluation on the original test split of CIFAR-10, which is the validation + test sets of CIFAR-10 on NATS-Bench.
'T-ori-test@total'] # T-ori-test@epoch * 12 times.
for xkey in xkeys:
print('The cost info [{:}] for 0-th architecture on CIFAR-10 is {:}'.format(xkey, cost_info[xkey]))
def _test_nats_bench(benchmark_dir, is_tss, fake_random, verbose=False):
"""The main test entry for NATS-Bench."""
if is_tss:
api = NATStopology(benchmark_dir, True, verbose)
else:
api = NATSsize(benchmark_dir, True, verbose)
test_indexes = [random.randint(0, len(api) - 1) for _ in range(10)]
if fake_random:
test_indexes = [0, 11, 241]
else:
test_indexes = [random.randint(0, len(api) - 1) for _ in range(10)]
key2dataset = {'cifar10': 'CIFAR-10',
'cifar100': 'CIFAR-100',
'ImageNet16-120': 'ImageNet16-120'}
@ -57,3 +98,6 @@ def test_nats_bench(benchmark_dir, is_tss, verbose=False):
# Show the information of the `index`-th architecture
api.show(index)
with pytest.raises(ValueError):
api.get_more_info(100000, 'cifar10')