36 lines
923 B
Python
36 lines
923 B
Python
import torch
|
|
import torch.nn as nn
|
|
import numpy as np
|
|
|
|
|
|
def count_parameters_in_MB(model):
    """Return the total number of parameters of ``model``, in millions.

    Despite the name, the result is an element *count* divided by ``1e6``,
    not a size in megabytes.

    Args:
        model: an ``nn.Module`` (its ``parameters()`` are counted) or any
            iterable of objects exposing a ``size()`` method, e.g. tensors.

    Returns:
        The summed element count of all entries divided by ``1e6``
        (``0.0`` for an empty iterable).
    """
    params = model.parameters() if isinstance(model, nn.Module) else model
    # Builtin sum on purpose: np.sum(generator) is deprecated in NumPy and
    # emits a DeprecationWarning on every call.
    return sum(np.prod(v.size()) for v in params) / 1e6
|
|
|
|
|
|
class Cutout:
    """Zero out one randomly placed square patch of an image tensor.

    The image is indexed as ``img.size(1)`` x ``img.size(2)`` for height and
    width (a channels-first layout is assumed; confirm against callers). The
    patch centre is drawn with ``np.random``, and the patch is clipped at the
    image borders.
    """

    def __init__(self, length):
        # Side length, in pixels, of the square patch to zero out.
        self.length = length

    def __repr__(self):
        return ('{name}(length={length})'.format(name=self.__class__.__name__, **self.__dict__))

    def __call__(self, img):
        """Apply cutout to ``img`` in place and return it.

        Args:
            img: tensor whose dims 1 and 2 are height and width; it is
                modified in place.

        Returns:
            The same ``img`` object, with one ``length`` x ``length`` patch
            (clipped at the borders) set to zero across all leading channels.
        """
        h, w = img.size(1), img.size(2)
        mask = np.ones((h, w), np.float32)

        # Patch centre; near an edge the patch is clipped rather than shifted.
        y = np.random.randint(h)
        x = np.random.randint(w)

        # Derive the far edge from the near edge so the patch spans exactly
        # `length` pixels; `y + length // 2` would lose a row/column when
        # `length` is odd.
        top = y - self.length // 2
        left = x - self.length // 2
        y1 = np.clip(top, 0, h)
        y2 = np.clip(top + self.length, 0, h)
        x1 = np.clip(left, 0, w)
        x2 = np.clip(left + self.length, 0, w)

        mask[y1: y2, x1: x2] = 0.
        # Match img's device and dtype so the in-place product also works for
        # CUDA tensors and integer images; the 0/1 mask values convert exactly.
        mask = torch.from_numpy(mask).to(device=img.device, dtype=img.dtype)
        mask = mask.expand_as(img)
        img *= mask
        return img
|