#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
#####################################################
from __future__ import division
from __future__ import print_function

import math
from functools import partial
from typing import Optional, Text, List

import torch
import torch.nn as nn
import torch.nn.functional as F

import spaces
from xlayers import trunc_normal_
from xlayers import super_core


__all__ = ["DefaultSearchSpace", "DEFAULT_NET_CONFIG", "get_transformer"]


def _get_mul_specs(candidates, num):
    """Create `num` independent categorical search spaces over `candidates`."""
    results = []
    for i in range(num):
        results.append(spaces.Categorical(*candidates))
    return results
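
# For example, _get_mul_specs((1, 2, 4, 8), 2) yields two independent
# spaces.Categorical(1, 2, 4, 8) objects -- one candidate space per layer,
# so each layer can settle on a different value during the search.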
					
						
def _get_list_mul(num, multiplier):
    """Return [multiplier, 2 * multiplier, ..., num * multiplier]."""
    results = []
    for i in range(1, num + 1):
        results.append(i * multiplier)
    return results
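
# For example, _get_list_mul(8, 16) -> [16, 32, 48, 64, 80, 96, 112, 128],
# which becomes the candidate set for the embedding dimension below.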
					
						
def _assert_types(x, expected_types):
    if not isinstance(x, expected_types):
        raise TypeError(
            "The type [{:}] is expected to be {:}.".format(type(x), expected_types)
        )
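
# For example, _assert_types([1, 2], (tuple, list)) passes silently, while
# _assert_types(3, (tuple, list)) raises a TypeError.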
					
						
DEFAULT_NET_CONFIG = None
_default_max_depth = 5
DefaultSearchSpace = dict(
    d_feat=6,
    embed_dim=spaces.Categorical(*_get_list_mul(8, 16)),
    num_heads=_get_mul_specs((1, 2, 4, 8), _default_max_depth),
    mlp_hidden_multipliers=_get_mul_specs((0.5, 1, 2, 4, 8), _default_max_depth),
    qkv_bias=True,
    pos_drop=0.0,
    other_drop=0.0,
)
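
# In words: the default super-network has a fixed depth of 5 encoder layers;
# the search covers embed_dim in {16, 32, ..., 128} plus, per layer, the number
# of attention heads in {1, 2, 4, 8} and an MLP hidden-dimension multiplier in
# {0.5, 1, 2, 4, 8}.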
					
						
class SuperTransformer(super_core.SuperModule):
    """The super model for transformer."""

    def __init__(
        self,
        d_feat: int = 6,
        embed_dim: super_core.IntSpaceType = DefaultSearchSpace["embed_dim"],
        num_heads: List[super_core.IntSpaceType] = DefaultSearchSpace["num_heads"],
        mlp_hidden_multipliers: List[super_core.IntSpaceType] = DefaultSearchSpace[
            "mlp_hidden_multipliers"
        ],
        qkv_bias: bool = DefaultSearchSpace["qkv_bias"],
        pos_drop: float = DefaultSearchSpace["pos_drop"],
        other_drop: float = DefaultSearchSpace["other_drop"],
        max_seq_len: int = 65,
    ):
        super(SuperTransformer, self).__init__()
        self._embed_dim = embed_dim
        self._num_heads = num_heads
        self._mlp_hidden_multipliers = mlp_hidden_multipliers

        # the stem part
        self.input_embed = super_core.SuperAlphaEBDv1(d_feat, embed_dim)
        # the cls token is allocated at the maximum candidate width and shrunk
        # on the fly for smaller candidates (see forward_candidate)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
        self.pos_embed = super_core.SuperPositionalEncoder(
            d_model=embed_dim, max_seq_len=max_seq_len, dropout=pos_drop
        )
        # build the transformer encode layers -->> check params
        _assert_types(num_heads, (tuple, list))
        _assert_types(mlp_hidden_multipliers, (tuple, list))
        assert len(num_heads) == len(mlp_hidden_multipliers), "{:} vs {:}".format(
            len(num_heads), len(mlp_hidden_multipliers)
        )
        # build the transformer encode layers -->> backbone
        layers = []
        for num_head, mlp_hidden_multiplier in zip(num_heads, mlp_hidden_multipliers):
            layer = super_core.SuperTransformerEncoderLayer(
                embed_dim,
                num_head,
                qkv_bias,
                mlp_hidden_multiplier,
                other_drop,
            )
            layers.append(layer)
        self.backbone = super_core.SuperSequential(*layers)

        # the regression head
        self.head = super_core.SuperSequential(
            super_core.SuperLayerNorm1D(embed_dim), super_core.SuperLinear(embed_dim, 1)
        )
        trunc_normal_(self.cls_token, std=0.02)
        self.apply(self._init_weights)

    @property
    def embed_dim(self):
        # the largest value in the embedding-dimension search space
        return spaces.get_max(self._embed_dim)

    @property
    def abstract_search_space(self):
        root_node = spaces.VirtualNode(id(self))
        if not spaces.is_determined(self._embed_dim):
            root_node.append("_embed_dim", self._embed_dim.abstract(reuse_last=True))
        xdict = dict(
            input_embed=self.input_embed.abstract_search_space,
            pos_embed=self.pos_embed.abstract_search_space,
            backbone=self.backbone.abstract_search_space,
            head=self.head.abstract_search_space,
        )
        for key, space in xdict.items():
            if not spaces.is_determined(space):
                root_node.append(key, space)
        return root_node

    def apply_candidate(self, abstract_child: spaces.VirtualNode):
        super(SuperTransformer, self).apply_candidate(abstract_child)
        xkeys = ("input_embed", "pos_embed", "backbone", "head")
        for key in xkeys:
            if key in abstract_child:
                getattr(self, key).apply_candidate(abstract_child[key])
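
    # A sketch of the intended search-time flow, given a SuperTransformer
    # instance `model` and an input tensor `inputs` (both hypothetical names;
    # the `random` sampling call assumes the API of the accompanying `spaces`
    # package):
    #   space = model.abstract_search_space        # tree of undetermined choices
    #   candidate = space.random(reuse_last=True)  # sample one architecture
    #   model.apply_candidate(candidate)           # pin the sampled choices
    #   preds = model.forward_candidate(inputs)    # run that sub-network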
					
						
    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, super_core.SuperLinear):
            trunc_normal_(m._super_weight, std=0.02)
            if m._super_bias is not None:
                nn.init.constant_(m._super_bias, 0)
        elif isinstance(m, super_core.SuperLayerNorm1D):
            nn.init.constant_(m.weight, 1.0)
            nn.init.constant_(m.bias, 0)

    def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
        batch, flatten_size = input.shape
        feats = self.input_embed(input)  # batch * 60 * 64
        # resolve the embedding dimension chosen by the current candidate
        if not spaces.is_determined(self._embed_dim):
            embed_dim = self.abstract_child["_embed_dim"].value
        else:
            embed_dim = spaces.get_determined_value(self._embed_dim)
        cls_tokens = self.cls_token.expand(batch, -1, -1)
        # the cls token is stored at the maximum width, so linearly interpolate
        # it down to the width of the sampled candidate
        cls_tokens = F.interpolate(
            cls_tokens, size=embed_dim, mode="linear", align_corners=True
        )
        feats_w_ct = torch.cat((cls_tokens, feats), dim=1)
        feats_w_tp = self.pos_embed(feats_w_ct)
        xfeats = self.backbone(feats_w_tp)
        xfeats = xfeats[:, 0, :]  # use the feature for the first token
        predicts = self.head(xfeats).squeeze(-1)
        return predicts
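
    # Shape-wise, the interpolation above maps cls_tokens from
    # (batch, 1, max_embed_dim) to (batch, 1, embed_dim): e.g., with the
    # default space, a stored (B, 1, 128) token becomes (B, 1, 64) when the
    # sampled candidate uses embed_dim = 64.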
					
						
    def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
        # the raw forward runs the super-network as constructed, without
        # applying any sampled candidate
        batch, flatten_size = input.shape
        feats = self.input_embed(input)  # batch * 60 * 64
        cls_tokens = self.cls_token.expand(batch, -1, -1)
        feats_w_ct = torch.cat((cls_tokens, feats), dim=1)
        feats_w_tp = self.pos_embed(feats_w_ct)
        xfeats = self.backbone(feats_w_tp)
        xfeats = xfeats[:, 0, :]  # use the feature for the first token
        predicts = self.head(xfeats).squeeze(-1)
        return predicts
					
						
def get_transformer(config):
    if config is None:
        return SuperTransformer(6)
    if not isinstance(config, dict):
        raise ValueError("Invalid Configuration: {:}".format(config))
    name = config.get("name", "basic")
    if name == "basic":
        model = SuperTransformer(
            d_feat=config.get("d_feat"),
            embed_dim=config.get("embed_dim"),
            num_heads=config.get("num_heads"),
            mlp_hidden_multipliers=config.get("mlp_hidden_multipliers"),
            qkv_bias=config.get("qkv_bias"),
            pos_drop=config.get("pos_drop"),
            other_drop=config.get("other_drop"),
        )
    else:
        raise ValueError("Unknown model name: {:}".format(name))
    return model
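

if __name__ == "__main__":
    # A minimal smoke test (a sketch, not part of the module's API): build the
    # default super model and push one random batch through the raw forward
    # path. Assumes the sibling `spaces` / `xlayers` packages are importable
    # and that inputs are flattened [batch, d_feat * seq_len] tensors
    # (d_feat=6 with 60 time steps here, matching the comments above).
    net = get_transformer(None)
    inputs = torch.randn(2, 6 * 60)
    outputs = net.forward_raw(inputs)
    print("prediction shape: {:}".format(outputs.shape))  # -> torch.Size([2])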