Update the query_by_arch function in API to be compatiable with the submission version of NAS-Bench-201
This commit is contained in:
parent
233a829bd7
commit
af1be7f740
@ -23,7 +23,7 @@ if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
|
|||||||
from config_utils import dict2config, load_config
|
from config_utils import dict2config, load_config
|
||||||
from nas_201_api import NASBench201API, NASBench301API
|
from nas_201_api import NASBench201API, NASBench301API
|
||||||
from log_utils import time_string
|
from log_utils import time_string
|
||||||
from models import get_cell_based_tiny_net
|
from models import get_cell_based_tiny_net, CellStructure
|
||||||
|
|
||||||
|
|
||||||
def test_api(api, is_301=True):
|
def test_api(api, is_301=True):
|
||||||
@ -80,6 +80,11 @@ def test_issue_81_82(api):
|
|||||||
print(results[888].get_eval('valid'))
|
print(results[888].get_eval('valid'))
|
||||||
print(results[888].get_eval('x-valid'))
|
print(results[888].get_eval('x-valid'))
|
||||||
result_dict = api.get_more_info(index=0, dataset='cifar10-valid', iepoch=11, hp='200', is_random=False)
|
result_dict = api.get_more_info(index=0, dataset='cifar10-valid', iepoch=11, hp='200', is_random=False)
|
||||||
|
info = api.query_by_arch('|nor_conv_3x3~0|+|skip_connect~0|nor_conv_3x3~1|+|skip_connect~0|none~1|nor_conv_3x3~2|', '200')
|
||||||
|
print(info)
|
||||||
|
structure = CellStructure.str2structure('|nor_conv_3x3~0|+|skip_connect~0|nor_conv_3x3~1|+|skip_connect~0|none~1|nor_conv_3x3~2|')
|
||||||
|
info = api.query_by_arch(structure, '200')
|
||||||
|
print(info)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
@ -4,7 +4,6 @@
|
|||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def get_combination(space, num):
|
def get_combination(space, num):
|
||||||
combs = []
|
combs = []
|
||||||
for i in range(num):
|
for i in range(num):
|
||||||
@ -19,7 +18,6 @@ def get_combination(space, num):
|
|||||||
new_combs.append( xstring )
|
new_combs.append( xstring )
|
||||||
combs = new_combs
|
combs = new_combs
|
||||||
return combs
|
return combs
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class Structure:
|
class Structure:
|
||||||
|
@ -123,7 +123,7 @@ class NASBench201API(NASBenchMetaAPI):
|
|||||||
"""
|
"""
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
print('Call query_info_str_by_arch with arch={:} and hp={:}'.format(arch, hp))
|
print('Call query_info_str_by_arch with arch={:} and hp={:}'.format(arch, hp))
|
||||||
self._query_info_str_by_arch(arch, hp, print_information)
|
return self._query_info_str_by_arch(arch, hp, print_information)
|
||||||
|
|
||||||
# obtain the metric for the `index`-th architecture
|
# obtain the metric for the `index`-th architecture
|
||||||
# `dataset` indicates the dataset:
|
# `dataset` indicates the dataset:
|
||||||
|
Loading…
Reference in New Issue
Block a user