##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
# modified from https://github.com/warmspringwinds/pytorch-segmentation-detection/blob/master/pytorch_segmentation_detection/utils/flops_benchmark.py
import copy, torch


def print_FLOPs(model, shape, logs):
  print_log, log = logs
  model = copy.deepcopy(model)
  model = add_flops_counting_methods(model)
  model = model.cuda()
  model.eval()

  cache_inputs = torch.zeros(*shape).cuda()
  # print_log('In the calculating function : cache input size : {:}'.format(cache_inputs.size()), log)
  with torch.no_grad():  # forward hooks still fire; no need to build the autograd graph
    _ = model(cache_inputs)
  FLOPs = compute_average_flops_cost(model) / 1e6
  print_log('FLOPs : {:} M'.format(FLOPs), log)  # mega-FLOPs, not megabytes
  torch.cuda.empty_cache()


# ---- Public functions
def add_flops_counting_methods(model):
  model.__batch_counter__ = 0
  add_batch_counter_hook_function(model)
  model.apply(add_flops_counter_variable_or_reset)
  model.apply(add_flops_counter_hook_function)
  return model


def compute_average_flops_cost(model):
  """
  A method that will be available after add_flops_counting_methods() is called on a desired net object.
  Returns the current mean FLOPs consumption per image.
  """
  batches_count = model.__batch_counter__
  flops_sum = 0
  # sum over every module type for which a counter hook was registered
  for module in model.modules():
    if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear, torch.nn.AvgPool2d, torch.nn.MaxPool2d)):
      flops_sum += module.__flops__
  return flops_sum / batches_count


# ---- Internal functions
def pool_flops_counter_hook(pool_module, inputs, output):
  batch_size = inputs[0].size(0)
  kernel_size = pool_module.kernel_size
  if isinstance(kernel_size, (tuple, list)):  # kernel_size may be an int or an (h, w) pair
    kernel_height, kernel_width = kernel_size
  else:
    kernel_height = kernel_width = kernel_size
  out_C, output_height, output_width = output.shape[1:]
  assert out_C == inputs[0].size(1), '{:} vs. {:}'.format(out_C, inputs[0].size())
  # one operation per kernel element, per output location
  overall_flops = batch_size * out_C * output_height * output_width * kernel_height * kernel_width
  pool_module.__flops__ += overall_flops


def fc_flops_counter_hook(fc_module, inputs, output):
  batch_size = inputs[0].size(0)
  xin, xout = fc_module.in_features, fc_module.out_features
  assert xin == inputs[0].size(1) and xout == output.size(1), 'IO=({:}, {:})'.format(xin, xout)
  # one multiply-accumulate per weight, plus one addition per output if biased
  overall_flops = batch_size * xin * xout
  if fc_module.bias is not None:
    overall_flops += batch_size * xout
  fc_module.__flops__ += overall_flops


def conv_flops_counter_hook(conv_module, inputs, output):
  batch_size = inputs[0].size(0)
  output_height, output_width = output.shape[2:]

  kernel_height, kernel_width = conv_module.kernel_size
  in_channels  = conv_module.in_channels
  out_channels = conv_module.out_channels
  groups       = conv_module.groups

  # multiply-accumulates needed to compute one output position
  conv_per_position_flops = kernel_height * kernel_width * in_channels * out_channels // groups

  active_elements_count = batch_size * output_height * output_width
  overall_flops = conv_per_position_flops * active_elements_count

  if conv_module.bias is not None:
    overall_flops += out_channels * active_elements_count
  conv_module.__flops__ += overall_flops


def batch_counter_hook(module, inputs, output):
  # there can be multiple inputs; the batch size is read from the first one
  inputs = inputs[0]
  batch_size = inputs.shape[0]
  module.__batch_counter__ += batch_size


def add_batch_counter_hook_function(module):
  if not hasattr(module, '__batch_counter_handle__'):
    handle = module.register_forward_hook(batch_counter_hook)
    module.__batch_counter_handle__ = handle


def add_flops_counter_variable_or_reset(module):
  if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear, torch.nn.AvgPool2d, torch.nn.MaxPool2d)):
    module.__flops__ = 0


def add_flops_counter_hook_function(module):
  if isinstance(module, torch.nn.Conv2d):
    if not hasattr(module, '__flops_handle__'):
      handle = module.register_forward_hook(conv_flops_counter_hook)
      module.__flops_handle__ = handle
  elif isinstance(module, torch.nn.Linear):
    if not hasattr(module, '__flops_handle__'):
      handle = module.register_forward_hook(fc_flops_counter_hook)
      module.__flops_handle__ = handle
  elif isinstance(module, (torch.nn.AvgPool2d, torch.nn.MaxPool2d)):
    if not hasattr(module, '__flops_handle__'):
      handle = module.register_forward_hook(pool_flops_counter_hook)
      module.__flops_handle__ = handle
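

# ---- Usage sketch (illustrative, not part of the original module).
# A minimal example of driving print_FLOPs. The toy model, input shape, and
# print-based logger below are assumptions for demonstration only; it also
# requires a CUDA-capable device, since print_FLOPs moves the model and the
# dummy input to the GPU.
if __name__ == '__main__':
  import torch.nn as nn

  toy_model = nn.Sequential(                        # hypothetical small CNN
      nn.Conv2d(3, 16, kernel_size=3, padding=1),   # counted by conv hook
      nn.ReLU(inplace=True),
      nn.AvgPool2d(kernel_size=32),                 # counted by pool hook
      nn.Flatten(),
      nn.Linear(16, 10))                            # counted by fc hook

  def simple_print_log(string, log):                # stand-in for the project's logger
    print(string)

  # shape is (batch, channels, height, width); logs is a (print_log, log) pair
  print_FLOPs(toy_model, (1, 3, 32, 32), (simple_print_log, None))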