16 lines
		
	
	
		
			475 B
		
	
	
	
		
			Python
		
	
	
	
	
	
		
		
			
		
	
	
			16 lines
		
	
	
		
			475 B
		
	
	
	
		
			Python
		
	
	
	
	
	
|   | import time | ||
|  | import torch | ||
|  | 
 | ||
|  | from . import measure | ||
|  | from ..p_utils import get_layer_metric_array | ||
|  | 
 | ||
|  | 
 | ||
|  | 
 | ||
@measure('param_count', copy_net=False, mode='param')
def get_param_count_array(net, inputs, targets, mode, loss_fn, split_data=1):
    """Compute a per-layer trainable-parameter count for ``net`` and time it.

    ``inputs``, ``targets``, ``loss_fn`` and ``split_data`` are not used by
    this measure. NOTE(review): they appear to be accepted only so the
    signature matches the other ``@measure``-registered functions — confirm
    against the ``measure`` registry.

    Args:
        net: Model passed to ``get_layer_metric_array``.
        inputs: Unused.
        targets: Unused.
        mode: Forwarded unchanged to ``get_layer_metric_array``.
        loss_fn: Unused.
        split_data: Unused.

    Returns:
        A ``(count, elapsed)`` tuple:

        - ``count`` is whatever ``get_layer_metric_array`` returns when given
          the per-layer counting callable below. It presumably holds one
          0-d tensor per layer selected by ``mode``; TODO confirm in
          ``p_utils``.
        - ``elapsed`` is the time spent in that call, in seconds (float).
    """
    def _trainable_param_count(layer):
        # Sum element counts of parameters with requires_grad=True; frozen
        # parameters are excluded. Wrapped in a tensor to keep the same return
        # type as before (int64 0-d tensor; tensor(0) if nothing is trainable).
        return torch.tensor(
            sum(p.numel() for p in layer.parameters() if p.requires_grad)
        )

    # perf_counter is monotonic and high-resolution, so the measured interval
    # is not affected by system clock adjustments the way time.time() is.
    start = time.perf_counter()
    count = get_layer_metric_array(net, _trainable_param_count, mode=mode)
    elapsed = time.perf_counter() - start
    return count, elapsed