diff --git a/exps/experimental/test-ww-bench.py b/exps/experimental/test-ww-bench.py index a07dfa0..2a2f988 100644 --- a/exps/experimental/test-ww-bench.py +++ b/exps/experimental/test-ww-bench.py @@ -53,7 +53,7 @@ def evaluate(api, weight_dir, data: str): config = api.get_net_config(arch_index, data) net = get_cell_based_tiny_net(config) meta_info = api.query_meta_info_by_index(arch_index, hp='200' if isinstance(api, NASBench201API) else '90') - params = meta_info.get_net_param(data, 777) + params = meta_info.get_net_param(data, 888 if isinstance(api, NASBench201API) else 777) with torch.no_grad(): net.load_state_dict(params) _, summary = weight_watcher.analyze(net, alphas=False) diff --git a/lib/nas_201_api/api_utils.py b/lib/nas_201_api/api_utils.py index 428fab1..74fd8c7 100644 --- a/lib/nas_201_api/api_utils.py +++ b/lib/nas_201_api/api_utils.py @@ -90,6 +90,10 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta): else: arch_index = -1 return arch_index + def query_by_arch(self, arch, hp): + # This is to make the current version be compatible with the old version. + return self.query_info_str_by_arch(arch, hp) + @abc.abstractmethod def reload(self, archive_root: Text = None, index: int = None): """Overwrite all information of the 'index'-th architecture in the search space, where the data will be loaded from 'archive_root'.