This commit is contained in:
HamsterMimi
2024-01-23 10:08:45 +08:00
parent 1a57decf65
commit 3f6d16e791
92 changed files with 12855 additions and 41 deletions

View File

@@ -0,0 +1,69 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
available_measures = []
_measure_impls = {}
def measure(name, bn=True, copy_net=True, force_clean=True, **impl_args):
def make_impl(func):
def measure_impl(net_orig, device, *args, **kwargs):
if copy_net:
net = net_orig.get_prunable_copy(bn=bn).to(device)
else:
net = net_orig
ret = func(net, *args, **kwargs, **impl_args)
if copy_net and force_clean:
import gc
import torch
del net
torch.cuda.empty_cache()
gc.collect()
return ret
global _measure_impls
if name in _measure_impls:
raise KeyError(f'Duplicate measure: {name}')
available_measures.append(name)
_measure_impls[name] = measure_impl
return func
return make_impl
def calc_measure(name, net, device, *args, **kwargs):
return _measure_impls[name](net, device, *args, **kwargs)
def load_all():
# from . import grad_norm
# from . import snip
# from . import grasp
# from . import fisher
# from . import jacob_cov
# from . import plain
# from . import synflow
# from . import var
# from . import cor
# from . import norm
from . import meco
# from . import zico
# from . import gradsign
# from . import ntk
# from . import zen
# TODO: should we do that by default?
load_all()
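For orientation, a minimal usage sketch of the registry above. The measure name 'toy' and the call site are hypothetical; only measure and calc_measure come from this file, and copy_net=True assumes the network class provides get_prunable_copy(bn=...).

import torch
from . import measure, calc_measure

@measure('toy', bn=True)
def compute_toy(net, inputs, targets, loss_fn=None, split_data=1):
    # hypothetical measure: summed L1 norm of all weights
    return sum(p.abs().sum().item() for p in net.parameters())

# later, with a real net/device/batch in scope:
# score = calc_measure('toy', net, device, inputs, targets, loss_fn=None)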

View File

@@ -0,0 +1,53 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
import time
import numpy as np
import torch
from . import measure
def get_score(net, x, target, device, split_data):
result_list = []
def forward_hook(module, data_input, data_output):
corr = np.mean(np.corrcoef(data_input[0].detach().cpu().numpy()))
result_list.append(corr)
handle = net.classifier.register_forward_hook(forward_hook)
N = x.shape[0]
for sp in range(split_data):
st = sp * N // split_data
en = (sp + 1) * N // split_data
y = net(x[st:en])
cor = result_list[0].item()
result_list.clear()
handle.remove()
return cor
@measure('cor', bn=True)
def compute_cor(net, inputs, targets, split_data=1, loss_fn=None):
device = inputs.device
# Compute gradients (but don't apply them)
net.zero_grad()
try:
cor = get_score(net, inputs, targets, device, split_data=split_data)
except Exception as e:
print(e)
cor = np.nan
return cor

View File

@@ -0,0 +1,67 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
import copy
import time
import numpy as np
from torch import nn
from . import measure
def get_score(net, x, target, device, split_data):
result_list = []
result_t = []
def forward_hook(module, data_input, data_output):
s = time.time()
fea = data_output[0].detach().cpu().numpy()
fea = fea.reshape(fea.shape[0], -1)
result = 1 / np.var(np.corrcoef(fea))
e = time.time()
t = e - s
result_list.append(result)
result_t.append(t)
handles = [module.register_forward_hook(forward_hook) for _, module in net.named_modules()]
N = x.shape[0]
for sp in range(split_data):
st = sp * N // split_data
en = (sp + 1) * N // split_data
y = net(x[st:en])
results = np.array(result_list)
results = results[np.logical_not(np.isnan(results))]
v = np.sum(results)
t = sum(result_t)
result_list.clear()
result_t.clear()
for handle in handles: handle.remove()
return v, t
@measure('cova', bn=True)
def compute_cova(net, inputs, targets, split_data=1, loss_fn=None):
device = inputs.device
# Compute gradients (but don't apply them)
net.zero_grad()
try:
cova, t = get_score(net, inputs, targets, device, split_data=split_data)
except Exception as e:
print(e)
cova, t = np.nan, None
return cova, t

View File

@@ -0,0 +1,107 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import types
from . import measure
from ..p_utils import get_layer_metric_array, reshape_elements
def fisher_forward_conv2d(self, x):
x = F.conv2d(x, self.weight, self.bias, self.stride,
self.padding, self.dilation, self.groups)
#intercept and store the activations after passing through 'hooked' identity op
self.act = self.dummy(x)
return self.act
def fisher_forward_linear(self, x):
x = F.linear(x, self.weight, self.bias)
self.act = self.dummy(x)
return self.act
@measure('fisher', bn=True, mode='channel')
def compute_fisher_per_weight(net, inputs, targets, loss_fn, mode, split_data=1):
device = inputs.device
if mode == 'param':
raise ValueError('Fisher pruning does not support parameter pruning.')
net.train()
all_hooks = []
for layer in net.modules():
if isinstance(layer, nn.Conv2d) or isinstance(layer, nn.Linear):
#variables/op needed for fisher computation
layer.fisher = None
layer.act = 0.
layer.dummy = nn.Identity()
#replace forward method of conv/linear
if isinstance(layer, nn.Conv2d):
layer.forward = types.MethodType(fisher_forward_conv2d, layer)
if isinstance(layer, nn.Linear):
layer.forward = types.MethodType(fisher_forward_linear, layer)
#function to call during backward pass (hooked on identity op at output of layer)
def hook_factory(layer):
def hook(module, grad_input, grad_output):
act = layer.act.detach()
grad = grad_output[0].detach()
if len(act.shape) > 2:
g_nk = torch.sum((act * grad), list(range(2,len(act.shape))))
else:
g_nk = act * grad
del_k = g_nk.pow(2).mean(0).mul(0.5)
if layer.fisher is None:
layer.fisher = del_k
else:
layer.fisher += del_k
del layer.act #without deleting this, a nasty memory leak occurs! related: https://discuss.pytorch.org/t/memory-leak-when-using-forward-hook-and-backward-hook-simultaneously/27555
return hook
#register backward hook on identity fcn to compute fisher info
layer.dummy.register_backward_hook(hook_factory(layer))
N = inputs.shape[0]
for sp in range(split_data):
st=sp*N//split_data
en=(sp+1)*N//split_data
net.zero_grad()
outputs = net(inputs[st:en])
loss = loss_fn(outputs, targets[st:en])
loss.backward()
# retrieve fisher info
def fisher(layer):
if layer.fisher is not None:
return torch.abs(layer.fisher.detach())
else:
return torch.zeros(layer.weight.shape[0]) #size=ch
grads_abs_ch = get_layer_metric_array(net, fisher, mode)
#broadcast channel value here to all parameters in that channel
#to be compatible with stuff downstream (which expects per-parameter metrics)
#TODO cleanup on the selectors/apply_prune_mask side (?)
shapes = get_layer_metric_array(net, lambda l : l.weight.shape[1:], mode)
grads_abs = reshape_elements(grads_abs_ch, shapes, device)
return grads_abs
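Each backward hook above accumulates, per channel k, the saliency del_k = 0.5 * E_n[(sum_{h,w} a_nkhw * g_nkhw)^2], where a is the post-layer activation and g the gradient flowing into it. A self-contained numeric sketch of that reduction on synthetic tensors (shapes are illustrative):

import torch

act = torch.randn(8, 16, 4, 4)            # batch x channels x H x W
grad = torch.randn_like(act)              # stand-in for grad_output[0]
g_nk = torch.sum(act * grad, dim=[2, 3])  # sum over spatial dims -> batch x channels
del_k = g_nk.pow(2).mean(0).mul(0.5)      # one Fisher saliency per channel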

View File

@@ -0,0 +1,38 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
import torch
import torch.nn.functional as F
import copy
from . import measure
from ..p_utils import get_layer_metric_array
@measure('grad_norm', bn=True)
def get_grad_norm_arr(net, inputs, targets, loss_fn, split_data=1, skip_grad=False):
net.zero_grad()
N = inputs.shape[0]
for sp in range(split_data):
st=sp*N//split_data
en=(sp+1)*N//split_data
outputs = net.forward(inputs[st:en])
loss = loss_fn(outputs, targets[st:en])
loss.backward()
grad_norm_arr = get_layer_metric_array(net, lambda l: l.weight.grad.norm() if l.weight.grad is not None else torch.zeros_like(l.weight), mode='param')
return grad_norm_arr

View File

@@ -0,0 +1,76 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
import torch
from torch import nn
import numpy as np
from . import measure
def get_flattened_metric(net, metric):
grad_list = []
for layer in net.modules():
if isinstance(layer, nn.Conv2d) or isinstance(layer, nn.Linear):
grad_list.append(metric(layer).flatten())
flattened_grad = np.concatenate(grad_list)
return flattened_grad
def get_grad_conflict(net, inputs, targets, loss_fn):
N = inputs.shape[0]
batch_grad = []
for i in range(N):
net.zero_grad()
outputs = net.forward(inputs[[i]])
loss = loss_fn(outputs, targets[[i]])
loss.backward()
flattened_grad = get_flattened_metric(net, lambda l: l.weight.grad.data.clone().cpu().numpy() if l.weight.grad is not None else torch.zeros_like(l.weight).cpu().numpy())
batch_grad.append(flattened_grad)
batch_grad = np.stack(batch_grad)
direction_code = np.sign(batch_grad)
direction_code = abs(direction_code.sum(axis=0))
score = np.nansum(direction_code)
return score
def get_gradsign(inputs, targets, net, device, loss_fn):
s = []
net = net.to(device)
inputs, targets = inputs.to(device), targets.to(device)
s.append(get_grad_conflict(net=net, inputs=inputs, targets=targets, loss_fn=loss_fn))
s = np.mean(s)
return s
@measure('gradsign', bn=True)
def compute_gradsign(net, inputs, targets, split_data=1, loss_fn=None):
device = inputs.device
# Compute gradients (but don't apply them)
net.zero_grad()
try:
gradsign = get_gradsign(inputs, targets, net, device, loss_fn)
except Exception as e:
print(e)
gradsign = np.nan
return gradsign

View File

@@ -0,0 +1,87 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as autograd
from . import measure
from ..p_utils import get_layer_metric_array
@measure('grasp', bn=True, mode='param')
def compute_grasp_per_weight(net, inputs, targets, mode, loss_fn, T=1, num_iters=1, split_data=1):
# get all applicable weights
weights = []
for layer in net.modules():
if isinstance(layer, nn.Conv2d) or isinstance(layer, nn.Linear):
weights.append(layer.weight)
layer.weight.requires_grad_(True) # TODO isn't this already true?
# NOTE original code had some input/target splitting into 2
# I am guessing this was because of GPU mem limit
net.zero_grad()
N = inputs.shape[0]
for sp in range(split_data):
st=sp*N//split_data
en=(sp+1)*N//split_data
#forward/grad pass #1
grad_w = None
for _ in range(num_iters):
#TODO get new data, otherwise num_iters is useless!
outputs = net.forward(inputs[st:en])/T
loss = loss_fn(outputs, targets[st:en])
grad_w_p = autograd.grad(loss, weights, allow_unused=True)
if grad_w is None:
grad_w = list(grad_w_p)
else:
for idx in range(len(grad_w)):
grad_w[idx] += grad_w_p[idx]
for sp in range(split_data):
st=sp*N//split_data
en=(sp+1)*N//split_data
# forward/grad pass #2
outputs = net.forward(inputs[st:en])/T
loss = loss_fn(outputs, targets[st:en])
grad_f = autograd.grad(loss, weights, create_graph=True, allow_unused=True)
# accumulate gradients computed in previous step and call backwards
z, count = 0, 0
for layer in net.modules():
if isinstance(layer, nn.Conv2d) or isinstance(layer, nn.Linear):
if grad_w[count] is not None:
z += (grad_w[count].data * grad_f[count]).sum()
count += 1
z.backward()
# compute final sensitivity metric and put in grads
def grasp(layer):
if layer.weight.grad is not None:
return -layer.weight.data * layer.weight.grad # -theta_q Hg
#NOTE in the grasp code they take the *bottom* (1-p)% of values
#but we take the *top* (1-p)%, therefore we remove the -ve sign
#EDIT accuracy seems to be negatively correlated with this metric, so we add -ve sign here!
else:
return torch.zeros_like(layer.weight)
grads = get_layer_metric_array(net, grasp, mode)
return grads
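The z.backward() call relies on a standard autograd identity: differentiating a gradient-vector dot product z = g.v (with g built under create_graph=True) yields the Hessian-vector product Hv in .grad. A self-contained check on a toy function, independent of the code above:

import torch

# f(w) = sum(w**4); the Hessian is diag(12 * w**2), so H @ v is easy to verify.
w = torch.tensor([1.0, 2.0], requires_grad=True)
loss = (w ** 4).sum()
(g,) = torch.autograd.grad(loss, w, create_graph=True)
v = torch.tensor([1.0, 1.0])
z = (g * v).sum()
z.backward()
print(w.grad)  # tensor([12., 48.]) == 12 * w**2 * v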

View File

@@ -0,0 +1,57 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
import torch
import numpy as np
from . import measure
def get_batch_jacobian(net, x, target, device, split_data):
x.requires_grad_(True)
N = x.shape[0]
for sp in range(split_data):
st=sp*N//split_data
en=(sp+1)*N//split_data
y = net(x[st:en])
y.backward(torch.ones_like(y))
jacob = x.grad.detach()
x.requires_grad_(False)
return jacob, target.detach()
def eval_score(jacob, labels=None):
corrs = np.corrcoef(jacob)
v, _ = np.linalg.eig(corrs)
k = 1e-5
return -np.sum(np.log(v + k) + 1./(v + k))
@measure('jacob_cov', bn=True)
def compute_jacob_cov(net, inputs, targets, split_data=1, loss_fn=None):
device = inputs.device
# Compute gradients (but don't apply them)
net.zero_grad()
jacobs, labels = get_batch_jacobian(net, inputs, targets, device, split_data=split_data)
jacobs = jacobs.reshape(jacobs.size(0), -1).cpu().numpy()
try:
jc = eval_score(jacobs, labels)
except Exception as e:
print(e)
jc = np.nan
return jc
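A hedged, standalone sketch of the same scoring pipeline on a synthetic Jacobian (shapes illustrative; it mirrors get_batch_jacobian followed by eval_score):

import numpy as np

jacob = np.random.randn(8, 3 * 32 * 32)  # 8 inputs, flattened d(output)/d(input) rows
corrs = np.corrcoef(jacob)               # 8x8 correlation across inputs
v, _ = np.linalg.eig(corrs)              # eigenvalues (real for a symmetric matrix)
k = 1e-5
score = -np.sum(np.log(v + k) + 1. / (v + k))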

View File

@@ -0,0 +1,22 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
from . import measure
from ..p_utils import get_layer_metric_array
@measure('l2_norm', copy_net=False, mode='param')
def get_l2_norm_array(net, inputs, targets, mode, split_data=1):
return get_layer_metric_array(net, lambda l: l.weight.norm(), mode=mode)

View File

@@ -0,0 +1,63 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
import time
import numpy as np
import torch
from . import measure
def get_score(net, x, target, device, split_data):
result_list = []
def forward_hook(module, data_input, data_output):
s = time.time()
mean = torch.mean(data_input[0])
e = time.time()
t = e - s
result_list.append(mean)
result_list.append(t)
handle = net.classifier.register_forward_hook(forward_hook)
N = x.shape[0]
for sp in range(split_data):
st = sp * N // split_data
en = (sp + 1) * N // split_data
# t1 = time.time()
y = net(x[st:en])
# t2 = time.time()
# print('var:', t2-t1)
m = result_list[0].item()
t = result_list[1]
result_list.clear()
handle.remove()
return m, t
@measure('mean', bn=True)
def compute_mean(net, inputs, targets, split_data=1, loss_fn=None):
device = inputs.device
# Compute gradients (but don't apply them)
net.zero_grad()
try:
mean, t = get_score(net, inputs, targets, device, split_data=split_data)
except Exception as e:
print(e)
mean, t = np.nan, None
return mean, t

View File

@@ -0,0 +1,73 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
import copy
import time
import numpy as np
import torch
from torch import nn
from . import measure
def get_score(net, x, target, device, split_data):
result_list = []
# NOTE: the incoming batch is deliberately replaced by a single random 64x64 input
x = torch.randn(size=(1, 3, 64, 64)).to(device)
net.to(device)
def forward_hook(module, data_input, data_output):
fea = data_output[0].detach()
fea = fea.reshape(fea.shape[0], -1)
corr = torch.corrcoef(fea)
corr[torch.isnan(corr)] = 0
corr[torch.isinf(corr)] = 0
values = torch.linalg.eig(corr)[0]
result = torch.min(torch.real(values))
result_list.append(result)
handles = [module.register_forward_hook(forward_hook) for _, module in net.named_modules()]
N = x.shape[0]
for sp in range(split_data):
st = sp * N // split_data
en = (sp + 1) * N // split_data
y = net(x[st:en])
results = torch.tensor(result_list)
results = results[torch.logical_not(torch.isnan(results))]
v = torch.sum(results)
result_list.clear()
for handle in handles: handle.remove()
return v.item()
@measure('meco', bn=True)
def compute_meco(net, inputs, targets, split_data=1, loss_fn=None):
device = inputs.device
# Compute gradients (but don't apply them)
net.zero_grad()
try:
meco = get_score(net, inputs, targets, device, split_data=split_data)
except Exception as e:
print(e)
meco = np.nan
return meco
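A standalone sketch of the quantity each forward hook contributes: the smallest eigenvalue of the Pearson correlation matrix of one sample's feature map, with channels as rows (tensor shapes are illustrative, not from the commit):

import torch

fea = torch.randn(16, 8 * 8)           # channels x flattened spatial map
corr = torch.corrcoef(fea)             # 16x16 channel correlation
corr[torch.isnan(corr)] = 0
corr[torch.isinf(corr)] = 0
values = torch.linalg.eig(corr)[0]
score = torch.min(torch.real(values))  # per-layer MeCo contribution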

View File

@@ -0,0 +1,55 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
import time
import numpy as np
import torch
from . import measure
def get_score(net, x, target, device, split_data):
result_list = []
def forward_hook(module, data_input, data_output):
norm = torch.norm(data_input[0])
result_list.append(norm)
handle = net.classifier.register_forward_hook(forward_hook)
N = x.shape[0]
for sp in range(split_data):
st = sp * N // split_data
en = (sp + 1) * N // split_data
y = net(x[st:en])
n = result_list[0].item()
result_list.clear()
handle.remove()
return n
@measure('norm', bn=True)
def compute_norm(net, inputs, targets, split_data=1, loss_fn=None):
device = inputs.device
# Compute gradients (but don't apply them)
net.zero_grad()
try:
norm = get_score(net, inputs, targets, device, split_data=split_data)
except Exception as e:
print(e)
norm = np.nan
return norm

View File

@@ -0,0 +1,94 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
import torch
import numpy as np
from . import measure
def recal_bn(network, inputs, targets, recalbn, device):
for m in network.modules():
if isinstance(m, torch.nn.BatchNorm2d):
m.running_mean.data.fill_(0)
m.running_var.data.fill_(0)
m.num_batches_tracked.data.zero_()
m.momentum = None
network.train()
with torch.no_grad():
for i, (inputs, targets) in enumerate(zip(inputs, targets)):
if i >= recalbn: break
inputs = inputs.cuda(device=device, non_blocking=True)
_, _ = network(inputs)
return network
def get_ntk_n(inputs, targets, network, device, recalbn=0, train_mode=False, num_batch=1):
# if recalbn > 0:
# network = recal_bn(network, xloader, recalbn, device)
# if network_2 is not None:
# network_2 = recal_bn(network_2, xloader, recalbn, device)
network.eval()
networks = []
networks.append(network)
ntks = []
# if train_mode:
# networks.train()
# else:
# networks.eval()
######
grads = [[] for _ in range(len(networks))]
for i in range(num_batch):
inputs = inputs.cuda(device=device, non_blocking=True)
for net_idx, network in enumerate(networks):
network.zero_grad()
# print(inputs.size())
inputs_ = inputs.clone().cuda(device=device, non_blocking=True)
logit = network(inputs_)
if isinstance(logit, tuple):
logit = logit[1] # 201 networks: return features and logits
for _idx in range(len(inputs_)):
logit[_idx:_idx + 1].backward(torch.ones_like(logit[_idx:_idx + 1]), retain_graph=True)
grad = []
for name, W in network.named_parameters():
if 'weight' in name and W.grad is not None:
grad.append(W.grad.view(-1).detach())
grads[net_idx].append(torch.cat(grad, -1))
network.zero_grad()
torch.cuda.empty_cache()
######
grads = [torch.stack(_grads, 0) for _grads in grads]
ntks = [torch.einsum('nc,mc->nm', [_grads, _grads]) for _grads in grads]
for ntk in ntks:
eigenvalues, _ = torch.linalg.eigh(ntk) # ascending
conds = np.nan_to_num((eigenvalues[-1] / eigenvalues[0]).item(), copy=True, nan=100000.0)
return conds
@measure('ntk', bn=True)
def compute_ntk(net, inputs, targets, split_data=1, loss_fn=None):
device = inputs.device
# Compute gradients (but don't apply them)
net.zero_grad()
try:
conds = get_ntk_n(inputs, targets, net, device)
except Exception as e:
print(e)
conds = np.nan
return conds
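For intuition, the Gram matrix built above is the empirical NTK, Theta_ij = <grad_w f(x_i), grad_w f(x_j)>, and the score is its condition number lambda_max / lambda_min. A toy sketch with a hypothetical one-layer model (lambda_min can be near zero at this scale, so the ratio is only illustrative):

import torch

net = torch.nn.Linear(4, 1)
x = torch.randn(3, 4)
grads = []
for i in range(3):
    net.zero_grad()
    net(x[i:i + 1]).sum().backward()
    grads.append(torch.cat([p.grad.view(-1) for p in net.parameters()]))
G = torch.stack(grads)                    # per-example gradient matrix
ntk = torch.einsum('nc,mc->nm', G, G)     # 3x3 empirical NTK
eig = torch.linalg.eigh(ntk).eigenvalues  # ascending
cond = (eig[-1] / eig[0]).item()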

View File

@@ -0,0 +1,16 @@
import time
import torch
from . import measure
from ..p_utils import get_layer_metric_array
@measure('param_count', copy_net=False, mode='param')
def get_param_count_array(net, inputs, targets, mode, loss_fn, split_data=1):
s = time.time()
count = get_layer_metric_array(net, lambda l: torch.tensor(sum(p.numel() for p in l.parameters() if p.requires_grad)), mode=mode)
e = time.time()
t = e - s
# print(f'param_count time: {t} s')
return count, t

View File

@@ -0,0 +1,71 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
import copy
import time
import numpy as np
from torch import nn
from . import measure
def get_score(net, x, target, device, split_data):
result_list = []
result_t = []
def forward_hook(module, data_input, data_output):
s = time.time()
fea = data_output[0].detach().cpu().numpy()
fea = fea.reshape(fea.shape[0], -1)
# result = 1 / np.var(np.corrcoef(fea))
result = np.var(np.corrcoef(fea))
e = time.time()
t = e - s
result_list.append(result)
result_t.append(t)
handles = [module.register_forward_hook(forward_hook) for _, module in net.named_modules()]
N = x.shape[0]
for sp in range(split_data):
st = sp * N // split_data
en = (sp + 1) * N // split_data
y = net(x[st:en])
results = np.array(result_list)
results = results[np.logical_not(np.isnan(results))]
v = np.sum(results)
t = sum(result_t)
result_list.clear()
result_t.clear()
for handle in handles: handle.remove()
return v, t
@measure('pearson', bn=True)
def compute_pearson(net, inputs, targets, split_data=1, loss_fn=None):
device = inputs.device
# Compute gradients (but don't apply them)
net.zero_grad()
try:
pearson, t = get_score(net, inputs, targets, device, split_data=split_data)
except Exception as e:
print(e)
pearson, t = np.nan, None
return pearson, t

View File

@@ -0,0 +1,44 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
import torch
import torch.nn.functional as F
from . import measure
from ..p_utils import get_layer_metric_array
@measure('plain', bn=True, mode='param')
def compute_plain_per_weight(net, inputs, targets, mode, loss_fn, split_data=1):
net.zero_grad()
N = inputs.shape[0]
for sp in range(split_data):
st=sp*N//split_data
en=(sp+1)*N//split_data
outputs = net.forward(inputs[st:en])
loss = loss_fn(outputs, targets[st:en])
loss.backward()
# select the gradients that we want to use for search/prune
def plain(layer):
if layer.weight.grad is not None:
return layer.weight.grad * layer.weight
else:
return torch.zeros_like(layer.weight)
grads_abs = get_layer_metric_array(net, plain, mode)
return grads_abs

View File

@@ -0,0 +1,69 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
import types
from . import measure
from ..p_utils import get_layer_metric_array
def snip_forward_conv2d(self, x):
return F.conv2d(x, self.weight * self.weight_mask, self.bias,
self.stride, self.padding, self.dilation, self.groups)
def snip_forward_linear(self, x):
return F.linear(x, self.weight * self.weight_mask, self.bias)
@measure('snip', bn=True, mode='param')
def compute_snip_per_weight(net, inputs, targets, mode, loss_fn, split_data=1):
for layer in net.modules():
if isinstance(layer, nn.Conv2d) or isinstance(layer, nn.Linear):
layer.weight_mask = nn.Parameter(torch.ones_like(layer.weight))
layer.weight.requires_grad = False
# Override the forward methods:
if isinstance(layer, nn.Conv2d):
layer.forward = types.MethodType(snip_forward_conv2d, layer)
if isinstance(layer, nn.Linear):
layer.forward = types.MethodType(snip_forward_linear, layer)
# Compute gradients (but don't apply them)
net.zero_grad()
N = inputs.shape[0]
for sp in range(split_data):
st=sp*N//split_data
en=(sp+1)*N//split_data
outputs = net.forward(inputs[st:en])
loss = loss_fn(outputs, targets[st:en])
loss.backward()
# select the gradients that we want to use for search/prune
def snip(layer):
if layer.weight_mask.grad is not None:
return torch.abs(layer.weight_mask.grad)
else:
return torch.zeros_like(layer.weight)
grads_abs = get_layer_metric_array(net, snip, mode)
return grads_abs
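The weight_mask trick works because, for weights w and mask m, dL/dm = w * dL/d(w * m); evaluated at m = 1 this is exactly SNIP's saliency |w * grad_w L|. A toy verification (all names local to this sketch):

import torch

w = torch.tensor([0.5, -2.0])               # frozen weights
m = torch.ones_like(w, requires_grad=True)  # trainable mask
x = torch.tensor([1.0, 3.0])
loss = ((w * m) @ x) ** 2
loss.backward()
# dL/d(w*m) = 2 * ((w*m) @ x) * x = [-11., -33.], so m.grad = w * that
print(m.grad)  # tensor([-5.5000, 66.0000])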

View File

@@ -0,0 +1,69 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
import torch
from . import measure
from ..p_utils import get_layer_metric_array
@measure('synflow', bn=False, mode='param')
@measure('synflow_bn', bn=True, mode='param')
def compute_synflow_per_weight(net, inputs, targets, mode, split_data=1, loss_fn=None):
device = inputs.device
#convert params to their abs. Keep sign for converting it back.
@torch.no_grad()
def linearize(net):
signs = {}
for name, param in net.state_dict().items():
signs[name] = torch.sign(param)
param.abs_()
return signs
#convert to orig values
@torch.no_grad()
def nonlinearize(net, signs):
for name, param in net.state_dict().items():
if 'weight_mask' not in name:
param.mul_(signs[name])
# keep signs of all params
signs = linearize(net)
# Compute gradients with input of 1s
net.zero_grad()
net.double()
input_dim = list(inputs[0,:].shape)
inputs = torch.ones([1] + input_dim).double().to(device)
output = net.forward(inputs)
torch.sum(output).backward()
# select the gradients that we want to use for search/prune
def synflow(layer):
if layer.weight.grad is not None:
return torch.abs(layer.weight * layer.weight.grad)
else:
return torch.zeros_like(layer.weight)
grads_abs = get_layer_metric_array(net, synflow, mode)
# apply signs of all params
nonlinearize(net, signs)
return grads_abs
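SynFlow's |w| reparameterization makes every path through the network contribute positively, so backpropagating the summed output of an all-ones input accumulates per-weight path products. A minimal sketch of the whole linearize / score / restore cycle on a hypothetical two-layer toy net:

import torch

net = torch.nn.Sequential(torch.nn.Linear(3, 4, bias=False),
                          torch.nn.Linear(4, 2, bias=False)).double()
signs = {n: p.sign() for n, p in net.state_dict().items()}  # remember signs
with torch.no_grad():
    for p in net.parameters():
        p.abs_()                                            # linearize
net.zero_grad()
torch.sum(net(torch.ones(1, 3, dtype=torch.double))).backward()
scores = [torch.abs(p * p.grad) for p in net.parameters()]  # per-weight synflow
with torch.no_grad():
    for n, p in net.state_dict().items():
        p.mul_(signs[n])                                    # restore signs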

View File

@@ -0,0 +1,55 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
import time
import numpy as np
import torch
from . import measure
def get_score(net, x, target, device, split_data):
result_list = []
def forward_hook(module, data_input, data_output):
var = torch.var(data_input[0])
result_list.append(var)
handle = net.classifier.register_forward_hook(forward_hook)
N = x.shape[0]
for sp in range(split_data):
st = sp * N // split_data
en = (sp + 1) * N // split_data
y = net(x[st:en])
v = result_list[0].item()
result_list.clear()
handle.remove()
return v
@measure('var', bn=True)
def compute_var(net, inputs, targets, split_data=1, loss_fn=None):
device = inputs.device
# Compute gradients (but don't apply them)
net.zero_grad()
try:
var = get_score(net, inputs, targets, device, split_data=split_data)
except Exception as e:
print(e)
var = np.nan
return var

View File

@@ -0,0 +1,110 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
import torch
from torch import nn
import numpy as np
from . import measure
def network_weight_gaussian_init(net: nn.Module):
with torch.no_grad():
for n, m in net.named_modules():
if isinstance(m, nn.Conv2d):
nn.init.normal_(m.weight)
if hasattr(m, 'bias') and m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.BatchNorm2d):
try:
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
except Exception:
pass
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight)
if hasattr(m, 'bias') and m.bias is not None:
nn.init.zeros_(m.bias)
else:
continue
return net
def get_zen(gpu, model, mixup_gamma=1e-2, resolution=32, batch_size=64, repeat=32,
fp16=False):
info = {}
nas_score_list = []
if gpu is not None:
device = torch.device(gpu)
else:
device = torch.device('cpu')
if fp16:
dtype = torch.half
else:
dtype = torch.float32
with torch.no_grad():
for repeat_count in range(repeat):
network_weight_gaussian_init(model)
input = torch.randn(size=[batch_size, 3, resolution, resolution], device=device, dtype=dtype)
input2 = torch.randn(size=[batch_size, 3, resolution, resolution], device=device, dtype=dtype)
mixup_input = input + mixup_gamma * input2
output = model.forward_pre_GAP(input)
mixup_output = model.forward_pre_GAP(mixup_input)
nas_score = torch.sum(torch.abs(output - mixup_output), dim=[1, 2, 3])
nas_score = torch.mean(nas_score)
# compute BN scaling
log_bn_scaling_factor = 0.0
for m in model.modules():
if isinstance(m, nn.BatchNorm2d):
try:
bn_scaling_factor = torch.sqrt(torch.mean(m.running_var))
log_bn_scaling_factor += torch.log(bn_scaling_factor)
except Exception:
pass
nas_score = torch.log(nas_score) + log_bn_scaling_factor
nas_score_list.append(float(nas_score))
std_nas_score = np.std(nas_score_list)
avg_precision = 1.96 * std_nas_score / np.sqrt(len(nas_score_list))
avg_nas_score = np.mean(nas_score_list)
info = float(avg_nas_score)
return info
@measure('zen', bn=True)
def compute_zen(net, inputs, targets, split_data=1, loss_fn=None):
device = inputs.device
# Compute gradients (but don't apply them)
net.zero_grad()
try:
zen = get_zen(device, net)
except Exception as e:
print(e)
zen = np.nan
return zen
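The core Zen quantity is E|f(x) - f(x + gamma * eps)| over Gaussian inputs and a small Gaussian perturbation, log-scaled and corrected by the BN variances accumulated above. A sketch of the perturbation term alone on a hypothetical BN-free stem (so the correction term is zero):

import torch

model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3, padding=1), torch.nn.ReLU())
x = torch.randn(64, 3, 32, 32)
eps = torch.randn(64, 3, 32, 32)
with torch.no_grad():
    out = model(x)
    out_mix = model(x + 1e-2 * eps)
    score = torch.log(torch.mean(torch.sum(torch.abs(out - out_mix), dim=[1, 2, 3])))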

View File

@@ -0,0 +1,106 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
import time
import numpy as np
import torch
from . import measure
from torch import nn
def getgrad(model: torch.nn.Module, grad_dict: dict, step_iter=0):
if step_iter == 0:
for name, mod in model.named_modules():
if isinstance(mod, nn.Conv2d) or isinstance(mod, nn.Linear):
# print(mod.weight.grad.data.size())
# print(mod.weight.data.size())
try:
grad_dict[name] = [mod.weight.grad.data.cpu().reshape(-1).numpy()]
except:
continue
else:
for name, mod in model.named_modules():
if isinstance(mod, nn.Conv2d) or isinstance(mod, nn.Linear):
try:
grad_dict[name].append(mod.weight.grad.data.cpu().reshape(-1).numpy())
except:
continue
return grad_dict
def calculate_zico(grad_dict):
for modname in grad_dict.keys():
grad_dict[modname] = np.array(grad_dict[modname])
nsr_mean_sum_abs = 0
nsr_mean_avg_abs = 0
for modname in grad_dict.keys():
nsr_std = np.std(grad_dict[modname], axis=0)
# print(grad_dict[modname].shape)
# print(grad_dict[modname].shape, nsr_std.shape)
nonzero_idx = np.nonzero(nsr_std)[0]
nsr_mean_abs = np.mean(np.abs(grad_dict[modname]), axis=0)
tmpsum = np.sum(nsr_mean_abs[nonzero_idx] / nsr_std[nonzero_idx])
if tmpsum == 0:
pass
else:
nsr_mean_sum_abs += np.log(tmpsum)
nsr_mean_avg_abs += np.log(np.mean(nsr_mean_abs[nonzero_idx] / nsr_std[nonzero_idx]))
return nsr_mean_sum_abs
def getzico(network, inputs, targets, loss_fn, split_data=2):
grad_dict = {}
network.train()
device = inputs.device
network.to(device)
N = inputs.shape[0]
for sp in range(split_data):
st = sp * N // split_data
en = (sp + 1) * N // split_data
outputs = network.forward(inputs[st:en])
loss = loss_fn(outputs, targets[st:en])
loss.backward()
grad_dict = getgrad(network, grad_dict, sp)
res = calculate_zico(grad_dict)
return res
@measure('zico', bn=True)
def compute_zico(net, inputs, targets, split_data=2, loss_fn=None):
# Compute gradients (but don't apply them)
net.zero_grad()
try:
zico = getzico(net, inputs, targets, loss_fn, split_data=split_data)
except Exception as e:
print(e)
zico = np.nan
return zico
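The returned statistic sums, over layers, the log of sum(mean |g| / std g) computed element-wise across the split_data backward passes. A hedged numpy sketch of that reduction for a single layer with synthetic gradients:

import numpy as np

grads = np.random.randn(2, 100)           # 2 backward passes x 100 weights
std = np.std(grads, axis=0)
nonzero = np.nonzero(std)[0]              # skip weights with zero variance
mean_abs = np.mean(np.abs(grads), axis=0)
layer_score = np.log(np.sum(mean_abs[nonzero] / std[nonzero]))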