diff --git a/exps/NAS-Bench-201/test-nas-api.py b/exps/NAS-Bench-201/test-nas-api.py index d0f69dc..9a79f28 100644 --- a/exps/NAS-Bench-201/test-nas-api.py +++ b/exps/NAS-Bench-201/test-nas-api.py @@ -23,7 +23,7 @@ if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) from config_utils import dict2config, load_config from nas_201_api import NASBench201API, NASBench301API from log_utils import time_string -from models import get_cell_based_tiny_net +from models import get_cell_based_tiny_net, CellStructure def test_api(api, is_301=True): @@ -80,6 +80,11 @@ def test_issue_81_82(api): print(results[888].get_eval('valid')) print(results[888].get_eval('x-valid')) result_dict = api.get_more_info(index=0, dataset='cifar10-valid', iepoch=11, hp='200', is_random=False) + info = api.query_by_arch('|nor_conv_3x3~0|+|skip_connect~0|nor_conv_3x3~1|+|skip_connect~0|none~1|nor_conv_3x3~2|', '200') + print(info) + structure = CellStructure.str2structure('|nor_conv_3x3~0|+|skip_connect~0|nor_conv_3x3~1|+|skip_connect~0|none~1|nor_conv_3x3~2|') + info = api.query_by_arch(structure, '200') + print(info) if __name__ == '__main__': diff --git a/lib/models/cell_searchs/genotypes.py b/lib/models/cell_searchs/genotypes.py index dcaa60c..b2b4091 100644 --- a/lib/models/cell_searchs/genotypes.py +++ b/lib/models/cell_searchs/genotypes.py @@ -4,7 +4,6 @@ from copy import deepcopy - def get_combination(space, num): combs = [] for i in range(num): @@ -19,7 +18,6 @@ def get_combination(space, num): new_combs.append( xstring ) combs = new_combs return combs - class Structure: diff --git a/lib/nas_201_api/api_201.py b/lib/nas_201_api/api_201.py index 02159c0..6af1fb9 100644 --- a/lib/nas_201_api/api_201.py +++ b/lib/nas_201_api/api_201.py @@ -123,7 +123,7 @@ class NASBench201API(NASBenchMetaAPI): """ if self.verbose: print('Call query_info_str_by_arch with arch={:} and hp={:}'.format(arch, hp)) - self._query_info_str_by_arch(arch, hp, print_information) + return self._query_info_str_by_arch(arch, hp, print_information) # obtain the metric for the `index`-th architecture # `dataset` indicates the dataset: