16 lines
475 B
Python
16 lines
475 B
Python
|
import time
|
||
|
import torch
|
||
|
|
||
|
from . import measure
|
||
|
from ..p_utils import get_layer_metric_array
|
||
|
|
||
|
|
||
|
|
||
|
@measure('param_count', copy_net=False, mode='param')
def get_param_count_array(net, inputs, targets, mode, loss_fn, split_data=1):
    """Count the trainable parameters of each layer of ``net``.

    A per-layer counting callable is handed to ``get_layer_metric_array``,
    which decides which layers of ``net`` are visited and how the per-layer
    results are collected (presumably a list with one entry per selected
    layer -- TODO confirm against ``p_utils``).

    Args:
        net: Model to inspect (presumably a ``torch.nn.Module``; its layers
            must expose ``.parameters()``).
        inputs: Unused. Kept so the signature matches the other measures
            (presumably required by the ``measure`` registry's calling
            convention -- confirm).
        targets: Unused; see ``inputs``.
        mode: Forwarded unchanged to ``get_layer_metric_array``.
        loss_fn: Unused; see ``inputs``.
        split_data: Unused; see ``inputs``.

    Returns:
        A ``(count, elapsed)`` tuple.

        ``count`` is whatever ``get_layer_metric_array`` returns. Each
        per-layer value is a 0-d ``torch.tensor`` holding the total number
        of elements across that layer's parameters with
        ``requires_grad=True``.

        ``elapsed`` is the time, in seconds, spent computing ``count``.
    """

    def _trainable_param_count(layer):
        # ``parameters()`` recurses into submodules by default; frozen
        # tensors (requires_grad=False) are excluded from the count.
        return torch.tensor(
            sum(p.numel() for p in layer.parameters() if p.requires_grad)
        )

    # perf_counter is monotonic and high-resolution, unlike time.time(),
    # which can jump if the system clock is adjusted mid-measurement.
    start = time.perf_counter()
    count = get_layer_metric_array(net, _trainable_param_count, mode=mode)
    elapsed = time.perf_counter() - start
    return count, elapsed
|