Update the query_by_arch function in API to be compatiable with the submission version of NAS-Bench-201
This commit is contained in:
		| @@ -53,7 +53,7 @@ def evaluate(api, weight_dir, data: str): | |||||||
|     config = api.get_net_config(arch_index, data) |     config = api.get_net_config(arch_index, data) | ||||||
|     net = get_cell_based_tiny_net(config) |     net = get_cell_based_tiny_net(config) | ||||||
|     meta_info = api.query_meta_info_by_index(arch_index, hp='200' if isinstance(api, NASBench201API) else '90') |     meta_info = api.query_meta_info_by_index(arch_index, hp='200' if isinstance(api, NASBench201API) else '90') | ||||||
|     params = meta_info.get_net_param(data, 777) |     params = meta_info.get_net_param(data, 888 if isinstance(api, NASBench201API) else 777) | ||||||
|     with torch.no_grad(): |     with torch.no_grad(): | ||||||
|       net.load_state_dict(params) |       net.load_state_dict(params) | ||||||
|       _, summary = weight_watcher.analyze(net, alphas=False) |       _, summary = weight_watcher.analyze(net, alphas=False) | ||||||
|   | |||||||
| @@ -90,6 +90,10 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta): | |||||||
|     else: arch_index = -1 |     else: arch_index = -1 | ||||||
|     return arch_index |     return arch_index | ||||||
|  |  | ||||||
|  |   def query_by_arch(self, arch, hp): | ||||||
|  |     # This is to make the current version be compatible with the old version. | ||||||
|  |     return self.query_info_str_by_arch(arch, hp) | ||||||
|  |  | ||||||
|   @abc.abstractmethod |   @abc.abstractmethod | ||||||
|   def reload(self, archive_root: Text = None, index: int = None): |   def reload(self, archive_root: Text = None, index: int = None): | ||||||
|     """Overwrite all information of the 'index'-th architecture in the search space, where the data will be loaded from 'archive_root'. |     """Overwrite all information of the 'index'-th architecture in the search space, where the data will be loaded from 'archive_root'. | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user