Add SuperAttention

lib/trade_models/naive_v1_model.py      0   Executable file → Normal file
lib/trade_models/naive_v2_model.py      0   Executable file → Normal file
lib/trade_models/quant_transformer.py   0   Executable file → Normal file

lib/trade_models/transformers.py        6   Executable file → Normal file
@@ -1,6 +1,6 @@
-##################################################
-# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021 #
-##################################################
+#####################################################
+# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
+#####################################################
 from __future__ import division
 from __future__ import print_function

lib/xlayers/super_attention.py        155   Normal file
@@ -0,0 +1,155 @@
+#####################################################
+# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
+#####################################################
+from __future__ import division
+from __future__ import print_function
+
+import math
+from functools import partial
+from typing import Optional, Text
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+import spaces
+from .super_module import SuperModule
+from .super_module import IntSpaceType
+from .super_module import BoolSpaceType
+from .super_linear import SuperLinear
+
+
+class SuperAttention(SuperModule):
+    """The super model for attention layer."""
+
+    def __init__(
+        self,
+        input_dim: IntSpaceType,
+        proj_dim: IntSpaceType,
+        num_heads: IntSpaceType,
+        qkv_bias: BoolSpaceType = False,
+        attn_drop: float = 0.0,
+        proj_drop: float = 0.0,
+    ):
+        super(SuperAttention, self).__init__()
+        self._input_dim = input_dim
+        self._proj_dim = proj_dim
+        self._num_heads = num_heads
+        self._qkv_bias = qkv_bias
+        # head_dim = dim // num_heads
+        # self.scale = qk_scale or math.sqrt(head_dim)
+
+        # self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.q_fc = SuperLinear(input_dim, input_dim, bias=qkv_bias)
+        self.k_fc = SuperLinear(input_dim, input_dim, bias=qkv_bias)
+        self.v_fc = SuperLinear(input_dim, input_dim, bias=qkv_bias)
+
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = SuperLinear(input_dim, proj_dim)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+    @property
+    def num_heads(self):
+        return spaces.get_max(self._num_heads)
+
+    @property
+    def input_dim(self):
+        return spaces.get_max(self._input_dim)
+
+    @property
+    def proj_dim(self):
+        return spaces.get_max(self._proj_dim)
+
+    @property
+    def abstract_search_space(self):
+        root_node = spaces.VirtualNode(id(self))
+        space_q = self.q_fc.abstract_search_space
+        space_k = self.k_fc.abstract_search_space
+        space_v = self.v_fc.abstract_search_space
+        space_proj = self.proj.abstract_search_space
+        if not spaces.is_determined(self._num_heads):
+            root_node.append("_num_heads", self._num_heads.abstract(reuse_last=True))
+        if not spaces.is_determined(space_q):
+            root_node.append("q_fc", space_q)
+        if not spaces.is_determined(space_k):
+            root_node.append("k_fc", space_k)
+        if not spaces.is_determined(space_v):
+            root_node.append("v_fc", space_v)
+        if not spaces.is_determined(space_proj):
+            root_node.append("proj", space_proj)
+        return root_node
+
+    def apply_candidate(self, abstract_child: spaces.VirtualNode):
+        super(SuperAttention, self).apply_candidate(abstract_child)
+        if "q_fc" in abstract_child:
+            self.q_fc.apply_candidate(abstract_child["q_fc"])
+        if "k_fc" in abstract_child:
+            self.k_fc.apply_candidate(abstract_child["k_fc"])
+        if "v_fc" in abstract_child:
+            self.v_fc.apply_candidate(abstract_child["v_fc"])
+        if "proj" in abstract_child:
+            self.proj.apply_candidate(abstract_child["proj"])
+
+    def forward_qkv(self, input: torch.Tensor, num_head: int) -> torch.Tensor:
+        B, N, C = input.shape
+        q = self.q_fc(input)
+        k = self.k_fc(input)
+        v = self.v_fc(input)
+        if num_head > C:
+            raise ValueError("Invalid num_head [{:}] vs C [{:}]".format(num_head, C))
+        head_dim = C // num_head
+        # process the first [num_head * head_dim] part
+        q_v1 = (
+            q[:, :, : num_head * head_dim]
+            .reshape(B, N, num_head, head_dim)
+            .permute(0, 2, 1, 3)
+        )
+        k_v1 = (
+            k[:, :, : num_head * head_dim]
+            .reshape(B, N, num_head, head_dim)
+            .permute(0, 2, 1, 3)
+        )
+        v_v1 = (
+            v[:, :, : num_head * head_dim]
+            .reshape(B, N, num_head, head_dim)
+            .permute(0, 2, 1, 3)
+        )
+        attn_v1 = (q_v1 @ k_v1.transpose(-2, -1)) * math.sqrt(head_dim)
+        attn_v1 = attn_v1.softmax(dim=-1)
+        attn_v1 = self.attn_drop(attn_v1)
+        feats_v1 = (attn_v1 @ v_v1).permute(0, 2, 1, 3).reshape(B, N, -1)
+        if C == head_dim * num_head:
+            feats = feats_v1
+        else:  # The channels can not be divided by num_head, the remainder forms an additional head
+            q_v2 = q[:, :, num_head * head_dim :]
+            k_v2 = k[:, :, num_head * head_dim :]
+            v_v2 = v[:, :, num_head * head_dim :]
+            attn_v2 = (q_v2 @ k_v2.transpose(-2, -1)) * math.sqrt(q_v2.shape[-1])
+            attn_v2 = attn_v2.softmax(dim=-1)
+            attn_v2 = self.attn_drop(attn_v2)
+            feats_v2 = attn_v2 @ v_v2
+            feats = torch.cat([feats_v1, feats_v2], dim=-1)
+        return feats
+
+    def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
+        # check the num_heads:
+        if not spaces.is_determined(self._num_heads):
+            num_heads = self.abstract_child["_num_heads"].value
+        else:
+            num_heads = spaces.get_determined_value(self._num_heads)
+        feats = self.forward_qkv(input, num_heads)
+        outs = self.proj(feats)
+        outs = self.proj_drop(outs)
+        return outs
+
+    def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
+        feats = self.forward_qkv(input, self.num_heads)
+        outs = self.proj(feats)
+        outs = self.proj_drop(outs)
+        return outs
+
+    def extra_repr(self) -> str:
+        return "input_dim={:}, proj_dim={:}, num_heads={:}".format(
+            self._input_dim, self._proj_dim, self._num_heads
+        )

@@ -5,3 +5,4 @@ from .super_module import SuperRunMode
 from .super_module import SuperModule
 from .super_linear import SuperLinear
 from .super_linear import SuperMLP
+from .super_attention import SuperAttention

lib/xlayers/super_linear.py
@@ -6,14 +6,12 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 import math
-from typing import Optional, Union, Callable
+from typing import Optional, Callable
 
 import spaces
 from .super_module import SuperModule
 from .super_module import SuperRunMode
-
-IntSpaceType = Union[int, spaces.Integer, spaces.Categorical]
-BoolSpaceType = Union[bool, spaces.Categorical]
+from .super_module import IntSpaceType
+from .super_module import BoolSpaceType
 
 
 class SuperLinear(SuperModule):

lib/xlayers/super_module.py
@@ -1,13 +1,18 @@
 #####################################################
-# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.01 #
+# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
 #####################################################
 
 import abc
+from typing import Optional, Union, Callable
 import torch
 import torch.nn as nn
 from enum import Enum
 
 import spaces
 
+IntSpaceType = Union[int, spaces.Integer, spaces.Categorical]
+BoolSpaceType = Union[bool, spaces.Categorical]
+
+
 class SuperRunMode(Enum):
     """This class defines the enumerations for Super Model Running Mode."""
@@ -24,6 +29,7 @@ class SuperModule(abc.ABC, nn.Module):
         super(SuperModule, self).__init__()
         self._super_run_type = SuperRunMode.Default
         self._abstract_child = None
+        self._verbose = False
 
     def set_super_run_type(self, super_run_type):
         def _reset_super_run(m):
@@ -32,6 +38,13 @@ class SuperModule(abc.ABC, nn.Module):
 
         self.apply(_reset_super_run)
 
+    def apply_verbose(self, verbose):
+        def _reset_verbose(m):
+            if isinstance(m, SuperModule):
+                m._verbose = verbose
+
+        self.apply(_reset_verbose)
+
     def apply_candidate(self, abstract_child):
         if not isinstance(abstract_child, spaces.VirtualNode):
             raise ValueError(
@@ -51,6 +64,10 @@ class SuperModule(abc.ABC, nn.Module):
     def abstract_child(self):
         return self._abstract_child
 
+    @property
+    def verbose(self):
+        return self._verbose
+
     @abc.abstractmethod
     def forward_raw(self, *inputs):
         """Use the largest candidate for forward. Similar to the original PyTorch model."""
@@ -60,12 +77,41 @@
     def forward_candidate(self, *inputs):
         raise NotImplementedError
 
+    @property
+    def name_with_id(self):
+        return "name={:}, id={:}".format(self.__class__.__name__, id(self))
+
+    def get_shape_str(self, tensors):
+        if isinstance(tensors, (list, tuple)):
+            shapes = [self.get_shape_str(tensor) for tensor in tensors]
+            if len(shapes) == 1:
+                return shapes[0]
+            else:
+                return ", ".join(shapes)
+        elif isinstance(tensors, (torch.Tensor, nn.Parameter)):
+            return str(tuple(tensors.shape))
+        else:
+            raise TypeError("Invalid input type: {:}.".format(type(tensors)))
+
     def forward(self, *inputs):
+        if self.verbose:
+            print(
+                "[{:}] inputs shape: {:}".format(
+                    self.name_with_id, self.get_shape_str(inputs)
+                )
+            )
         if self.super_run_type == SuperRunMode.FullModel:
-            return self.forward_raw(*inputs)
+            outputs = self.forward_raw(*inputs)
         elif self.super_run_type == SuperRunMode.Candidate:
-            return self.forward_candidate(*inputs)
+            outputs = self.forward_candidate(*inputs)
         else:
             raise ModeError(
                 "Unknown Super Model Run Mode: {:}".format(self.super_run_type)
             )
+        if self.verbose:
+            print(
+                "[{:}] outputs shape: {:}".format(
+                    self.name_with_id, self.get_shape_str(outputs)
+                )
+            )
+        return outputs

@@ -26,6 +26,7 @@ class TestSuperLinear(unittest.TestCase):
         bias = spaces.Categorical(True, False)
         model = super_core.SuperLinear(10, out_features, bias=bias)
         print("The simple super linear module is:\n{:}".format(model))
+        model.apply_verbose(True)
 
         print(model.super_run_type)
         self.assertTrue(model.bias)
@@ -55,6 +56,7 @@ class TestSuperLinear(unittest.TestCase):
         out_features = spaces.Categorical(24, 36, 48)
         mlp = super_core.SuperMLP(10, hidden_features, out_features)
         print(mlp)
+        mlp.apply_verbose(True)
         self.assertTrue(mlp.fc1._out_features, mlp.fc2._in_features)
 
         inputs = torch.rand(4, 10)
@@ -85,3 +87,29 @@ class TestSuperLinear(unittest.TestCase):
         outputs = mlp(inputs)
         output_shape = (4, abstract_child["fc2"]["_out_features"].value)
         self.assertEqual(tuple(outputs.shape), output_shape)
+
+    def test_super_attention(self):
+        proj_dim = spaces.Categorical(12, 24, 36)
+        num_heads = spaces.Categorical(2, 4, 6)
+        model = super_core.SuperAttention(10, proj_dim, num_heads)
+        print(model)
+        model.apply_verbose(True)
+
+        inputs = torch.rand(4, 20, 10)  # batch size, sequence length, channel
+        outputs = model(inputs)
+
+        abstract_space = model.abstract_search_space
+        print(
+            "The abstract search space for SuperAttention is:\n{:}".format(
+                abstract_space
+            )
+        )
+        abstract_space.clean_last()
+        abstract_child = abstract_space.random(reuse_last=True)
+        print("The abstract child program is:\n{:}".format(abstract_child))
+
+        model.set_super_run_type(super_core.SuperRunMode.Candidate)
+        model.apply_candidate(abstract_child)
+        outputs = model(inputs)
+        output_shape = (4, 20, abstract_child["proj"]["_out_features"].value)
+        self.assertEqual(tuple(outputs.shape), output_shape)

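For reference, below is a minimal usage sketch of the new SuperAttention module, adapted directly from the test_super_attention case above. The "from xlayers import super_core" import path is an assumption based on the lib/xlayers layout shown in this commit; adjust it to how the package is exposed in your environment.

import torch
import spaces
from xlayers import super_core  # assumed import path; the unit test above references super_core the same way

# proj_dim and num_heads are search spaces, so this module acts as a super-network over attention layers
proj_dim = spaces.Categorical(12, 24, 36)
num_heads = spaces.Categorical(2, 4, 6)
model = super_core.SuperAttention(10, proj_dim, num_heads)

inputs = torch.rand(4, 20, 10)  # batch size, sequence length, channel
outputs = model(inputs)  # default run mode, called before any candidate is applied (as in the test)

# sample one candidate sub-network from the abstract search space and run it
abstract_space = model.abstract_search_space
abstract_space.clean_last()
abstract_child = abstract_space.random(reuse_last=True)
model.set_super_run_type(super_core.SuperRunMode.Candidate)
model.apply_candidate(abstract_child)
outputs = model(inputs)  # shape: (4, 20, sampled out_features of the proj layer)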