36 lines
		
	
	
		
			923 B
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			36 lines
		
	
	
		
			923 B
		
	
	
	
		
			Python
		
	
	
	
	
	
| import torch
 | |
| import torch.nn as nn
 | |
| import numpy as np
 | |
| 
 | |
| 
 | |
def count_parameters_in_MB(model):
    """Return the total number of parameter elements, in millions.

    Despite the name, this counts tensor *elements* and divides by 1e6; it
    does not multiply by bytes per element.

    Args:
        model: An ``nn.Module`` (its ``parameters()`` are counted) or any
            iterable of tensors.

    Returns:
        Total element count divided by 1e6, as a float (0.0 for no tensors).
    """
    params = model.parameters() if isinstance(model, nn.Module) else model
    # Builtin sum: np.sum(generator) is deprecated in NumPy and only worked
    # by silently falling back to this builtin anyway.
    return sum(v.numel() for v in params) / 1e6
 | |
| 
 | |
| 
 | |
class Cutout(object):
    """Zero out one randomly placed square patch of an image tensor.

    A patch centre (y, x) is drawn uniformly over the image with NumPy's
    global RNG. The rows ``[y - length//2, y + length//2)`` and columns
    ``[x - length//2, x + length//2)``, clipped to the image borders, are
    then set to zero in every channel. For odd ``length`` the unclipped
    patch side is therefore ``length - 1``.

    Args:
        length: Nominal side length of the square patch, in pixels.
    """

    def __init__(self, length):
        self.length = length

    def __repr__(self):
        return ('{name}(length={length})'.format(name=self.__class__.__name__, **self.__dict__))

    def __call__(self, img):
        """Apply the cutout to ``img`` in place and return it.

        Args:
            img: Tensor whose dims 1 and 2 are height and width, i.e. laid
                out as (channels, height, width).

        Returns:
            The same tensor object, with the patch zeroed.
        """
        h, w = img.size(1), img.size(2)
        # Draw order (rows, then columns) is kept so np.random.seed gives
        # the same patches as before.
        y = np.random.randint(h)
        x = np.random.randint(w)

        half = self.length // 2
        y1 = int(np.clip(y - half, 0, h))
        y2 = int(np.clip(y + half, 0, h))
        x1 = int(np.clip(x - half, 0, w))
        x2 = int(np.clip(x + half, 0, w))

        # Match the image's dtype and device: a CPU float32 mask makes the
        # in-place multiply fail for integer images or non-CPU tensors.
        mask = torch.ones((h, w), dtype=img.dtype, device=img.device)
        mask[y1:y2, x1:x2] = 0
        img *= mask.expand_as(img)
        return img
 |