114 lines
4.1 KiB
Python
114 lines
4.1 KiB
Python
# modified from https://github.com/warmspringwinds/pytorch-segmentation-detection/blob/master/pytorch_segmentation_detection/utils/flops_benchmark.py
|
|
import copy, torch
|
|
|
|
def print_FLOPs(model, shape, logs):
    """Log the average per-sample cost of one forward pass through `model`.

    The measurement is done on a deep copy, so the counting hooks, the device
    move and the switch to eval mode never touch the caller's model.

    Args:
        model: network to profile; it is deep-copied before being instrumented.
        shape: sizes of the all-zeros probe tensor, unpacked into `torch.zeros`.
        logs: `(print_log, log)` pair; `print_log(message, log)` is called once
            with the formatted result.
    """
    print_log, log = logs

    # Use the GPU when there is one, but do not crash on CPU-only machines.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = copy.deepcopy(model)
    model = add_flops_counting_methods(model)
    model = model.to(device)
    model.eval()

    cache_inputs = torch.zeros(*shape, device=device)
    # Only the side effects of the forward hooks are needed, so skip the
    # autograd bookkeeping for this throw-away pass.
    with torch.no_grad():
        model(cache_inputs)

    # The value is expressed in millions of operations; the 'MB' suffix is
    # kept verbatim so existing log output does not change.
    FLOPs = compute_average_flops_cost(model) / 1e6
    print_log('FLOPs : {:} MB'.format(FLOPs), log)

    if device.type == 'cuda':
        torch.cuda.empty_cache()
|
|
|
|
|
|
# ---- Public functions
|
|
def add_flops_counting_methods(model):
    """Instrument `model` in place for FLOPs counting and return it.

    The sample counter on the root module is zeroed, the batch-counting hook
    is attached to the root, and then every sub-module gets its FLOPs
    accumulator (re)initialised followed by its FLOPs-counting hook.
    """
    model.__batch_counter__ = 0
    add_batch_counter_hook_function(model)
    # Two separate passes: all accumulators are reset before any hook is added.
    for per_module_step in (add_flops_counter_variable_or_reset,
                            add_flops_counter_hook_function):
        model.apply(per_module_step)
    return model
|
|
|
|
|
|
|
|
def compute_average_flops_cost(model):
    """Return the mean FLOPs per sample accumulated on `model` so far.

    Only `Conv2d` and `Linear` sub-modules contribute to the total; the sum of
    their `__flops__` accumulators is divided by the number of samples recorded
    in `model.__batch_counter__`.
    """
    counted_types = (torch.nn.Conv2d, torch.nn.Linear)
    total_flops = sum(
        layer.__flops__
        for layer in model.modules()
        if isinstance(layer, counted_types)
    )
    return total_flops / model.__batch_counter__
|
|
|
|
|
|
# ---- Internal functions
|
|
def pool_flops_counter_hook(pool_module, inputs, output):
    """Forward hook that adds a pooling layer's cost to `pool_module.__flops__`.

    Every output element is charged one operation per element of the pooling
    window.

    Args:
        pool_module: pooling layer exposing `kernel_size` and a `__flops__`
            accumulator.
        inputs: tuple of forward inputs; only the first one is inspected.
        output: the layer's output, expected to have four dimensions.
    """
    batch_size = inputs[0].size(0)
    out_C, output_height, output_width = output.shape[1:]
    assert out_C == inputs[0].size(1), '{:} vs. {:}'.format(out_C, inputs[0].size())

    # kernel_size may be a single int or a (height, width) pair; multiplying
    # by a tuple directly would raise a TypeError.
    kernel_size = pool_module.kernel_size
    if isinstance(kernel_size, int):
        kernel_height = kernel_width = kernel_size
    else:
        kernel_height, kernel_width = kernel_size

    overall_flops = (batch_size * out_C * output_height * output_width
                     * kernel_height * kernel_width)
    pool_module.__flops__ += overall_flops
|
|
|
|
|
|
def fc_flops_counter_hook(fc_module, inputs, output):
    """Forward hook that adds a linear layer's cost to `fc_module.__flops__`.

    The layer is charged `in_features * out_features` operations per feature
    vector, plus `out_features` more when it has a bias. Features are taken
    from the last dimension, so inputs with extra leading dimensions (for
    example `(batch, steps, features)`) are counted position by position
    instead of tripping the shape assertion.

    Args:
        fc_module: layer exposing `in_features`, `out_features`, `bias` and a
            `__flops__` accumulator.
        inputs: tuple of forward inputs; only the first one is inspected.
        output: the layer's output tensor.
    """
    xin, xout = fc_module.in_features, fc_module.out_features
    assert xin == inputs[0].size(-1) and xout == output.size(-1), 'IO=({:}, {:})'.format(xin, xout)

    # Number of feature vectors pushed through the layer: the product of all
    # leading dimensions (equal to the batch size for a 2-D input).
    num_vectors = 1
    for dim in output.shape[:-1]:
        num_vectors *= dim

    overall_flops = num_vectors * xin * xout
    if fc_module.bias is not None:
        overall_flops += num_vectors * xout
    fc_module.__flops__ += overall_flops
|
|
|
|
|
|
def conv_flops_counter_hook(conv_module, inputs, output):
    """Forward hook that adds a convolution's cost to `conv_module.__flops__`.

    Each output position of each sample is charged
    `kernel_h * kernel_w * in_channels * out_channels / groups` operations,
    plus `out_channels` more when the layer has a bias.
    """
    num_samples = inputs[0].size(0)
    out_h, out_w = output.shape[2:]
    k_h, k_w = conv_module.kernel_size
    c_out = conv_module.out_channels

    # True division on purpose: the accumulated total becomes a float.
    cost_per_position = k_h * k_w * conv_module.in_channels * c_out / conv_module.groups
    num_positions = num_samples * out_h * out_w

    total = cost_per_position * num_positions
    if conv_module.bias is not None:
        total += c_out * num_positions
    conv_module.__flops__ += total
|
|
|
|
|
|
def batch_counter_hook(module, inputs, output):
    """Forward hook that adds the current batch size to `module.__batch_counter__`.

    The batch size is read from the leading dimension of the first input;
    any additional inputs are ignored.
    """
    first_input = inputs[0]
    module.__batch_counter__ += first_input.shape[0]
|
|
|
|
|
|
def add_batch_counter_hook_function(module):
    """Attach `batch_counter_hook` to `module` unless one is already attached.

    The hook handle is stored on the module as `__batch_counter_handle__`,
    which doubles as the "already registered" marker.
    """
    if hasattr(module, '__batch_counter_handle__'):
        return
    module.__batch_counter_handle__ = module.register_forward_hook(batch_counter_hook)
|
|
|
|
|
|
def add_flops_counter_variable_or_reset(module):
    """Create or zero the `__flops__` accumulator on layers that are counted.

    Only `Conv2d`, `Linear`, `AvgPool2d` and `MaxPool2d` modules are touched;
    every other module is left unchanged.
    """
    counted_layers = (
        torch.nn.Conv2d,
        torch.nn.Linear,
        torch.nn.AvgPool2d,
        torch.nn.MaxPool2d,
    )
    if isinstance(module, counted_layers):
        module.__flops__ = 0
|
|
|
|
|
|
def add_flops_counter_hook_function(module):
    """Attach the matching FLOPs-counting forward hook to `module`, once.

    `Conv2d`, `Linear` and `AvgPool2d`/`MaxPool2d` modules each get their own
    hook; other modules are ignored. The handle is stored on the module as
    `__flops_handle__`, which also prevents a second registration.
    """
    if hasattr(module, '__flops_handle__'):
        return

    if isinstance(module, torch.nn.Conv2d):
        hook = conv_flops_counter_hook
    elif isinstance(module, torch.nn.Linear):
        hook = fc_flops_counter_hook
    elif isinstance(module, (torch.nn.AvgPool2d, torch.nn.MaxPool2d)):
        hook = pool_flops_counter_hook
    else:
        return

    module.__flops_handle__ = module.register_forward_hook(hook)
|