update
This commit is contained in:
		
							
								
								
									
										76
									
								
								zero-cost-nas/foresight/pruners/measures/gradsign.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										76
									
								
								zero-cost-nas/foresight/pruners/measures/gradsign.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,76 @@ | ||||
| # Copyright 2021 Samsung Electronics Co., Ltd. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
|  | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| # ============================================================================= | ||||
|  | ||||
| import torch | ||||
| from torch import nn | ||||
| import numpy as np | ||||
|  | ||||
| from . import measure | ||||
|  | ||||
|  | ||||
def get_flattened_metric(net, metric):
    """Collect a per-layer metric over all Conv2d/Linear modules as one 1-D array.

    Args:
        net: Module whose ``modules()`` are scanned.
        metric: Callable taking a layer and returning an object with a
            ``flatten()`` method whose result NumPy can concatenate.

    Returns:
        The NumPy concatenation of the flattened per-layer results, in
        ``net.modules()`` order.

    Raises:
        ValueError: Propagated from ``np.concatenate`` when ``net`` contains
            no Conv2d or Linear module.
    """
    weighted_types = (nn.Conv2d, nn.Linear)
    pieces = [
        metric(module).flatten()
        for module in net.modules()
        if isinstance(module, weighted_types)
    ]
    return np.concatenate(pieces)
|  | ||||
|  | ||||
def get_grad_conflict(net, inputs, targets, loss_fn):
    """Score how strongly per-sample weight-gradient signs agree across a batch.

    A separate forward/backward pass is run for every sample. For each
    Conv2d/Linear weight the sign of its gradient is summed over the samples;
    the absolute values of those sums are then added up, ignoring NaNs.

    Args:
        net: Module to evaluate. Its gradients are zeroed and overwritten as a
            side effect.
        inputs: Batch indexed along dimension 0; one sample is fed at a time
            (the batch dimension is kept via ``inputs[[i]]``).
        targets: Targets indexed in lockstep with ``inputs``.
        loss_fn: Callable ``loss_fn(outputs, targets)`` returning a value that
            supports ``.backward()``.

    Returns:
        NumPy scalar with the summed absolute sign totals.

    Raises:
        ValueError: If ``inputs`` holds no samples.
    """

    def _weight_grad(layer):
        grad = layer.weight.grad
        if grad is None:
            # No gradient reached this layer in the current backward pass:
            # contribute zeros so every sample yields a vector of equal length.
            return torch.zeros_like(layer.weight).cpu().numpy()
        return grad.detach().cpu().numpy()

    sign_sum = None
    for i in range(inputs.shape[0]):
        net.zero_grad()
        # Call the module itself rather than ``forward`` so registered hooks run.
        outputs = net(inputs[[i]])
        loss = loss_fn(outputs, targets[[i]])
        loss.backward()
        # Accumulate signs as we go instead of stacking every gradient vector;
        # the totals are identical but only one vector is held at a time.
        signs = np.sign(get_flattened_metric(net, _weight_grad))
        sign_sum = signs if sign_sum is None else sign_sum + signs

    if sign_sum is None:
        raise ValueError("need at least one sample to compute gradient signs")
    return np.nansum(np.abs(sign_sum))
|  | ||||
|  | ||||
def get_gradsign(input, target, net, device, loss_fn):
    """Move the network and batch to ``device`` and return the grad-conflict score.

    Args:
        input: Input batch tensor.
        target: Target tensor matching ``input``.
        net: Module to score; it is moved to ``device``.
        device: Device on which the computation runs.
        loss_fn: Loss callable forwarded to ``get_grad_conflict``.

    Returns:
        The mean over a one-element list containing the single
        ``get_grad_conflict`` result.
    """
    net = net.to(device)
    batch = input.to(device)
    labels = target.to(device)
    conflict = get_grad_conflict(net=net, inputs=batch, targets=labels, loss_fn=loss_fn)
    return np.mean([conflict])
|  | ||||
@measure('gradsign', bn=True)
def compute_gradsign(net, inputs, targets, split_data=1, loss_fn=None):
    """Entry point for the 'gradsign' measure, wrapped by the ``measure`` decorator.

    Args:
        net: Network to score; its gradients are cleared first.
        inputs: Input batch; its ``device`` attribute selects where to compute.
        targets: Targets matching ``inputs``.
        split_data: Not used by this function.
        loss_fn: Loss callable forwarded to ``get_gradsign``.

    Returns:
        The value returned by ``get_gradsign``, or ``np.nan`` if it raised.
    """
    device = inputs.device
    # Start from clean gradients; nothing here applies an optimizer step.
    net.zero_grad()

    try:
        return get_gradsign(inputs, targets, net, device, loss_fn)
    except Exception as err:
        # Best effort: report the failure and fall back to NaN.
        print(err)
        return np.nan
							
								
								
									
										94
									
								
								zero-cost-nas/foresight/pruners/measures/ntk.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										94
									
								
								zero-cost-nas/foresight/pruners/measures/ntk.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,94 @@ | ||||
| # Copyright 2021 Samsung Electronics Co., Ltd. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
|  | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| # ============================================================================= | ||||
|  | ||||
| import torch | ||||
| import numpy as np | ||||
|  | ||||
| from . import measure | ||||
|  | ||||
|  | ||||
def recal_bn(network, inputs, targets, recalbn, device):
    """Reset BatchNorm2d running statistics and re-estimate them from batches.

    Every BatchNorm2d layer has its running mean/variance and batch counter
    zeroed (when it tracks them) and its momentum set to ``None``. The network
    is then switched to training mode and fed up to ``recalbn`` batches under
    ``torch.no_grad()``.

    Args:
        network: Module to recalibrate; modified in place and left in
            training mode.
        inputs: Iterable of input batches.
        targets: Iterable iterated in lockstep with ``inputs``; its values are
            not used.
        recalbn: Maximum number of batches to feed through the network.
        device: Device the batches are moved to before the forward pass.

    Returns:
        The same ``network`` object.
    """
    for m in network.modules():
        if not isinstance(m, torch.nn.BatchNorm2d):
            continue
        # Layers built without running statistics hold None buffers.
        if m.running_mean is not None:
            m.running_mean.data.fill_(0)
            m.running_var.data.fill_(0)
            m.num_batches_tracked.data.zero_()
        m.momentum = None
    network.train()
    with torch.no_grad():
        for i, (batch, _) in enumerate(zip(inputs, targets)):
            if i >= recalbn:
                break
            # ``.to`` works for CPU and CUDA devices alike; the output is
            # discarded, so it may be a tensor or a tuple.
            network(batch.to(device, non_blocking=True))
    return network
|  | ||||
|  | ||||
def get_ntk_n(inputs, targets, network, device, recalbn=0, train_mode=False, num_batch=1):
    """Return the smallest-to-largest eigenvalue ratio of an empirical NTK matrix.

    The network is put in eval mode. For each sample, the output row for that
    sample is back-propagated with an all-ones upstream gradient and the
    gradients of all parameters whose name contains ``'weight'`` are flattened
    into one vector. The Gram matrix of those vectors is eigendecomposed and
    ``eigenvalues[0] / eigenvalues[-1]`` (ascending order) is returned.

    Args:
        inputs: Input batch tensor; moved to ``device``.
        targets: Not used by this function.
        network: Module to evaluate; switched to eval mode and its gradients
            are cleared as a side effect. If it returns a tuple, element 1 is
            taken as the logits.
        device: Device to run on (CPU or CUDA).
        recalbn: Not used by this function.
        train_mode: Not used by this function.
        num_batch: Number of passes over the same ``inputs`` batch; each pass
            appends another set of per-sample gradient rows.

    Returns:
        The eigenvalue ratio as produced by ``np.nan_to_num``, with NaN
        replaced by ``100000.0``.
    """
    network.eval()
    # ``.to`` handles both CPU and CUDA devices, unlike ``.cuda(device=...)``.
    inputs = inputs.to(device, non_blocking=True)

    per_sample_grads = []
    for _ in range(num_batch):
        network.zero_grad()
        # Feed a copy so the forward pass cannot alter the caller's tensor.
        logit = network(inputs.clone())
        if isinstance(logit, tuple):
            logit = logit[1]  # 201 networks: return features and logits
        for idx in range(len(inputs)):
            sample_logit = logit[idx:idx + 1]
            # The graph is reused for every sample, so it must be retained.
            sample_logit.backward(torch.ones_like(sample_logit), retain_graph=True)
            grad = [
                W.grad.view(-1).detach()
                for name, W in network.named_parameters()
                if 'weight' in name and W.grad is not None
            ]
            per_sample_grads.append(torch.cat(grad, -1))
            network.zero_grad()
            torch.cuda.empty_cache()

    grads = torch.stack(per_sample_grads, 0)
    ntk = torch.einsum('nc,mc->nm', grads, grads)
    eigenvalues, _ = torch.linalg.eigh(ntk)  # ascending
    ratio = (eigenvalues[0] / eigenvalues[-1]).item()
    return np.nan_to_num(ratio, copy=True, nan=100000.0)
|  | ||||
@measure('ntk', bn=True)
def compute_ntk(net, inputs, targets, split_data=1, loss_fn=None):
    """Entry point for the 'ntk' measure, wrapped by the ``measure`` decorator.

    Args:
        net: Network to score; its gradients are cleared first.
        inputs: Input batch; its ``device`` attribute selects where to compute.
        targets: Forwarded to ``get_ntk_n``.
        split_data: Not used by this function.
        loss_fn: Not used by this function.

    Returns:
        The value returned by ``get_ntk_n``, or ``np.nan`` if it raised.
    """
    device = inputs.device
    # Start from clean gradients; nothing here applies an optimizer step.
    net.zero_grad()

    try:
        return get_ntk_n(inputs, targets, net, device)
    except Exception as err:
        # Best effort: report the failure and fall back to NaN.
        print(err)
        return np.nan
							
								
								
									
										110
									
								
								zero-cost-nas/foresight/pruners/measures/zen.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										110
									
								
								zero-cost-nas/foresight/pruners/measures/zen.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,110 @@ | ||||
| # Copyright 2021 Samsung Electronics Co., Ltd. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
|  | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| # ============================================================================= | ||||
|  | ||||
| import torch | ||||
| from torch import nn | ||||
| import numpy as np | ||||
|  | ||||
| from . import measure | ||||
|  | ||||
|  | ||||
def network_weight_gaussian_init(net: nn.Module):
    """Re-initialise ``net`` in place for scoring.

    Conv2d and Linear weights are drawn with ``nn.init.normal_`` and their
    biases (when present) are zeroed. BatchNorm2d scale and shift (when
    present) are set to ones and zeros. Other modules are left untouched.

    Args:
        net: Module whose parameters are overwritten.

    Returns:
        The same ``net`` object.
    """
    with torch.no_grad():
        for m in net.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.normal_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                # Non-affine BatchNorm layers have no weight/bias to reset;
                # check explicitly instead of swallowing every exception.
                if m.weight is not None:
                    nn.init.ones_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    return net
|  | ||||
|  | ||||
def get_zen(gpu, model, mixup_gamma=1e-2, resolution=32, batch_size=64, repeat=32,
            fp16=False):
    """Average a log-scale output-perturbation score over random re-initialisations.

    Each repeat re-initialises ``model`` with ``network_weight_gaussian_init``,
    draws a random input batch plus a second random batch scaled by
    ``mixup_gamma``, and measures the mean (over the batch) summed absolute
    difference between ``model.forward_pre_GAP`` on the original and the
    perturbed input. The log of that value, plus the log of
    ``sqrt(mean(running_var))`` for every BatchNorm2d layer, is one sample.

    Args:
        gpu: Device spec passed to ``torch.device``; ``None`` selects CPU.
        model: Module exposing ``forward_pre_GAP``; its weights are
            overwritten on every repeat. The ``dim=[1, 2, 3]`` reduction
            requires that method to return a 4-D tensor.
        mixup_gamma: Scale of the perturbation added to the input.
        resolution: Height and width of the random 3-channel input.
        batch_size: Number of random samples per repeat.
        repeat: Number of re-initialisations to average over.
        fp16: Draw inputs as ``torch.half`` instead of ``torch.float32``.

    Returns:
        The mean of the per-repeat scores as a Python float.
    """
    # NOTE(review): the model itself is not moved to ``device``/``dtype`` here -
    # presumably the caller has already placed it; confirm.
    device = torch.device(gpu) if gpu is not None else torch.device('cpu')
    dtype = torch.half if fp16 else torch.float32
    input_shape = [batch_size, 3, resolution, resolution]

    nas_score_list = []
    with torch.no_grad():
        for _ in range(repeat):
            network_weight_gaussian_init(model)
            base_input = torch.randn(size=input_shape, device=device, dtype=dtype)
            noise = torch.randn(size=input_shape, device=device, dtype=dtype)
            mixup_input = base_input + mixup_gamma * noise
            output = model.forward_pre_GAP(base_input)
            mixup_output = model.forward_pre_GAP(mixup_input)

            nas_score = torch.sum(torch.abs(output - mixup_output), dim=[1, 2, 3])
            nas_score = torch.mean(nas_score)

            # compute BN scaling
            log_bn_scaling_factor = 0.0
            for m in model.modules():
                # Layers that do not track running statistics have no
                # running_var; skip them explicitly instead of using a bare except.
                if isinstance(m, nn.BatchNorm2d) and m.running_var is not None:
                    bn_scaling_factor = torch.sqrt(torch.mean(m.running_var))
                    log_bn_scaling_factor += torch.log(bn_scaling_factor)

            nas_score = torch.log(nas_score) + log_bn_scaling_factor
            nas_score_list.append(float(nas_score))

    return float(np.mean(nas_score_list))
|  | ||||
|  | ||||
|  | ||||
|  | ||||
|  | ||||
@measure('zen', bn=True)
def compute_zen(net, inputs, targets, split_data=1, loss_fn=None):
    """Entry point for the 'zen' measure, wrapped by the ``measure`` decorator.

    Args:
        net: Network to score; its gradients are cleared first.
        inputs: Only its ``device`` attribute is read here.
        targets: Not used by this function.
        split_data: Not used by this function.
        loss_fn: Not used by this function.

    Returns:
        The value returned by ``get_zen``, or ``np.nan`` if it raised.
    """
    device = inputs.device
    # Start from clean gradients; nothing here applies an optimizer step.
    net.zero_grad()

    try:
        return get_zen(device, net)
    except Exception as err:
        # Best effort: report the failure and fall back to NaN.
        print(err)
        return np.nan
| @@ -108,7 +108,7 @@ def find_measures(net_orig,                  # neural network | ||||
|  | ||||
|     measures = {} | ||||
|     for k,v in measures_arr.items(): | ||||
|         if k in ['jacob_cov', 'meco', 'zico']: | ||||
|         if k in ['jacob_cov', 'var', 'cor', 'norm', 'meco', 'zico', 'ntk', 'gradsign', 'zen']: | ||||
|             measures[k] = v | ||||
|         else: | ||||
|             measures[k] = sum_arr(v) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user