# Transformer完整架构详解:从自注意力机制到编码器解码器设计及PyTorch实现
#目录
- Transformer架构概述
- 自注意力机制详解
- 多头注意力机制
- 位置编码(Positional Encoding)
- 编码器(Encoder)架构
- 解码器(Decoder)架构
- 残差连接与层归一化
- 完整Transformer模型实现
- 实际应用与变体
#Transformer架构概述
Transformer是2017年Google提出的革命性架构,完全基于注意力机制,摒弃了传统的循环和卷积结构。它已成为现代NLP和多模态AI的基础架构。
#Transformer的整体结构
完整的Transformer架构 = N个编码器层 + N个解码器层
编码器(Encoder) - 用于理解输入:
输入序列 → 词嵌入 + 位置编码 → 多头自注意力 → 前馈网络 → 输出表示
解码器(Decoder) - 用于生成输出:
输出序列 → 词嵌入 + 位置编码 → 掩码多头自注意力 → 交叉注意力 → 前馈网络 → 输出概率
关键创新:Self-Attention机制允许模型关注序列中的任意位置,解决了RNN的并行化问题。

# Transformer相比传统模型的优势
- 并行化:不像RNN需要顺序处理,Transformer可以并行处理所有位置
- 长距离依赖:注意力机制可以直接捕获任意距离的依赖关系
- 可解释性:注意力权重可以可视化,显示模型关注的位置
- 扩展性:架构简单,容易扩展到更大的模型
#自注意力机制详解
自注意力是Transformer的核心,它允许序列中的每个位置关注序列中的所有位置。
#自注意力的数学原理
import torch
import torch.nn as nn
import math
def scaled_dot_product_attention(Q, K, V, mask=None):
    """Compute scaled dot-product attention.

    Args:
        Q: query tensor of shape (..., seq_len_q, d_k)
        K: key tensor of shape (..., seq_len_k, d_k)
        V: value tensor of shape (..., seq_len_k, d_v)
        mask: optional tensor (broadcastable to the score shape); positions
            where mask == 0 are excluded from attention.

    Returns:
        Tuple of (output, attention_weights).
    """
    key_dim = K.size(-1)
    # Similarity of every query with every key, scaled by sqrt(d_k) so the
    # softmax does not saturate for large feature dimensions.
    scores = Q @ K.transpose(-2, -1) / math.sqrt(key_dim)
    if mask is not None:
        # A large negative score acts as -inf: softmax sends it to ~0 weight.
        scores.masked_fill_(mask == 0, -1e9)
    weights = torch.softmax(scores, dim=-1)
    return weights @ V, weights
def manual_self_attention_example():
    """Walk through the self-attention computation on a tiny random input."""
    print("自注意力计算步骤:")
    print("1. 输入序列 X 通过线性变换得到 Q, K, V")
    print("2. 计算注意力分数: Q @ K^T / sqrt(d_k)")
    print("3. 应用softmax得到注意力权重")
    print("4. 加权求和: AttentionWeights @ V")
    batch_size, seq_len, d_model = 1, 4, 8
    X = torch.randn(batch_size, seq_len, d_model)
    # Random projection matrices stand in for learned Q/K/V weights.
    W_q = torch.randn(d_model, d_model)
    W_k = torch.randn(d_model, d_model)
    W_v = torch.randn(d_model, d_model)
    Q, K, V = X @ W_q, X @ W_k, X @ W_v
    output, weights = scaled_dot_product_attention(Q, K, V)
    print(f"输入形状: {X.shape}")
    print(f"输出形状: {output.shape}")
    print(f"注意力权重形状: {weights.shape}")
    print(f"注意力权重示例:\n{weights[0]}")
manual_self_attention_example()#自注意力的直观理解
自注意力机制可以理解为一种查询-键-值机制:
- 查询(Query):当前位置想要获取信息
- 键(Key):其他位置用来匹配的特征
- 值(Value):其他位置的实际内容
- 注意力分数:查询与键的相似度,决定值的权重
#多头注意力机制
多头注意力允许模型同时关注来自不同表示子空间的不同信息。
#多头注意力的实现
class MultiHeadAttention(nn.Module):
    """Multi-head attention: several attention heads run in parallel.

    Each head attends within a d_model / num_heads subspace; head outputs are
    concatenated and passed through a final linear projection.
    """

    def __init__(self, d_model, num_heads, dropout=0.1):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # per-head feature dimension
        # Linear projections for queries, keys, values, and the merged output.
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        # NOTE(review): defined but not applied in forward() — kept as-is for
        # interface parity; confirm whether attention dropout was intended.
        self.dropout = nn.Dropout(dropout)

    def split_heads(self, x, batch_size):
        """Reshape (B, L, d_model) -> (B, num_heads, L, d_k)."""
        return x.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

    def combine_heads(self, x, batch_size):
        """Reshape (B, num_heads, L, d_k) -> (B, L, d_model)."""
        return x.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)

    def forward(self, Q, K, V, mask=None):
        """Project the inputs, attend per head, and merge the results."""
        batch_size = Q.size(0)
        # Project, then split the projections into heads.
        Q = self.split_heads(self.W_q(Q), batch_size)
        K = self.split_heads(self.W_k(K), batch_size)
        V = self.split_heads(self.W_v(V), batch_size)
        # Attention is computed independently for every head at once.
        attended, attention_weights = scaled_dot_product_attention(Q, K, V, mask)
        # Concatenate the heads and apply the output projection.
        output = self.W_o(self.combine_heads(attended, batch_size))
        return output, attention_weights
def multi_head_attention_example():
    """Run a freshly initialized multi-head attention module once."""
    d_model, num_heads, seq_len = 512, 8, 10
    mha = MultiHeadAttention(d_model, num_heads)
    batch_size = 2
    # Independent random inputs for queries, keys, and values.
    Q = torch.randn(batch_size, seq_len, d_model)
    K = torch.randn(batch_size, seq_len, d_model)
    V = torch.randn(batch_size, seq_len, d_model)
    output, attention_weights = mha(Q, K, V)
    print(f"多头注意力输出形状: {output.shape}")
    print(f"注意力权重形状: {attention_weights.shape}")
    print(f"多头数量: {num_heads}")
    print(f"每个头的维度: {d_model // num_heads}")
multi_head_attention_example()#多头注意力的优势
- 多样化表示:不同的头可以学习不同的注意力模式
- 并行计算:多个头可以并行计算
- 增强表达能力:能够捕获不同类型的依赖关系
#位置编码(Positional Encoding)
由于Transformer没有循环结构,需要显式地加入位置信息。
#位置编码的实现
class PositionalEncoding(nn.Module):
    """Add sinusoidal position information to token embeddings.

    Even feature indices carry sin(pos / 10000^(2i/d_model)) and odd indices
    the matching cos, giving every position a unique, deterministic code.
    """

    def __init__(self, d_model, max_len=5000, dropout=0.1):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        positions = torch.arange(0, max_len).unsqueeze(1).float()
        # Per-pair frequencies decay geometrically across the feature axis.
        freqs = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(positions * freqs)  # even indices
        table[:, 1::2] = torch.cos(positions * freqs)  # odd indices
        # Buffer: stored with the module and moved by .to(), never trained.
        self.register_buffer('pe', table.unsqueeze(0))

    def forward(self, x):
        """Add the positional code to x of shape (batch, seq_len, d_model)."""
        return self.dropout(x + self.pe[:, :x.size(1)])
def positional_encoding_example():
    """Demonstrate PositionalEncoding on a small random batch.

    Builds the module, encodes random embeddings, and prints a slice of the
    precomputed encoding table.
    """
    d_model = 512
    max_len = 50
    pe = PositionalEncoding(d_model, max_len)
    batch_size, seq_len = 2, 10
    x = torch.randn(batch_size, seq_len, d_model)
    encoded_x = pe(x)
    print(f"输入形状: {x.shape}")
    print(f"输出形状: {encoded_x.shape}")
    # Fix: removed the unused `import matplotlib.pyplot as plt` — nothing was
    # ever plotted, and the import crashed this example when matplotlib was
    # not installed.
    pos_encoding = pe.pe[0, :seq_len, :8]  # first 8 feature dimensions
    print(f"位置编码示例 (前{seq_len}个位置,前8个维度):")
    print(pos_encoding)
positional_encoding_example()#位置编码的特点
- 可学习vs固定:可以是固定的sin/cos函数,也可以是可学习的
- 唯一性:每个位置都有独特的位置编码
- 连续性:相邻位置的编码相似,远距离位置的编码差异较大
#编码器(Encoder)架构
编码器负责将输入序列转换为上下文表示。
#编码器层实现
class EncoderLayer(nn.Module):
    """One Transformer encoder layer.

    Self-attention followed by a position-wise feed-forward network, each
    wrapped in dropout + residual connection + LayerNorm (post-norm layout).
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super(EncoderLayer, self).__init__()
        self.multi_head_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Run the two encoder sublayers over x."""
        # Sublayer 1: self-attention (queries, keys, values all from x).
        attn_out, _ = self.multi_head_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_out))
        # Sublayer 2: position-wise feed-forward network.
        x = self.norm2(x + self.dropout(self.feed_forward(x)))
        return x
class Encoder(nn.Module):
    """Full encoder: token embedding + positional encoding + N encoder layers.

    Maps a batch of token-id sequences to contextual representations.
    """

    def __init__(self, num_layers, d_model, num_heads, d_ff, input_vocab_size,
                 maximum_position_encoding, dropout=0.1):
        super(Encoder, self).__init__()
        self.d_model = d_model
        self.num_layers = num_layers
        self.embedding = nn.Embedding(input_vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, maximum_position_encoding)
        self.enc_layers = nn.ModuleList(
            EncoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Embed token ids, add positions, and run every encoder layer."""
        # Embeddings are scaled by sqrt(d_model), as in the original paper.
        x = self.pos_encoding(self.embedding(x) * math.sqrt(self.d_model))
        for layer in self.enc_layers:
            x = layer(x, mask)
        return x
def encoder_example():
    """Build a 6-layer encoder and push one random token batch through it."""
    num_layers, d_model, num_heads, d_ff = 6, 512, 8, 2048
    input_vocab_size = 10000
    maximum_position_encoding = 100
    encoder = Encoder(num_layers, d_model, num_heads, d_ff,
                      input_vocab_size, maximum_position_encoding)
    batch_size, seq_len = 32, 50
    # Random token ids simulate a tokenized input batch.
    input_seq = torch.randint(0, input_vocab_size, (batch_size, seq_len))
    encoder_output = encoder(input_seq)
    print(f"编码器输入形状: {input_seq.shape}")
    print(f"编码器输出形状: {encoder_output.shape}")
    print(f"编码器层数: {num_layers}")
    print(f"模型维度: {d_model}")
encoder_example()#解码器(Decoder)架构
解码器负责生成输出序列,包含掩码多头注意力和交叉注意力。
#解码器层实现
class DecoderLayer(nn.Module):
    """One Transformer decoder layer.

    Masked self-attention, cross-attention over the encoder output, then a
    feed-forward network — each sublayer wrapped in dropout + residual
    connection + LayerNorm (post-norm layout).
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super(DecoderLayer, self).__init__()
        self.masked_multi_head_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.cross_multi_head_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_output, look_ahead_mask=None, padding_mask=None):
        """Run the three decoder sublayers over the target-side states x."""
        # Sublayer 1: causal self-attention over the target sequence.
        self_attn, _ = self.masked_multi_head_attn(x, x, x, look_ahead_mask)
        x = self.norm1(x + self.dropout(self_attn))
        # Sublayer 2: queries from the decoder, keys/values from the encoder.
        cross_attn, _ = self.cross_multi_head_attn(
            x, enc_output, enc_output, padding_mask
        )
        x = self.norm2(x + self.dropout(cross_attn))
        # Sublayer 3: position-wise feed-forward network.
        x = self.norm3(x + self.dropout(self.feed_forward(x)))
        return x
class Decoder(nn.Module):
    """Full decoder: token embedding + positional encoding + N decoder layers.

    Consumes target token ids together with the encoder output ("memory").
    """

    def __init__(self, num_layers, d_model, num_heads, d_ff, target_vocab_size,
                 maximum_position_encoding, dropout=0.1):
        super(Decoder, self).__init__()
        self.d_model = d_model
        self.num_layers = num_layers
        self.embedding = nn.Embedding(target_vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, maximum_position_encoding)
        self.dec_layers = nn.ModuleList(
            DecoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_output, look_ahead_mask=None, padding_mask=None):
        """Embed targets, add positions, and run every decoder layer."""
        # Embeddings are scaled by sqrt(d_model), as in the original paper.
        x = self.pos_encoding(self.embedding(x) * math.sqrt(self.d_model))
        for layer in self.dec_layers:
            x = layer(x, enc_output, look_ahead_mask, padding_mask)
        return x
def decoder_example():
    """Build a 6-layer decoder and run it on random targets + fake memory."""
    num_layers, d_model, num_heads, d_ff = 6, 512, 8, 2048
    target_vocab_size = 8000
    maximum_position_encoding = 100
    decoder = Decoder(num_layers, d_model, num_heads, d_ff,
                      target_vocab_size, maximum_position_encoding)
    batch_size, seq_len = 32, 30
    target_seq = torch.randint(0, target_vocab_size, (batch_size, seq_len))
    # Random tensor standing in for a real encoder output.
    enc_output = torch.randn(batch_size, seq_len, d_model)
    decoder_output = decoder(target_seq, enc_output)
    print(f"解码器输入形状: {target_seq.shape}")
    print(f"编码器输出形状: {enc_output.shape}")
    print(f"解码器输出形状: {decoder_output.shape}")
decoder_example()#残差连接与层归一化
残差连接和层归一化是训练深度Transformer的关键技术。
#残差连接
def residual_connection_example():
    """Print a short explanation of why residual connections help training."""
    notes = [
        "残差连接:",
        "F(x) + x,其中F(x)是子层的输出,x是输入",
        "优势:",
        "1. 缓解梯度消失问题",
        "2. 允许训练更深的网络",
        "3. 保持信息流动",
    ]
    for note in notes:
        print(note)
def layer_normalization_example():
    """Contrast layer normalization with batch normalization on a 3-D input."""
    # Created only for side-by-side comparison; normalizes over the batch dim.
    batch_norm = nn.BatchNorm1d(512)
    # LayerNorm normalizes over the feature (last) dimension.
    layer_norm = nn.LayerNorm(512)
    x = torch.randn(32, 50, 512)  # (batch_size, seq_len, d_model)
    normalized_x = layer_norm(x)
    print(f"输入形状: {x.shape}")
    print(f"层归一化输出形状: {normalized_x.shape}")
    print("层归一化在NLP中更合适,因为它独立处理每个样本")
layer_normalization_example()#完整Transformer模型实现
#完整的Transformer架构
class Transformer(nn.Module):
    """Complete encoder-decoder Transformer with a vocabulary projection head.

    Returns raw logits over the target vocabulary; apply softmax (or feed the
    logits straight to a cross-entropy loss) outside the model.
    """

    def __init__(self, num_layers, d_model, num_heads, d_ff, input_vocab_size,
                 target_vocab_size, pe_input, pe_target, dropout=0.1):
        super(Transformer, self).__init__()
        self.encoder = Encoder(num_layers, d_model, num_heads, d_ff,
                               input_vocab_size, pe_input, dropout)
        self.decoder = Decoder(num_layers, d_model, num_heads, d_ff,
                               target_vocab_size, pe_target, dropout)
        # Projects decoder states to target-vocabulary logits.
        self.final_layer = nn.Linear(d_model, target_vocab_size)

    def forward(self, inp, tar, enc_padding_mask, combined_mask, dec_padding_mask):
        """Encode the source, decode the target, and project to logits."""
        memory = self.encoder(inp, enc_padding_mask)
        decoded = self.decoder(tar, memory, combined_mask, dec_padding_mask)
        return self.final_layer(decoded)
def create_padding_mask(seq):
    """Build an attention mask from token ids, treating id 0 as padding.

    Returns a (batch, 1, 1, seq_len) bool tensor that is True at real tokens
    and False at padding. This matches the convention used elsewhere in this
    file: scaled_dot_product_attention fills positions where mask == 0 with
    -1e9, and create_look_ahead_mask is likewise True at *allowed* positions.

    Fix: the original returned (seq == 0) — True at padding — which, combined
    with masked_fill_(mask == 0, ...), masked out the real tokens and let the
    model attend to the padding instead.
    """
    return (seq != 0).unsqueeze(1).unsqueeze(2)
def create_look_ahead_mask(size):
    """Build a (size, size) causal mask for decoder self-attention.

    Entry (i, j) is True when position i may attend to position j — that is,
    for j <= i — so future tokens stay hidden during training.
    """
    # Lower-triangular True (incl. diagonal) = allowed; strict upper = blocked.
    return torch.tril(torch.ones(size, size)).bool()
def transformer_classifier_example():
    """Sequence-classification demo built on nn.TransformerEncoder."""

    class TransformerClassifier(nn.Module):
        """Encoder-only classifier: embedding + positions -> encoder -> MLP head."""

        def __init__(self, vocab_size, num_classes, d_model=256, num_heads=8,
                     num_layers=6, d_ff=1024, max_len=512, dropout=0.1):
            super(TransformerClassifier, self).__init__()
            self.embedding = nn.Embedding(vocab_size, d_model)
            self.pos_encoding = PositionalEncoding(d_model, max_len)
            encoder_layer = nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=num_heads,
                dim_feedforward=d_ff,
                dropout=dropout,
                batch_first=True
            )
            self.transformer_encoder = nn.TransformerEncoder(
                encoder_layer, num_layers
            )
            self.classifier = nn.Sequential(
                nn.Dropout(dropout),
                nn.Linear(d_model, d_ff),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(d_ff, num_classes),
            )

        def forward(self, input_ids, attention_mask=None):
            """Return (batch, num_classes) logits for a batch of token ids.

            attention_mask follows the HuggingFace convention: 1 = real
            token, 0 = padding. (Assumed from the usage here — confirm
            against callers.)
            """
            # Embedding scaled by sqrt(d_model) + positional encoding.
            x = self.embedding(input_ids) * math.sqrt(self.embedding.embedding_dim)
            x = self.pos_encoding(x)
            if attention_mask is not None:
                # Fix: a padding mask must be passed as src_key_padding_mask
                # (shape (batch, seq), True = position to ignore). The old
                # 4-D expanded tensor was an invalid shape for the `mask`
                # argument, which expects (seq, seq) or
                # (batch*num_heads, seq, seq).
                x = self.transformer_encoder(
                    x, src_key_padding_mask=(attention_mask == 0)
                )
            else:
                x = self.transformer_encoder(x)
            # Pool with the first token ([CLS]-style); mean pooling is an
            # alternative: pooled_output = x.mean(dim=1)
            pooled_output = x[:, 0]
            return self.classifier(pooled_output)

    vocab_size = 10000
    num_classes = 2
    model = TransformerClassifier(vocab_size, num_classes)
    batch_size, seq_len = 32, 100
    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))
    output = model(input_ids)
    print(f"分类模型输入形状: {input_ids.shape}")
    print(f"分类模型输出形状: {output.shape}")
    print(f"类别数量: {num_classes}")
transformer_classifier_example()#实际应用与变体
#BERT vs GPT 架构差异
def architecture_comparison():
    """Contrast BERT, GPT, and T5 at the architecture level."""
    lines = [
        "BERT (Bidirectional Encoder Representations from Transformers):",
        "- 只有编码器部分",
        "- 双向注意力(可以看到完整的输入序列)",
        "- 用于理解任务(分类、问答等)",
        "- 使用[CLS] token进行分类",
        "\nGPT (Generative Pre-trained Transformer):",
        "- 只有解码器部分(掩码自注意力)",
        "- 单向注意力(只能看到前面的token)",
        "- 用于生成任务(文本生成、对话等)",
        "- 自回归生成",
        "\nT5 (Text-to-Text Transfer Transformer):",
        "- 完整的编码器-解码器架构",
        "- 将所有NLP任务转化为文本到文本的形式",
        "- 统一的框架处理多种任务",
    ]
    for line in lines:
        print(line)
architecture_comparison()#现代Transformer优化
def modern_transformer_improvements():
    """List notable modern refinements to the original Transformer."""
    lines = [
        "现代Transformer架构改进:",
        "\n1. 位置编码改进:",
        " - 相对位置编码",
        " - 旋转位置编码(RoPE)",
        " - ALiBi位置偏置",
        "\n2. 注意力机制改进:",
        " - 稀疏注意力(Longformer, BigBird)",
        " - 线性注意力(Performer)",
        " - FlashAttention",
        "\n3. 架构改进:",
        " - SwiGLU激活函数",
        " - RMSNorm替代LayerNorm",
        " - 注意力头分离",
    ]
    for line in lines:
        print(line)
modern_transformer_improvements()#相关教程
#总结
Transformer架构是NLP领域的革命性突破,其核心贡献包括:
- 自注意力机制:允许模型关注序列中的任意位置
- 并行化处理:解决了RNN的序列依赖问题
- 可扩展性:架构简单,容易扩展到更大模型
- 通用性:适用于多种NLP任务和领域
💡 核心要点:Transformer的成功催生了BERT、GPT等现代大模型,理解其架构对掌握现代NLP至关重要。
🔗 扩展阅读
- Attention is All You Need 论文
- The Illustrated Transformer
- Hugging Face Transformers 文档
- Transformer Explained Visually

