Fix the potential memory leak in NAS-Bench-201 clear_param
This commit is contained in:
		| @@ -114,15 +114,27 @@ class NASBench201API(object): | ||||
|     assert os.path.isfile(xfile_path), 'invalid data path : {:}'.format(xfile_path) | ||||
|     xdata = torch.load(xfile_path, map_location='cpu') | ||||
|     assert isinstance(xdata, dict) and 'full' in xdata and 'less' in xdata, 'invalid format of data in {:}'.format(xfile_path) | ||||
|     if index in self.arch2infos_less: del self.arch2infos_less[index] | ||||
|     if index in self.arch2infos_full: del self.arch2infos_full[index] | ||||
|     self.arch2infos_less[index] = ArchResults.create_from_state_dict( xdata['less'] ) | ||||
|     self.arch2infos_full[index] = ArchResults.create_from_state_dict( xdata['full'] ) | ||||
|  | ||||
|   def clear_params(self, index: int, use_12epochs_result: bool): | ||||
|     """Remove the architecture's weights to save memory.""" | ||||
|     if use_12epochs_result: arch2infos = self.arch2infos_less | ||||
|     else                  : arch2infos = self.arch2infos_full | ||||
|     archresult = arch2infos[index] | ||||
|     archresult.clear_params() | ||||
|   def clear_params(self, index: int, use_12epochs_result: Union[bool, None]): | ||||
|     """Remove the architecture's weights to save memory. | ||||
|     :arg | ||||
|       index: the index of the target architecture | ||||
|       use_12epochs_result: a flag to controll how to clear the parameters. | ||||
|         -- None: clear all the weights in both `less` and `full`, which indicates the training hyper-parameters. | ||||
|         -- True: clear all the weights in arch2infos_less, which by default is 12-epoch-training result. | ||||
|         -- False: clear all the weights in arch2infos_full, which by default is 200-epoch-training result. | ||||
|     """ | ||||
|     if use_12epochs_result is None: | ||||
|       self.arch2infos_less[index].clear_params() | ||||
|       self.arch2infos_full[index].clear_params() | ||||
|     else: | ||||
|       if use_12epochs_result: arch2infos = self.arch2infos_less | ||||
|       else                  : arch2infos = self.arch2infos_full | ||||
|       arch2infos[index].clear_params() | ||||
|    | ||||
|   # This function is used to query the information of a specific archiitecture | ||||
|   # 'arch' can be an architecture index or an architecture string | ||||
| @@ -193,7 +205,6 @@ class NASBench201API(object): | ||||
|         best_index, highest_accuracy = idx, accuracy | ||||
|     return best_index, highest_accuracy | ||||
|  | ||||
|  | ||||
|   def arch(self, index: int): | ||||
|     """Return the topology structure of the `index`-th architecture.""" | ||||
|     assert 0 <= index < len(self.meta_archs), 'invalid index : {:} vs. {:}.'.format(index, len(self.meta_archs)) | ||||
| @@ -213,7 +224,6 @@ class NASBench201API(object): | ||||
|     else: arch2infos = self.arch2infos_full | ||||
|     arch_result = arch2infos[index] | ||||
|     return arch_result.get_net_param(dataset, seed) | ||||
|    | ||||
|  | ||||
|   def get_net_config(self, index: int, dataset: Text): | ||||
|     """ | ||||
| @@ -235,7 +245,6 @@ class NASBench201API(object): | ||||
|       #print ('SEED [{:}] : {:}'.format(seed, result)) | ||||
|     raise ValueError('Impossible to reach here!') | ||||
|  | ||||
|  | ||||
|   def get_cost_info(self, index: int, dataset: Text, use_12epochs_result: bool = False) -> Dict[Text, float]: | ||||
|     """To obtain the cost metric for the `index`-th architecture on a dataset.""" | ||||
|     if use_12epochs_result: arch2infos = self.arch2infos_less | ||||
| @@ -243,7 +252,6 @@ class NASBench201API(object): | ||||
|     arch_result = arch2infos[index] | ||||
|     return arch_result.get_compute_costs(dataset) | ||||
|  | ||||
|  | ||||
|   def get_latency(self, index: int, dataset: Text, use_12epochs_result: bool = False) -> float: | ||||
|     """ | ||||
|     To obtain the latency of the network (by default it will return the latency with the batch size of 256). | ||||
| @@ -254,7 +262,6 @@ class NASBench201API(object): | ||||
|     cost_dict = self.get_cost_info(index, dataset, use_12epochs_result) | ||||
|     return cost_dict['latency'] | ||||
|  | ||||
|  | ||||
|   # obtain the metric for the `index`-th architecture | ||||
|   # `dataset` indicates the dataset: | ||||
|   #   'cifar10-valid'  : using the proposed train set of CIFAR-10 as the training set | ||||
| @@ -388,7 +395,6 @@ class NASBench201API(object): | ||||
|       return xifo | ||||
|   """ | ||||
|  | ||||
|  | ||||
|   def show(self, index: int = -1) -> None: | ||||
|     """ | ||||
|     This function will print the information of a specific (or all) architecture(s). | ||||
| @@ -423,7 +429,6 @@ class NASBench201API(object): | ||||
|       else: | ||||
|         print('This index ({:}) is out of range (0~{:}).'.format(index, len(self.meta_archs))) | ||||
|  | ||||
|  | ||||
|   def statistics(self, dataset: Text, use_12epochs_result: bool) -> Dict[int, int]: | ||||
|     """ | ||||
|     This function will count the number of total trials. | ||||
| @@ -443,7 +448,6 @@ class NASBench201API(object): | ||||
|         nums[len(dataset_seed[dataset])] += 1 | ||||
|     return dict(nums) | ||||
|  | ||||
|  | ||||
|   @staticmethod | ||||
|   def str2lists(arch_str: Text) -> List[tuple]: | ||||
|     """ | ||||
| @@ -471,7 +475,6 @@ class NASBench201API(object): | ||||
|       genotypes.append( input_infos ) | ||||
|     return genotypes | ||||
|  | ||||
|  | ||||
|   @staticmethod | ||||
|   def str2matrix(arch_str: Text, | ||||
|                  search_space: List[Text] = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']) -> np.ndarray: | ||||
| @@ -511,7 +514,6 @@ class NASBench201API(object): | ||||
|     return matrix | ||||
|  | ||||
|  | ||||
|  | ||||
| class ArchResults(object): | ||||
|  | ||||
|   def __init__(self, arch_index, arch_str): | ||||
| @@ -752,7 +754,6 @@ class ArchResults(object): | ||||
|  | ||||
|   def __repr__(self): | ||||
|     return ('{name}(arch-index={index}, arch={arch}, {num} runs, clear={clear})'.format(name=self.__class__.__name__, index=self.arch_index, arch=self.arch_str, num=len(self.all_results), clear=self.clear_net_done)) | ||||
|      | ||||
|  | ||||
|  | ||||
| """ | ||||
| @@ -872,8 +873,8 @@ class ResultsCount(object): | ||||
|             'cur_time': xtime, | ||||
|             'all_time': atime} | ||||
|  | ||||
|   # get the evaluation information ; there could be multiple evaluation sets (identified by the 'name' argument). | ||||
|   def get_eval(self, name, iepoch=None): | ||||
|     """Get the evaluation information ; there could be multiple evaluation sets (identified by the 'name' argument).""" | ||||
|     if iepoch is None: iepoch = self.epochs-1 | ||||
|     assert 0 <= iepoch < self.epochs, 'invalid iepoch={:} < {:}'.format(iepoch, self.epochs) | ||||
|     if isinstance(self.eval_times,dict) and len(self.eval_times) > 0: | ||||
| @@ -890,8 +891,8 @@ class ResultsCount(object): | ||||
|     if clone: return copy.deepcopy(self.net_state_dict) | ||||
|     else: return self.net_state_dict | ||||
|  | ||||
|   # This function is used to obtain the config dict for this architecture. | ||||
|   def get_config(self, str2structure): | ||||
|     """This function is used to obtain the config dict for this architecture.""" | ||||
|     if str2structure is None: | ||||
|       return {'name': 'infer.tiny', 'C': self.arch_config['channel'], | ||||
|               'N'   : self.arch_config['num_cells'], | ||||
|   | ||||
		Reference in New Issue
	
	Block a user