autodl-projects/lib/layers/positional_embedding.py

#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.02 #
#####################################################
import math

import torch
import torch.nn as nn


class PositionalEncoder(nn.Module):
    # Attention Is All You Need: https://arxiv.org/pdf/1706.03762.pdf
    # https://github.com/pytorch/examples/blob/master/word_language_model/model.py#L65
    def __init__(self, d_model, max_seq_len, dropout=0.1):
        super(PositionalEncoder, self).__init__()
        self.d_model = d_model
        self.dropout = nn.Dropout(p=dropout)
        # create the constant 'pe' matrix, whose values depend on the
        # position 'pos' and the dimension index 'i'
        pe = torch.zeros(max_seq_len, d_model)
        for pos in range(max_seq_len):
            for i in range(0, d_model):
                div = 10000 ** ((i // 2) * 2 / d_model)
                value = pos / div
                if i % 2 == 0:
                    pe[pos, i] = math.sin(value)
                else:
                    pe[pos, i] = math.cos(value)
        # add a batch dimension and register as a non-trainable buffer
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x is expected to have shape (batch, seq, d_model)
        batch, seq, fdim = x.shape[:3]
        # slice the positional encodings to the input's sequence length and
        # feature dimension, add them to the input, and apply dropout
        embeddings = self.pe[:, :seq, :fdim]
        return self.dropout(x + embeddings)
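

# A minimal usage sketch, not part of the original file: it only assumes the
# module is fed tensors shaped (batch, seq_len, d_model), as in the forward
# pass above. The d_model, max_seq_len, batch and sequence sizes below are
# arbitrary example values.
if __name__ == "__main__":
    encoder = PositionalEncoder(d_model=64, max_seq_len=128, dropout=0.1)
    tokens = torch.randn(2, 10, 64)  # (batch, seq, d_model)
    out = encoder(tokens)
    print(out.shape)  # torch.Size([2, 10, 64]); dropout does not change shape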