Update NATS-Bench API to v1.1
This commit is contained in:
parent
c4ef3f6620
commit
dae387a97d
@ -243,7 +243,10 @@ class NATSsize(NASBenchMetaAPI):
|
||||
except Exception as unused_e: # pylint: disable=broad-except
|
||||
test_info = None
|
||||
valtest_info = None
|
||||
xinfo['comment'] = 'In this dict, train-loss/accuracy/time is the metric on the train set of CIFAR-10. The test-loss/accuracy/time is the performance of the CIFAR-10 test set after training on the train set by {:} epochs. The per-time and total-time indicate the per epoch and total time costs, respectively.'.format(hp)
|
||||
else:
|
||||
if dataset == 'cifar10':
|
||||
xinfo['comment'] = 'In this dict, train-loss/accuracy/time is the metric on the train+valid sets of CIFAR-10. The test-loss/accuracy/time is the performance of the CIFAR-10 test set after training on the train+valid sets by {:} epochs. The per-time and total-time indicate the per epoch and total time costs, respectively.'.format(hp)
|
||||
try: # collect results on the proposed test set
|
||||
if dataset == 'cifar10':
|
||||
test_info = archresult.get_metrics(
|
||||
|
@ -32,26 +32,46 @@ class TestNATSBench(object):
|
||||
benchmark_dir = os.path.join(get_fake_torch_home_dir(), tss_base_names[-1] + '-simple')
|
||||
return _test_nats_bench(benchmark_dir, False, fake_random)
|
||||
|
||||
def prepare_fake_tss(self):
|
||||
print('')
|
||||
tss_benchmark_dir = os.path.join(get_fake_torch_home_dir(), tss_base_names[-1] + '-simple')
|
||||
api = NATStopology(tss_benchmark_dir, True, False)
|
||||
return api
|
||||
|
||||
def test_01_th_issue(self):
|
||||
# Link: https://github.com/D-X-Y/NATS-Bench/issues/1
|
||||
print('')
|
||||
tss_benchmark_dir = os.path.join(get_fake_torch_home_dir(), sss_base_names[-1] + '-simple')
|
||||
api = NATStopology(tss_benchmark_dir, True, False)
|
||||
api = self.prepare_fake_tss()
|
||||
# The performance of 0-th architecture on CIFAR-10 (trained by 12 epochs)
|
||||
info = api.get_more_info(0, 'cifar10', hp=12)
|
||||
print('The loss on the training set of CIFAR-10: {:}'.format(info['train-loss']))
|
||||
print('The total training time for 12 epochs on CIFAR-10: {:}'.format(info['train-all-time']))
|
||||
# First of all, the data split in NATS-Bench is different from that in the official CIFAR paper.
|
||||
# In NATS-Bench, we split the original CIFAR-10 training set into two parts, i.e., a training set and a validation set.
|
||||
# In the following, we will use the splits of NATS-Bench to explain.
|
||||
print(info['comment'])
|
||||
print('The loss on the training + validation sets of CIFAR-10: {:}'.format(info['train-loss']))
|
||||
print('The total training time for 12 epochs on the training + validation sets of CIFAR-10: {:}'.format(info['train-all-time']))
|
||||
print('The per-epoch training time on CIFAR-10: {:}'.format(info['train-per-time']))
|
||||
print('The total evaluation time on the test set of CIFAR-10 for 12 times: {:}'.format(info['test-all-time']))
|
||||
print('The evaluation time on the test set of CIFAR-10: {:}'.format(info['test-per-time']))
|
||||
# Please note that the splits of train/validation/test on CIFAR-10 in our NATS-Bench paper is different from the original CIFAR paper.
|
||||
cost_info = api.get_cost_info(0, 'cifar10')
|
||||
xkeys = ['T-train@epoch', # The per epoch training cost for CIFAR-10. Note that the training set of CIFAR-10 in NATS-Bench is a subset of the original training set in CIFAR paper.
|
||||
xkeys = ['T-train@epoch', # The per epoch training time on the training + validation sets of CIFAR-10.
|
||||
'T-train@total',
|
||||
'T-ori-test@epoch', # The time cost for the evaluation on the original test split of CIFAR-10, which is the validation + test sets of CIFAR-10 on NATS-Bench.
|
||||
'T-ori-test@epoch', # The time cost for the evaluation on CIFAR-10 test set.
|
||||
'T-ori-test@total'] # T-ori-test@epoch * 12 times.
|
||||
for xkey in xkeys:
|
||||
print('The cost info [{:}] for 0-th architecture on CIFAR-10 is {:}'.format(xkey, cost_info[xkey]))
|
||||
|
||||
def test_02_th_issue(self):
|
||||
# https://github.com/D-X-Y/NATS-Bench/issues/2
|
||||
api = self.prepare_fake_tss()
|
||||
data = api.query_by_index(284, dataname='cifar10', hp=200)
|
||||
for xkey, xvalue in data.items():
|
||||
print('{:} : {:}'.format(xkey, xvalue))
|
||||
xinfo = data[777].get_train()
|
||||
print(xinfo)
|
||||
print(data[777].train_acc1es)
|
||||
|
||||
info_012_epochs = api.get_more_info(284, 'cifar10', hp=200)
|
||||
print(info_012_epochs['train-accuracy'])
|
||||
|
||||
|
||||
def _test_nats_bench(benchmark_dir, is_tss, fake_random, verbose=False):
|
||||
@ -62,7 +82,7 @@ def _test_nats_bench(benchmark_dir, is_tss, fake_random, verbose=False):
|
||||
api = NATSsize(benchmark_dir, True, verbose)
|
||||
|
||||
if fake_random:
|
||||
test_indexes = [0, 11, 241]
|
||||
test_indexes = [0, 11, 284]
|
||||
else:
|
||||
test_indexes = [random.randint(0, len(api) - 1) for _ in range(10)]
|
||||
|
||||
|
@ -222,7 +222,10 @@ class NATStopology(NASBenchMetaAPI):
|
||||
except Exception as unused_e: # pylint: disable=broad-except
|
||||
test_info = None
|
||||
valtest_info = None
|
||||
xinfo['comment'] = 'In this dict, train-loss/accuracy/time is the metric on the train set of CIFAR-10. The test-loss/accuracy/time is the performance of the CIFAR-10 test set after training on the train set by {:} epochs. The per-time and total-time indicate the per epoch and total time costs, respectively.'.format(hp)
|
||||
else:
|
||||
if dataset == 'cifar10':
|
||||
xinfo['comment'] = 'In this dict, train-loss/accuracy/time is the metric on the train+valid sets of CIFAR-10. The test-loss/accuracy/time is the performance of the CIFAR-10 test set after training on the train+valid sets by {:} epochs. The per-time and total-time indicate the per epoch and total time costs, respectively.'.format(hp)
|
||||
try: # collect results on the proposed test set
|
||||
if dataset == 'cifar10':
|
||||
test_info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
|
||||
|
@ -426,13 +426,13 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
|
||||
arch_index, hp))
|
||||
self._prepare_info(arch_index)
|
||||
if arch_index in self.arch2infos_dict:
|
||||
if hp not in self.arch2infos_dict[arch_index]:
|
||||
if str(hp) not in self.arch2infos_dict[arch_index]:
|
||||
raise ValueError('The {:}-th architecture only has hyper-parameters of '
|
||||
'{:} instead of {:}.'.format(
|
||||
arch_index,
|
||||
list(self.arch2infos_dict[arch_index].keys()),
|
||||
hp))
|
||||
info = self.arch2infos_dict[arch_index][hp]
|
||||
info = self.arch2infos_dict[arch_index][str(hp)]
|
||||
else:
|
||||
raise ValueError('arch_index [{:}] does not in arch2infos'.format(
|
||||
arch_index))
|
||||
@ -472,7 +472,7 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
|
||||
if self.verbose:
|
||||
print('{:} Call query_by_index with arch_index={:}, dataname={:}, '
|
||||
'hp={:}'.format(time_string(), arch_index, dataname, hp))
|
||||
info = self.query_meta_info_by_index(arch_index, hp)
|
||||
info = self.query_meta_info_by_index(arch_index, str(hp))
|
||||
if dataname is None:
|
||||
return info
|
||||
else:
|
||||
|
Loading…
Reference in New Issue
Block a user