7 模型构建

学习目标

1 模型构建介绍

通过上面的小节, 我们已经完成了所有组成部分的实现, 接下来就来实现完整的编码器-解码器结构.

2 编码器-解码器结构的代码实现

EncoderDecoder函数完成编码解码的子任务,就是把编码和解码的流程进行封装实现。

# 编码解码内部函数类 EncoderDecoder 实现分析
# init函数 (self, encoder, decoder, source_embed, target_embed, generator)
    # 5个成员属性赋值 encoder 编码器对象 decoder 解码器对象 source_embed source端词嵌入层对象
    # target_embed target端词嵌入层对象 generator 输出层对象
# forward函数 (self, source,  target, source_mask, target_mask)
    # 1 编码 s.encoder(self.src_embed(source), source_mask)
    # 2 解码 s.decoder(self.tgt_embed(target), memory, source_mask, target_mask)
    # 3 输出 s.generator()

# 使用EncoderDecoder类来实现编码器-解码器结构
class EncoderDecoder(nn.Module):
    def __init__(self, encoder, decoder, source_embed, target_embed, generator):
        """初始化函数中有5个参数, 分别是编码器对象, 解码器对象, 
           源数据嵌入函数, 目标数据嵌入函数,  以及输出部分的类别生成器对象
        """
        super(EncoderDecoder, self).__init__()
        # 将参数传入到类中
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = source_embed
        self.tgt_embed = target_embed
        self.generator = generator

    def forward(self, source, target, source_mask, target_mask):
        """
        在forward函数中,有四个参数, source代表源数据, target代表目标数据, 
        source_mask和target_mask代表对应的掩码张量
        source mask:padding mask机制,防止填充的pad值影响注意力计算
        target_maskt,sentence mask机制,防止未来的信息被提前利用
        """

        # 在函数中, 将source, source_mask传入编码函数, 得到结果后,
        # 与source_mask,target,和target_mask一同传给解码函数
        return self.generator(self.decode(self.encode(source, source_mask), 
                                          source_mask, target, target_mask))

    def encode(self, source, source_mask):
        """编码函数, 以source和source_mask为参数"""
        # 使用src_embed对source做处理, 然后和source_mask一起传给self.encoder
        return self.encoder(self.src_embed(source), source_mask)

    def decode(self, memory, source_mask, target, target_mask):
        """解码函数, 以memory即编码器的输出, source_mask, target, target_mask为参数"""
        # 使用tgt_embed对target做处理, 然后和source_mask, target_mask, memory一起传给self.decoder
        return self.decoder(self.tgt_embed(target), memory, source_mask, target_mask)

• 实例化参数

vocab_size = 1000
d_model = 512
encoder = en
decoder = de
source_embed = nn.Embedding(vocab_size, d_model)
target_embed = nn.Embedding(vocab_size, d_model)
generator = gen

• 输入参数:

# 假设源数据与目标数据相同, 实际中并不相同
source = target = Variable(torch.LongTensor([[100, 2, 421, 508], [491, 998, 1, 221]]))

# 假设src_mask与tgt_mask相同,实际中并不相同
source_mask = target_mask = Variable(torch.zeros(8, 4, 4))

• 调用:

ed = EncoderDecoder(encoder, decoder, source_embed, target_embed, generator)
ed_result = ed(source, target, source_mask, target_mask)
print(ed_result)
print(ed_result.shape)

• 输出效果:

tensor([[[ 0.2102, -0.0826, -0.0550,  ...,  1.5555,  1.3025, -0.6296],
         [ 0.8270, -0.5372, -0.9559,  ...,  0.3665,  0.4338, -0.7505],
         [ 0.4956, -0.5133, -0.9323,  ...,  1.0773,  1.1913, -0.6240],
         [ 0.5770, -0.6258, -0.4833,  ...,  0.1171,  1.0069, -1.9030]],

        [[-0.4355, -1.7115, -1.5685,  ..., -0.6941, -0.1878, -0.1137],
         [-0.8867, -1.2207, -1.4151,  ..., -0.9618,  0.1722, -0.9562],
         [-0.0946, -0.9012, -1.6388,  ..., -0.2604, -0.3357, -0.6436],
         [-1.1204, -1.4481, -1.5888,  ..., -0.8816, -0.6497,  0.0606]]],
       grad_fn=<AddBackward0>)
torch.Size([2, 4, 512])