import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn import init
from torch.nn.parameter import Parameter
from typing import Optional

from layers.super_module import SuperModule


class SuperLinear(SuperModule):
    """Applies a linear transformation to the incoming data: :math:`y = xA^T + b`"""

    def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
        super(SuperLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.empty(out_features, in_features))
        if bias:
            self.bias = Parameter(torch.empty(out_features))
        else:
            self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self) -> None:
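        # Same default initialization scheme as torch.nn.Linear:
        # Kaiming-uniform weights, bias bounded by 1/sqrt(fan_in).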
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

    def forward(self, input: Tensor) -> Tensor:
        return F.linear(input, self.weight, self.bias)

    def extra_repr(self) -> str:
        return "in_features={:}, out_features={:}, bias={:}".format(
            self.in_features, self.out_features, self.bias is not None
        )


class SuperMLP(nn.Module):
    """MLP: FC -> Activation -> Drop -> FC -> Drop."""
    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer=nn.GELU,
        drop: Optional[float] = None,
    ) -> None:
        super(SuperMLP, self).__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop or 0)
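        # Note: a single Dropout instance is reused after both FCs; Dropout
        # is stateless, so this is equivalent to two modules with the same rate.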

    def forward(self, x: Tensor) -> Tensor:
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
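

# Minimal smoke test: a sketch assuming this file is run standalone with
# `layers.super_module` importable; shapes follow the standard Linear contract.
if __name__ == "__main__":
    x = torch.randn(4, 16)

    linear = SuperLinear(in_features=16, out_features=8)
    print(linear)           # extra_repr: in_features=16, out_features=8, bias=True
    print(linear(x).shape)  # torch.Size([4, 8])

    mlp = SuperMLP(16, hidden_features=32, out_features=8, drop=0.1)
    print(mlp(x).shape)     # torch.Size([4, 8])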