import torch
import torch.nn as nn
import numpy as np


def count_parameters_in_MB(model):
    """Return the total number of parameter elements, in millions.

    Despite the name, the result is a count divided by 1e6 ("M" params),
    not a size in megabytes.

    Args:
        model: an ``nn.Module`` (its ``.parameters()`` are counted) or any
            iterable of tensor-like objects exposing ``.size()``.

    Returns:
        The element count of all parameters divided by 1e6.
    """
    params = model.parameters() if isinstance(model, nn.Module) else model
    # Builtin sum over a generator: np.sum(generator) is deprecated in NumPy
    # and only falls back to this anyway (with a DeprecationWarning).
    total = sum(int(np.prod(v.size())) for v in params)
    return total / 1e6


class Cutout(object):
    """Zero out one random square patch of an image tensor (Cutout augmentation).

    The patch centre is drawn with ``np.random`` (row first, then column), so
    results are reproducible under ``np.random.seed``. The patch is clipped to
    the image bounds. The input tensor is modified in place and also returned.
    """

    def __init__(self, length):
        # Nominal side length of the square patch, in pixels.
        self.length = length

    def __repr__(self):
        return ('{name}(length={length})'.format(name=self.__class__.__name__, **self.__dict__))

    def __call__(self, img):
        """Apply cutout to ``img`` in place and return it.

        Args:
            img: tensor indexed as (channels, height, width) — dims 1 and 2
                are read as height and width.

        Returns:
            ``img`` itself, with the patch set to zero in every channel.
        """
        h, w = img.size(1), img.size(2)

        # Draw order (y, then x) matches the original implementation so seeded
        # runs produce the same patches.
        y = np.random.randint(h)
        x = np.random.randint(w)

        half = self.length // 2
        # NOTE: for odd ``length`` the patch spans ``length - 1`` pixels per side;
        # kept unchanged to stay reproducible with existing training runs.
        y1 = int(np.clip(y - half, 0, h))
        y2 = int(np.clip(y + half, 0, h))
        x1 = int(np.clip(x - half, 0, w))
        x2 = int(np.clip(x + half, 0, w))

        # Build the mask with the image's dtype and device so the in-place
        # multiply also works for integer (e.g. uint8) and non-CPU tensors;
        # a CPU float32 mask fails on both.
        mask = torch.ones((h, w), dtype=img.dtype, device=img.device)
        mask[y1:y2, x1:x2] = 0
        img *= mask.expand_as(img)
        return img