#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
#####################################################
import math
from typing import Optional, Text

import torch
import torch.nn as nn
import torch.nn.functional as F

from xautodl import spaces
from .super_module import SuperModule
from .super_module import IntSpaceType
from .super_module import BoolSpaceType
from .super_dropout import SuperDropout, SuperDrop
from .super_linear import SuperLinear


class SuperSelfAttention(SuperModule):
    """The super model for a self-attention layer."""

    def __init__(
        self,
        input_dim: IntSpaceType,
        proj_dim: Optional[IntSpaceType],
        num_heads: IntSpaceType,
        qkv_bias: BoolSpaceType = False,
        attn_drop: Optional[float] = None,
        proj_drop: Optional[float] = None,
        use_mask=False,
    ):
        super(SuperSelfAttention, self).__init__()
        self._input_dim = input_dim
        self._proj_dim = proj_dim
        self._num_heads = num_heads
        self._qkv_bias = qkv_bias
        self._use_mask = use_mask
        self._infinity = 1e9

        # The maximum input dimension must be divisible by the minimum number of heads.
        mul_head_dim = (
            spaces.get_max(input_dim) // spaces.get_min(num_heads)
        ) * spaces.get_min(num_heads)
        assert mul_head_dim == spaces.get_max(input_dim)
        self.q_fc = SuperLinear(input_dim, input_dim, bias=qkv_bias)
        self.k_fc = SuperLinear(input_dim, input_dim, bias=qkv_bias)
        self.v_fc = SuperLinear(input_dim, input_dim, bias=qkv_bias)

        self.attn_drop = SuperDrop(attn_drop or 0.0, [-1, -1, -1, -1], recover=True)
        if proj_dim is not None:
            self.proj = SuperLinear(input_dim, proj_dim)
            self.proj_drop = SuperDropout(proj_drop or 0.0)
        else:
            self.proj = None

    @property
    def num_heads(self):
        return spaces.get_max(self._num_heads)

    @property
    def input_dim(self):
        return spaces.get_max(self._input_dim)

    @property
    def proj_dim(self):
        return spaces.get_max(self._proj_dim)

    @property
    def abstract_search_space(self):
        root_node = spaces.VirtualNode(id(self))
        space_q = self.q_fc.abstract_search_space
        space_k = self.k_fc.abstract_search_space
        space_v = self.v_fc.abstract_search_space
        if not spaces.is_determined(self._num_heads):
            root_node.append("_num_heads", self._num_heads.abstract(reuse_last=True))
        if not spaces.is_determined(space_q):
            root_node.append("q_fc", space_q)
        if not spaces.is_determined(space_k):
            root_node.append("k_fc", space_k)
        if not spaces.is_determined(space_v):
            root_node.append("v_fc", space_v)
        if self.proj is not None:
            space_proj = self.proj.abstract_search_space
            if not spaces.is_determined(space_proj):
                root_node.append("proj", space_proj)
        return root_node

    def apply_candidate(self, abstract_child: spaces.VirtualNode):
        super(SuperSelfAttention, self).apply_candidate(abstract_child)
        if "q_fc" in abstract_child:
            self.q_fc.apply_candidate(abstract_child["q_fc"])
        if "k_fc" in abstract_child:
            self.k_fc.apply_candidate(abstract_child["k_fc"])
        if "v_fc" in abstract_child:
            self.v_fc.apply_candidate(abstract_child["v_fc"])
        if "proj" in abstract_child:
            self.proj.apply_candidate(abstract_child["proj"])

    def forward_qkv(self, input: torch.Tensor, num_head: int) -> torch.Tensor:
        B, N, C = input.shape
        q = self.q_fc(input)
        k = self.k_fc(input)
        v = self.v_fc(input)
        if num_head > C:
            raise ValueError("Invalid num_head [{:}] vs C [{:}]".format(num_head, C))
        head_dim = C // num_head
        # process the first [num_head * head_dim] channels with multi-head attention
        q_v1 = (
            q[:, :, : num_head * head_dim]
            .reshape(B, N, num_head, head_dim)
            .permute(0, 2, 1, 3)
        )
        k_v1 = (
            k[:, :, : num_head * head_dim]
            .reshape(B, N, num_head, head_dim)
            .permute(0, 2, 1, 3)
        )
        v_v1 = (
            v[:, :, : num_head * head_dim]
            .reshape(B, N, num_head, head_dim)
            .permute(0, 2, 1, 3)
        )
        attn_v1 = (q_v1 @ k_v1.transpose(-2, -1)) * math.sqrt(head_dim)
        if self._use_mask:
            # causal mask: hide the upper-triangular (future) positions
            mask = torch.triu(
                torch.ones((N, N), dtype=torch.bool, device=input.device), 1
            )
            mask = torch.unsqueeze(torch.unsqueeze(mask, dim=0), dim=0)
            attn_v1 = attn_v1.masked_fill(mask, -self._infinity)
        attn_v1 = attn_v1.softmax(dim=-1)  # B * #head * N * N
        attn_v1 = self.attn_drop(attn_v1)
        feats_v1 = (attn_v1 @ v_v1).permute(0, 2, 1, 3).reshape(B, N, -1)
        if C == head_dim * num_head:
            feats = feats_v1
        else:  # the channels cannot be divided by num_head; the remainder forms an additional head
            q_v2 = q[:, :, num_head * head_dim :]
            k_v2 = k[:, :, num_head * head_dim :]
            v_v2 = v[:, :, num_head * head_dim :]
            attn_v2 = (q_v2 @ k_v2.transpose(-2, -1)) * math.sqrt(q_v2.shape[-1])
            attn_v2 = attn_v2.softmax(dim=-1)
            attn_v2 = self.attn_drop(attn_v2)
            feats_v2 = attn_v2 @ v_v2
            feats = torch.cat([feats_v1, feats_v2], dim=-1)
        return feats

    def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
        # check the num_heads:
        if not spaces.is_determined(self._num_heads):
            num_heads = self.abstract_child["_num_heads"].value
        else:
            num_heads = spaces.get_determined_value(self._num_heads)
        feats = self.forward_qkv(input, num_heads)
        if self.proj is None:
            return feats
        else:
            outs = self.proj(feats)
            outs = self.proj_drop(outs)
            return outs

    def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
        feats = self.forward_qkv(input, self.num_heads)
        if self.proj is None:
            return feats
        else:
            outs = self.proj(feats)
            outs = self.proj_drop(outs)
            return outs

    def extra_repr(self) -> str:
        return (
            "input_dim={:}, proj_dim={:}, num_heads={:}, mask={:}, infinity={:}".format(
                self._input_dim,
                self._proj_dim,
                self._num_heads,
                self._use_mask,
                self._infinity,
            )
        )


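# A minimal usage sketch (not part of the original module). It assumes that plain
# Python ints are accepted wherever an IntSpaceType is expected, and that the
# default super-run mode of the sub-layers executes their raw forward; calling
# forward_raw() here uses the maximum candidate of every search dimension. The
# helper name below is hypothetical and only for illustration.
def _demo_super_self_attention():
    layer = SuperSelfAttention(input_dim=64, proj_dim=64, num_heads=4)
    tokens = torch.rand(2, 10, 64)  # (batch, sequence length, channels)
    feats = layer.forward_raw(tokens)  # bypass the super-module dispatch
    print(feats.shape)  # expected: torch.Size([2, 10, 64])

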
class SuperQKVAttention(SuperModule):
    """The super model for a query-key-value (cross) attention layer."""

    def __init__(
        self,
        in_q_dim: IntSpaceType,
        in_k_dim: IntSpaceType,
        in_v_dim: IntSpaceType,
        proj_dim: IntSpaceType,
        num_heads: IntSpaceType,
        qkv_bias: BoolSpaceType = False,
        attn_drop: Optional[float] = None,
        proj_drop: Optional[float] = None,
    ):
        super(SuperQKVAttention, self).__init__()
        self._in_v_dim = in_v_dim
        self._in_q_dim = in_q_dim
        self._in_k_dim = in_k_dim
        self._proj_dim = proj_dim
        self._num_heads = num_heads
        self._qkv_bias = qkv_bias

        self.q_fc = SuperLinear(in_q_dim, proj_dim, bias=qkv_bias)
        self.k_fc = SuperLinear(in_k_dim, proj_dim, bias=qkv_bias)
        self.v_fc = SuperLinear(in_v_dim, proj_dim, bias=qkv_bias)

        self.attn_drop = nn.Dropout(attn_drop or 0.0)
        self.proj = SuperLinear(proj_dim, proj_dim)
        self.proj_drop = nn.Dropout(proj_drop or 0.0)
        self._infinity = 1e9

    @property
    def num_heads(self):
        return spaces.get_max(self._num_heads)

    @property
    def in_v_dim(self):
        return spaces.get_max(self._in_v_dim)

    @property
    def in_q_dim(self):
        return spaces.get_max(self._in_q_dim)

    @property
    def in_k_dim(self):
        return spaces.get_max(self._in_k_dim)

    @property
    def proj_dim(self):
        return spaces.get_max(self._proj_dim)

    @property
    def abstract_search_space(self):
        root_node = spaces.VirtualNode(id(self))
        space_q = self.q_fc.abstract_search_space
        space_k = self.k_fc.abstract_search_space
        space_v = self.v_fc.abstract_search_space
        space_proj = self.proj.abstract_search_space
        if not spaces.is_determined(self._num_heads):
            root_node.append("_num_heads", self._num_heads.abstract(reuse_last=True))
        if not spaces.is_determined(space_q):
            root_node.append("q_fc", space_q)
        if not spaces.is_determined(space_k):
            root_node.append("k_fc", space_k)
        if not spaces.is_determined(space_v):
            root_node.append("v_fc", space_v)
        if not spaces.is_determined(space_proj):
            root_node.append("proj", space_proj)
        return root_node

    def apply_candidate(self, abstract_child: spaces.VirtualNode):
        super(SuperQKVAttention, self).apply_candidate(abstract_child)
        if "q_fc" in abstract_child:
            self.q_fc.apply_candidate(abstract_child["q_fc"])
        if "k_fc" in abstract_child:
            self.k_fc.apply_candidate(abstract_child["k_fc"])
        if "v_fc" in abstract_child:
            self.v_fc.apply_candidate(abstract_child["v_fc"])
        if "proj" in abstract_child:
            self.proj.apply_candidate(abstract_child["proj"])

    def forward_qkv(
        self, q_tensor, k_tensor, v_tensor, num_head: int, mask=None
    ) -> torch.Tensor:
        q = self.q_fc(q_tensor)
        B, N, C = q.shape

        k = self.k_fc(k_tensor)
        B0, S, _ = k.shape

        v = self.v_fc(v_tensor)
        assert B0 == v.shape[0] and S == v.shape[1]

        head_dim = C // num_head
        if num_head > C:
            raise ValueError("Invalid num_head [{:}] vs C [{:}]".format(num_head, C))
        # process the first [num_head * head_dim] channels with multi-head attention
        q_v1 = (
            q[:, :, : num_head * head_dim]
            .reshape(B, N, num_head, head_dim)
            .permute(0, 2, 1, 3)
        )
        k_v1 = (
            k[:, :, : num_head * head_dim]
            .reshape(B0, S, num_head, head_dim)
            .permute(0, 2, 1, 3)
        )
        # compute the attention map
        attn_v1 = (q_v1 @ k_v1.transpose(-2, -1)) * math.sqrt(head_dim)
        if mask is not None:
            # broadcast the boolean mask over heads and hide the selected key positions
            mask = torch.unsqueeze(mask, dim=1)
            attn_v1 = attn_v1.masked_fill(mask, -self._infinity)
        attn_v1 = attn_v1.softmax(dim=-1)  # B * #head * N * S
        attn_v1 = self.attn_drop(attn_v1)

        v_v1 = (
            v[:, :, : num_head * head_dim]
            .reshape(B0, S, num_head, head_dim)
            .permute(0, 2, 1, 3)
        )
        feats_v1 = (attn_v1 @ v_v1).permute(0, 2, 1, 3).reshape(B, N, -1)
        if C == head_dim * num_head:
            feats = feats_v1
        else:  # the channels cannot be divided by num_head; the remainder forms an additional head
            # [might have bugs, did not check yet]
            q_v2 = q[:, :, num_head * head_dim :]
            k_v2 = k[:, :, num_head * head_dim :]
            v_v2 = v[:, :, num_head * head_dim :]
            attn_v2 = (q_v2 @ k_v2.transpose(-2, -1)) * math.sqrt(q_v2.shape[-1])
            attn_v2 = attn_v2.softmax(dim=-1)
            attn_v2 = self.attn_drop(attn_v2)
            feats_v2 = attn_v2 @ v_v2
            feats = torch.cat([feats_v1, feats_v2], dim=-1)
        return feats

    def forward_candidate(
        self, q_tensor, k_tensor, v_tensor, mask=None
    ) -> torch.Tensor:
        # check the num_heads:
        if not spaces.is_determined(self._num_heads):
            num_heads = self.abstract_child["_num_heads"].value
        else:
            num_heads = spaces.get_determined_value(self._num_heads)
        feats = self.forward_qkv(q_tensor, k_tensor, v_tensor, num_heads, mask)
        outs = self.proj(feats)
        outs = self.proj_drop(outs)
        return outs

    def forward_raw(self, q_tensor, k_tensor, v_tensor, mask=None) -> torch.Tensor:
        feats = self.forward_qkv(q_tensor, k_tensor, v_tensor, self.num_heads, mask)
        outs = self.proj(feats)
        outs = self.proj_drop(outs)
        return outs

    def extra_repr(self) -> str:
        return "input_dim={:}, proj_dim={:}, num_heads={:}, infinity={:}".format(
            (self.in_q_dim, self.in_k_dim, self.in_v_dim),
            self._proj_dim,
            self._num_heads,
            self._infinity,
        )


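# A minimal usage sketch for the cross-attention variant (not part of the original
# module). As in the sketch above, plain ints are assumed to be accepted as
# IntSpaceType values; the query/key/value inputs may have different dimensions and
# sequence lengths. The helper name is hypothetical and only for illustration.
def _demo_super_qkv_attention():
    layer = SuperQKVAttention(
        in_q_dim=32, in_k_dim=48, in_v_dim=48, proj_dim=64, num_heads=4
    )
    query = torch.rand(2, 10, 32)  # (batch, N, in_q_dim)
    key = torch.rand(2, 7, 48)  # (batch, S, in_k_dim)
    value = torch.rand(2, 7, 48)  # (batch, S, in_v_dim)
    outs = layer.forward_raw(query, key, value, mask=None)
    print(outs.shape)  # expected: torch.Size([2, 10, 64])


if __name__ == "__main__":
    _demo_super_self_attention()
    _demo_super_qkv_attention()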