Refine Transformer

This commit is contained in:
D-X-Y
2021-07-04 11:59:06 +00:00
parent 9136f33684
commit 11f313288a
10 changed files with 160 additions and 28 deletions

View File

@@ -13,7 +13,7 @@ import torch.nn as nn
import torch.nn.functional as F
from xautodl import spaces
from xautodl.xlayers import trunc_normal_
from xautodl.xlayers import weight_init
from xautodl.xlayers import super_core
@@ -104,7 +104,7 @@ class SuperTransformer(super_core.SuperModule):
self.head = super_core.SuperSequential(
super_core.SuperLayerNorm1D(embed_dim), super_core.SuperLinear(embed_dim, 1)
)
trunc_normal_(self.cls_token, std=0.02)
weight_init.trunc_normal_(self.cls_token, std=0.02)
self.apply(self._init_weights)
@property
@@ -136,11 +136,11 @@ class SuperTransformer(super_core.SuperModule):
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=0.02)
weight_init.trunc_normal_(m.weight, std=0.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, super_core.SuperLinear):
trunc_normal_(m._super_weight, std=0.02)
weight_init.trunc_normal_(m._super_weight, std=0.02)
if m._super_bias is not None:
nn.init.constant_(m._super_bias, 0)
elif isinstance(m, super_core.SuperLayerNorm1D):