Refine Transformer
This commit is contained in:
		| @@ -21,10 +21,10 @@ import torch.nn.functional as F | ||||
| import torch.optim as optim | ||||
| import torch.utils.data as th_data | ||||
|  | ||||
| from log_utils import AverageMeter | ||||
| from utils import count_parameters | ||||
| from xautodl.xmisc import AverageMeter | ||||
| from xautodl.xmisc import count_parameters | ||||
|  | ||||
| from xlayers import super_core | ||||
| from xautodl.xlayers import super_core | ||||
| from .transformers import DEFAULT_NET_CONFIG | ||||
| from .transformers import get_transformer | ||||
|  | ||||
|   | ||||
| @@ -13,7 +13,7 @@ import torch.nn as nn | ||||
| import torch.nn.functional as F | ||||
|  | ||||
| from xautodl import spaces | ||||
| from xautodl.xlayers import trunc_normal_ | ||||
| from xautodl.xlayers import weight_init | ||||
| from xautodl.xlayers import super_core | ||||
|  | ||||
|  | ||||
| @@ -104,7 +104,7 @@ class SuperTransformer(super_core.SuperModule): | ||||
|         self.head = super_core.SuperSequential( | ||||
|             super_core.SuperLayerNorm1D(embed_dim), super_core.SuperLinear(embed_dim, 1) | ||||
|         ) | ||||
|         trunc_normal_(self.cls_token, std=0.02) | ||||
|         weight_init.trunc_normal_(self.cls_token, std=0.02) | ||||
|         self.apply(self._init_weights) | ||||
|  | ||||
|     @property | ||||
| @@ -136,11 +136,11 @@ class SuperTransformer(super_core.SuperModule): | ||||
|  | ||||
|     def _init_weights(self, m): | ||||
|         if isinstance(m, nn.Linear): | ||||
|             trunc_normal_(m.weight, std=0.02) | ||||
|             weight_init.trunc_normal_(m.weight, std=0.02) | ||||
|             if isinstance(m, nn.Linear) and m.bias is not None: | ||||
|                 nn.init.constant_(m.bias, 0) | ||||
|         elif isinstance(m, super_core.SuperLinear): | ||||
|             trunc_normal_(m._super_weight, std=0.02) | ||||
|             weight_init.trunc_normal_(m._super_weight, std=0.02) | ||||
|             if m._super_bias is not None: | ||||
|                 nn.init.constant_(m._super_bias, 0) | ||||
|         elif isinstance(m, super_core.SuperLayerNorm1D): | ||||
|   | ||||
		Reference in New Issue
	
	Block a user