upload
zero-cost-nas/foresight/pruners/measures/__init__.py (new file, 66 lines)
@@ -0,0 +1,66 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================


available_measures = []
_measure_impls = {}


def measure(name, bn=True, copy_net=True, force_clean=True, **impl_args):
    def make_impl(func):
        def measure_impl(net_orig, device, *args, **kwargs):
            if copy_net:
                net = net_orig.get_prunable_copy(bn=bn).to(device)
            else:
                net = net_orig
            ret = func(net, *args, **kwargs, **impl_args)
            if copy_net and force_clean:
                import gc
                import torch
                del net
                torch.cuda.empty_cache()
                gc.collect()
            return ret

        global _measure_impls
        if name in _measure_impls:
            raise KeyError(f'Duplicated measure! {name}')
        available_measures.append(name)
        _measure_impls[name] = measure_impl
        return func
    return make_impl


def calc_measure(name, net, device, *args, **kwargs):
    return _measure_impls[name](net, device, *args, **kwargs)


def load_all():
    from . import grad_norm
    from . import snip
    from . import grasp
    from . import fisher
    from . import jacob_cov
    from . import plain
    from . import synflow
    from . import var
    from . import cor
    from . import norm
    from . import meco
    from . import zico


# TODO: should we do that by default?
load_all()
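For orientation, here is a minimal sketch of how this registry is typically driven end to end. The TinyNet model and its get_prunable_copy implementation are illustrative assumptions (the decorator above only requires that such a method exist); calc_measure and the 'grad_norm' measure are from this commit.

    import copy
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    from foresight.pruners.measures import calc_measure

    class TinyNet(nn.Module):
        """Toy model satisfying the interface measure_impl expects."""
        def __init__(self):
            super().__init__()
            self.classifier = nn.Linear(8, 4)

        def forward(self, x):
            return self.classifier(x)

        def get_prunable_copy(self, bn=True):
            # illustrative only: a real implementation would honor `bn`
            return copy.deepcopy(self)

    net = TinyNet()
    x, y = torch.randn(16, 8), torch.randint(0, 4, (16,))
    # returns the per-layer gradient-norm array computed on a prunable copy
    scores = calc_measure('grad_norm', net, torch.device('cpu'), x, y,
                          loss_fn=F.cross_entropy)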
zero-cost-nas/foresight/pruners/measures/cor.py (new file, 53 lines)
@@ -0,0 +1,53 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

import numpy as np
import torch

from . import measure


def get_score(net, x, target, device, split_data):
    result_list = []
    def forward_hook(module, data_input, data_output):
        # mean pairwise correlation of the (flattened) classifier inputs
        corr = np.mean(np.corrcoef(data_input[0].detach().cpu().numpy()))
        result_list.append(corr)
    # assumes the network exposes its final layer as `net.classifier`
    handle = net.classifier.register_forward_hook(forward_hook)

    N = x.shape[0]
    for sp in range(split_data):
        st = sp * N // split_data
        en = (sp + 1) * N // split_data
        y = net(x[st:en])
    handle.remove()  # avoid leaking the hook across repeated calls
    cor = result_list[0].item()  # only the first split's value is used
    result_list.clear()
    return cor


@measure('cor', bn=True)
def compute_cor(net, inputs, targets, split_data=1, loss_fn=None):
    device = inputs.device
    # no backward pass is needed; just clear any stale gradients
    net.zero_grad()

    try:
        cor = get_score(net, inputs, targets, device, split_data=split_data)
    except Exception as e:
        print(e)
        cor = np.nan

    return cor
zero-cost-nas/foresight/pruners/measures/fisher.py (new file, 107 lines)
@@ -0,0 +1,107 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

import torch
import torch.nn as nn
import torch.nn.functional as F

import types

from . import measure
from ..p_utils import get_layer_metric_array, reshape_elements


def fisher_forward_conv2d(self, x):
    x = F.conv2d(x, self.weight, self.bias, self.stride,
                    self.padding, self.dilation, self.groups)
    # intercept and store the activations after passing through the 'hooked' identity op
    self.act = self.dummy(x)
    return self.act

def fisher_forward_linear(self, x):
    x = F.linear(x, self.weight, self.bias)
    self.act = self.dummy(x)
    return self.act

@measure('fisher', bn=True, mode='channel')
def compute_fisher_per_weight(net, inputs, targets, loss_fn, mode, split_data=1):

    device = inputs.device

    if mode == 'param':
        raise ValueError('Fisher pruning does not support parameter pruning.')

    net.train()
    all_hooks = []
    for layer in net.modules():
        if isinstance(layer, nn.Conv2d) or isinstance(layer, nn.Linear):
            # variables/op needed for fisher computation
            layer.fisher = None
            layer.act = 0.
            layer.dummy = nn.Identity()

            # replace forward method of conv/linear
            if isinstance(layer, nn.Conv2d):
                layer.forward = types.MethodType(fisher_forward_conv2d, layer)
            if isinstance(layer, nn.Linear):
                layer.forward = types.MethodType(fisher_forward_linear, layer)

            # function to call during backward pass (hooked on identity op at output of layer)
            def hook_factory(layer):
                def hook(module, grad_input, grad_output):
                    act = layer.act.detach()
                    grad = grad_output[0].detach()
                    if len(act.shape) > 2:
                        g_nk = torch.sum((act * grad), list(range(2, len(act.shape))))
                    else:
                        g_nk = act * grad
                    del_k = g_nk.pow(2).mean(0).mul(0.5)
                    if layer.fisher is None:
                        layer.fisher = del_k
                    else:
                        layer.fisher += del_k
                    del layer.act  # without deleting this, a nasty memory leak occurs! related: https://discuss.pytorch.org/t/memory-leak-when-using-forward-hook-and-backward-hook-simultaneously/27555
                return hook

            # register backward hook on the identity op to compute fisher info
            # (keep the handles so the hooks could be removed later)
            all_hooks.append(layer.dummy.register_backward_hook(hook_factory(layer)))

    N = inputs.shape[0]
    for sp in range(split_data):
        st = sp * N // split_data
        en = (sp + 1) * N // split_data

        net.zero_grad()
        outputs = net(inputs[st:en])
        loss = loss_fn(outputs, targets[st:en])
        loss.backward()

    # retrieve fisher info
    def fisher(layer):
        if layer.fisher is not None:
            return torch.abs(layer.fisher.detach())
        else:
            return torch.zeros(layer.weight.shape[0])  # size=ch

    grads_abs_ch = get_layer_metric_array(net, fisher, mode)

    # broadcast the per-channel value to all parameters in that channel
    # to be compatible with stuff downstream (which expects per-parameter metrics)
    # TODO cleanup on the selectors/apply_prune_mask side (?)
    shapes = get_layer_metric_array(net, lambda l: l.weight.shape[1:], mode)

    grads_abs = reshape_elements(grads_abs_ch, shapes, device)

    return grads_abs
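As a sanity check on the hook algebra above, this standalone snippet reproduces the per-channel saliency del_k = 0.5 * mean_n[(sum_hw act * grad)^2] on random tensors; the shapes are arbitrary and nothing here is part of the commit.

    import torch

    act = torch.randn(4, 3, 2, 2)   # batch N=4, channels C=3, 2x2 spatial
    grad = torch.randn(4, 3, 2, 2)  # gradient w.r.t. the identity op's output

    g_nk = (act * grad).sum(dim=(2, 3))   # inner product per sample and channel
    del_k = g_nk.pow(2).mean(0).mul(0.5)  # 0.5 * E_n[g_nk^2], one score per channel
    assert del_k.shape == (3,)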
zero-cost-nas/foresight/pruners/measures/grad_norm.py (new file, 38 lines)
@@ -0,0 +1,38 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

import torch
import torch.nn.functional as F

import copy

from . import measure
from ..p_utils import get_layer_metric_array

@measure('grad_norm', bn=True)
def get_grad_norm_arr(net, inputs, targets, loss_fn, split_data=1, skip_grad=False):
    net.zero_grad()
    N = inputs.shape[0]
    for sp in range(split_data):
        st = sp * N // split_data
        en = (sp + 1) * N // split_data

        outputs = net.forward(inputs[st:en])
        loss = loss_fn(outputs, targets[st:en])
        loss.backward()

    # gradients accumulate across the splits, so the norms are taken once after the loop
    grad_norm_arr = get_layer_metric_array(net, lambda l: l.weight.grad.norm() if l.weight.grad is not None else torch.zeros_like(l.weight), mode='param')

    return grad_norm_arr
zero-cost-nas/foresight/pruners/measures/grasp.py (new file, 87 lines)
@@ -0,0 +1,87 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as autograd

from . import measure
from ..p_utils import get_layer_metric_array


@measure('grasp', bn=True, mode='param')
def compute_grasp_per_weight(net, inputs, targets, mode, loss_fn, T=1, num_iters=1, split_data=1):

    # get all applicable weights
    weights = []
    for layer in net.modules():
        if isinstance(layer, nn.Conv2d) or isinstance(layer, nn.Linear):
            weights.append(layer.weight)
            layer.weight.requires_grad_(True)  # TODO isn't this already true?

    # NOTE original code had some input/target splitting into 2
    # I am guessing this was because of GPU mem limit
    net.zero_grad()
    N = inputs.shape[0]
    for sp in range(split_data):
        st = sp * N // split_data
        en = (sp + 1) * N // split_data

        # forward/grad pass #1
        grad_w = None
        for _ in range(num_iters):
            # TODO get new data, otherwise num_iters is useless!
            outputs = net.forward(inputs[st:en]) / T
            loss = loss_fn(outputs, targets[st:en])
            grad_w_p = autograd.grad(loss, weights, allow_unused=True)
            if grad_w is None:
                grad_w = list(grad_w_p)
            else:
                for idx in range(len(grad_w)):
                    # allow_unused can yield None entries; skip them when accumulating
                    if grad_w_p[idx] is None:
                        continue
                    if grad_w[idx] is None:
                        grad_w[idx] = grad_w_p[idx]
                    else:
                        grad_w[idx] += grad_w_p[idx]

    for sp in range(split_data):
        st = sp * N // split_data
        en = (sp + 1) * N // split_data

        # forward/grad pass #2
        outputs = net.forward(inputs[st:en]) / T
        loss = loss_fn(outputs, targets[st:en])
        grad_f = autograd.grad(loss, weights, create_graph=True, allow_unused=True)

        # accumulate gradients computed in previous step and call backwards
        z, count = 0, 0
        for layer in net.modules():
            if isinstance(layer, nn.Conv2d) or isinstance(layer, nn.Linear):
                if grad_w[count] is not None:
                    z += (grad_w[count].data * grad_f[count]).sum()
                count += 1
        z.backward()

    # compute final sensitivity metric and put in grads
    def grasp(layer):
        if layer.weight.grad is not None:
            return -layer.weight.data * layer.weight.grad   # -theta_q Hg
            # NOTE in the grasp code they take the *bottom* (1-p)% of values
            # but we take the *top* (1-p)%, therefore we remove the -ve sign
            # EDIT accuracy seems to be negatively correlated with this metric, so we add the -ve sign here!
        else:
            return torch.zeros_like(layer.weight)

    grads = get_layer_metric_array(net, grasp, mode)

    return grads
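The second pass above relies on the standard Hessian-vector-product trick: differentiating the inner product of v with a gradient built under create_graph=True deposits Hv into .grad. A minimal standalone illustration on a toy loss with known Hessian 2I, independent of the code in this file:

    import torch

    w = torch.randn(3, requires_grad=True)
    loss = (w ** 2).sum()                            # gradient 2w, Hessian 2I
    (g,) = torch.autograd.grad(loss, [w], create_graph=True)
    v = torch.ones_like(w)                           # plays the role of grad_w
    (g * v).sum().backward()                         # w.grad now holds H @ v
    assert torch.allclose(w.grad, 2 * v)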
zero-cost-nas/foresight/pruners/measures/jacob_cov.py (new file, 57 lines)
@@ -0,0 +1,57 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

import torch
import numpy as np

from . import measure


def get_batch_jacobian(net, x, target, device, split_data):
    x.requires_grad_(True)

    N = x.shape[0]
    for sp in range(split_data):
        st = sp * N // split_data
        en = (sp + 1) * N // split_data
        y = net(x[st:en])
        y.backward(torch.ones_like(y))

    jacob = x.grad.detach()
    x.requires_grad_(False)
    return jacob, target.detach()

def eval_score(jacob, labels=None):
    corrs = np.corrcoef(jacob)
    v, _ = np.linalg.eig(corrs)
    k = 1e-5
    return -np.sum(np.log(v + k) + 1. / (v + k))

@measure('jacob_cov', bn=True)
def compute_jacob_cov(net, inputs, targets, split_data=1, loss_fn=None):
    device = inputs.device
    # Compute gradients (but don't apply them)
    net.zero_grad()

    jacobs, labels = get_batch_jacobian(net, inputs, targets, device, split_data=split_data)
    jacobs = jacobs.reshape(jacobs.size(0), -1).cpu().numpy()

    try:
        jc = eval_score(jacobs, labels)
    except Exception as e:
        print(e)
        jc = np.nan

    return jc
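For intuition, eval_score can be exercised on a random Jacobian directly; this toy check (not part of the commit) shows the shapes involved:

    import numpy as np

    jac = np.random.randn(8, 32)   # 8 samples, flattened per-sample input gradients
    corrs = np.corrcoef(jac)       # 8x8 correlation between the samples' Jacobians
    v, _ = np.linalg.eig(corrs)
    k = 1e-5
    score = -np.sum(np.log(v + k) + 1. / (v + k))  # penalizes highly correlated rows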
zero-cost-nas/foresight/pruners/measures/l2_norm.py (new file, 22 lines)
@@ -0,0 +1,22 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

from . import measure
from ..p_utils import get_layer_metric_array


@measure('l2_norm', copy_net=False, mode='param')
def get_l2_norm_array(net, inputs, targets, mode, split_data=1):
    return get_layer_metric_array(net, lambda l: l.weight.norm(), mode=mode)
zero-cost-nas/foresight/pruners/measures/meco.py (new file, 69 lines)
@@ -0,0 +1,69 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

import numpy as np
import torch
from torch import nn

from . import measure


def get_score(net, x, target, device, split_data):
    result_list = []

    def forward_hook(module, data_input, data_output):
        # first sample's feature map, flattened to (channels, -1)
        fea = data_output[0].detach()
        fea = fea.reshape(fea.shape[0], -1)
        corr = torch.corrcoef(fea)
        corr[torch.isnan(corr)] = 0
        corr[torch.isinf(corr)] = 0
        values = torch.linalg.eig(corr)[0]
        # result = np.real(np.min(values)) / np.real(np.max(values))
        result = torch.min(torch.real(values))
        result_list.append(result)

    handles = []
    for name, module in net.named_modules():
        handles.append(module.register_forward_hook(forward_hook))

    N = x.shape[0]
    for sp in range(split_data):
        st = sp * N // split_data
        en = (sp + 1) * N // split_data
        y = net(x[st:en])

    for handle in handles:
        handle.remove()  # avoid leaking hooks across repeated calls

    results = torch.tensor(result_list)
    results = results[torch.logical_not(torch.isnan(results))]
    v = torch.sum(results)
    result_list.clear()
    return v.item()


@measure('meco', bn=True)
def compute_meco(net, inputs, targets, split_data=1, loss_fn=None):
    device = inputs.device
    # no backward pass is needed; just clear any stale gradients
    net.zero_grad()

    try:
        meco = get_score(net, inputs, targets, device, split_data=split_data)
    except Exception as e:
        print(e)
        meco = np.nan  # was `np.nan, None`, which made the except path return a tuple
    return meco
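The hook above scores each module by the smallest real eigenvalue of the channel-correlation matrix of one sample's feature map; in isolation that computation is just the following (shapes illustrative):

    import torch

    fea = torch.randn(16, 32 * 32)        # 16 channels, flattened 32x32 feature map
    corr = torch.corrcoef(fea)
    corr[torch.isnan(corr)] = 0
    corr[torch.isinf(corr)] = 0
    eigvals = torch.linalg.eig(corr)[0]
    score = torch.min(torch.real(eigvals))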
zero-cost-nas/foresight/pruners/measures/norm.py (new file, 55 lines)
@@ -0,0 +1,55 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

import numpy as np
import torch

from . import measure


def get_score(net, x, target, device, split_data):
    result_list = []
    def forward_hook(module, data_input, data_output):
        norm = torch.norm(data_input[0])
        result_list.append(norm)
    # assumes the network exposes its final layer as `net.classifier`
    handle = net.classifier.register_forward_hook(forward_hook)

    N = x.shape[0]
    for sp in range(split_data):
        st = sp * N // split_data
        en = (sp + 1) * N // split_data
        y = net(x[st:en])
    handle.remove()  # avoid leaking the hook across repeated calls
    n = result_list[0].item()
    result_list.clear()
    return n


@measure('norm', bn=True)
def compute_norm(net, inputs, targets, split_data=1, loss_fn=None):
    device = inputs.device
    # no backward pass is needed; just clear any stale gradients
    net.zero_grad()

    try:
        # get_score returns a single float; unpacking it into `norm, t` always
        # raised, silently turning every score into NaN, so the unused timing
        # value is dropped and a plain float is returned like the other measures
        norm = get_score(net, inputs, targets, device, split_data=split_data)
    except Exception as e:
        print(e)
        norm = np.nan
    return norm
zero-cost-nas/foresight/pruners/measures/param_count.py (new file, 16 lines)
@@ -0,0 +1,16 @@
import time
import torch

from . import measure
from ..p_utils import get_layer_metric_array


@measure('param_count', copy_net=False, mode='param')
def get_param_count_array(net, inputs, targets, mode, loss_fn, split_data=1):
    s = time.time()
    count = get_layer_metric_array(net, lambda l: torch.tensor(sum(p.numel() for p in l.parameters() if p.requires_grad)), mode=mode)
    t = time.time() - s
    return count, t
zero-cost-nas/foresight/pruners/measures/plain.py (new file, 44 lines)
@@ -0,0 +1,44 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

import torch
import torch.nn.functional as F

from . import measure
from ..p_utils import get_layer_metric_array


@measure('plain', bn=True, mode='param')
def compute_plain_per_weight(net, inputs, targets, mode, loss_fn, split_data=1):

    net.zero_grad()
    N = inputs.shape[0]
    for sp in range(split_data):
        st = sp * N // split_data
        en = (sp + 1) * N // split_data

        outputs = net.forward(inputs[st:en])
        loss = loss_fn(outputs, targets[st:en])
        loss.backward()

    # select the gradients that we want to use for search/prune
    def plain(layer):
        if layer.weight.grad is not None:
            return layer.weight.grad * layer.weight
        else:
            return torch.zeros_like(layer.weight)

    # note: despite the name, the values keep their sign (no abs is taken)
    grads_abs = get_layer_metric_array(net, plain, mode)
    return grads_abs
zero-cost-nas/foresight/pruners/measures/snip.py (new file, 69 lines)
@@ -0,0 +1,69 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import copy
import types

from . import measure
from ..p_utils import get_layer_metric_array


def snip_forward_conv2d(self, x):
    return F.conv2d(x, self.weight * self.weight_mask, self.bias,
                    self.stride, self.padding, self.dilation, self.groups)

def snip_forward_linear(self, x):
    return F.linear(x, self.weight * self.weight_mask, self.bias)

@measure('snip', bn=True, mode='param')
def compute_snip_per_weight(net, inputs, targets, mode, loss_fn, split_data=1):
    for layer in net.modules():
        if isinstance(layer, nn.Conv2d) or isinstance(layer, nn.Linear):
            layer.weight_mask = nn.Parameter(torch.ones_like(layer.weight))
            layer.weight.requires_grad = False

        # Override the forward methods:
        if isinstance(layer, nn.Conv2d):
            layer.forward = types.MethodType(snip_forward_conv2d, layer)

        if isinstance(layer, nn.Linear):
            layer.forward = types.MethodType(snip_forward_linear, layer)

    # Compute gradients (but don't apply them)
    net.zero_grad()
    N = inputs.shape[0]
    for sp in range(split_data):
        st = sp * N // split_data
        en = (sp + 1) * N // split_data

        outputs = net.forward(inputs[st:en])
        loss = loss_fn(outputs, targets[st:en])
        loss.backward()

    # select the gradients that we want to use for search/prune
    def snip(layer):
        if layer.weight_mask.grad is not None:
            return torch.abs(layer.weight_mask.grad)
        else:
            return torch.zeros_like(layer.weight)

    grads_abs = get_layer_metric_array(net, snip, mode)

    return grads_abs
zero-cost-nas/foresight/pruners/measures/synflow.py (new file, 69 lines)
@@ -0,0 +1,69 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

import torch

from . import measure
from ..p_utils import get_layer_metric_array


@measure('synflow', bn=False, mode='param')
@measure('synflow_bn', bn=True, mode='param')
def compute_synflow_per_weight(net, inputs, targets, mode, split_data=1, loss_fn=None):

    device = inputs.device

    # convert params to their abs. Keep sign for converting it back.
    @torch.no_grad()
    def linearize(net):
        signs = {}
        for name, param in net.state_dict().items():
            signs[name] = torch.sign(param)
            param.abs_()
        return signs

    # convert to orig values
    @torch.no_grad()
    def nonlinearize(net, signs):
        for name, param in net.state_dict().items():
            if 'weight_mask' not in name:
                param.mul_(signs[name])

    # keep signs of all params
    signs = linearize(net)

    # Compute gradients with an input of 1s
    net.zero_grad()
    net.double()
    input_dim = list(inputs[0, :].shape)
    inputs = torch.ones([1] + input_dim).double().to(device)
    output = net.forward(inputs)
    torch.sum(output).backward()

    # select the gradients that we want to use for search/prune
    def synflow(layer):
        if layer.weight.grad is not None:
            return torch.abs(layer.weight * layer.weight.grad)
        else:
            return torch.zeros_like(layer.weight)

    grads_abs = get_layer_metric_array(net, synflow, mode)

    # apply signs of all params
    nonlinearize(net, signs)

    return grads_abs
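The linearize/nonlinearize pair is SynFlow's |theta| trick: the all-ones pass is scored on the element-wise absolute parameters, then the signs are restored in place. A compact standalone check on a single layer (illustrative only):

    import torch
    import torch.nn as nn

    layer = nn.Linear(4, 4)
    signs = {n: torch.sign(p) for n, p in layer.state_dict().items()}
    for p in layer.state_dict().values():
        p.abs_()          # state_dict tensors share storage, so this edits the layer
    # ... the all-ones forward/backward on |theta| would run here ...
    for n, p in layer.state_dict().items():
        p.mul_(signs[n])  # restore the original parameter values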
zero-cost-nas/foresight/pruners/measures/var.py (new file, 55 lines)
@@ -0,0 +1,55 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

import numpy as np
import torch

from . import measure


def get_score(net, x, target, device, split_data):
    result_list = []
    def forward_hook(module, data_input, data_output):
        var = torch.var(data_input[0])
        result_list.append(var)
    # assumes the network exposes its final layer as `net.classifier`
    handle = net.classifier.register_forward_hook(forward_hook)

    N = x.shape[0]
    for sp in range(split_data):
        st = sp * N // split_data
        en = (sp + 1) * N // split_data
        y = net(x[st:en])
    handle.remove()  # avoid leaking the hook across repeated calls
    v = result_list[0].item()
    result_list.clear()
    return v


@measure('var', bn=True)
def compute_var(net, inputs, targets, split_data=1, loss_fn=None):
    device = inputs.device
    # no backward pass is needed; just clear any stale gradients
    net.zero_grad()

    try:
        var = get_score(net, inputs, targets, device, split_data=split_data)
    except Exception as e:
        print(e)
        var = np.nan
    return var
zero-cost-nas/foresight/pruners/measures/zico.py (new file, 106 lines)
@@ -0,0 +1,106 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

import numpy as np
import torch
from torch import nn

from . import measure


def getgrad(model: torch.nn.Module, grad_dict: dict, step_iter=0):
    if step_iter == 0:
        for name, mod in model.named_modules():
            if isinstance(mod, nn.Conv2d) or isinstance(mod, nn.Linear):
                try:
                    grad_dict[name] = [mod.weight.grad.data.cpu().reshape(-1).numpy()]
                except AttributeError:  # layer received no gradient
                    continue
    else:
        for name, mod in model.named_modules():
            if isinstance(mod, nn.Conv2d) or isinstance(mod, nn.Linear):
                try:
                    grad_dict[name].append(mod.weight.grad.data.cpu().reshape(-1).numpy())
                except (AttributeError, KeyError):
                    continue
    return grad_dict


def calculate_zico(grad_dict):
    for modname in grad_dict.keys():
        grad_dict[modname] = np.array(grad_dict[modname])
    nsr_mean_sum_abs = 0
    for modname in grad_dict.keys():
        nsr_std = np.std(grad_dict[modname], axis=0)
        nonzero_idx = np.nonzero(nsr_std)[0]
        nsr_mean_abs = np.mean(np.abs(grad_dict[modname]), axis=0)
        tmpsum = np.sum(nsr_mean_abs[nonzero_idx] / nsr_std[nonzero_idx])
        if tmpsum != 0:
            nsr_mean_sum_abs += np.log(tmpsum)
    return nsr_mean_sum_abs


def getzico(network, inputs, targets, loss_fn, split_data=2):
    grad_dict = {}
    network.train()
    device = inputs.device
    network.to(device)
    N = inputs.shape[0]

    for sp in range(split_data):
        st = sp * N // split_data
        en = (sp + 1) * N // split_data
        network.zero_grad()  # ZiCo needs per-batch gradients, not accumulated ones
        outputs = network.forward(inputs[st:en])
        loss = loss_fn(outputs, targets[st:en])
        loss.backward()
        grad_dict = getgrad(network, grad_dict, sp)
    res = calculate_zico(grad_dict)
    return res


@measure('zico', bn=True)
def compute_zico(net, inputs, targets, split_data=2, loss_fn=None):
    # Compute gradients (but don't apply them)
    net.zero_grad()

    try:
        zico = getzico(net, inputs, targets, loss_fn, split_data=split_data)
    except Exception as e:
        print(e)
        zico = np.nan
    return zico
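The statistic getzico sums per layer is log of the sum over weights of E|g_w| / sigma(g_w), taken across the gradients of the split batches; on a toy gradient matrix that reduces to the following (illustrative only):

    import numpy as np

    g = np.random.randn(2, 100)          # gradients of 100 weights from 2 mini-batches
    std = np.std(g, axis=0)
    nz = np.nonzero(std)[0]              # guard against zero-variance weights
    mean_abs = np.mean(np.abs(g), axis=0)
    layer_score = np.log(np.sum(mean_abs[nz] / std[nz]))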