# autodl-projects/xautodl/utils/flop_benchmark.py

import torch
import torch.nn as nn
import numpy as np


def count_parameters_in_MB(model):
    return count_parameters(model, "mb")


def count_parameters(model_or_parameters, unit="mb"):
    if isinstance(model_or_parameters, nn.Module):
        counts = sum(np.prod(v.size()) for v in model_or_parameters.parameters())
    elif isinstance(model_or_parameters, nn.Parameter):
        counts = model_or_parameters.numel()
    elif isinstance(model_or_parameters, (list, tuple)):
        # Sum the raw counts of the elements; the unit conversion happens once below.
        counts = sum(count_parameters(x, None) for x in model_or_parameters)
    else:
        counts = sum(np.prod(v.size()) for v in model_or_parameters)
    if unit is None:
        return counts
    if unit.lower() in ("kb", "k"):
        counts /= 2 ** 10  # changed from 1e3 to 2^10
    elif unit.lower() in ("mb", "m"):
        counts /= 2 ** 20  # changed from 1e6 to 2^20
    elif unit.lower() in ("gb", "g"):
        counts /= 2 ** 30  # changed from 1e9 to 2^30
    else:
        raise ValueError("Unknown unit: {:}".format(unit))
    return counts
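

# Worked example (illustrative sizes): nn.Linear(1000, 1000) has
# 1000 * 1000 + 1000 = 1_001_000 parameters, so count_parameters_in_MB(...)
# returns 1_001_000 / 2**20 ≈ 0.9546 (binary-prefix "MB" of parameters,
# not bytes).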


def get_model_infos(model, shape):
    # model = copy.deepcopy( model )
    model = add_flops_counting_methods(model)
    # model = model.cuda()
    model.eval()

    # cache_inputs = torch.zeros(*shape).cuda()
    # cache_inputs = torch.zeros(*shape)
    cache_inputs = torch.rand(*shape)
    if next(model.parameters()).is_cuda:
        cache_inputs = cache_inputs.cuda()
    # print_log('In the calculating function : cache input size : {:}'.format(cache_inputs.size()), log)
    with torch.no_grad():
        _ = model(cache_inputs)
    FLOPs = compute_average_flops_cost(model) / 1e6
    Param = count_parameters_in_MB(model)

    if hasattr(model, "auxiliary_param"):
        aux_params = count_parameters_in_MB(model.auxiliary_param())
        print("The auxiliary params of this model are : {:}".format(aux_params))
        print(
            "We remove the auxiliary params from the total params ({:}) when counting".format(
                Param
            )
        )
        Param = Param - aux_params
    # print_log('FLOPs : {:} MB'.format(FLOPs), log)
    torch.cuda.empty_cache()
    model.apply(remove_hook_function)
    return FLOPs, Param
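

# Example usage (a minimal sketch; the toy network and input shape are
# illustrative assumptions, any nn.Module works):
#   net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.Flatten(), nn.Linear(8 * 30 * 30, 10))
#   flops, params = get_model_infos(net, (1, 3, 32, 32))
#   # flops is in mega-FLOPs per image, params in "MB" (parameters / 2**20)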


# ---- Public functions
def add_flops_counting_methods(model):
    model.__batch_counter__ = 0
    add_batch_counter_hook_function(model)
    model.apply(add_flops_counter_variable_or_reset)
    model.apply(add_flops_counter_hook_function)
    return model


def compute_average_flops_cost(model):
    """
    A method that will be available after add_flops_counting_methods() is called on a desired net object.
    Returns the current mean FLOPs consumption per image.
    """
    batches_count = model.__batch_counter__
    flops_sum = 0
    # or isinstance(module, torch.nn.AvgPool2d) or isinstance(module, torch.nn.MaxPool2d) \
    for module in model.modules():
        if (
            isinstance(module, torch.nn.Conv2d)
            or isinstance(module, torch.nn.Linear)
            or isinstance(module, torch.nn.Conv1d)
            or hasattr(module, "calculate_flop_self")
        ):
            flops_sum += module.__flops__
    return flops_sum / batches_count
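

# Worked example (illustrative numbers): after two forward passes with batch
# size 4, __batch_counter__ == 8 and each hook has accumulated the FLOPs of
# both passes, so flops_sum / batches_count is the cost per image.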


# ---- Internal functions
def pool_flops_counter_hook(pool_module, inputs, output):
    batch_size = inputs[0].size(0)
    # kernel_size may be a single int or an (h, w) tuple
    kernel_size = pool_module.kernel_size
    if isinstance(kernel_size, (tuple, list)):
        kernel_height, kernel_width = kernel_size
    else:
        kernel_height = kernel_width = kernel_size
    out_C, output_height, output_width = output.shape[1:]
    assert out_C == inputs[0].size(1), "{:} vs. {:}".format(out_C, inputs[0].size())
    overall_flops = (
        batch_size * out_C * output_height * output_width * kernel_height * kernel_width
    )
    pool_module.__flops__ += overall_flops
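

# Worked example (illustrative): MaxPool2d(kernel_size=2) on a (4, 16, 8, 8)
# input yields a (4, 16, 4, 4) output, so 4 * 16 * 4 * 4 * 2 * 2 = 4096 FLOPs
# are charged (one op per element of each pooling window).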


def self_calculate_flops_counter_hook(self_module, inputs, output):
    overall_flops = self_module.calculate_flop_self(inputs[0].shape, output.shape)
    self_module.__flops__ += overall_flops


def fc_flops_counter_hook(fc_module, inputs, output):
    batch_size = inputs[0].size(0)
    xin, xout = fc_module.in_features, fc_module.out_features
    assert xin == inputs[0].size(1) and xout == output.size(1), "IO=({:}, {:})".format(
        xin, xout
    )
    overall_flops = batch_size * xin * xout
    if fc_module.bias is not None:
        overall_flops += batch_size * xout
    fc_module.__flops__ += overall_flops
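

# Worked example (illustrative): nn.Linear(512, 10) with a batch of 4 counts
# 4 * 512 * 10 = 20480 multiply-adds plus 4 * 10 = 40 for the bias, i.e.
# 20520 FLOPs (a multiply-add is counted as one FLOP here).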


def conv1d_flops_counter_hook(conv_module, inputs, outputs):
    batch_size = inputs[0].size(0)
    outL = outputs.shape[-1]
    [kernel] = conv_module.kernel_size
    in_channels = conv_module.in_channels
    out_channels = conv_module.out_channels
    groups = conv_module.groups

    conv_per_position_flops = kernel * in_channels * out_channels / groups
    active_elements_count = batch_size * outL
    overall_flops = conv_per_position_flops * active_elements_count
    if conv_module.bias is not None:
        overall_flops += out_channels * active_elements_count
    conv_module.__flops__ += overall_flops
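

# Worked example (illustrative): nn.Conv1d(16, 32, kernel_size=3) with output
# length 100 and batch 2 costs 3 * 16 * 32 = 1536 multiply-adds per output
# position, over 2 * 100 = 200 positions -> 307200 FLOPs, plus
# 32 * 200 = 6400 for the bias.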


def conv2d_flops_counter_hook(conv_module, inputs, output):
    batch_size = inputs[0].size(0)
    output_height, output_width = output.shape[2:]
    kernel_height, kernel_width = conv_module.kernel_size
    in_channels = conv_module.in_channels
    out_channels = conv_module.out_channels
    groups = conv_module.groups

    conv_per_position_flops = (
        kernel_height * kernel_width * in_channels * out_channels / groups
    )
    active_elements_count = batch_size * output_height * output_width
    overall_flops = conv_per_position_flops * active_elements_count
    if conv_module.bias is not None:
        overall_flops += out_channels * active_elements_count
    conv_module.__flops__ += overall_flops
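

# Worked example (illustrative): nn.Conv2d(3, 8, kernel_size=3) on a 32x32
# image (output 30x30, batch 1) costs 3 * 3 * 3 * 8 = 216 multiply-adds per
# output position, over 1 * 30 * 30 = 900 positions -> 194400 FLOPs, plus
# 8 * 900 = 7200 for the bias.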


def batch_counter_hook(module, inputs, output):
    # Can have multiple inputs; take the first one.
    inputs = inputs[0]
    batch_size = inputs.shape[0]
    module.__batch_counter__ += batch_size


def add_batch_counter_hook_function(module):
    if not hasattr(module, "__batch_counter_handle__"):
        handle = module.register_forward_hook(batch_counter_hook)
        module.__batch_counter_handle__ = handle


def add_flops_counter_variable_or_reset(module):
    if (
        isinstance(module, torch.nn.Conv2d)
        or isinstance(module, torch.nn.Linear)
        or isinstance(module, torch.nn.Conv1d)
        or isinstance(module, torch.nn.AvgPool2d)
        or isinstance(module, torch.nn.MaxPool2d)
        or hasattr(module, "calculate_flop_self")
    ):
        module.__flops__ = 0


def add_flops_counter_hook_function(module):
    if isinstance(module, torch.nn.Conv2d):
        if not hasattr(module, "__flops_handle__"):
            handle = module.register_forward_hook(conv2d_flops_counter_hook)
            module.__flops_handle__ = handle
    elif isinstance(module, torch.nn.Conv1d):
        if not hasattr(module, "__flops_handle__"):
            handle = module.register_forward_hook(conv1d_flops_counter_hook)
            module.__flops_handle__ = handle
    elif isinstance(module, torch.nn.Linear):
        if not hasattr(module, "__flops_handle__"):
            handle = module.register_forward_hook(fc_flops_counter_hook)
            module.__flops_handle__ = handle
    elif isinstance(module, torch.nn.AvgPool2d) or isinstance(
        module, torch.nn.MaxPool2d
    ):
        if not hasattr(module, "__flops_handle__"):
            handle = module.register_forward_hook(pool_flops_counter_hook)
            module.__flops_handle__ = handle
    elif hasattr(module, "calculate_flop_self"):  # self-defined module
        if not hasattr(module, "__flops_handle__"):
            handle = module.register_forward_hook(self_calculate_flops_counter_hook)
            module.__flops_handle__ = handle
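

# A minimal sketch of a self-defined module using the `calculate_flop_self`
# protocol above; the class and its cost formula are illustrative assumptions,
# not part of this file's API:
#
#   class ScaledResidual(nn.Module):
#       def forward(self, x):
#           return x * 2.0 + x
#
#       def calculate_flop_self(self, input_shape, output_shape):
#           # one multiply and one add per output element (batch included,
#           # matching the per-batch accounting of the built-in hooks)
#           return 2 * int(np.prod(output_shape))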


def remove_hook_function(module):
    hookers = ["__batch_counter_handle__", "__flops_handle__"]
    for hooker in hookers:
        if hasattr(module, hooker):
            handle = getattr(module, hooker)
            handle.remove()
    keys = ["__flops__", "__batch_counter__"] + hookers
    for ckey in keys:
        if hasattr(module, ckey):
            delattr(module, ckey)
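

if __name__ == "__main__":
    # A minimal smoke test (a sketch; the toy network and input shape are
    # illustrative assumptions): build a small CNN and print its cost.
    net = nn.Sequential(
        nn.Conv2d(3, 8, kernel_size=3),
        nn.MaxPool2d(kernel_size=2),
        nn.Flatten(),
        nn.Linear(8 * 15 * 15, 10),
    )
    flops, params = get_model_infos(net, (1, 3, 32, 32))
    print("FLOPs: {:.3f} M, Params: {:.3f} MB".format(flops, params))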