Vision Transformer (ViT)详解:从图像到序列的视觉革命

引言

2020年,Google发布的《An Image is Worth 16x16 Words》论文彻底改变了计算机视觉领域的格局。Vision Transformer (ViT)首次成功将原本用于自然语言处理的Transformer架构应用于图像分类任务,并在大规模数据集上取得了超越CNN的性能表现。

ViT的核心创新在于:像处理自然语言一样处理图像,将图像视为一系列"视觉词汇"的序列,通过自注意力机制捕捉全局依赖关系。


1. ViT概述:为什么需要Transformer?

1.1 CNN的局限性

尽管卷积神经网络(CNN)在计算机视觉领域取得了巨大成功,但它存在一些固有的局限性:

  • 归纳偏置(Inductive Bias)

    • 局部性:卷积核每次只关注局部区域,难以直接捕捉图像中远距离区域的关系(如左上角与右下角的关联)
    • 平移不变性:CNN假设图像的局部模式在不同位置具有相似性,但缺乏对全局结构的显式建模
  • 静态权重:卷积核的参数在训练完成后是固定的,对所有输入图像使用相同的过滤方式

1.2 ViT的核心思想

ViT通过以下方式克服了CNN的局限性:

  • 全局感知:自注意力机制允许图像中的每个区域与所有其他区域进行交互
  • 动态权重:注意力权重根据输入内容动态调整,更具适应性
  • 可扩展性:能够有效利用大规模数据集进行预训练

2. ViT架构详解

2.1 核心流程

ViT的处理流程可以分为四个关键步骤:

  1. 图像分块 (Patch Embedding):将图像切割成固定大小的块(Patches)
  2. 线性投影:将每个图像块映射到固定维度的向量
  3. 位置编码 (Position Embedding):为每个图像块添加位置信息
  4. 分类标记 (CLS Token):添加特殊标记用于最终分类

2.2 架构组件

图像分块模块 (Patch Embedding)

  • 将图像划分为不重叠的块
  • 使用线性投影将每个块映射到向量空间

位置编码 (Position Embedding)

  • 为每个图像块添加位置信息
  • 使用可学习的位置嵌入

Transformer编码器

  • 多头自注意力机制
  • 前馈神经网络
  • 层归一化和残差连接

分类头 (Classification Head)

  • 使用CLS标记的输出进行分类
  • 通常是一个简单的MLP

3. PyTorch实现详解

3.1 图像分块模块

import torch
import torch.nn as nn
import torch.nn.functional as F

class PatchEmbedding(nn.Module):
    """
    图像分块嵌入模块
    将图像切割成不重叠的块并映射到向量空间
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.n_patches = (img_size // patch_size) ** 2
        
        # 使用卷积层实现分块和投影(高效的方法)
        self.proj = nn.Conv2d(
            in_chans, embed_dim, 
            kernel_size=patch_size, 
            stride=patch_size
        )

    def forward(self, x):
        """
        输入: (B, C, H, W)
        输出: (B, n_patches, embed_dim)
        """
        x = self.proj(x)  # (B, embed_dim, n_patches_h, n_patches_w)
        x = x.flatten(2)   # (B, embed_dim, n_patches)
        x = x.transpose(1, 2)  # (B, n_patches, embed_dim)
        return x

# 测试图像分块
patch_embed = PatchEmbedding()
x = torch.randn(2, 3, 224, 224)  # 2张224x224的RGB图像
patches = patch_embed(x)
print(f"Patches shape: {patches.shape}")  # (2, 196, 768)

3.2 多头自注意力机制

class MultiHeadAttention(nn.Module):
    """
    多头自注意力机制
    """
    def __init__(self, embed_dim=768, n_heads=12, dropout=0.1):
        super().__init__()
        self.embed_dim = embed_dim
        self.n_heads = n_heads
        self.head_dim = embed_dim // n_heads
        self.scale = self.head_dim ** -0.5
        
        # 线性投影矩阵
        self.qkv = nn.Linear(embed_dim, embed_dim * 3, bias=True)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, N, C = x.shape
        
        # 计算Q, K, V
        qkv = self.qkv(x).reshape(B, N, 3, self.n_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, B, n_heads, N, head_dim)
        q, k, v = qkv.unbind(0)  # (B, n_heads, N, head_dim)

        # 计算注意力分数
        attn = (q @ k.transpose(-2, -1)) * self.scale  # (B, n_heads, N, N)
        attn = attn.softmax(dim=-1)
        attn = self.dropout(attn)

        # 应用注意力到V
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        return x

# 测试多头注意力
attention = MultiHeadAttention()
x = torch.randn(2, 197, 768)  # 2个序列,每个序列197个token,每个token 768维
attn_out = attention(x)
print(f"Attention output shape: {attn_out.shape}")  # (2, 197, 768)

3.3 Transformer编码器块

class TransformerBlock(nn.Module):
    """
    Transformer编码器块
    """
    def __init__(self, embed_dim=768, n_heads=12, mlp_ratio=4, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadAttention(embed_dim, n_heads, dropout)
        self.norm2 = nn.LayerNorm(embed_dim)
        
        # MLP层
        mlp_hidden_dim = int(embed_dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, mlp_hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_hidden_dim, embed_dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        # 残差连接 + 层归一化
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x

# 测试Transformer块
block = TransformerBlock()
x = torch.randn(2, 197, 768)
out = block(x)
print(f"Transformer block output shape: {out.shape}")  # (2, 197, 768)

3.4 完整ViT模型

class VisionTransformer(nn.Module):
    """
    完整的Vision Transformer模型
    """
    def __init__(
        self, 
        img_size=224, 
        patch_size=16, 
        in_chans=3, 
        n_classes=1000, 
        embed_dim=768, 
        depth=12, 
        n_heads=12, 
        mlp_ratio=4, 
        dropout=0.1,
        attn_dropout=0.1
    ):
        super().__init__()
        
        # 图像分块
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_chans, embed_dim)
        n_patches = self.patch_embed.n_patches
        
        # CLS token和位置编码
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, n_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(dropout)
        
        # Transformer编码器层
        self.blocks = nn.ModuleList([
            TransformerBlock(embed_dim, n_heads, mlp_ratio, attn_dropout)
            for _ in range(depth)
        ])
        
        # 分类头
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, n_classes)

    def forward(self, x):
        B = x.shape[0]
        
        # 图像分块
        x = self.patch_embed(x)  # (B, n_patches, embed_dim)
        
        # 添加CLS token
        cls_tokens = self.cls_token.expand(B, -1, -1)  # (B, 1, embed_dim)
        x = torch.cat((cls_tokens, x), dim=1)  # (B, n_patches+1, embed_dim)
        
        # 添加位置编码
        x = x + self.pos_embed
        x = self.pos_drop(x)
        
        # 通过Transformer块
        for block in self.blocks:
            x = block(x)
        
        # 应用最终层归一化
        x = self.norm(x)
        
        # 使用CLS token进行分类
        cls_output = x[:, 0]  # (B, embed_dim)
        return self.head(cls_output)

# 测试完整模型
vit = VisionTransformer(n_classes=1000)
x = torch.randn(4, 3, 224, 224)  # 4张图像
logits = vit(x)
print(f"Final output shape: {logits.shape}")  # (4, 1000)

4. 位置编码详解

4.1 可学习位置编码

ViT使用可学习的位置编码,而不是固定的正弦/余弦编码:

class LearnablePositionEncoding(nn.Module):
    """
    可学习的位置编码
    """
    def __init__(self, seq_len, embed_dim):
        super().__init__()
        self.pos_embed = nn.Parameter(torch.zeros(1, seq_len, embed_dim))
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

    def forward(self, x):
        return x + self.pos_embed

4.2 二维位置编码

虽然ViT使用一维位置编码,但也可以扩展到二维:

class PositionEmbedding2D(nn.Module):
    """
    2D位置编码(用于某些ViT变体)
    """
    def __init__(self, height, width, embed_dim):
        super().__init__()
        self.height_embed = nn.Embedding(height, embed_dim // 2)
        self.width_embed = nn.Embedding(width, embed_dim // 2)
        
        # 初始化
        nn.init.trunc_normal_(self.height_embed.weight, std=0.02)
        nn.init.trunc_normal_(self.width_embed.weight, std=0.02)

    def forward(self, x, H, W):
        # x: (B, N, C) where N = H * W
        pos_h = torch.arange(H, device=x.device)
        pos_w = torch.arange(W, device=x.device)
        
        embed_h = self.height_embed(pos_h)  # (H, C//2)
        embed_w = self.width_embed(pos_w)  # (W, C//2)
        
        # 组合2D位置编码
        pos_2d = torch.cat([
            embed_h.unsqueeze(1).expand(-1, W, -1),  # (H, W, C//2)
            embed_w.unsqueeze(0).expand(H, -1, -1)   # (H, W, C//2)
        ], dim=-1).flatten(0, 1)  # (H*W, C)
        
        return x + pos_2d.unsqueeze(0)  # (1, H*W, C) -> (B, H*W, C)

5. 自注意力机制深入分析

5.1 注意力计算详解

def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    缩放点积注意力
    """
    d_k = Q.size(-1)
    
    # 计算注意力分数
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    
    if mask is not None:
        scores.masked_fill_(mask == 0, -1e9)
    
    # 应用softmax
    attention_weights = F.softmax(scores, dim=-1)
    
    # 应用注意力到V
    output = torch.matmul(attention_weights, V)
    
    return output, attention_weights

class SelfAttentionWithVisualization(nn.Module):
    """
    带可视化功能的自注意力
    """
    def __init__(self, embed_dim=768, n_heads=12):
        super().__init__()
        self.multihead_attn = nn.MultiheadAttention(
            embed_dim, n_heads, batch_first=True
        )

    def forward(self, x):
        # 返回注意力权重用于可视化
        output, attn_weights = self.multihead_attn(x, x, x)
        return output, attn_weights

5.2 注意力热力图可视化

import matplotlib.pyplot as plt
import seaborn as sns

def visualize_attention(attention_weights, patch_size=16):
    """
    可视化注意力热力图
    attention_weights: (batch_size, n_heads, seq_len, seq_len)
    """
    # 取第一个样本的第一个头
    attn = attention_weights[0, 0]  # (seq_len, seq_len)
    
    # 可视化CLS token对其他patch的注意力
    cls_attn = attn[0, 1:]  # 排除CLS token自身,只看对其他patch的注意力
    
    # 重塑为图像形状
    img_size = int(math.sqrt(len(cls_attn)))
    cls_attn_img = cls_attn.reshape(img_size, img_size)
    
    plt.figure(figsize=(8, 8))
    sns.heatmap(cls_attn_img.detach().cpu().numpy(), cmap='viridis')
    plt.title('CLS Token Attention Heatmap')
    plt.show()

6. ViT变体与改进

6.1 DeiT (Data-efficient Image Transformer)

DeiT通过知识蒸馏提高了ViT在小数据集上的性能:

class DistillationToken(nn.Module):
    """
    知识蒸馏token
    """
    def __init__(self, embed_dim=768):
        super().__init__()
        self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

class DeiT(nn.Module):
    """
    Data-efficient Image Transformer
    """
    def __init__(self, vit_model, num_classes=1000):
        super().__init__()
        self.vit = vit_model
        self.dist_token = nn.Parameter(torch.zeros(1, 1, vit_model.embed_dim))
        self.head_dist = nn.Linear(vit_model.embed_dim, num_classes)
        
    def forward(self, x):
        B = x.shape[0]
        
        # 添加distillation token
        x = self.vit.patch_embed(x)
        cls_tokens = self.vit.cls_token.expand(B, -1, -1)
        dist_tokens = self.dist_token.expand(B, -1, -1)
        
        x = torch.cat((cls_tokens, dist_tokens, x), dim=1)
        x = x + self.vit.pos_embed
        
        for block in self.vit.blocks:
            x = block(x)
        
        x = self.vit.norm(x)
        
        # 两个输出头
        cls_output = self.vit.head(x[:, 0])
        dist_output = self.head_dist(x[:, 1])
        
        return cls_output, dist_output

6.2 Swin Transformer

Swin Transformer引入了滑动窗口机制:

class WindowAttention(nn.Module):
    """
    窗口注意力机制
    """
    def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        # 相对位置偏置
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
        )

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        nn.init.trunc_normal_(self.relative_position_bias_table, std=.02)

    def forward(self, x, mask=None):
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        # 添加相对位置偏置
        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index.view(-1)
        ].view(self.window_size[0] * self.window_size[1],
               self.window_size[0] * self.window_size[1], -1)
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = F.softmax(attn, dim=-1)
        else:
            attn = F.softmax(attn, dim=-1)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

7. 训练策略与优化

7.1 数据增强策略

import torchvision.transforms as transforms

def get_vit_transforms():
    """
    ViT专用数据增强策略
    """
    return transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        # RandAugment等更高级的数据增强
    ])

7.2 优化器配置

def get_vit_optimizer(model, lr=1e-3, weight_decay=0.05):
    """
    ViT专用优化器配置
    """
    param_groups = []
    
    # 对不同类型的参数使用不同的权重衰减
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        
        if 'norm' in name or 'bias' in name:
            # 归一化层和偏置不使用权重衰减
            param_groups.append({'params': param, 'weight_decay': 0.0})
        else:
            param_groups.append({'params': param, 'weight_decay': weight_decay})
    
    optimizer = torch.optim.AdamW(param_groups, lr=lr, betas=(0.9, 0.999))
    return optimizer

7.3 学习率调度

from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR

def get_vit_scheduler(optimizer, num_epochs, warmup_epochs=10):
    """
    ViT学习率调度策略
    """
    def warmup_lr_lambda(current_epoch):
        if current_epoch < warmup_epochs:
            return float(current_epoch) / float(max(1, warmup_epochs))
        return 1.0
    
    scheduler = LambdaLR(optimizer, lr_lambda=warmup_lr_lambda)
    return scheduler

8. 性能对比与选择指南

8.1 ViT vs CNN对比

特性CNN (如ResNet)ViT (Vision Transformer)
数据量需求中小规模数据表现良好极度依赖大数据 (如ImageNet-21k)
全局感知需要深层堆叠获得第一层就能看到全局
训练难度相对容易收敛训练较慢,对超参数敏感
可解释性关注局部边缘/纹理关注图像各部分间的逻辑关联
计算复杂度O(n)O(n²) 对于标准注意力
参数效率参数利用率相对较低在大数据集上参数效率更高

8.2 应用场景选择

def choose_model_for_task(task_requirements):
    """
    根据任务需求选择合适的模型
    """
    if task_requirements['data_size'] < 10000:
        return "建议使用CNN或DeiT(带知识蒸馏)"
    elif task_requirements['compute_budget'] < 100:  # GFLOPs
        return "建议使用EfficientNet或MobileViT"
    elif task_requirements['accuracy_priority']:
        return "建议使用大型ViT模型(如ViT-H/14)"
    else:
        return "建议使用中等规模ViT模型(如ViT-B/16)"

9. 实际应用与部署

9.1 模型微调

def finetune_vit_on_custom_dataset(pretrained_vit, num_classes_new):
    """
    在自定义数据集上微调预训练ViT
    """
    # 冻结主干网络
    for param in pretrained_vit.parameters():
        param.requires_grad = False
    
    # 替换分类头
    pretrained_vit.head = nn.Linear(pretrained_vit.head.in_features, num_classes_new)
    
    # 只训练分类头
    for param in pretrained_vit.head.parameters():
        param.requires_grad = True
    
    return pretrained_vit

9.2 模型压缩与加速

def compress_vit_model(model, compression_ratio=0.5):
    """
    ViT模型压缩
    """
    # 模型量化
    quantized_model = torch.quantization.quantize_dynamic(
        model, {nn.Linear, nn.Conv2d}, dtype=torch.qint8
    )
    
    # 知识蒸馏(如果teacher模型可用)
    # 这里省略具体实现
    
    # 剪枝
    import torch.nn.utils.prune as prune
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            prune.l1_unstructured(module, name='weight', amount=compression_ratio)
    
    return quantized_model

10. 实践建议

10.1 数据准备建议

  • 大规模数据集:ViT需要大量数据才能发挥优势
  • 高质量标注:确保数据质量,避免噪声标签
  • 数据增强:使用RandAugment、CutMix、MixUp等高级增强技术
  • 预处理一致性:确保训练和推理时的预处理完全一致

10.2 模型调优建议

  • 预训练权重:优先使用在大规模数据集上预训练的权重
  • 学习率策略:使用warmup和cosine decay
  • 正则化:适当使用dropout和weight decay
  • 批归一化:考虑使用更大的batch size

10.3 部署考虑

  • 推理优化:使用TensorRT、ONNX等优化推理
  • 模型压缩:量化、剪枝以减小模型大小
  • 硬件适配:考虑GPU/TPU等硬件的特性
  • 延迟优化:针对实时应用进行延迟优化

11. 发展趋势与未来方向

11.1 技术趋势

  • 混合架构:CNN与Transformer的结合(如CoAtNet)
  • 高效注意力:线性复杂度的注意力机制
  • 多模态融合:图像与文本的统一表示
  • 自监督学习:MAE、SimMIM等预训练方法

11.2 挑战与机遇

  • 计算效率:降低大规模模型的计算成本
  • 可解释性:提高模型决策的透明度
  • 鲁棒性:增强模型对对抗攻击的防御能力
  • 持续学习:支持模型的增量学习能力

12. 总结

Vision Transformer作为计算机视觉领域的里程碑式创新,通过将Transformer架构引入图像处理,展现了强大的建模能力。其核心优势在于:

  • 全局感受野:自注意力机制提供全局信息整合能力
  • 可扩展性:在大规模数据集上表现卓越
  • 统一架构:为视觉任务提供统一的建模框架

通过本文的详细分析和代码实现,读者应该对ViT的核心原理、架构设计和实际应用有了深入的理解。在实际项目中,应根据数据规模、计算资源和性能要求选择合适的模型架构。


相关教程

建议先理解Transformer在NLP中的应用,再学习ViT的视觉应用。通过复现论文中的实验,可以更好地理解ViT的工作原理和优势。

🔗 扩展阅读