import torch
import torch.nn as nn
import torch.nn.functional as F
from .initialization import initialize_resnet


class WideBasicblock(nn.Module):
    """
    Pre-activation residual block used by the wide ResNet below.

    The residual branch is BN -> ReLU -> 3x3 conv -> BN -> ReLU ->
    (optional dropout) -> 3x3 conv. Its output is added to a shortcut of the
    block input. The shortcut is the identity when the residual branch keeps
    the input shape, and a bias-free 1x1 convolution (with the same stride as
    the first 3x3 convolution) whenever the channel count or the spatial
    resolution changes.

    Args:
        inplanes: number of channels of the input feature map.
        planes: number of channels produced by the block.
        stride: stride of the first 3x3 convolution (and of the 1x1
            projection shortcut when one is needed).
        dropout: if truthy, apply ``nn.Dropout2d(p=0.5)`` right before the
            second convolution.
    """

    def __init__(self, inplanes: int, planes: int, stride: int, dropout: bool = False):
        super(WideBasicblock, self).__init__()

        self.bn_a = nn.BatchNorm2d(inplanes)
        self.conv_a = nn.Conv2d(
            inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )

        self.bn_b = nn.BatchNorm2d(planes)
        # Not in-place: the dropout input is the output of F.relu, which
        # autograd keeps for the ReLU backward pass; overwriting it in place
        # invalidates that saved tensor during training.
        self.dropout = nn.Dropout2d(p=0.5, inplace=False) if dropout else None
        self.conv_b = nn.Conv2d(
            planes, planes, kernel_size=3, stride=1, padding=1, bias=False
        )

        # A projection is required whenever the residual branch changes the
        # tensor shape: either the channel count or (through the stride) the
        # spatial resolution. Checking only the channels would leave a
        # strided block with inplanes == planes adding mismatched shapes.
        if stride != 1 or inplanes != planes:
            self.downsample = nn.Conv2d(
                inplanes, planes, kernel_size=1, stride=stride, padding=0, bias=False
            )
        else:
            self.downsample = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``shortcut(x) + residual(x)``."""
        # First pre-activated convolution (may change channels / resolution).
        residual = self.conv_a(F.relu(self.bn_a(x)))

        # Second pre-activated convolution, with optional dropout in between.
        residual = F.relu(self.bn_b(residual))
        if self.dropout is not None:
            residual = self.dropout(residual)
        residual = self.conv_b(residual)

        # The shortcut is taken from the raw block input, not the activated one.
        shortcut = x if self.downsample is None else self.downsample(x)
        return shortcut + residual

class CifarWideResNet(nn.Module):
    """
    Wide ResNet for CIFAR-style images, following "Wide Residual Networks"
    (https://arxiv.org/abs/1605.07146).

    The network is a 3x3 stem convolution with 16 output channels, three
    stages of ``WideBasicblock`` with ``16 * widen_factor``,
    ``32 * widen_factor`` and ``64 * widen_factor`` channels (the second and
    third stage start with a stride-2 block), a final BN + ReLU, global
    average pooling and a linear classifier.

    Args:
        depth: total depth; must satisfy ``(depth - 4) % 6 == 0``. Each of
            the three stages gets ``(depth - 4) // 6`` blocks.
        widen_factor: multiplier applied to the base widths 16 / 32 / 64.
        num_classes: number of outputs of the linear classifier.
        dropout: forwarded to every ``WideBasicblock``; truthy enables
            dropout inside the blocks.
    """

    def __init__(self, depth: int, widen_factor: int, num_classes: int, dropout: bool):
        super(CifarWideResNet, self).__init__()

        # Three stages, each block holding two 3x3 convolutions, plus four
        # further layers => depth = 6 * layer_blocks + 4.
        assert (
            depth - 4
        ) % 6 == 0, "depth should be 6n+4, e.g. one of 10, 16, 22, 28, 40"
        layer_blocks = (depth - 4) // 6
        print(
            "CifarWideResNet : Depth : {} , Layers for each block : {}".format(
                depth, layer_blocks
            )
        )

        self.num_classes = num_classes
        self.dropout = dropout
        self.conv_3x3 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)

        self.message = "Wide ResNet : depth={:}, widen_factor={:}, class={:}".format(
            depth, widen_factor, num_classes
        )
        # Running channel count consumed and updated by _make_layer.
        self.inplanes = 16
        self.stage_1 = self._make_layer(
            WideBasicblock, 16 * widen_factor, layer_blocks, 1
        )
        self.stage_2 = self._make_layer(
            WideBasicblock, 32 * widen_factor, layer_blocks, 2
        )
        self.stage_3 = self._make_layer(
            WideBasicblock, 64 * widen_factor, layer_blocks, 2
        )
        self.lastact = nn.Sequential(
            nn.BatchNorm2d(64 * widen_factor), nn.ReLU(inplace=True)
        )
        # Global average pooling. For 32x32 inputs the final feature map is
        # 8x8, so this averages the same window a fixed 8x8 pool would, but
        # it also keeps the flattened feature size equal to the classifier's
        # input size for other input resolutions.
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(64 * widen_factor, num_classes)

        self.apply(initialize_resnet)

    def get_message(self):
        """Return the human-readable description built in the constructor."""
        return self.message

    def _make_layer(self, block, planes, blocks, stride):
        """
        Build one stage as an ``nn.Sequential`` of ``block`` instances.

        The first block maps ``self.inplanes`` channels to ``planes`` using
        ``stride``; the remaining ``blocks - 1`` blocks keep ``planes``
        channels with stride 1. ``self.inplanes`` is updated to ``planes`` so
        that the next stage starts from this stage's output width.
        """
        layers = [block(self.inplanes, planes, stride, self.dropout)]
        self.inplanes = planes
        layers.extend(
            block(self.inplanes, planes, 1, self.dropout) for _ in range(1, blocks)
        )
        return nn.Sequential(*layers)

    def forward(self, x: torch.Tensor):
        """
        Run the network on a batch of images.

        Returns:
            A tuple ``(features, outs)``: the flattened pooled features that
            feed the classifier, and the classifier outputs.
        """
        x = self.conv_3x3(x)
        x = self.stage_1(x)
        x = self.stage_2(x)
        x = self.stage_3(x)
        x = self.lastact(x)
        x = self.avgpool(x)
        # Collapse (N, C, 1, 1) to (N, C).
        features = torch.flatten(x, 1)
        outs = self.classifier(features)
        return features, outs