upload
This commit is contained in:
		
							
								
								
									
										175
									
								
								Layers/layers.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										175
									
								
								Layers/layers.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,175 @@ | ||||
| import math | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| from torch.nn import functional as F | ||||
| from torch.nn import init | ||||
| from torch.nn.parameter import Parameter | ||||
| from torch.nn.modules.utils import _pair | ||||
|  | ||||
|  | ||||
| class Linear(nn.Linear): | ||||
|     def __init__(self, in_features, out_features, bias=True): | ||||
|         super(Linear, self).__init__(in_features, out_features, bias)         | ||||
|         self.register_buffer('weight_mask', torch.ones(self.weight.shape)) | ||||
|         self.register_buffer('score', torch.zeros(self.weight.shape)) | ||||
|         if self.bias is not None: | ||||
|             self.register_buffer('bias_mask', torch.ones(self.bias.shape)) | ||||
|          | ||||
|     def forward(self, input): | ||||
|         W = self.weight_mask * self.weight | ||||
|         if self.bias is not None: | ||||
|             b = self.bias_mask * self.bias | ||||
|         else: | ||||
|             b = self.bias | ||||
|         return F.linear(input, W, b) | ||||
|  | ||||
|  | ||||
| class Conv2d(nn.Conv2d): | ||||
|     def __init__(self, in_channels, out_channels, kernel_size, stride=1, | ||||
|                  padding=0, dilation=1, groups=1, | ||||
|                  bias=True, padding_mode='zeros'): | ||||
|         super(Conv2d, self).__init__( | ||||
|             in_channels, out_channels, kernel_size, stride, padding,  | ||||
|             dilation, groups, bias, padding_mode) | ||||
|         self.register_buffer('weight_mask', torch.ones(self.weight.shape)) | ||||
|         self.register_buffer('score', torch.zeros(self.weight.shape)) | ||||
|         if self.bias is not None: | ||||
|             self.register_buffer('bias_mask', torch.ones(self.bias.shape)) | ||||
|  | ||||
|     def _conv_forward(self, input, weight, bias): | ||||
|         if self.padding_mode != 'zeros': | ||||
|             return F.conv2d(F.pad(input, self._padding_repeated_twice, mode=self.padding_mode), | ||||
|                             weight, bias, self.stride, | ||||
|                             _pair(0), self.dilation, self.groups) | ||||
|         return F.conv2d(input, weight, bias, self.stride, | ||||
|                         self.padding, self.dilation, self.groups) | ||||
|  | ||||
|     def forward(self, input): | ||||
|         W = self.weight_mask * self.weight | ||||
|         if self.bias is not None: | ||||
|             b = self.bias_mask * self.bias | ||||
|         else: | ||||
|             b = self.bias | ||||
|         return self._conv_forward(input, W, b) | ||||
|  | ||||
|  | ||||
| class BatchNorm1d(nn.BatchNorm1d): | ||||
|     def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, | ||||
|                  track_running_stats=True): | ||||
|         super(BatchNorm1d, self).__init__( | ||||
|             num_features, eps, momentum, affine, track_running_stats) | ||||
|         if self.affine:      | ||||
|             self.register_buffer('weight_mask', torch.ones(self.weight.shape)) | ||||
|             self.register_buffer('bias_mask', torch.ones(self.bias.shape)) | ||||
|             self.register_buffer('score', torch.zeros(self.weight.shape))  | ||||
|     def forward(self, input): | ||||
|         self._check_input_dim(input) | ||||
|  | ||||
|         # exponential_average_factor is set to self.momentum | ||||
|         # (when it is available) only so that if gets updated | ||||
|         # in ONNX graph when this node is exported to ONNX. | ||||
|         if self.momentum is None: | ||||
|             exponential_average_factor = 0.0 | ||||
|         else: | ||||
|             exponential_average_factor = self.momentum | ||||
|  | ||||
|         if self.training and self.track_running_stats: | ||||
|             # TODO: if statement only here to tell the jit to skip emitting this when it is None | ||||
|             if self.num_batches_tracked is not None: | ||||
|                 self.num_batches_tracked = self.num_batches_tracked + 1 | ||||
|                 if self.momentum is None:  # use cumulative moving average | ||||
|                     exponential_average_factor = 1.0 / float(self.num_batches_tracked) | ||||
|                 else:  # use exponential moving average | ||||
|                     exponential_average_factor = self.momentum | ||||
|         if self.affine: | ||||
|             W = self.weight_mask * self.weight | ||||
|             b = self.bias_mask * self.bias | ||||
|         else: | ||||
|             W = self.weight | ||||
|             b = self.bias | ||||
|  | ||||
|         return F.batch_norm( | ||||
|             input, self.running_mean, self.running_var, W, b, | ||||
|             self.training or not self.track_running_stats, | ||||
|             exponential_average_factor, self.eps) | ||||
|  | ||||
|  | ||||
| class BatchNorm2d(nn.BatchNorm2d): | ||||
|     def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, | ||||
|                  track_running_stats=True): | ||||
|         super(BatchNorm2d, self).__init__( | ||||
|             num_features, eps, momentum, affine, track_running_stats) | ||||
|         if self.affine:      | ||||
|             self.register_buffer('weight_mask', torch.ones(self.weight.shape)) | ||||
|             self.register_buffer('bias_mask', torch.ones(self.bias.shape)) | ||||
|             self.register_buffer('score', torch.zeros(self.weight.shape)) | ||||
|     def forward(self, input): | ||||
|         self._check_input_dim(input) | ||||
|  | ||||
|         # exponential_average_factor is set to self.momentum | ||||
|         # (when it is available) only so that if gets updated | ||||
|         # in ONNX graph when this node is exported to ONNX. | ||||
|         if self.momentum is None: | ||||
|             exponential_average_factor = 0.0 | ||||
|         else: | ||||
|             exponential_average_factor = self.momentum | ||||
|  | ||||
|         if self.training and self.track_running_stats: | ||||
|             # TODO: if statement only here to tell the jit to skip emitting this when it is None | ||||
|             if self.num_batches_tracked is not None: | ||||
|                 self.num_batches_tracked = self.num_batches_tracked + 1 | ||||
|                 if self.momentum is None:  # use cumulative moving average | ||||
|                     exponential_average_factor = 1.0 / float(self.num_batches_tracked) | ||||
|                 else:  # use exponential moving average | ||||
|                     exponential_average_factor = self.momentum | ||||
|         if self.affine: | ||||
|             W = self.weight_mask * self.weight | ||||
|             b = self.bias_mask * self.bias | ||||
|         else: | ||||
|             W = self.weight | ||||
|             b = self.bias | ||||
|  | ||||
|         return F.batch_norm( | ||||
|             input, self.running_mean, self.running_var, W, b, | ||||
|             self.training or not self.track_running_stats, | ||||
|             exponential_average_factor, self.eps) | ||||
|  | ||||
|  | ||||
| class Identity1d(nn.Module): | ||||
|     def __init__(self, num_features): | ||||
|         super(Identity1d, self).__init__() | ||||
|         self.num_features = num_features | ||||
|         self.weight = Parameter(torch.Tensor(num_features)) | ||||
|         self.bias = None | ||||
|         self.register_buffer('weight_mask', torch.ones(self.weight.shape)) | ||||
|         self.reset_parameters() | ||||
|         self.register_buffer('score', torch.zeros(self.weight.shape)) | ||||
|  | ||||
|     def reset_parameters(self): | ||||
|         init.ones_(self.weight) | ||||
|  | ||||
|     def forward(self, input): | ||||
|         W = self.weight_mask * self.weight | ||||
|         return input * W | ||||
|  | ||||
|  | ||||
| class Identity2d(nn.Module): | ||||
|     def __init__(self, num_features): | ||||
|         super(Identity2d, self).__init__() | ||||
|         self.num_features = num_features | ||||
|         self.weight = Parameter(torch.Tensor(num_features, 1, 1)) | ||||
|         self.bias = None | ||||
|         self.register_buffer('weight_mask', torch.ones(self.weight.shape)) | ||||
|         self.reset_parameters() | ||||
|         self.register_buffer('score', torch.zeros(self.weight.shape)) | ||||
|  | ||||
|     def reset_parameters(self): | ||||
|         init.ones_(self.weight) | ||||
|  | ||||
|     def forward(self, input): | ||||
|         W = self.weight_mask * self.weight | ||||
|         return input * W | ||||
|  | ||||
|  | ||||
|  | ||||
|  | ||||
		Reference in New Issue
	
	Block a user