输入部分包含:

# 导入必备的工具包
import torch
# 预定义的网络层torch.nn, 工具开发者已经帮助我们开发好的一些常用层,
# 比如,卷积层, lstm层, embedding层等, 不需要我们再重新造轮子.
import torch.nn as nn
# 数学计算工具包
import math
# torch中变量封装函数Variable.
from torch.autograd import Variable
# Embeddings 类实现思路分析
# 1. __init__(self, d_model, vocab):
#    设置类属性, 定义词嵌入层 self.lut
# 2. forward(self, x):
#    返回 self.lut(x) * math.sqrt(self.d_model)
class Embeddings(nn.Module):
    """Token embedding layer whose output is scaled by sqrt(d_model).

    Named ``Embeddings`` (plural) to distinguish it from ``nn.Embedding``,
    which it wraps.

    Args:
        d_model: Size of each embedding vector (the model / feature dimension).
        vocab: Number of entries in the vocabulary; valid token ids are
            ``0 .. vocab - 1``.
    """

    def __init__(self, d_model, vocab):
        super().__init__()
        self.d_model = d_model
        self.vocab = vocab
        # Lookup table ("lut") mapping each token id to a learnable
        # d_model-dimensional vector.
        self.lut = nn.Embedding(self.vocab, self.d_model)

    def forward(self, x):
        """Embed token ids ``x`` and scale the result by ``sqrt(d_model)``.

        Args:
            x: Integer (long) tensor of token ids, each in ``[0, vocab)``.

        Returns:
            Tensor of shape ``x.shape + (d_model,)``.
        """
        # The sqrt(d_model) scaling follows the original Transformer design;
        # it is usually explained as keeping the embedding magnitudes
        # comparable to the positional encodings added afterwards.
        return self.lut(x) * math.sqrt(self.d_model)
>>> embedding = nn.Embedding(10, 3)
>>> input = torch.LongTensor([[1,2,4,5],[4,3,2,9]])
>>> embedding(input)
tensor([[[-0.0251, -1.6902, 0.7172],
[-0.6431, 0.0748, 0.6969],
[ 1.4970, 1.3448, -0.9685],
[-0.3677, -2.7265, -0.1685]],
[[ 1.4970, 1.3448, -0.9685],
[ 0.4362, -0.4004, 0.9400],
[-0.6431, 0.0748, 0.6969],
[ 0.9124, -2.3616, 1.1151]]])
>>> embedding = nn.Embedding(10, 3, padding_idx=0)
>>> input = torch.LongTensor([[0,2,0,5]])
>>> embedding(input)
tensor([[[ 0.0000, 0.0000, 0.0000],
[ 0.1535, -2.0309, 0.9315],
[ 0.0000, 0.0000, 0.0000],
[-0.1655, 0.9897, 0.0635]]])
• 调用
def dm_test_Embeddings():
    """Demo: embed a small batch of token ids and print the result.

    Builds an ``Embeddings`` layer with d_model=512 and vocab=1000, feeds it
    a 2 x 4 batch of token ids, and prints the output shape and values.

    Returns:
        The embedded tensor, of shape (2, 4, 512).
    """
    d_model = 512  # embedding dimension
    vocab = 1000  # vocabulary size; every id in x below must be < vocab
    my_embeddings = Embeddings(d_model, vocab)
    # Plain tensors take part in autograd directly; the deprecated
    # ``Variable`` wrapper is unnecessary in modern PyTorch.
    x = torch.tensor([[100, 2, 421, 508], [491, 998, 1, 221]], dtype=torch.long)
    embed = my_embeddings(x)
    print('embed.shape', embed.shape, '\nembed--->\n', embed)
    return embed
• 输出效果
embed.shape torch.Size([2, 4, 512])
embed--->
tensor([[[-19.0429, -44.2167, 2.6662, ..., -21.1199, -36.5275, -15.6872],
[-25.4621, 25.6046, -45.5382, ..., 43.7159, 0.9437, -3.1733],
[-15.7487, 8.1787, -20.6409, ..., -8.7201, -3.2585, -22.1298],
[ 21.5044, 2.0660, -1.4059, ..., -6.3673, 3.4387, -22.4600]],
[[ 15.7010, 2.6187, 14.1192, ..., -19.1751, 10.5954, 9.1155],
[-21.5745, 9.6403, 17.9778, ..., 2.3668, 30.1526, -30.3724],
[-17.6655, 33.6687, 19.3059, ..., -10.6276, -0.8653, 10.0715],
[ 12.9400, -23.6355, -2.4750, ..., 19.1028, 6.6492, -45.1315]]],
grad_fn=<MulBackward0>)