Vision Transformer:从图像切片到Patch Embedding详解

引言

Vision Transformer (ViT) 是计算机视觉领域的一项重要突破,它成功地将原本用于自然语言处理的Transformer架构应用于图像分类任务。ViT通过将图像分割成小块(patches)并将其作为序列输入到Transformer中,证明了纯Transformer架构在视觉任务上的有效性。本文将深入探讨ViT的核心概念、架构细节和实现方法。

📂 所属阶段:第二阶段 — 深度学习视觉基础(CNN 篇)
🔗 相关章节:关键点检测 (Keypoints) · Swin Transformer


1. Vision Transformer核心思想

1.1 ViT的创新理念

Vision Transformer的核心创新在于将计算机视觉问题转化为序列建模问题,主要包括以下几点:

  1. 图像分块:将图像分割成固定大小的patch序列,每个patch被视为一个"token",类似于NLP中的词汇序列
  2. 序列建模:将图像patch序列输入Transformer,利用自注意力机制捕获全局关系
  3. 全局连接:每个patch与所有其他patch直接连接,从第一层就能捕获全局信息

1.2 ViT架构演进

ViT的发展历程与Transformer在NLP领域的成功密不可分:

  • 2017年:"Attention Is All You Need" 提出Transformer架构
  • 2018年:BERT在NLP领域取得突破性进展
  • 2020年:Vision Transformer首次成功应用到视觉任务
  • 2021年:Swin Transformer提出层次化ViT架构
  • 2022年至今:各种高效ViT变体和多模态Vision Transformers不断涌现

2. ViT架构详解

2.1 图像分块与Patch Embedding

图像分块是ViT的第一个关键步骤,将二维图像转换为一维序列。下面是具体实现:

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

class ImageToPatches(nn.Module):
    """
    图像到patch的转换模块
    """
    def __init__(self, image_size=224, patch_size=16, channels=3):
        super().__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_patches = (image_size // patch_size) ** 2
        self.patch_dim = channels * patch_size ** 2
        
        # 线性投影层
        self.projection = nn.Linear(self.patch_dim, self.patch_dim)
    
    def forward(self, x):
        """
        x: (batch, channels, height, width)
        return: (batch, num_patches, patch_dim)
        """
        batch_size, channels, height, width = x.shape
        
        # 验证图像尺寸
        assert height == self.image_size and width == self.image_size, \
            f"输入图像尺寸应为 ({self.image_size}, {self.image_size})"
        
        # 将图像分割成patch
        x = rearrange(
            x, 
            'b c (h p1) (w p2) -> b (h w) (p1 p2 c)',
            p1=self.patch_size,
            p2=self.patch_size
        )
        
        # 线性投影
        x = self.projection(x)
        
        return x

Patch Embedding过程详解

  1. 输入图像: (B, C, H, W) = (B, 3, 224, 224)
  2. 分割patch: (B, 3, 14×16, 14×16) → (B, 14×14, 16×16×3)
  3. 每个patch: 16×16×3 = 768维向量
  4. 输出: (B, 196, 768) # 196个patch,每个768维

2.2 ViT完整架构实现

下面是Vision Transformer的完整实现:

class VisionTransformer(nn.Module):
    """
    Vision Transformer完整实现
    """
    def __init__(
        self,
        image_size=224,
        patch_size=16,
        num_classes=1000,
        dim=768,
        depth=12,
        heads=12,
        mlp_dim=3072,
        dropout=0.1,
        emb_dropout=0.1
    ):
        super(VisionTransformer, self).__init__()
        
        # 计算patch数量和维度
        num_patches = (image_size // patch_size) ** 2
        patch_dim = 3 * patch_size ** 2
        
        # 图像到patch的投影
        self.to_patch_embedding = nn.Sequential(
            nn.Conv2d(3, patch_dim, kernel_size=patch_size, stride=patch_size),
            nn.Flatten(start_dim=2),
            nn.Linear(patch_dim, dim)
        )
        
        # 类别token和位置嵌入
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        
        # Dropout
        self.dropout = nn.Dropout(emb_dropout)
        
        # Transformer编码器
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=dim,
            nhead=heads,
            dim_feedforward=mlp_dim,
            dropout=dropout,
            activation='gelu',
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=depth)
        
        # 分类头
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )
    
    def forward(self, img):
        x = self.to_patch_embedding(img)  # (B, num_patches, dim)
        b, n, _ = x.shape
        
        # 添加类别token
        cls_tokens = self.cls_token.repeat(b, 1, 1)  # (B, 1, dim)
        x = torch.cat([cls_tokens, x], dim=1)  # (B, num_patches+1, dim)
        
        # 添加位置编码
        x = x + self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)
        
        # Transformer编码
        x = self.transformer(x)
        
        # 使用CLS token进行分类
        cls_output = x[:, 0]  # (B, dim)
        output = self.mlp_head(cls_output)
        
        return output

ViT架构组件

  • Patch Embedding Layer:将图像转换为patch序列
  • Class Token:用于最终分类的特殊token
  • Positional Embedding:保持空间位置信息
  • Transformer Encoder:自注意力机制处理
  • MLP Head:分类输出层

2.3 注意力机制详解

多头自注意力机制是Transformer的核心,下面是具体实现:

class MultiHeadAttention(nn.Module):
    """
    多头自注意力机制
    """
    def __init__(self, d_model=768, num_heads=12, dropout=0.1):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        # 线性投影层
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        
        self.dropout = nn.Dropout(dropout)
        self.scale = torch.sqrt(torch.FloatTensor([self.d_k])).to('cuda' if torch.cuda.is_available() else 'cpu')
    
    def forward(self, x):
        batch_size, seq_len, _ = x.shape
        
        # 线性投影
        Q = self.W_q(x)  # (B, seq_len, d_model)
        K = self.W_k(x)  # (B, seq_len, d_model)
        V = self.W_v(x)  # (B, seq_len, d_model)
        
        # 分割成多个头
        Q = Q.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)  # (B, num_heads, seq_len, d_k)
        K = K.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)  # (B, num_heads, seq_len, d_k)
        V = V.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)  # (B, num_heads, seq_len, d_k)
        
        # 计算注意力分数
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale  # (B, num_heads, seq_len, seq_len)
        attention = F.softmax(scores, dim=-1)
        attention = self.dropout(attention)
        
        # 应用注意力到V
        output = torch.matmul(attention, V)  # (B, num_heads, seq_len, d_k)
        
        # 合并多个头
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        
        # 最终线性投影
        output = self.W_o(output)
        
        return output

3. 位置编码与Class Token

3.1 位置编码策略

ViT中常用的位置编码类型包括:

  • Learnable PE:可学习的位置嵌入参数
  • Sinusoidal PE:正弦余弦函数编码
  • 2D PE:二维空间位置编码
  • Rotary PE:旋转位置编码

3.2 Class Token的作用

Class Token是ViT中的一个关键设计,它的主要作用包括:

  1. 聚合信息:收集所有patch的信息
  2. 分类表示:作为图像的全局表示
  3. 梯度流动:为反向传播提供路径
  4. 位置不变:与所有patch交互

4. ViT变体与改进

4.1 DeiT (Data-efficient Image Transformer)

DeiT通过知识蒸馏提高了ViT的训练效率,主要改进包括:

  • 蒸馏token:使用教师模型指导训练
  • 更强的数据增强:RandAugment, Mixup等
  • 训练技巧:AdamW优化器,学习率调度
  • 正则化:Dropout, Stochastic Depth

4.2 Efficient Vision Transformers

为了提高ViT的效率,研究人员提出了多种高效变体:

  • MobileViT:轻量级架构,适合移动设备
  • Twins:Spatial Attention + Sequential Self-Attention
  • PVT:Pyramid Vision Transformer
  • Shuffle Transformer:通道混洗注意力
  • CMT:Convolutional Neural Networks Meet Vision Transformers

5. 预训练模型使用

5.1 使用PyTorch预训练ViT

import torch
import torchvision.models as models
from torchvision import transforms

# 加载预训练ViT模型
vit_b_16 = models.vit_b_16(weights='IMAGENET1K_V1')
model = vit_b_16
model.eval()

# 预处理图像
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                        std=[0.229, 0.224, 0.225]),
])

# 推理
with torch.no_grad():
    output = model(input_tensor)
    probabilities = torch.nn.functional.softmax(output[0], dim=0)

5.2 使用Hugging Face Transformers

from transformers import ViTImageProcessor, ViTForImageClassification
from PIL import Image
import requests

# 加载处理器和模型
processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')

# 处理图像
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(images=image, return_tensors="pt")

# 推理
outputs = model(**inputs)
logits = outputs.logits
predicted_class_idx = logits.argmax(-1).item()
print("Predicted class:", model.config.id2label[predicted_class_idx])

6. ViT vs CNN对比分析

6.1 详细对比

特性CNNViT
感受野逐层增长,局部到全局全局连接,从第一层开始
参数效率参数较少,计算效率高参数较多,需要大数据集
数据需求小数据集上表现良好需要大规模预训练
可解释性可视化困难,可解释性低注意力权重可直接可视化
计算复杂度O(n),其中n是像素数O(n²),其中n是patch数
归纳偏置强归纳偏置(平移不变性、局部性)弱归纳偏置,更通用
扩展性扩展受限于架构设计容易扩展到更大规模

6.2 适用场景分析

CNN适用场景

  • 小数据集训练
  • 实时推理要求高
  • 移动设备部署
  • 边缘计算场景
  • 传统视觉任务

ViT适用场景

  • 大数据集预训练
  • 需要全局信息的任务
  • 可解释性要求高
  • 多模态任务
  • 研究和前沿应用

7. 实践应用与调优

7.1 训练技巧

  1. 使用大规模数据集预训练
  2. 采用AdamW优化器,权重衰减
  3. 学习率预热和余弦退火调度
  4. 强数据增强(RandAugment, Mixup)
  5. Dropout和标签平滑正则化
  6. 知识蒸馏提升小模型性能

7.2 性能优化

  • 模型量化:INT8量化减少内存占用
  • 知识蒸馏:用大模型训练小模型
  • 模型剪枝:移除冗余参数
  • 混合精度训练:节省显存加速训练
  • 序列长度优化:减少patch数量
  • 注意力稀疏化:降低计算复杂度

相关教程

ViT是计算机视觉的重要里程碑。建议先理解Transformer基础,再学习ViT的具体实现。在实践中可以从预训练模型开始,逐步了解架构细节。

8. 总结

Vision Transformer代表了计算机视觉的新范式:

核心创新

  1. 图像分块:将图像转换为序列
  2. 全局注意力:捕获长距离依赖
  3. 可扩展性:容易扩展到大规模

关键技术

  • Patch Embedding
  • Class Token
  • Positional Encoding
  • Multi-head Attention

💡 重要提醒:ViT证明了纯Transformer架构在视觉任务上的有效性,开启了视觉与语言统一建模的新时代。

🔗 扩展阅读