3 输入部分实现

学习目标

1 输入部分介绍

输入部分包含:

2 文本嵌入层的作用

# 导入必备的工具包
import torch

# 预定义的网络层torch.nn, 工具开发者已经帮助我们开发好的一些常用层, 
# 比如,卷积层, lstm层, embedding层等, 不需要我们再重新造轮子.
import torch.nn as nn

# 数学计算工具包
import math

# torch中变量封装函数Variable.
from torch.autograd import Variable
# Embeddings类 实现思路分析
# 1 init函数 (self, d_model, vocab)
    # 设置类属性 定义词嵌入层 self.lut层
# 2 forward(x)函数
    # self.lut(x) * math.sqrt(self.d_model)
class Embeddings(nn.Module):# 为什么加s,因为这里要和nn.Embedding做区分
    def __init__(self, d_model, vocab):
        # 参数d_model 每个词汇的特征尺寸 词嵌入维度
        # 参数vocab   词汇表大小
        super(Embeddings, self).__init__()
        self.d_model = d_model
        self.vocab = vocab

        # 定义词嵌入层
        self.lut = nn.Embedding(self.vocab, self.d_model)

    def forward(self, x):
        # 将x传给self.lut并与根号下self.d_model相乘作为结果返回
        # 目的是为了平衡梯度,为了避免梯度爆炸和梯度消失
        # x经过词嵌入后 增大x的值, 词嵌入后的embedding_vector+位置编码信息,值量纲差差不多
        return self.lut(x) * math.sqrt(self.d_model)
>>> embedding = nn.Embedding(10, 3)
>>> input = torch.LongTensor([[1,2,4,5],[4,3,2,9]])
>>> embedding(input)
tensor([[[-0.0251, -1.6902,  0.7172],
         [-0.6431,  0.0748,  0.6969],
         [ 1.4970,  1.3448, -0.9685],
         [-0.3677, -2.7265, -0.1685]],

        [[ 1.4970,  1.3448, -0.9685],
         [ 0.4362, -0.4004,  0.9400],
         [-0.6431,  0.0748,  0.6969],
         [ 0.9124, -2.3616,  1.1151]]])

>>> embedding = nn.Embedding(10, 3, padding_idx=0)
>>> input = torch.LongTensor([[0,2,0,5]])
>>> embedding(input)
tensor([[[ 0.0000,  0.0000,  0.0000],
         [ 0.1535, -2.0309,  0.9315],
         [ 0.0000,  0.0000,  0.0000],
         [-0.1655,  0.9897,  0.0635]]])

• 调用

def dm_test_Embeddings():
    d_model = 512   # 词嵌入维度是512维
    vocab = 1000    # 词表大小是1000
    # 实例化词嵌入层
    my_embeddings = Embeddings(d_model, vocab)
    x = Variable(torch.LongTensor([[100,2,421,508],[491,998,1,221]]))
    embed = my_embeddings(x)
    print('embed.shape', embed.shape, '\nembed--->\n',embed)

• 输出效果

embed.shape torch.Size([2, 4, 512]) 
embed--->
 tensor([[[-19.0429, -44.2167,   2.6662,  ..., -21.1199, -36.5275, -15.6872],
         [-25.4621,  25.6046, -45.5382,  ...,  43.7159,   0.9437,  -3.1733],
         [-15.7487,   8.1787, -20.6409,  ...,  -8.7201,  -3.2585, -22.1298],
         [ 21.5044,   2.0660,  -1.4059,  ...,  -6.3673,   3.4387, -22.4600]],

        [[ 15.7010,   2.6187,  14.1192,  ..., -19.1751,  10.5954,   9.1155],
         [-21.5745,   9.6403,  17.9778,  ...,   2.3668,  30.1526, -30.3724],
         [-17.6655,  33.6687,  19.3059,  ..., -10.6276,  -0.8653,  10.0715],
         [ 12.9400, -23.6355,  -2.4750,  ...,  19.1028,   6.6492, -45.1315]]],
       grad_fn=<MulBackward0>)