通过上面的小节, 我们已经完成了所有组成部分的实现, 接下来就来实现完整的编码器-解码器结构.

EncoderDecoder函数完成编码解码的子任务,就是把编码和解码的流程进行封装实现。
# 编码解码内部函数类 EncoderDecoder 实现分析
# init函数 (self, encoder, decoder, source_embed, target_embed, generator)
# 5个成员属性赋值 encoder 编码器对象 decoder 解码器对象 source_embed source端词嵌入层对象
# target_embed target端词嵌入层对象 generator 输出层对象
# forward函数 (self, source, target, source_mask, target_mask)
# 1 编码 s.encoder(self.src_embed(source), source_mask)
# 2 解码 s.decoder(self.tgt_embed(target), memory, source_mask, target_mask)
# 3 输出 s.generator()
# 使用EncoderDecoder类来实现编码器-解码器结构
class EncoderDecoder(nn.Module):
def __init__(self, encoder, decoder, source_embed, target_embed, generator):
"""初始化函数中有5个参数, 分别是编码器对象, 解码器对象,
源数据嵌入函数, 目标数据嵌入函数, 以及输出部分的类别生成器对象
"""
super(EncoderDecoder, self).__init__()
# 将参数传入到类中
self.encoder = encoder
self.decoder = decoder
self.src_embed = source_embed
self.tgt_embed = target_embed
self.generator = generator
def forward(self, source, target, source_mask, target_mask):
"""
在forward函数中,有四个参数, source代表源数据, target代表目标数据,
source_mask和target_mask代表对应的掩码张量
source mask:padding mask机制,防止填充的pad值影响注意力计算
target_maskt,sentence mask机制,防止未来的信息被提前利用
"""
# 在函数中, 将source, source_mask传入编码函数, 得到结果后,
# 与source_mask,target,和target_mask一同传给解码函数
return self.generator(self.decode(self.encode(source, source_mask),
source_mask, target, target_mask))
def encode(self, source, source_mask):
"""编码函数, 以source和source_mask为参数"""
# 使用src_embed对source做处理, 然后和source_mask一起传给self.encoder
return self.encoder(self.src_embed(source), source_mask)
def decode(self, memory, source_mask, target, target_mask):
"""解码函数, 以memory即编码器的输出, source_mask, target, target_mask为参数"""
# 使用tgt_embed对target做处理, 然后和source_mask, target_mask, memory一起传给self.decoder
return self.decoder(self.tgt_embed(target), memory, source_mask, target_mask)
• 实例化参数
vocab_size = 1000
d_model = 512
encoder = en
decoder = de
source_embed = nn.Embedding(vocab_size, d_model)
target_embed = nn.Embedding(vocab_size, d_model)
generator = gen
• 输入参数:
# 假设源数据与目标数据相同, 实际中并不相同
source = target = Variable(torch.LongTensor([[100, 2, 421, 508], [491, 998, 1, 221]]))
# 假设src_mask与tgt_mask相同,实际中并不相同
source_mask = target_mask = Variable(torch.zeros(8, 4, 4))
• 调用:
ed = EncoderDecoder(encoder, decoder, source_embed, target_embed, generator)
ed_result = ed(source, target, source_mask, target_mask)
print(ed_result)
print(ed_result.shape)
• 输出效果:
tensor([[[ 0.2102, -0.0826, -0.0550, ..., 1.5555, 1.3025, -0.6296],
[ 0.8270, -0.5372, -0.9559, ..., 0.3665, 0.4338, -0.7505],
[ 0.4956, -0.5133, -0.9323, ..., 1.0773, 1.1913, -0.6240],
[ 0.5770, -0.6258, -0.4833, ..., 0.1171, 1.0069, -1.9030]],
[[-0.4355, -1.7115, -1.5685, ..., -0.6941, -0.1878, -0.1137],
[-0.8867, -1.2207, -1.4151, ..., -0.9618, 0.1722, -0.9562],
[-0.0946, -0.9012, -1.6388, ..., -0.2604, -0.3357, -0.6436],
[-1.1204, -1.4481, -1.5888, ..., -0.8816, -0.6497, 0.0606]]],
grad_fn=<AddBackward0>)
torch.Size([2, 4, 512])