
Understanding Positional Encoding in the Transformer


Paper: https://proceedings.neurips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf

1. Formula description

$$
\begin{aligned}
PE_{(pos, 2i)} &= \sin\left(pos / 10000^{2i/d_{\mathrm{model}}}\right) \\
PE_{(pos, 2i+1)} &= \cos\left(pos / 10000^{2i/d_{\mathrm{model}}}\right)
\end{aligned}
$$
where $d_{\mathrm{model}}$ is the embedding dimension, $i$ is the index of a dimension within the word vector, and $pos$ is the position index of the token.
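The formula can be checked directly on a small example. The values of `d_model` and `pos` below are illustrative choices, not from the paper:

```python
import math

d_model = 4   # illustrative embedding dimension
pos = 3       # illustrative position index

# PE(pos, 2i) = sin(pos / 10000^(2i/d_model)); PE(pos, 2i+1) = cos(...)
pe = []
for i in range(d_model // 2):
    angle = pos / (10000 ** (2 * i / d_model))
    pe.append(math.sin(angle))  # even dimension 2i
    pe.append(math.cos(angle))  # odd dimension 2i+1

print(pe)
```

Each pair of dimensions shares the same angle, so the encoding interleaves sine and cosine values at geometrically decreasing frequencies.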

2. Formula transformation

First, rewrite the scaling factor:
$$
\frac{1}{10000^{2i/d_{\text{model}}}}
= e^{\log \frac{1}{10000^{2i/d_{\text{model}}}}}
= e^{-2i/d_{\text{model}} \cdot \log 10000}
= e^{2i \cdot \left(-\log 10000 / d_{\text{model}}\right)}
$$
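This identity (which is why the implementation below can use `torch.exp` with a `-log(10000)` factor) can be verified numerically; `d_model = 512` matches the paper's base configuration:

```python
import math

d_model = 512  # dimension used in the paper's base model
for i in range(6):
    direct = 1.0 / (10000 ** (2 * i / d_model))
    via_exp = math.exp(2 * i * (-math.log(10000.0) / d_model))
    print(i, direct, via_exp)  # the two forms agree
```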

3. Code implementation

```python
# https://pytorch.org/tutorials/beginner/transformer_tutorial.html
import math
import torch
import torch.nn as nn

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # exp(2i * (-log 10000 / d_model)) is the transformed form of 1 / 10000^(2i/d_model)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)  # even dimensions
        pe[:, 1::2] = torch.cos(position * div_term)  # odd dimensions
        pe = pe.unsqueeze(0).transpose(0, 1)          # shape: (max_len, 1, d_model)
        self.register_buffer('pe', pe)                # fixed table, not a learnable parameter

    def forward(self, x):
        # x: (seq_len, batch, d_model); add the encoding for each position
        x = x + self.pe[:x.size(0), :]
        return x
```
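As a sanity check on the table the code above builds, we can recompute it standalone and verify that each (sin, cos) pair at a given position and frequency lies on the unit circle; the sizes here are illustrative, not from the paper:

```python
import math
import torch

max_len, d_model = 50, 16  # illustrative sizes
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

pe = torch.zeros(max_len, d_model)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)

# each (sin, cos) pair satisfies sin^2 + cos^2 = 1
norms = pe[:, 0::2] ** 2 + pe[:, 1::2] ** 2
print(torch.allclose(norms, torch.ones_like(norms)))
```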

4. Personal understanding

The positional information is a fixed value computed from the sinusoidal formula; it is not learned during training.
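This is exactly what `register_buffer` expresses in the code: a buffer is stored with the module but excluded from its learnable parameters, so the optimizer never updates it. A minimal sketch (the `WithBuffer` module is a hypothetical example for illustration):

```python
import torch
import torch.nn as nn

class WithBuffer(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer('pe', torch.zeros(4, 8))  # fixed table, like the positional encodings
        self.proj = nn.Linear(8, 8)                    # a learnable layer, for contrast

m = WithBuffer()
param_names = [name for name, _ in m.named_parameters()]
# 'pe' is not among the learnable parameters, so gradient updates never touch it
print(param_names)
```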

References:
https://pytorch.org/tutorials/beginner/transformer_tutorial.html
