Upgrade spaces and add more tests
This commit is contained in:
		| @@ -1,7 +1,43 @@ | ||||
| import torch.nn as nn | ||||
| from torch.nn.parameter import Parameter | ||||
| from typing import Optional | ||||
|  | ||||
| class MLP(nn.Module): | ||||
|  | ||||
| class Linear(nn.Module): | ||||
|     """Applies a linear transformation to the incoming data: :math:`y = xA^T + b` | ||||
|     """ | ||||
|     __constants__ = ['in_features', 'out_features'] | ||||
|     in_features: int | ||||
|     out_features: int | ||||
|     weight: Tensor | ||||
|  | ||||
|     def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None: | ||||
|         super(Linear, self).__init__() | ||||
|         self.in_features = in_features | ||||
|         self.out_features = out_features | ||||
|         self.weight = Parameter(torch.Tensor(out_features, in_features)) | ||||
|         if bias: | ||||
|             self.bias = Parameter(torch.Tensor(out_features)) | ||||
|         else: | ||||
|             self.register_parameter('bias', None) | ||||
|         self.reset_parameters() | ||||
|  | ||||
|     def reset_parameters(self) -> None: | ||||
|         init.kaiming_uniform_(self.weight, a=math.sqrt(5)) | ||||
|         if self.bias is not None: | ||||
|             fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight) | ||||
|             bound = 1 / math.sqrt(fan_in) | ||||
|             init.uniform_(self.bias, -bound, bound) | ||||
|  | ||||
|     def forward(self, input: Tensor) -> Tensor: | ||||
|         return F.linear(input, self.weight, self.bias) | ||||
|  | ||||
|     def extra_repr(self) -> str: | ||||
|         return 'in_features={}, out_features={}, bias={}'.format( | ||||
|             self.in_features, self.out_features, self.bias is not None | ||||
|         ) | ||||
|  | ||||
| class SuperMLP(nn.Module): | ||||
|   # MLP: FC -> Activation -> Drop -> FC -> Drop | ||||
|   def __init__(self, in_features, hidden_features: Optional[int] = None, | ||||
|                out_features: Optional[int] = None, | ||||
|   | ||||
		Reference in New Issue
	
	Block a user