Update the query_by_arch function in API to be compatiable with the submission version of NAS-Bench-201

This commit is contained in:
D-X-Y 2020-07-08 05:08:55 +00:00
parent 233a829bd7
commit af1be7f740
3 changed files with 7 additions and 4 deletions

View File

@ -23,7 +23,7 @@ if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
from config_utils import dict2config, load_config from config_utils import dict2config, load_config
from nas_201_api import NASBench201API, NASBench301API from nas_201_api import NASBench201API, NASBench301API
from log_utils import time_string from log_utils import time_string
from models import get_cell_based_tiny_net from models import get_cell_based_tiny_net, CellStructure
def test_api(api, is_301=True): def test_api(api, is_301=True):
@ -80,6 +80,11 @@ def test_issue_81_82(api):
print(results[888].get_eval('valid')) print(results[888].get_eval('valid'))
print(results[888].get_eval('x-valid')) print(results[888].get_eval('x-valid'))
result_dict = api.get_more_info(index=0, dataset='cifar10-valid', iepoch=11, hp='200', is_random=False) result_dict = api.get_more_info(index=0, dataset='cifar10-valid', iepoch=11, hp='200', is_random=False)
info = api.query_by_arch('|nor_conv_3x3~0|+|skip_connect~0|nor_conv_3x3~1|+|skip_connect~0|none~1|nor_conv_3x3~2|', '200')
print(info)
structure = CellStructure.str2structure('|nor_conv_3x3~0|+|skip_connect~0|nor_conv_3x3~1|+|skip_connect~0|none~1|nor_conv_3x3~2|')
info = api.query_by_arch(structure, '200')
print(info)
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -4,7 +4,6 @@
from copy import deepcopy from copy import deepcopy
def get_combination(space, num): def get_combination(space, num):
combs = [] combs = []
for i in range(num): for i in range(num):
@ -21,7 +20,6 @@ def get_combination(space, num):
return combs return combs
class Structure: class Structure:
def __init__(self, genotype): def __init__(self, genotype):

View File

@ -123,7 +123,7 @@ class NASBench201API(NASBenchMetaAPI):
""" """
if self.verbose: if self.verbose:
print('Call query_info_str_by_arch with arch={:} and hp={:}'.format(arch, hp)) print('Call query_info_str_by_arch with arch={:} and hp={:}'.format(arch, hp))
self._query_info_str_by_arch(arch, hp, print_information) return self._query_info_str_by_arch(arch, hp, print_information)
# obtain the metric for the `index`-th architecture # obtain the metric for the `index`-th architecture
# `dataset` indicates the dataset: # `dataset` indicates the dataset: