Vision Transformer: From Image Patches to Patch Embedding

Introduction

The Vision Transformer (ViT) is a major breakthrough in computer vision: it successfully applies the Transformer architecture, originally designed for natural language processing, to image classification. By splitting an image into small patches and feeding them into a Transformer as a sequence, ViT demonstrated that a pure Transformer architecture is effective for vision tasks. This article walks through ViT's core concepts, architectural details, and implementation.



1. Core Ideas of the Vision Transformer

1.1 The Key Innovation of ViT

The core innovation of the Vision Transformer is to recast a computer vision problem as a sequence modeling problem.

"""
Vision Transformer的核心思想:

1. 图像分块 (Image Patching):
   - 将图像分割成固定大小的patch序列
   - 每个patch被视为一个"token"
   - 类似于NLP中的词汇序列

2. 序列建模:
   - 将图像patch序列输入Transformer
   - 利用自注意力机制捕获全局关系
   - 避免了CNN的局部感受野限制

3. 全局连接:
   - 每个patch与所有其他patch直接连接
   - 从第一层就能捕获全局信息
"""

def vit_concept_explanation():
    """
    Explain the core ViT concepts
    """
    concepts = {
        "Image Patching": "Split the image into fixed-size patches",
        "Patch Embedding": "Map each patch into the embedding space",
        "Class Token": "A special token used for classification",
        "Positional Encoding": "Preserves positional information",
        "Global Attention": "Captures global feature relationships"
    }
    
    print("Core ViT concepts:")
    for concept, desc in concepts.items():
        print(f"• {concept}: {desc}")

vit_concept_explanation()

1.2 The Evolution of ViT Architectures

def vit_development_timeline():
    """
    ViT development timeline
    """
    timeline = {
        "2017": "Attention Is All You Need (Transformer introduced)",
        "2018": "BERT (breakthrough in NLP)",
        "2020": "Vision Transformer (first pure-Transformer vision model)",
        "2021": "Swin Transformer (hierarchical ViT)",
        "2022": "Efficient Vision Transformers (efficient architectures)",
        "2023": "Multimodal Vision Transformers (multimodal)",
        "2024-2026": "Vision Foundation Models"
    }
    
    print("ViT development timeline:")
    for year, milestone in timeline.items():
        print(f"• {year}: {milestone}")

vit_development_timeline()

2. ViT Architecture in Detail

2.1 Image Patching and Patch Embedding

Image patching is the first key step of ViT: it converts a two-dimensional image into a one-dimensional sequence.

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

class ImageToPatches(nn.Module):
    """
    Converts an image into a sequence of patch embeddings
    """
    def __init__(self, image_size=224, patch_size=16, channels=3, dim=768):
        super().__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_patches = (image_size // patch_size) ** 2
        self.patch_dim = channels * patch_size ** 2
        
        # Linear projection from flattened patches to the embedding dimension
        self.projection = nn.Linear(self.patch_dim, dim)
    
    def forward(self, x):
        """
        x: (batch, channels, height, width)
        return: (batch, num_patches, dim)
        """
        batch_size, channels, height, width = x.shape
        
        # Validate the input resolution
        assert height == self.image_size and width == self.image_size, \
            f"Expected input of size ({self.image_size}, {self.image_size})"
        
        # Split the image into patches and flatten each one
        x = rearrange(
            x, 
            'b c (h p1) (w p2) -> b (h w) (p1 p2 c)',
            p1=self.patch_size,
            p2=self.patch_size
        )
        
        # Linear projection
        x = self.projection(x)
        
        return x

def patch_embedding_process():
    """
    Walk through the patch embedding shapes
    """
    print("Patch embedding process:")
    print("1. Input image: (B, C, H, W) = (B, 3, 224, 224)")
    print("2. Split into patches: (B, 3, 14*16, 14*16) -> (B, 14*14, 16*16*3)")
    print("3. Each patch: 16x16x3 = a 768-dimensional vector")
    print("4. Output: (B, 196, 768)  # 196 patches, 768 dims each")

patch_embedding_process()
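To make the patch ordering concrete, here is a minimal pure-Python sketch (no PyTorch) that splits a toy single-channel image into flattened patches, scanning left-to-right, top-to-bottom — the same row-major order the `rearrange` pattern above produces. The 4×4 image and patch size 2 are illustrative only:

```python
def extract_patches(image, patch_size):
    """Split a 2D single-channel image (nested lists) into flattened patches,
    scanning left-to-right, top-to-bottom."""
    h, w = len(image), len(image[0])
    assert h % patch_size == 0 and w % patch_size == 0
    patches = []
    for top in range(0, h, patch_size):
        for left in range(0, w, patch_size):
            patch = []
            for r in range(top, top + patch_size):
                for c in range(left, left + patch_size):
                    patch.append(image[r][c])
            patches.append(patch)
    return patches

# A 4x4 "image" with distinct values, split into 2x2 patches
img = [[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11],
       [12, 13, 14, 15]]
patches = extract_patches(img, 2)
print(len(patches), patches[0])  # 4 patches; the first is [0, 1, 4, 5]
```

Each flattened patch would then be mapped by the linear projection into the embedding space.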

2.2 Complete ViT Implementation

class VisionTransformer(nn.Module):
    """
    Complete Vision Transformer implementation
    """
    def __init__(
        self,
        image_size=224,
        patch_size=16,
        num_classes=1000,
        dim=768,
        depth=12,
        heads=12,
        mlp_dim=3072,
        dropout=0.1,
        emb_dropout=0.1
    ):
        super().__init__()
        
        # Number of patches
        num_patches = (image_size // patch_size) ** 2
        
        # Image-to-patch embedding: a strided convolution extracts and embeds
        # each patch in a single step, producing (B, dim, H/p, W/p)
        self.to_patch_embedding = nn.Conv2d(
            3, dim, kernel_size=patch_size, stride=patch_size
        )
        
        # Class token and positional embedding
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        
        # Dropout
        self.dropout = nn.Dropout(emb_dropout)
        
        # Transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=dim,
            nhead=heads,
            dim_feedforward=mlp_dim,
            dropout=dropout,
            activation='gelu',
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=depth)
        
        # Classification head
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )
    
    def forward(self, img):
        x = self.to_patch_embedding(img)  # (B, dim, H/p, W/p)
        x = x.flatten(2).transpose(1, 2)  # (B, num_patches, dim)
        b, n, _ = x.shape
        
        # Prepend the class token
        cls_tokens = self.cls_token.expand(b, -1, -1)  # (B, 1, dim)
        x = torch.cat([cls_tokens, x], dim=1)  # (B, num_patches+1, dim)
        
        # Add positional embeddings
        x = x + self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)
        
        # Transformer encoding
        x = self.transformer(x)
        
        # Classify from the CLS token
        cls_output = x[:, 0]  # (B, dim)
        output = self.mlp_head(cls_output)
        
        return output
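As a sanity check on the default configuration above (dim=768, depth=12, mlp_dim=3072, patch size 16, 1000 classes), a back-of-the-envelope parameter count is useful. The breakdown below follows the standard ViT block layout rather than the exact internals of `nn.TransformerEncoderLayer`, so treat it as an approximation; it lands near the commonly quoted ~86M figure for ViT-B/16:

```python
def vit_param_count(image_size=224, patch_size=16, num_classes=1000,
                    dim=768, depth=12, mlp_dim=3072, channels=3):
    """Approximate parameter count for a standard ViT (weights + biases)."""
    num_patches = (image_size // patch_size) ** 2
    patch_embed = channels * patch_size ** 2 * dim + dim   # conv kernel + bias
    cls_and_pos = dim + (num_patches + 1) * dim            # class token + positions
    per_layer = (
        2 * 2 * dim                    # two LayerNorms (scale and shift)
        + 4 * (dim * dim + dim)        # Q, K, V and output projections
        + dim * mlp_dim + mlp_dim      # MLP expansion
        + mlp_dim * dim + dim          # MLP contraction
    )
    head = 2 * dim + dim * num_classes + num_classes       # final LayerNorm + classifier
    return patch_embed + cls_and_pos + depth * per_layer + head

print(f"~{vit_param_count() / 1e6:.1f}M parameters")  # ~86.6M
```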

def vit_architecture_components():
    """
    ViT architecture components
    """
    components = {
        "Patch Embedding Layer": "Converts the image into a patch sequence",
        "Class Token": "Special token used for the final classification",
        "Positional Embedding": "Preserves spatial position information",
        "Transformer Encoder": "Self-attention processing",
        "MLP Head": "Classification output layer"
    }
    
    print("ViT architecture components:")
    for component, desc in components.items():
        print(f"• {component}: {desc}")

vit_architecture_components()

2.3 The Attention Mechanism in Detail

class MultiHeadAttention(nn.Module):
    """
    Multi-head self-attention
    """
    def __init__(self, d_model=768, num_heads=12, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        # Linear projection layers
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        
        self.dropout = nn.Dropout(dropout)
        self.scale = self.d_k ** 0.5
    
    def forward(self, x):
        batch_size, seq_len, _ = x.shape
        
        # Linear projections
        Q = self.W_q(x)  # (B, seq_len, d_model)
        K = self.W_k(x)  # (B, seq_len, d_model)
        V = self.W_v(x)  # (B, seq_len, d_model)
        
        # Split into heads
        Q = Q.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)  # (B, num_heads, seq_len, d_k)
        K = K.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)  # (B, num_heads, seq_len, d_k)
        V = V.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)  # (B, num_heads, seq_len, d_k)
        
        # Attention scores
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale  # (B, num_heads, seq_len, seq_len)
        attention = F.softmax(scores, dim=-1)
        attention = self.dropout(attention)
        
        # Apply attention to V
        output = torch.matmul(attention, V)  # (B, num_heads, seq_len, d_k)
        
        # Merge the heads
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        
        # Final linear projection
        output = self.W_o(output)
        
        return output

def attention_mechanism_explanation():
    """
    Explain the attention mechanism
    """
    print("Multi-head self-attention:")
    print("1. Query, Key, Value: three linear projections")
    print("2. Attention scores: Q @ K^T / sqrt(d_k)")
    print("3. Softmax: normalize the attention weights")
    print("4. Output: Attention @ V")
    print("5. Multi-head: compute several attention heads in parallel")
    print("6. Concatenate: merge the outputs of all heads")

attention_mechanism_explanation()
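The steps above can be sketched for a single head in plain Python; the 2-dimensional queries, keys, and values are toy values chosen only to make the numbers easy to follow:

```python
import math

def softmax(row):
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    s = sum(exps)
    return [e / s for e in exps]

def attention(Q, K, V):
    """Single-head scaled dot-product attention on plain lists of vectors."""
    d_k = len(Q[0])
    # Attention scores: Q @ K^T / sqrt(d_k)
    scores = [[sum(q[i] * k[i] for i in range(d_k)) / math.sqrt(d_k)
               for k in K] for q in Q]
    # Row-wise softmax over the key axis
    weights = [softmax(row) for row in scores]
    # Weighted sum of the values: weights @ V
    out = [[sum(w[j] * V[j][i] for j in range(len(V))) for i in range(len(V[0]))]
           for w in weights]
    return out, weights

Q = K = V = [[1.0, 0.0], [0.0, 1.0]]
out, w = attention(Q, K, V)
print([round(x, 2) for x in w[0]])  # the first query attends more to the first key
```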

3. Positional Encoding and the Class Token

3.1 Positional Encoding Strategies

import math

class PositionalEncoding(nn.Module):
    """
    Sinusoidal positional encoding module
    """
    def __init__(self, d_model=768, max_len=5000):
        super().__init__()
        
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                            -(math.log(10000.0) / d_model))
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        self.register_buffer('pe', pe)
    
    def forward(self, x):
        # x: (batch_size, seq_len, d_model)
        return x + self.pe[:, :x.size(1)]

def positional_encoding_types():
    """
    Types of positional encoding
    """
    encodings = {
        "Learnable PE": "Learned positional embedding parameters",
        "Sinusoidal PE": "Sine/cosine function encoding",
        "2D PE": "Two-dimensional spatial position encoding",
        "Rotary PE": "Rotary position embedding"
    }
    
    print("Types of positional encoding:")
    for encoding, desc in encodings.items():
        print(f"• {encoding}: {desc}")

positional_encoding_types()
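A minimal pure-Python version of the sinusoidal variant makes the formula explicit: PE(pos, 2i) = sin(pos / 10000^(2i/d_model)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model)). This computes the encoding for a single position (d_model assumed even):

```python
import math

def sinusoidal_pe(pos, d_model):
    """Positional encoding vector for one position (sinusoidal variant)."""
    pe = [0.0] * d_model
    for i in range(0, d_model, 2):
        angle = pos / (10000 ** (i / d_model))
        pe[i] = math.sin(angle)      # even indices: sine
        pe[i + 1] = math.cos(angle)  # odd indices: cosine
    return pe

pe0 = sinusoidal_pe(0, 8)
print(pe0)  # at position 0, all sine terms are 0 and all cosine terms are 1
```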

3.2 The Role of the Class Token

def class_token_purpose():
    """
    Explain the role of the class token
    """
    print("Role of the class token:")
    print("1. Information aggregation: collects information from all patches")
    print("2. Classification representation: serves as a global image representation")
    print("3. Gradient flow: provides a path for backpropagation")
    print("4. Position-agnostic: interacts with every patch")

class_token_purpose()
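The mechanics are simple enough to sketch with plain lists: prepend one learned vector to the patch sequence, and the classifier later reads the encoded sequence back at index 0. The toy dimensions below are illustrative (in a real model the class token is a learned parameter):

```python
def prepend_cls(patch_tokens, cls_token):
    """Prepend the class token so it occupies index 0 of the sequence."""
    return [cls_token] + patch_tokens

patches = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]  # 3 patch embeddings, dim=2
cls = [0.0, 0.0]                                 # learned in a real model
seq = prepend_cls(patches, cls)
print(len(seq), seq[0])  # sequence length 4; classification reads seq[0]
```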

4. ViT Variants and Improvements

4.1 DeiT (Data-efficient Image Transformer)

DeiT improves ViT's training efficiency through knowledge distillation.

class DistillationToken(nn.Module):
    """
    Knowledge-distillation token
    """
    def __init__(self, dim=768):
        super().__init__()
        self.dist_token = nn.Parameter(torch.randn(1, 1, dim))
    
    def forward(self, x):
        # x: (B, num_patches+1, dim) - already includes cls token
        b, n, d = x.shape
        dist_tokens = self.dist_token.repeat(b, 1, 1)  # (B, 1, dim)
        x = torch.cat([x, dist_tokens], dim=1)  # (B, num_patches+2, dim)
        return x

def deit_improvements():
    """
    DeiT improvements
    """
    improvements = [
        "Distillation token: a teacher model guides training",
        "Stronger data augmentation: RandAugment, Mixup, etc.",
        "Training recipe: AdamW optimizer, learning-rate scheduling",
        "Regularization: Dropout, Stochastic Depth"
    ]
    
    print("DeiT improvements:")
    for imp in improvements:
        print(f"• {imp}")

deit_improvements()

4.2 Efficient Vision Transformers

def efficient_vit_variants():
    """
    Efficient ViT variants
    """
    variants = {
        "MobileViT": "Lightweight architecture for mobile devices",
        "Twins": "Spatially separable self-attention (local + global)",
        "PVT": "Pyramid Vision Transformer",
        "Shuffle Transformer": "Window attention with spatial shuffle",
        "CMT": "Convolutional Neural Networks Meet Vision Transformers"
    }
    
    print("Efficient ViT variants:")
    for variant, desc in variants.items():
        print(f"• {variant}: {desc}")

efficient_vit_variants()

5. Using Pretrained Models

5.1 Using Pretrained ViT in PyTorch

def use_pretrained_vit():
    """
    Using a pretrained ViT model
    """
    print("Using pretrained ViT models from PyTorch:")
    print("""
import torch
import torchvision.models as models

# Load pretrained ViT models
vit_b_16 = models.vit_b_16(weights='IMAGENET1K_V1')
vit_b_32 = models.vit_b_32(weights='IMAGENET1K_V1')
vit_l_16 = models.vit_l_16(weights='IMAGENET1K_V1')

# Inference example
model = vit_b_16
model.eval()

# Preprocess the image
from torchvision import transforms
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                        std=[0.229, 0.224, 0.225]),
])

# Load an image (path is illustrative) and add a batch dimension
from PIL import Image
img = Image.open('example.jpg').convert('RGB')
input_tensor = preprocess(img).unsqueeze(0)

# Inference
with torch.no_grad():
    output = model(input_tensor)
    probabilities = torch.nn.functional.softmax(output[0], dim=0)
""")

use_pretrained_vit()

5.2 Using Hugging Face Transformers

def huggingface_vit():
    """
    Using the Vision Transformer from Hugging Face
    """
    print("Using Hugging Face Transformers:")
    print("""
from transformers import ViTImageProcessor, ViTForImageClassification
from PIL import Image
import requests

# Load the processor and the model
processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')

# Fetch and process an image
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(images=image, return_tensors="pt")

# Inference
outputs = model(**inputs)
logits = outputs.logits
predicted_class_idx = logits.argmax(-1).item()
print("Predicted class:", model.config.id2label[predicted_class_idx])
""")

huggingface_vit()

6. ViT vs. CNN: A Comparison

6.1 Detailed Comparison

def detailed_comparison():
    """
    Detailed comparison of ViT and CNNs
    """
    comparison = {
        "Receptive field": {
            "CNN": "Grows layer by layer, local to global",
            "ViT": "Globally connected from the first layer"
        },
        "Parameter efficiency": {
            "CNN": "Fewer parameters, computationally efficient",
            "ViT": "More parameters, needs large datasets"
        },
        "Data requirements": {
            "CNN": "Performs well on small datasets",
            "ViT": "Requires large-scale pretraining"
        },
        "Interpretability": {
            "CNN": "Hard to visualize, low interpretability",
            "ViT": "Attention weights can be visualized directly"
        },
        "Computational complexity": {
            "CNN": "O(n), where n is the number of pixels",
            "ViT": "O(n²), where n is the number of patches"
        },
        "Inductive bias": {
            "CNN": "Strong inductive biases (translation invariance, locality)",
            "ViT": "Weak inductive biases, more general"
        },
        "Scalability": {
            "CNN": "Scaling limited by architecture design",
            "ViT": "Scales easily to larger models"
        }
    }
    
    print("ViT vs. CNN detailed comparison:")
    for aspect, methods in comparison.items():
        print(f"• {aspect}:")
        print(f"  - CNN: {methods['CNN']}")
        print(f"  - ViT: {methods['ViT']}")

detailed_comparison()
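The O(n²) scaling above can be made concrete: self-attention compares every token with every other token, so halving the patch size quadruples the token count and grows the attention score matrix roughly 16-fold. A quick sketch for a 224×224 input:

```python
def attention_matrix_size(image_size, patch_size):
    """Token count and entries in one self-attention score matrix
    (patches plus the class token)."""
    n = (image_size // patch_size) ** 2 + 1
    return n, n * n

for p in (32, 16, 8):
    n, cost = attention_matrix_size(224, p)
    print(f"patch {p:2d}: {n:4d} tokens, {cost:,} attention scores")
```

This is why sequence-length reduction (larger patches, token pruning, windowed attention) is a common lever for efficient ViTs.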

6.2 When to Use Which

def use_case_analysis():
    """
    When to prefer ViT vs. CNN
    """
    scenarios = {
        "CNN-friendly scenarios": [
            "Training on small datasets",
            "Strict real-time inference requirements",
            "Mobile deployment",
            "Edge computing",
            "Traditional vision tasks"
        ],
        "ViT-friendly scenarios": [
            "Large-scale pretraining",
            "Tasks that need global context",
            "High interpretability requirements",
            "Multimodal tasks",
            "Research and frontier applications"
        ],
        "Hybrid architectures": [
            "CvT: CNN + Transformer",
            "CoAtNet: Convolution + Attention",
            "ConViT: soft convolutional inductive biases"
        ]
    }
    
    for category, items in scenarios.items():
        print(f"{category}:")
        for item in items:
            print(f"  • {item}")
        print()

use_case_analysis()

7. Practical Application and Tuning

7.1 Training Tips

def training_tips():
    """
    ViT training tips
    """
    tips = [
        "Pretrain on a large-scale dataset",
        "Use the AdamW optimizer with weight decay",
        "Learning-rate warmup with cosine annealing",
        "Strong data augmentation (RandAugment, Mixup)",
        "Dropout and label-smoothing regularization",
        "Knowledge distillation to boost small models"
    ]
    
    print("ViT training tips:")
    for i, tip in enumerate(tips, 1):
        print(f"{i}. {tip}")

training_tips()

7.2 Performance Optimization

def performance_optimization():
    """
    ViT performance optimization strategies
    """
    optimizations = [
        "Model quantization: INT8 reduces the memory footprint",
        "Knowledge distillation: train small models with large teachers",
        "Model pruning: remove redundant parameters",
        "Mixed-precision training: saves memory and speeds up training",
        "Sequence-length optimization: reduce the number of patches",
        "Sparse attention: lower the computational complexity"
    ]
    
    print("ViT performance optimization strategies:")
    for opt in optimizations:
        print(f"• {opt}")

performance_optimization()

Related Tutorials

ViT is a major milestone in computer vision. It helps to understand Transformer basics first, then study the ViT implementation in detail. In practice, start from a pretrained model and work toward the architectural details.

8. Summary

The Vision Transformer represents a new paradigm for computer vision:

Core innovations:

  1. Image patching: converts an image into a sequence
  2. Global attention: captures long-range dependencies
  3. Scalability: easy to scale up

Key techniques:

  • Patch Embedding
  • Class Token
  • Positional Encoding
  • Multi-head Attention

💡 Key takeaway: ViT demonstrated that a pure Transformer architecture is effective for vision tasks, opening a new era of unified modeling for vision and language.
