# autodl-projects/lib/xlayers/super_transformer.py

#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
#####################################################
from __future__ import division
from __future__ import print_function
import math
from functools import partial
from typing import Optional, Callable
import torch
import torch.nn as nn
import torch.nn.functional as F
import spaces
from .super_module import IntSpaceType
from .super_module import BoolSpaceType
from .super_module import LayerOrder
from .super_module import SuperModule
from .super_linear import SuperMLPv2
from .super_norm import SuperLayerNorm1D
from .super_attention import SuperAttention


class SuperTransformerEncoderLayer(SuperModule):
"""TransformerEncoderLayer is made up of self-attn and feedforward network.
This is a super model for TransformerEncoderLayer that can support search for the transformer encoder layer.
Reference:
- Paper: Attention Is All You Need, NeurIPS 2017
- PyTorch Implementation: https://pytorch.org/docs/stable/_modules/torch/nn/modules/transformer.html#TransformerEncoderLayer
Details:
2021-03-23 12:13:51 +01:00
the original post-norm version: MHA -> residual -> norm -> MLP -> residual -> norm
the pre-norm version: norm -> MHA -> residual -> norm -> MLP -> residual
2021-03-20 15:28:23 +01:00
"""
def __init__(
self,
d_model: IntSpaceType,
num_heads: IntSpaceType,
qkv_bias: BoolSpaceType = False,
mlp_hidden_multiplier: IntSpaceType = 4,
drop: Optional[float] = None,
act_layer: Callable[[], nn.Module] = nn.GELU,
order: LayerOrder = LayerOrder.PreNorm,
):
super(SuperTransformerEncoderLayer, self).__init__()
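        # Build the attention and MLP blocks as super-modules so that their dimensions stay searchable.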
mha = SuperAttention(
d_model,
d_model,
num_heads=num_heads,
qkv_bias=qkv_bias,
attn_drop=drop,
proj_drop=drop,
)
mlp = SuperMLPv2(
d_model,
hidden_multiplier=mlp_hidden_multiplier,
out_features=d_model,
act_layer=act_layer,
drop=drop,
)
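        # Register the sub-modules in the order in which they are applied; the actual data flow is defined in forward_raw.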
if order is LayerOrder.PreNorm:
self.norm1 = SuperLayerNorm1D(d_model)
self.mha = mha
self.drop1 = nn.Dropout(drop or 0.0)
self.norm2 = SuperLayerNorm1D(d_model)
self.mlp = mlp
self.drop2 = nn.Dropout(drop or 0.0)
elif order is LayerOrder.PostNorm:
self.mha = mha
self.drop1 = nn.Dropout(drop or 0.0)
self.norm1 = SuperLayerNorm1D(d_model)
self.mlp = mlp
self.drop2 = nn.Dropout(drop or 0.0)
self.norm2 = SuperLayerNorm1D(d_model)
else:
raise ValueError("Unknown order: {:}".format(order))
self._order = order
@property
def abstract_search_space(self):
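        # Collect the abstract search space of each sub-module and keep only the dimensions that are not yet determined.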
root_node = spaces.VirtualNode(id(self))
xdict = dict(
mha=self.mha.abstract_search_space,
norm1=self.norm1.abstract_search_space,
mlp=self.mlp.abstract_search_space,
norm2=self.norm2.abstract_search_space,
)
for key, space in xdict.items():
if not spaces.is_determined(space):
root_node.append(key, space)
return root_node
def apply_candidate(self, abstract_child: spaces.VirtualNode):
super(SuperTransformerEncoderLayer, self).apply_candidate(abstract_child)
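        # Propagate the sampled candidate to every searchable sub-module.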
valid_keys = ["mha", "norm1", "mlp", "norm2"]
for key in valid_keys:
if key in abstract_child:
getattr(self, key).apply_candidate(abstract_child[key])
def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
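        # The sampled candidate does not change the wiring of this layer; each
        # sub-module applies its own candidate internally, so re-use the raw forward.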
return self.forward_raw(input)
def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
        if self._order is LayerOrder.PreNorm:
            # multi-head attention (residual over the normalized input)
            x = self.norm1(input)
            x = x + self.drop1(self.mha(x))
            # feed-forward layer
            x = self.norm2(x)
            x = x + self.drop2(self.mlp(x))
        elif self._order is LayerOrder.PostNorm:
            # multi-head attention: the residual adds the attention output back to the input
            x = input + self.drop1(self.mha(input))
            x = self.norm1(x)
            # feed-forward layer
            x = x + self.drop2(self.mlp(x))
            x = self.norm2(x)
        else:
            raise ValueError("Unknown order: {:}".format(self._order))
return x
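
# A minimal usage sketch, assuming this module is imported as part of its package (the
# relative imports above prevent running the file directly) and that the super_* sub-layers
# follow the usual (batch, seq_len, d_model) tensor convention:
#
#   layer = SuperTransformerEncoderLayer(d_model=64, num_heads=4, drop=0.1)
#   tokens = torch.rand(2, 16, 64)       # (batch, seq_len, d_model)
#   output = layer(tokens)               # same shape as the input
#   print(layer.abstract_search_space)   # inspect the searchable dimensions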