Update the query_by_arch function in API to be compatiable with the submission version of NAS-Bench-201
This commit is contained in:
parent
cba4741d10
commit
4892692622
@ -53,7 +53,7 @@ def evaluate(api, weight_dir, data: str):
|
|||||||
config = api.get_net_config(arch_index, data)
|
config = api.get_net_config(arch_index, data)
|
||||||
net = get_cell_based_tiny_net(config)
|
net = get_cell_based_tiny_net(config)
|
||||||
meta_info = api.query_meta_info_by_index(arch_index, hp='200' if isinstance(api, NASBench201API) else '90')
|
meta_info = api.query_meta_info_by_index(arch_index, hp='200' if isinstance(api, NASBench201API) else '90')
|
||||||
params = meta_info.get_net_param(data, 777)
|
params = meta_info.get_net_param(data, 888 if isinstance(api, NASBench201API) else 777)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
net.load_state_dict(params)
|
net.load_state_dict(params)
|
||||||
_, summary = weight_watcher.analyze(net, alphas=False)
|
_, summary = weight_watcher.analyze(net, alphas=False)
|
||||||
|
@ -90,6 +90,10 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
|
|||||||
else: arch_index = -1
|
else: arch_index = -1
|
||||||
return arch_index
|
return arch_index
|
||||||
|
|
||||||
|
def query_by_arch(self, arch, hp):
|
||||||
|
# This is to make the current version be compatible with the old version.
|
||||||
|
return self.query_info_str_by_arch(arch, hp)
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def reload(self, archive_root: Text = None, index: int = None):
|
def reload(self, archive_root: Text = None, index: int = None):
|
||||||
"""Overwrite all information of the 'index'-th architecture in the search space, where the data will be loaded from 'archive_root'.
|
"""Overwrite all information of the 'index'-th architecture in the search space, where the data will be loaded from 'archive_root'.
|
||||||
|
Loading…
Reference in New Issue
Block a user