The Transformer Architecture in Depth: From Self-Attention to Encoder-Decoder Design, with a PyTorch Implementation

Overview of the Transformer Architecture

The Transformer is the landmark architecture Google introduced in 2017. Built entirely on attention, it dispenses with the recurrent and convolutional structures of earlier models and has become the foundation of modern NLP and multimodal AI.

Overall Structure of the Transformer

Full Transformer architecture = N encoder layers + N decoder layers

Encoder - understands the input:
  input sequence → token embedding + positional encoding → multi-head self-attention → feed-forward network → output representation

Decoder - generates the output:
  output sequence → token embedding + positional encoding → masked multi-head self-attention → cross-attention → feed-forward network → output probabilities

Key innovation: self-attention lets the model attend to any position in the sequence, removing the sequential bottleneck that keeps RNNs from parallelizing.

Advantages over Traditional Models

  • Parallelization: unlike an RNN, which must process tokens in order, a Transformer processes all positions at once
  • Long-range dependencies: attention captures dependencies at any distance directly
  • Interpretability: attention weights can be visualized to show where the model is looking
  • Scalability: the architecture is simple and extends readily to much larger models
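The parallelization point can be made concrete with a tiny sketch (illustrative shapes only, not a real RNN cell): an RNN must loop over time steps because each hidden state depends on the previous one, whereas attention scores for all position pairs come out of a single matrix product.

```python
import torch

seq_len, d = 6, 4
x = torch.randn(seq_len, d)

# RNN-style: each step depends on the previous hidden state -> sequential
W = torch.randn(d, d)
h = torch.zeros(d)
hiddens = []
for t in range(seq_len):          # cannot be parallelized across t
    h = torch.tanh(x[t] + h @ W)
    hiddens.append(h)

# Attention-style: all pairwise scores in one matmul -> parallel across positions
scores = x @ x.T / d ** 0.5       # (seq_len, seq_len), computed at once
out = torch.softmax(scores, dim=-1) @ x

print(out.shape)  # torch.Size([6, 4])
```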

Self-Attention in Detail

Self-attention is the core of the Transformer: it lets every position in a sequence attend to every other position.

The Math Behind Self-Attention

import torch
import torch.nn as nn
import math

def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Scaled dot-product attention.
    Q: queries (..., seq_len, d_k)
    K: keys    (..., seq_len, d_k)
    V: values  (..., seq_len, d_v)
    """
    # Compute attention scores
    d_k = Q.size(-1)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    
    # Apply the mask, if one was provided (positions where mask == 0 are blocked)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    
    # Softmax over the last dimension gives the attention weights
    attention_weights = torch.softmax(scores, dim=-1)
    
    # Weighted sum of the values
    output = torch.matmul(attention_weights, V)
    
    return output, attention_weights

def manual_self_attention_example():
    """
    Manual walk-through of self-attention.
    """
    print("Self-attention, step by step:")
    print("1. Linearly project the input X to obtain Q, K, V")
    print("2. Compute attention scores: Q @ K^T / sqrt(d_k)")
    print("3. Apply softmax to get attention weights")
    print("4. Weighted sum: AttentionWeights @ V")
    
    # Toy computation
    batch_size, seq_len, d_model = 1, 4, 8
    X = torch.randn(batch_size, seq_len, d_model)
    
    # Linear projections producing Q, K, V
    W_q = torch.randn(d_model, d_model)
    W_k = torch.randn(d_model, d_model)
    W_v = torch.randn(d_model, d_model)
    
    Q = torch.matmul(X, W_q)
    K = torch.matmul(X, W_k)
    V = torch.matmul(X, W_v)
    
    # Run attention
    output, weights = scaled_dot_product_attention(Q, K, V)
    
    print(f"Input shape: {X.shape}")
    print(f"Output shape: {output.shape}")
    print(f"Attention weights shape: {weights.shape}")
    print(f"Sample attention weights:\n{weights[0]}")

manual_self_attention_example()

An Intuitive View of Self-Attention

Self-attention can be understood as a query-key-value lookup:

  • Query: what the current position is looking for
  • Key: the feature each other position exposes for matching
  • Value: the actual content each position carries
  • Attention score: the query-key similarity, which weights the values
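A toy lookup makes this concrete (hypothetical numbers, chosen so the query aligns best with the first key):

```python
import torch

# Three "memory slots": keys describe them, values hold their content
keys = torch.tensor([[1.0, 0.0],
                     [0.0, 1.0],
                     [0.5, 0.5]])
values = torch.tensor([[10.0], [20.0], [30.0]])

# A query pointing in the same direction as the first key
query = torch.tensor([[1.0, 0.0]])

weights = torch.softmax(query @ keys.T, dim=-1)   # similarity -> weights
output = weights @ values                          # weighted mix of values

print(weights)  # largest weight on the first key
print(output)   # output pulled toward that key's value
```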

Multi-Head Attention

Multi-head attention lets the model attend to information from different representation subspaces simultaneously.

Implementing Multi-Head Attention

class MultiHeadAttention(nn.Module):
    """
    Multi-head attention.
    """
    def __init__(self, d_model, num_heads, dropout=0.1):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        # Linear projections
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        
        self.dropout = nn.Dropout(dropout)
        
    def split_heads(self, x, batch_size):
        """
        Split the last dimension into (num_heads, d_k) and move heads forward.
        """
        x = x.view(batch_size, -1, self.num_heads, self.d_k)
        return x.transpose(1, 2)
    
    def combine_heads(self, x, batch_size):
        """
        Merge the heads back into a single d_model-sized dimension.
        """
        x = x.transpose(1, 2).contiguous()
        return x.view(batch_size, -1, self.d_model)
    
    def forward(self, Q, K, V, mask=None):
        batch_size = Q.size(0)
        
        # Linear projections
        Q = self.W_q(Q)
        K = self.W_k(K)
        V = self.W_v(V)
        
        # Split into heads
        Q = self.split_heads(Q, batch_size)
        K = self.split_heads(K, batch_size)
        V = self.split_heads(V, batch_size)
        
        # Scaled dot-product attention
        scaled_attention, attention_weights = scaled_dot_product_attention(
            Q, K, V, mask
        )
        
        # Concatenate the heads
        concat_attention = self.combine_heads(scaled_attention, batch_size)
        
        # Final linear projection
        output = self.W_o(concat_attention)
        
        return output, attention_weights

def multi_head_attention_example():
    """
    Multi-head attention demo.
    """
    d_model = 512
    num_heads = 8
    seq_len = 10
    
    mha = MultiHeadAttention(d_model, num_heads)
    
    # Example inputs
    batch_size = 2
    Q = torch.randn(batch_size, seq_len, d_model)
    K = torch.randn(batch_size, seq_len, d_model)
    V = torch.randn(batch_size, seq_len, d_model)
    
    output, attention_weights = mha(Q, K, V)
    
    print(f"Multi-head attention output shape: {output.shape}")
    print(f"Attention weights shape: {attention_weights.shape}")
    print(f"Number of heads: {num_heads}")
    print(f"Dimension per head: {d_model // num_heads}")

multi_head_attention_example()

Why Multiple Heads Help

  • Diverse representations: different heads can learn different attention patterns
  • Parallel computation: all heads are computed in parallel
  • Greater expressiveness: the model can capture several kinds of dependency at once

Positional Encoding

Because the Transformer has no recurrence, position information has to be injected explicitly.

Implementing Positional Encoding

class PositionalEncoding(nn.Module):
    """
    Sinusoidal positional encoding.
    """
    def __init__(self, d_model, max_len=5000, dropout=0.1):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        
        # Build the positional encoding table
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        
        # Frequency terms for the sinusoids
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * 
            (-math.log(10000.0) / d_model)
        )
        
        # sin at even indices, cos at odd indices
        pe[:, 0::2] = torch.sin(position * div_term)  # even indices
        pe[:, 1::2] = torch.cos(position * div_term)  # odd indices
        
        # Register as a buffer (saved with the model, never updated by the optimizer)
        self.register_buffer('pe', pe.unsqueeze(0))
        
    def forward(self, x):
        """
        Add the positional encoding to the input.
        x: (batch_size, seq_len, d_model)
        """
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)

def positional_encoding_example():
    """
    Positional encoding demo.
    """
    d_model = 512
    max_len = 50
    pe = PositionalEncoding(d_model, max_len)
    
    # Example input
    batch_size, seq_len = 2, 10
    x = torch.randn(batch_size, seq_len, d_model)
    
    encoded_x = pe(x)
    
    print(f"Input shape: {x.shape}")
    print(f"Output shape: {encoded_x.shape}")
    
    # Inspect a few dimensions of the encoding table
    pos_encoding = pe.pe[0, :seq_len, :8]  # first 8 dimensions
    print(f"Positional encodings (first {seq_len} positions, first 8 dims):")
    print(pos_encoding)

positional_encoding_example()

Properties of Positional Encodings

  • Learned vs. fixed: encodings can be fixed sin/cos functions or learned parameters
  • Uniqueness: every position gets a distinct encoding
  • Continuity: nearby positions get similar encodings, while distant positions differ more
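The continuity claim can be checked numerically. This self-contained snippet rebuilds the sinusoidal table from above and compares cosine similarities of nearby versus distant positions:

```python
import math
import torch

d_model, max_len = 64, 100
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len).unsqueeze(1).float()
div_term = torch.exp(torch.arange(0, d_model, 2).float()
                     * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)

def sim(i, j):
    # Cosine similarity between the encodings of positions i and j
    return torch.nn.functional.cosine_similarity(pe[i], pe[j], dim=0).item()

print(f"sim(10, 11) = {sim(10, 11):.3f}")   # adjacent positions: high
print(f"sim(10, 60) = {sim(10, 60):.3f}")   # distant positions: lower
```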

Encoder Architecture

The encoder converts the input sequence into contextual representations.

Encoder Layer Implementation

class EncoderLayer(nn.Module):
    """
    A single encoder layer.
    """
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super(EncoderLayer, self).__init__()
        
        self.multi_head_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, mask=None):
        # Multi-head self-attention + residual connection + layer norm
        attn_output, _ = self.multi_head_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_output))
        
        # Feed-forward network + residual connection + layer norm
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))
        
        return x

class Encoder(nn.Module):
    """
    The full encoder stack.
    """
    def __init__(self, num_layers, d_model, num_heads, d_ff, input_vocab_size, 
                 maximum_position_encoding, dropout=0.1):
        super(Encoder, self).__init__()
        
        self.d_model = d_model
        self.num_layers = num_layers
        
        self.embedding = nn.Embedding(input_vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, maximum_position_encoding)
        
        self.enc_layers = nn.ModuleList([
            EncoderLayer(d_model, num_heads, d_ff, dropout) 
            for _ in range(num_layers)
        ])
        
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, mask=None):
        # Token embedding (scaled by sqrt(d_model), as in the original paper)
        # plus positional encoding
        x = self.embedding(x) * math.sqrt(self.d_model)
        x = self.pos_encoding(x)
        
        # Pass through every encoder layer
        for i in range(self.num_layers):
            x = self.enc_layers[i](x, mask)
        
        return x

def encoder_example():
    """
    Encoder demo.
    """
    # Hyperparameters
    num_layers = 6
    d_model = 512
    num_heads = 8
    d_ff = 2048
    input_vocab_size = 10000
    maximum_position_encoding = 100
    
    encoder = Encoder(num_layers, d_model, num_heads, d_ff, 
                      input_vocab_size, maximum_position_encoding)
    
    # Example input
    batch_size, seq_len = 32, 50
    input_seq = torch.randint(0, input_vocab_size, (batch_size, seq_len))
    
    encoder_output = encoder(input_seq)
    
    print(f"Encoder input shape: {input_seq.shape}")
    print(f"Encoder output shape: {encoder_output.shape}")
    print(f"Number of encoder layers: {num_layers}")
    print(f"Model dimension: {d_model}")

encoder_example()

Decoder Architecture

The decoder generates the output sequence; each layer combines masked multi-head self-attention with cross-attention over the encoder output.

Decoder Layer Implementation

class DecoderLayer(nn.Module):
    """
    A single decoder layer.
    """
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super(DecoderLayer, self).__init__()
        
        self.masked_multi_head_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.cross_multi_head_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, enc_output, look_ahead_mask=None, padding_mask=None):
        # Masked multi-head self-attention + residual connection + layer norm
        attn1, _ = self.masked_multi_head_attn(x, x, x, look_ahead_mask)
        x = self.norm1(x + self.dropout(attn1))
        
        # Cross-attention over the encoder output + residual connection + layer norm
        attn2, _ = self.cross_multi_head_attn(
            x, enc_output, enc_output, padding_mask
        )
        x = self.norm2(x + self.dropout(attn2))
        
        # Feed-forward network + residual connection + layer norm
        ff_output = self.feed_forward(x)
        x = self.norm3(x + self.dropout(ff_output))
        
        return x

class Decoder(nn.Module):
    """
    The full decoder stack.
    """
    def __init__(self, num_layers, d_model, num_heads, d_ff, target_vocab_size,
                 maximum_position_encoding, dropout=0.1):
        super(Decoder, self).__init__()
        
        self.d_model = d_model
        self.num_layers = num_layers
        
        self.embedding = nn.Embedding(target_vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, maximum_position_encoding)
        
        self.dec_layers = nn.ModuleList([
            DecoderLayer(d_model, num_heads, d_ff, dropout) 
            for _ in range(num_layers)
        ])
        
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, enc_output, look_ahead_mask=None, padding_mask=None):
        # Token embedding (scaled by sqrt(d_model)) + positional encoding
        x = self.embedding(x) * math.sqrt(self.d_model)
        x = self.pos_encoding(x)
        
        # Pass through every decoder layer
        for i in range(self.num_layers):
            x = self.dec_layers[i](x, enc_output, look_ahead_mask, padding_mask)
        
        return x

def decoder_example():
    """
    Decoder demo.
    """
    # Hyperparameters
    num_layers = 6
    d_model = 512
    num_heads = 8
    d_ff = 2048
    target_vocab_size = 8000
    maximum_position_encoding = 100
    
    decoder = Decoder(num_layers, d_model, num_heads, d_ff, 
                      target_vocab_size, maximum_position_encoding)
    
    # Example target input
    batch_size, tgt_len = 32, 30
    target_seq = torch.randint(0, target_vocab_size, (batch_size, tgt_len))
    
    # Stand-in encoder output (the source length may differ from the target length)
    src_len = 40
    enc_output = torch.randn(batch_size, src_len, d_model)
    
    decoder_output = decoder(target_seq, enc_output)
    
    print(f"Decoder input shape: {target_seq.shape}")
    print(f"Encoder output shape: {enc_output.shape}")
    print(f"Decoder output shape: {decoder_output.shape}")

decoder_example()

Residual Connections and Layer Normalization

Residual connections and layer normalization are the key techniques that make deep Transformers trainable.

Residual Connections

def residual_connection_example():
    """
    Why residual connections matter.
    """
    print("Residual connection:")
    print("F(x) + x, where F(x) is the sublayer output and x is its input")
    print("Benefits:")
    print("1. Mitigates vanishing gradients")
    print("2. Makes much deeper networks trainable")
    print("3. Preserves information flow")
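The F(x) + x pattern described above can be shown directly; here is a minimal sketch with an `nn.Linear` standing in for the attention or feed-forward sublayer:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

sublayer = nn.Linear(8, 8)   # stand-in for attention or the feed-forward net
norm = nn.LayerNorm(8)

x = torch.randn(2, 5, 8)     # (batch, seq_len, d_model)

# Post-norm residual block, as in the original Transformer:
out = norm(x + sublayer(x))  # F(x) + x, then LayerNorm

print(out.shape)  # torch.Size([2, 5, 8])
# The identity path lets gradients reach x even if the sublayer saturates.
```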

def layer_normalization_example():
    """
    Layer normalization demo.
    """
    # Layer normalization vs. batch normalization
    batch_norm = nn.BatchNorm1d(512)  # normalizes across the batch dimension
    layer_norm = nn.LayerNorm(512)    # normalizes across the feature dimension
    
    # Example input: (batch_size, seq_len, d_model)
    x = torch.randn(32, 50, 512)
    
    # Layer norm operates over the last dimension
    normalized_x = layer_norm(x)
    
    print(f"Input shape: {x.shape}")
    print(f"LayerNorm output shape: {normalized_x.shape}")
    print("LayerNorm suits NLP better because it treats each sample independently")

layer_normalization_example()

The Complete Transformer Model

Putting the Full Architecture Together

class Transformer(nn.Module):
    """
    The complete Transformer model.
    """
    def __init__(self, num_layers, d_model, num_heads, d_ff, input_vocab_size,
                 target_vocab_size, pe_input, pe_target, dropout=0.1):
        super(Transformer, self).__init__()
        
        self.encoder = Encoder(num_layers, d_model, num_heads, d_ff,
                               input_vocab_size, pe_input, dropout)
        
        self.decoder = Decoder(num_layers, d_model, num_heads, d_ff,
                               target_vocab_size, pe_target, dropout)
        
        self.final_layer = nn.Linear(d_model, target_vocab_size)
        
    def forward(self, inp, tar, enc_padding_mask, combined_mask, dec_padding_mask):
        # Encode the source sequence
        enc_output = self.encoder(inp, enc_padding_mask)
        
        # Decode, attending to the encoder output
        dec_output = self.decoder(
            tar, enc_output, combined_mask, dec_padding_mask
        )
        
        # Final linear layer produces vocabulary logits
        # (softmax is applied later, typically inside the loss)
        final_output = self.final_layer(dec_output)
        
        return final_output

def create_padding_mask(seq):
    """
    Build a padding mask: True (kept) at real tokens, False at padding
    (token id 0), matching scaled_dot_product_attention, which blocks
    positions where the mask is 0.
    """
    return (seq != 0).unsqueeze(1).unsqueeze(2)

def create_look_ahead_mask(size):
    """
    Build a look-ahead mask: position i may attend only to positions <= i,
    preventing the decoder from seeing future tokens.
    """
    mask = torch.triu(torch.ones((size, size)), diagonal=1).type(torch.uint8)
    return mask == 0
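To see what the look-ahead mask actually permits, the snippet below repeats the helper (so it runs standalone) and prints the allowed positions; recall the convention that positions where the mask is 0/False get filled with -1e9 in the attention scores:

```python
import torch

def create_look_ahead_mask(size):
    mask = torch.triu(torch.ones((size, size)), diagonal=1).type(torch.uint8)
    return mask == 0

look_ahead = create_look_ahead_mask(4)
print(look_ahead.int())
# tensor([[1, 0, 0, 0],
#         [1, 1, 0, 0],
#         [1, 1, 1, 0],
#         [1, 1, 1, 1]])

# Combine with a padding mask (position 3 is padding, so its column is blocked)
padding = torch.tensor([True, True, True, False])   # True = real token
combined = look_ahead & padding                     # broadcasts over rows
print(combined.int())
# tensor([[1, 0, 0, 0],
#         [1, 1, 0, 0],
#         [1, 1, 1, 0],
#         [1, 1, 1, 0]])
```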

def transformer_classifier_example():
    """
    Using a Transformer encoder for classification.
    """
    class TransformerClassifier(nn.Module):
        def __init__(self, vocab_size, num_classes, d_model=256, num_heads=8,
                     num_layers=6, d_ff=1024, max_len=512, dropout=0.1):
            super(TransformerClassifier, self).__init__()
            
            self.embedding = nn.Embedding(vocab_size, d_model)
            self.pos_encoding = PositionalEncoding(d_model, max_len)
            
            encoder_layer = nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=num_heads,
                dim_feedforward=d_ff,
                dropout=dropout,
                batch_first=True
            )
            
            self.transformer_encoder = nn.TransformerEncoder(
                encoder_layer, num_layers
            )
            
            self.classifier = nn.Sequential(
                nn.Dropout(dropout),
                nn.Linear(d_model, d_ff),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(d_ff, num_classes),
            )
            
        def forward(self, input_ids, attention_mask=None):
            # Token embedding + positional encoding
            x = self.embedding(input_ids) * math.sqrt(self.embedding.embedding_dim)
            x = self.pos_encoding(x)
            
            # Transformer encoder; padding positions are excluded via
            # src_key_padding_mask (True means "ignore this position")
            if attention_mask is not None:
                key_padding_mask = (attention_mask == 0)
                x = self.transformer_encoder(
                    x, src_key_padding_mask=key_padding_mask
                )
            else:
                x = self.transformer_encoder(x)
            
            # Pool: take the first token ([CLS]-style) ...
            pooled_output = x[:, 0]
            # ... or average over the sequence: pooled_output = x.mean(dim=1)
            
            # Classification head
            logits = self.classifier(pooled_output)
            
            return logits
    
    # Example usage
    vocab_size = 10000
    num_classes = 2
    model = TransformerClassifier(vocab_size, num_classes)
    
    batch_size, seq_len = 32, 100
    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))
    
    output = model(input_ids)
    
    print(f"Classifier input shape: {input_ids.shape}")
    print(f"Classifier output shape: {output.shape}")
    print(f"Number of classes: {num_classes}")

transformer_classifier_example()

Applications and Variants

BERT vs. GPT: Architectural Differences

def architecture_comparison():
    """
    Comparing Transformer variants.
    """
    print("BERT (Bidirectional Encoder Representations from Transformers):")
    print("- Encoder-only")
    print("- Bidirectional attention (sees the full input sequence)")
    print("- Suited to understanding tasks (classification, QA, ...)")
    print("- Classifies via the [CLS] token")
    
    print("\nGPT (Generative Pre-trained Transformer):")
    print("- Decoder-only (masked self-attention)")
    print("- Unidirectional attention (sees only preceding tokens)")
    print("- Suited to generation tasks (text generation, dialogue, ...)")
    print("- Generates autoregressively")
    
    print("\nT5 (Text-to-Text Transfer Transformer):")
    print("- Full encoder-decoder architecture")
    print("- Casts every NLP task as text-to-text")
    print("- One unified framework for many tasks")

architecture_comparison()

Modern Transformer Optimizations

def modern_transformer_improvements():
    """
    Recent improvements to the Transformer architecture.
    """
    print("Modern Transformer architecture improvements:")
    print("\n1. Positional encoding:")
    print("   - Relative position encodings")
    print("   - Rotary position embeddings (RoPE)")
    print("   - ALiBi position biases")
    
    print("\n2. Attention mechanisms:")
    print("   - Sparse attention (Longformer, BigBird)")
    print("   - Linear attention (Performer)")
    print("   - FlashAttention")
    
    print("\n3. Architecture:")
    print("   - SwiGLU activation")
    print("   - RMSNorm in place of LayerNorm")
    print("   - Multi-query / grouped-query attention")

modern_transformer_improvements()
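As a taste of these improvements, here is a minimal RMSNorm sketch: unlike LayerNorm, it skips mean-centering and the bias term, rescaling by the root mean square only (simplified relative to production implementations):

```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    """RMSNorm: normalize by the root mean square, with no mean subtraction."""
    def __init__(self, d_model, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x):
        # 1 / sqrt(mean(x^2) + eps), computed over the feature dimension
        rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt()
        return x * rms * self.weight

x = torch.randn(2, 5, 512)
norm = RMSNorm(512)
out = norm(x)
print(out.shape)  # torch.Size([2, 5, 512])
```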

Where to Go Next

The Transformer is the bedrock of modern NLP. Start by internalizing self-attention, then work through multi-head attention, positional encoding, and the remaining components. In real projects, prefer pretrained models from Hugging Face.

Summary

The Transformer was a revolutionary breakthrough for NLP. Its core contributions:

  1. Self-attention: the model can attend to any position in the sequence
  2. Parallel processing: removes the sequential dependency of RNNs
  3. Scalability: a simple architecture that extends readily to larger models
  4. Generality: applicable across many NLP tasks and other domains

💡 Key takeaway: the Transformer's success gave rise to BERT, GPT, and today's large models; understanding its architecture is essential for mastering modern NLP.

