From 489269262279e64b760b7825cc16947c0dbdfc57 Mon Sep 17 00:00:00 2001 From: D-X-Y <280835372@qq.com> Date: Wed, 8 Jul 2020 04:05:55 +0000 Subject: [PATCH] Update the query_by_arch function in API to be compatiable with the submission version of NAS-Bench-201 --- exps/experimental/test-ww-bench.py | 2 +- lib/nas_201_api/api_utils.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/exps/experimental/test-ww-bench.py b/exps/experimental/test-ww-bench.py index a07dfa0..2a2f988 100644 --- a/exps/experimental/test-ww-bench.py +++ b/exps/experimental/test-ww-bench.py @@ -53,7 +53,7 @@ def evaluate(api, weight_dir, data: str): config = api.get_net_config(arch_index, data) net = get_cell_based_tiny_net(config) meta_info = api.query_meta_info_by_index(arch_index, hp='200' if isinstance(api, NASBench201API) else '90') - params = meta_info.get_net_param(data, 777) + params = meta_info.get_net_param(data, 888 if isinstance(api, NASBench201API) else 777) with torch.no_grad(): net.load_state_dict(params) _, summary = weight_watcher.analyze(net, alphas=False) diff --git a/lib/nas_201_api/api_utils.py b/lib/nas_201_api/api_utils.py index 428fab1..74fd8c7 100644 --- a/lib/nas_201_api/api_utils.py +++ b/lib/nas_201_api/api_utils.py @@ -90,6 +90,10 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta): else: arch_index = -1 return arch_index + def query_by_arch(self, arch, hp): + # This is to make the current version be compatible with the old version. + return self.query_info_str_by_arch(arch, hp) + @abc.abstractmethod def reload(self, archive_root: Text = None, index: int = None): """Overwrite all information of the 'index'-th architecture in the search space, where the data will be loaded from 'archive_root'.