Update NATS-Bench API to v1.1
This commit is contained in:
parent
c8ca1790e9
commit
59b5696a93
@ -12,7 +12,8 @@ from nats_bench.api_utils import pickle_save
|
|||||||
from nats_bench.api_utils import ResultsCount
|
from nats_bench.api_utils import ResultsCount
|
||||||
|
|
||||||
|
|
||||||
NATS_BENCH_API_VERSIONs = ['v1.0'] # [2020.08.31]
|
NATS_BENCH_API_VERSIONs = ['v1.0', # [2020.08.31]
|
||||||
|
'v1.1'] # [2020.12.20] adding unit tests
|
||||||
NATS_BENCH_SSS_NAMEs = ('sss', 'size')
|
NATS_BENCH_SSS_NAMEs = ('sss', 'size')
|
||||||
NATS_BENCH_TSS_NAMEs = ('tss', 'topology')
|
NATS_BENCH_TSS_NAMEs = ('tss', 'topology')
|
||||||
|
|
||||||
|
@ -3,28 +3,69 @@
|
|||||||
##############################################################################
|
##############################################################################
|
||||||
# NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size #
|
# NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size #
|
||||||
##############################################################################
|
##############################################################################
|
||||||
|
# pytest --capture=tee-sys #
|
||||||
|
##############################################################################
|
||||||
"""This file is used to quickly test the API."""
|
"""This file is used to quickly test the API."""
|
||||||
|
import os
|
||||||
|
import pytest
|
||||||
import random
|
import random
|
||||||
|
|
||||||
from nats_bench.api_size import NATSsize
|
from nats_bench.api_size import NATSsize
|
||||||
|
from nats_bench.api_size import ALL_BASE_NAMES as sss_base_names
|
||||||
from nats_bench.api_topology import NATStopology
|
from nats_bench.api_topology import NATStopology
|
||||||
|
from nats_bench.api_topology import ALL_BASE_NAMES as tss_base_names
|
||||||
|
|
||||||
|
|
||||||
def test_nats_bench_tss(benchmark_dir):
|
def get_fake_torch_home_dir():
|
||||||
return test_nats_bench(benchmark_dir, True)
|
return os.environ['FAKE_TORCH_HOME']
|
||||||
|
|
||||||
|
|
||||||
def test_nats_bench_sss(benchmark_dir):
|
class TestNATSBench(object):
|
||||||
return test_nats_bench(benchmark_dir, False)
|
|
||||||
|
|
||||||
|
def test_nats_bench_tss(self, benchmark_dir=None, fake_random=True):
|
||||||
|
if benchmark_dir is None:
|
||||||
|
benchmark_dir = os.path.join(get_fake_torch_home_dir(), sss_base_names[-1] + '-simple')
|
||||||
|
return _test_nats_bench(benchmark_dir, True, fake_random)
|
||||||
|
|
||||||
def test_nats_bench(benchmark_dir, is_tss, verbose=False):
|
def test_nats_bench_sss(self, benchmark_dir=None, fake_random=True):
|
||||||
|
if benchmark_dir is None:
|
||||||
|
benchmark_dir = os.path.join(get_fake_torch_home_dir(), tss_base_names[-1] + '-simple')
|
||||||
|
return _test_nats_bench(benchmark_dir, False, fake_random)
|
||||||
|
|
||||||
|
def test_01_th_issue(self):
|
||||||
|
# Link: https://github.com/D-X-Y/NATS-Bench/issues/1
|
||||||
|
print('')
|
||||||
|
tss_benchmark_dir = os.path.join(get_fake_torch_home_dir(), sss_base_names[-1] + '-simple')
|
||||||
|
api = NATStopology(tss_benchmark_dir, True, False)
|
||||||
|
# The performance of 0-th architecture on CIFAR-10 (trained by 12 epochs)
|
||||||
|
info = api.get_more_info(0, 'cifar10', hp=12)
|
||||||
|
print('The loss on the training set of CIFAR-10: {:}'.format(info['train-loss']))
|
||||||
|
print('The total training time for 12 epochs on CIFAR-10: {:}'.format(info['train-all-time']))
|
||||||
|
print('The per-epoch training time on CIFAR-10: {:}'.format(info['train-per-time']))
|
||||||
|
print('The total evaluation time on the test set of CIFAR-10 for 12 times: {:}'.format(info['test-all-time']))
|
||||||
|
print('The evaluation time on the test set of CIFAR-10: {:}'.format(info['test-per-time']))
|
||||||
|
# Please note that the splits of train/validation/test on CIFAR-10 in our NATS-Bench paper is different from the original CIFAR paper.
|
||||||
|
cost_info = api.get_cost_info(0, 'cifar10')
|
||||||
|
xkeys = ['T-train@epoch', # The per epoch training cost for CIFAR-10. Note that the training set of CIFAR-10 in NATS-Bench is a subset of the original training set in CIFAR paper.
|
||||||
|
'T-train@total',
|
||||||
|
'T-ori-test@epoch', # The time cost for the evaluation on the original test split of CIFAR-10, which is the validation + test sets of CIFAR-10 on NATS-Bench.
|
||||||
|
'T-ori-test@total'] # T-ori-test@epoch * 12 times.
|
||||||
|
for xkey in xkeys:
|
||||||
|
print('The cost info [{:}] for 0-th architecture on CIFAR-10 is {:}'.format(xkey, cost_info[xkey]))
|
||||||
|
|
||||||
|
|
||||||
|
def _test_nats_bench(benchmark_dir, is_tss, fake_random, verbose=False):
|
||||||
|
"""The main test entry for NATS-Bench."""
|
||||||
if is_tss:
|
if is_tss:
|
||||||
api = NATStopology(benchmark_dir, True, verbose)
|
api = NATStopology(benchmark_dir, True, verbose)
|
||||||
else:
|
else:
|
||||||
api = NATSsize(benchmark_dir, True, verbose)
|
api = NATSsize(benchmark_dir, True, verbose)
|
||||||
|
|
||||||
test_indexes = [random.randint(0, len(api) - 1) for _ in range(10)]
|
if fake_random:
|
||||||
|
test_indexes = [0, 11, 241]
|
||||||
|
else:
|
||||||
|
test_indexes = [random.randint(0, len(api) - 1) for _ in range(10)]
|
||||||
|
|
||||||
key2dataset = {'cifar10': 'CIFAR-10',
|
key2dataset = {'cifar10': 'CIFAR-10',
|
||||||
'cifar100': 'CIFAR-100',
|
'cifar100': 'CIFAR-100',
|
||||||
'ImageNet16-120': 'ImageNet16-120'}
|
'ImageNet16-120': 'ImageNet16-120'}
|
||||||
@ -57,3 +98,6 @@ def test_nats_bench(benchmark_dir, is_tss, verbose=False):
|
|||||||
|
|
||||||
# Show the information of the `index`-th architecture
|
# Show the information of the `index`-th architecture
|
||||||
api.show(index)
|
api.show(index)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
api.get_more_info(100000, 'cifar10')
|
||||||
|
Loading…
Reference in New Issue
Block a user