# Copyright 2021 Samsung Electronics Co., Ltd. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # http://www.apache.org/licenses/LICENSE-2.0 # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================= import torch import torch.nn as nn import torch.nn.functional as F from .p_utils import * from . import measures import types import copy def no_op(self,x): return x def copynet(self, bn): net = copy.deepcopy(self) if bn==False: for l in net.modules(): if isinstance(l,nn.BatchNorm2d) or isinstance(l,nn.BatchNorm1d) : l.forward = types.MethodType(no_op, l) return net def find_measures_arrays(net_orig, trainloader, dataload_info, device, measure_names=None, loss_fn=F.cross_entropy): if measure_names is None: measure_names = measures.available_measures dataload, num_imgs_or_batches, num_classes = dataload_info if not hasattr(net_orig,'get_prunable_copy'): net_orig.get_prunable_copy = types.MethodType(copynet, net_orig) #move to cpu to free up mem torch.cuda.empty_cache() net_orig = net_orig.cpu() torch.cuda.empty_cache() #given 1 minibatch of data if dataload == 'random': inputs, targets = get_some_data(trainloader, num_batches=num_imgs_or_batches, device=device) elif dataload == 'grasp': inputs, targets = get_some_data_grasp(trainloader, num_classes, samples_per_class=num_imgs_or_batches, device=device) else: raise NotImplementedError(f'dataload {dataload} is not supported') done, ds = False, 1 measure_values = {} while not done: try: for measure_name in measure_names: if measure_name not in measure_values: val = measures.calc_measure(measure_name, net_orig, device, inputs, targets, loss_fn=loss_fn, split_data=ds) measure_values[measure_name] = val done = True except RuntimeError as e: if 'out of memory' in str(e): done=False if ds == inputs.shape[0]//2: raise ValueError(f'Can\'t split data anymore, but still unable to run. Something is wrong') ds += 1 while inputs.shape[0] % ds != 0: ds += 1 torch.cuda.empty_cache() print(f'Caught CUDA OOM, retrying with data split into {ds} parts') else: raise e net_orig = net_orig.to(device).train() return measure_values def find_measures(net_orig, # neural network dataloader, # a data loader (typically for training data) dataload_info, # a tuple with (dataload_type = {random, grasp}, number_of_batches_for_random_or_images_per_class_for_grasp, number of classes) device, # GPU/CPU device used loss_fn=F.cross_entropy, # loss function to use within the zero-cost metrics measure_names=None, # an array of measure names to compute, if left blank, all measures are computed by default measures_arr=None): # [not used] if the measures are already computed but need to be summarized, pass them here #Given a neural net #and some information about the input data (dataloader) #and loss function (loss_fn) #this function returns an array of zero-cost proxy metrics. def sum_arr(arr): sum = 0. for i in range(len(arr)): sum += torch.sum(arr[i]) return sum.item() if measures_arr is None: measures_arr = find_measures_arrays(net_orig, dataloader, dataload_info, device, loss_fn=loss_fn, measure_names=measure_names) measures = {} for k,v in measures_arr.items(): if k in ['jacob_cov', 'var', 'cor', 'norm', 'meco', 'zico', 'ntk', 'gradsign', 'zen']: measures[k] = v else: measures[k] = sum_arr(v) return measures