Update the query_by_arch function in API to be compatiable with the submission version of NAS-Bench-201

This commit is contained in:
D-X-Y 2020-07-08 04:05:55 +00:00
parent cba4741d10
commit 4892692622
2 changed files with 5 additions and 1 deletions

View File

@ -53,7 +53,7 @@ def evaluate(api, weight_dir, data: str):
config = api.get_net_config(arch_index, data)
net = get_cell_based_tiny_net(config)
meta_info = api.query_meta_info_by_index(arch_index, hp='200' if isinstance(api, NASBench201API) else '90')
params = meta_info.get_net_param(data, 777)
params = meta_info.get_net_param(data, 888 if isinstance(api, NASBench201API) else 777)
with torch.no_grad():
net.load_state_dict(params)
_, summary = weight_watcher.analyze(net, alphas=False)

View File

@ -90,6 +90,10 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
else: arch_index = -1
return arch_index
def query_by_arch(self, arch, hp):
# This is to make the current version be compatible with the old version.
return self.query_info_str_by_arch(arch, hp)
@abc.abstractmethod
def reload(self, archive_root: Text = None, index: int = None):
"""Overwrite all information of the 'index'-th architecture in the search space, where the data will be loaded from 'archive_root'.