#Vision Transformer:从图像切片到Patch Embedding详解
#引言
Vision Transformer (ViT) 是计算机视觉领域的一项重要突破,它成功地将原本用于自然语言处理的Transformer架构应用于图像分类任务。ViT通过将图像分割成小块(patches)并将其作为序列输入到Transformer中,证明了纯Transformer架构在视觉任务上的有效性。本文将深入探讨ViT的核心概念、架构细节和实现方法。
📂 所属阶段:第二阶段 — 深度学习视觉基础(CNN 篇)
🔗 相关章节:关键点检测 (Keypoints) · Swin Transformer
#1. Vision Transformer核心思想
#1.1 ViT的创新理念
Vision Transformer的核心创新在于将计算机视觉问题转化为序列建模问题。
"""
Vision Transformer的核心思想:
1. 图像分块 (Image Patching):
- 将图像分割成固定大小的patch序列
- 每个patch被视为一个"token"
- 类似于NLP中的词汇序列
2. 序列建模:
- 将图像patch序列输入Transformer
- 利用自注意力机制捕获全局关系
- 避免了CNN的局部感受野限制
3. 全局连接:
- 每个patch与所有其他patch直接连接
- 从第一层就能捕获全局信息
"""
def vit_concept_explanation():
    """Print each core ViT concept with a one-line description."""
    # (concept, description) pairs; a list makes the printed order explicit.
    core_ideas = [
        ("Image Patching", "将图像分割成固定大小的块"),
        ("Patch Embedding", "将图像块映射到嵌入空间"),
        ("Class Token", "用于分类的特殊token"),
        ("Positional Encoding", "保持位置信息"),
        ("Global Attention", "捕获全局特征关系"),
    ]
    print("ViT核心概念:")
    for concept, desc in core_ideas:
        print(f"• {concept}: {desc}")


vit_concept_explanation()

# 1.2 Evolution of the ViT architecture
def vit_development_timeline():
    """Print the milestones leading from the Transformer to modern ViTs."""
    # (year, milestone) pairs in chronological order.
    milestones = [
        ("2017", "Attention Is All You Need (Transformer提出)"),
        ("2018", "BERT (NLP领域的突破)"),
        ("2020", "Vision Transformer (ViT首次应用到视觉)"),
        ("2021", "Swin Transformer (层次化ViT)"),
        ("2022", "Efficient Vision Transformers (高效架构)"),
        ("2023", "Multimodal Vision Transformers (多模态)"),
        ("2024-2026", "Vision Foundation Models (视觉基础模型)"),
    ]
    print("ViT发展时间线:")
    for year, milestone in milestones:
        print(f"• {year}: {milestone}")


vit_development_timeline()

# 2. ViT architecture in detail
#2.1 图像分块与Patch Embedding
图像分块是ViT的第一个关键步骤,将二维图像转换为一维序列。
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
class ImageToPatches(nn.Module):
    """Split a square image into non-overlapping patches and project each one.

    Input:  (batch, channels, image_size, image_size)
    Output: (batch, num_patches, patch_dim), where
            num_patches = (image_size // patch_size) ** 2 and
            patch_dim   = channels * patch_size ** 2.

    The patch extraction is done with plain reshape/permute instead of the
    original `einops.rearrange` call, removing a third-party dependency
    while producing the same 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)' layout.
    """

    def __init__(self, image_size=224, patch_size=16, channels=3):
        super().__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_patches = (image_size // patch_size) ** 2
        self.patch_dim = channels * patch_size ** 2
        # Per-patch linear projection (patch_dim -> patch_dim).
        self.projection = nn.Linear(self.patch_dim, self.patch_dim)

    def forward(self, x):
        """Patchify and project.

        Args:
            x: tensor of shape (batch, channels, height, width).

        Returns:
            Tensor of shape (batch, num_patches, patch_dim).

        Raises:
            ValueError: if the spatial size does not match `image_size`
                (was an `assert` before, which is stripped under `-O`).
        """
        batch_size, channels, height, width = x.shape
        if height != self.image_size or width != self.image_size:
            raise ValueError(
                f"输入图像尺寸应为 ({self.image_size}, {self.image_size}), "
                f"实际为 ({height}, {width})"
            )
        p = self.patch_size
        grid = self.image_size // p
        # (B, C, gh, p1, gw, p2) -> (B, gh, gw, p1, p2, C) -> (B, gh*gw, p1*p2*C)
        # matches einops 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)'.
        x = x.reshape(batch_size, channels, grid, p, grid, p)
        x = x.permute(0, 2, 4, 3, 5, 1).contiguous()
        x = x.reshape(batch_size, self.num_patches, self.patch_dim)
        return self.projection(x)
def patch_embedding_process():
    """Walk through the tensor shapes of the patch-embedding pipeline."""
    steps = (
        "1. 输入图像: (B, C, H, W) = (B, 3, 224, 224)",
        "2. 分割patch: (B, 3, 14*16, 14*16) -> (B, 14*14, 16*16*3)",
        "3. 每个patch: 16x16x3 = 768维向量",
        "4. 输出: (B, 196, 768) # 196个patch,每个768维",
    )
    print("Patch Embedding过程:")
    for step in steps:
        print(step)


patch_embedding_process()

# 2.2 Full ViT implementation
class VisionTransformer(nn.Module):
    """Vision Transformer (ViT) for image classification.

    Pipeline: patchify+embed (strided conv) -> prepend CLS token ->
    add learned positional embeddings -> Transformer encoder ->
    classify from the CLS token.

    Bug fixed: the original `to_patch_embedding` was
    `Sequential(Conv2d(3, patch_dim, ...), Flatten(start_dim=2), Linear(patch_dim, dim))`.
    After Flatten the tensor is (B, patch_dim, num_patches), so the Linear was
    applied to a last dimension of size num_patches instead of patch_dim and
    the forward pass crashed. A single strided Conv2d projecting directly to
    `dim`, followed by flatten+transpose in `forward`, is the standard fix.
    """

    def __init__(
        self,
        image_size=224,
        patch_size=16,
        num_classes=1000,
        dim=768,
        depth=12,
        heads=12,
        mlp_dim=3072,
        dropout=0.1,
        emb_dropout=0.1,
    ):
        super().__init__()
        num_patches = (image_size // patch_size) ** 2
        # Patchify + embed in one op: each patch_size x patch_size window
        # is linearly projected straight to `dim` channels.
        self.to_patch_embedding = nn.Conv2d(
            3, dim, kernel_size=patch_size, stride=patch_size
        )
        # Learnable classification token and positional embeddings
        # (+1 position for the CLS token).
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.dropout = nn.Dropout(emb_dropout)
        # Standard pre-built Transformer encoder stack.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=dim,
            nhead=heads,
            dim_feedforward=mlp_dim,
            dropout=dropout,
            activation='gelu',
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=depth)
        # Classification head applied to the CLS token.
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes),
        )

    def forward(self, img):
        """img: (B, 3, H, W) -> logits of shape (B, num_classes)."""
        # (B, dim, H/p, W/p) -> (B, num_patches, dim)
        x = self.to_patch_embedding(img).flatten(2).transpose(1, 2)
        b, n, _ = x.shape
        # Prepend one CLS token per sample.
        cls_tokens = self.cls_token.expand(b, -1, -1)  # (B, 1, dim)
        x = torch.cat([cls_tokens, x], dim=1)          # (B, n+1, dim)
        # Add positional information, then regularize.
        x = x + self.pos_embedding[:, :n + 1]
        x = self.dropout(x)
        x = self.transformer(x)
        # Classify from the CLS token only.
        return self.mlp_head(x[:, 0])
def vit_architecture_components():
    """Print the main architectural components of a ViT."""
    parts = [
        ("Patch Embedding Layer", "将图像转换为patch序列"),
        ("Class Token", "用于最终分类的特殊token"),
        ("Positional Embedding", "保持空间位置信息"),
        ("Transformer Encoder", "自注意力机制处理"),
        ("MLP Head", "分类输出层"),
    ]
    print("ViT架构组件:")
    for component, desc in parts:
        print(f"• {component}: {desc}")


vit_architecture_components()

# 2.3 Attention mechanism in detail
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention: softmax(QK^T / sqrt(d_k)) V, per head.

    Bug fixed: the original stored the scale as
    `torch.sqrt(torch.FloatTensor([d_k])).to('cuda' if ...)` — a tensor pinned
    to one device at construction time. On a CUDA machine this breaks CPU
    inputs, and the tensor never follows `module.to(device)`. A plain Python
    float is device-agnostic and numerically identical.
    """

    def __init__(self, d_model=768, num_heads=12, dropout=0.1):
        """
        Args:
            d_model: total embedding width; must be divisible by num_heads.
            num_heads: number of parallel attention heads.
            dropout: dropout applied to the attention weights.

        Raises:
            ValueError: if d_model is not divisible by num_heads
                (was an `assert` before, which is stripped under `-O`).
        """
        super().__init__()
        if d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        # Q/K/V projections and the output projection.
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        # sqrt(d_k) as a plain float — device-agnostic.
        self.scale = self.d_k ** 0.5

    def forward(self, x):
        """x: (B, seq_len, d_model) -> (B, seq_len, d_model)."""
        batch_size, seq_len, _ = x.shape
        # Project, then split into heads: (B, L, d_model) -> (B, H, L, d_k).
        q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        k = self.W_k(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        v = self.W_v(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        # Scaled dot-product attention scores: (B, H, L, L).
        scores = torch.matmul(q, k.transpose(-2, -1)) / self.scale
        weights = self.dropout(F.softmax(scores, dim=-1))
        # Weighted sum of values: (B, H, L, d_k).
        context = torch.matmul(weights, v)
        # Merge heads back: (B, L, d_model).
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        return self.W_o(context)
def attention_mechanism_explanation():
    """Print a step-by-step outline of multi-head self-attention."""
    steps = (
        "1. Query, Key, Value: 三个线性投影",
        "2. 注意力分数: Q*K^T / sqrt(d_k)",
        "3. Softmax: 归一化注意力权重",
        "4. 输出: Attention * V",
        "5. 多头: 并行计算多个注意力头",
        "6. 拼接: 合并所有头的输出",
    )
    print("多头自注意力机制:")
    for step in steps:
        print(step)


attention_mechanism_explanation()

# 3. Positional encoding and the class token
#3.1 位置编码策略
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding (Vaswani et al., 2017).

    Precomputes a (1, max_len, d_model) table of interleaved sin/cos values
    and adds the first `seq_len` rows to the input on each forward pass.
    """

    def __init__(self, d_model=768, max_len=5000):
        super().__init__()
        # Column vector of positions: (max_len, 1).
        positions = torch.arange(0, max_len).unsqueeze(1).float()
        # Geometric frequency progression over the even dimensions.
        freqs = torch.exp(
            torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)
        )
        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(positions * freqs)  # even dims: sin
        table[:, 1::2] = torch.cos(positions * freqs)  # odd dims:  cos
        # Stored as a buffer: moves with .to(device), not a trainable parameter.
        self.register_buffer('pe', table.unsqueeze(0))  # (1, max_len, d_model)

    def forward(self, x):
        """Add positional encodings to x of shape (batch, seq_len, d_model)."""
        return x + self.pe[:, :x.size(1)]
import math
def positional_encoding_types():
    """Print the common families of positional encodings."""
    families = [
        ("Learnable PE", "可学习的位置嵌入参数"),
        ("Sinusoidal PE", "正弦余弦函数编码"),
        ("2D PE", "二维空间位置编码"),
        ("Rotary PE", "旋转位置编码"),
    ]
    print("位置编码类型:")
    for encoding, desc in families:
        print(f"• {encoding}: {desc}")


positional_encoding_types()

# 3.2 Role of the class token
def class_token_purpose():
    """Print why ViT uses a dedicated class token."""
    reasons = (
        "1. 聚合信息: 收集所有patch的信息",
        "2. 分类表示: 作为图像的全局表示",
        "3. 梯度流动: 为反向传播提供路径",
        "4. 位置不变: 与所有patch交互",
    )
    print("Class Token的作用:")
    for reason in reasons:
        print(reason)


class_token_purpose()

# 4. ViT variants and improvements
#4.1 DeiT (Data-efficient Image Transformer)
DeiT通过知识蒸馏提高了ViT的训练效率。
class DistillationToken(nn.Module):
    """Appends a learnable distillation token (DeiT-style) to a token sequence."""

    def __init__(self, dim=768):
        super().__init__()
        # One learnable token, broadcast over the batch at forward time.
        self.dist_token = nn.Parameter(torch.randn(1, 1, dim))

    def forward(self, x):
        """x: (B, N, dim) — already including the CLS token — -> (B, N+1, dim)."""
        batch = x.shape[0]
        per_sample = self.dist_token.repeat(batch, 1, 1)  # (B, 1, dim)
        return torch.cat([x, per_sample], dim=1)
def deit_improvements():
    """Print the training improvements introduced by DeiT."""
    measures = (
        "蒸馏token: 使用教师模型指导训练",
        "更强的数据增强: RandAugment, Mixup等",
        "训练技巧: AdamW优化器,学习率调度",
        "正则化: Dropout, Stochastic Depth",
    )
    print("DeiT改进措施:")
    for measure in measures:
        print(f"• {measure}")


deit_improvements()

# 4.2 Efficient Vision Transformers
def efficient_vit_variants():
    """Print well-known efficiency-oriented ViT variants."""
    variants = [
        ("MobileViT", "轻量级架构,适合移动设备"),
        ("Twins", "Spatial Attention + Sequential Self-Attention"),
        ("PVT", "Pyramid Vision Transformer"),
        ("Shuffle Transformer", "通道混洗注意力"),
        ("CMT", "Convolutional Neural Networks Meet Vision Transformers"),
    ]
    print("高效ViT变体:")
    for variant, desc in variants:
        print(f"• {variant}: {desc}")


efficient_vit_variants()

# 5. Using pretrained models
#5.1 使用PyTorch预训练ViT
def use_pretrained_vit():
    """Print a torchvision code sample for running a pretrained ViT."""
    # The sample is held in a local so the print calls stay short.
    sample = """
import torch
import torchvision.models as models
# 加载预训练ViT模型
vit_b_16 = models.vit_b_16(weights='IMAGENET1K_V1')
vit_b_32 = models.vit_b_32(weights='IMAGENET1K_V1')
vit_l_16 = models.vit_l_16(weights='IMAGENET1K_V1')
# 推理示例
model = vit_b_16
model.eval()
# 预处理图像
from torchvision import transforms
preprocess = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
])
# 推理
with torch.no_grad():
output = model(input_tensor)
probabilities = torch.nn.functional.softmax(output[0], dim=0)
"""
    print("使用PyTorch预训练ViT:")
    print(sample)


use_pretrained_vit()

# 5.2 Using Hugging Face Transformers
def huggingface_vit():
    """Print a Hugging Face Transformers code sample for ViT classification."""
    sample = """
from transformers import ViTImageProcessor, ViTForImageClassification
from PIL import Image
import requests
# 加载处理器和模型
processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
# 处理图像
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(images=image, return_tensors="pt")
# 推理
outputs = model(**inputs)
logits = outputs.logits
predicted_class_idx = logits.argmax(-1).item()
print("Predicted class:", model.config.id2label[predicted_class_idx])
"""
    print("使用Hugging Face Transformers:")
    print(sample)


huggingface_vit()

# 6. ViT vs CNN comparison
#6.1 详细对比表格
def detailed_comparison():
    """Print a side-by-side CNN vs ViT comparison across key aspects."""
    # Each row: (aspect, CNN trait, ViT trait).
    rows = [
        ("感受野", "逐层增长,局部到全局", "全局连接,从第一层开始"),
        ("参数效率", "参数较少,计算效率高", "参数较多,需要大数据集"),
        ("数据需求", "小数据集上表现良好", "需要大规模预训练"),
        ("可解释性", "可视化困难,可解释性低", "注意力权重可直接可视化"),
        ("计算复杂度", "O(n),其中n是像素数", "O(n²),其中n是patch数"),
        ("归纳偏置", "强归纳偏置(平移不变性、局部性)", "弱归纳偏置,更通用"),
        ("扩展性", "扩展受限于架构设计", "容易扩展到更大规模"),
    ]
    print("ViT vs CNN 详细对比:")
    for aspect, cnn_trait, vit_trait in rows:
        print(f"• {aspect}:")
        print(f" - CNN: {cnn_trait}")
        print(f" - ViT: {vit_trait}")


detailed_comparison()

# 6.2 Use-case analysis
def use_case_analysis():
    """Print which scenarios favor CNNs, ViTs, or hybrid architectures."""
    # (category, list of scenarios) in presentation order.
    catalog = [
        ("CNN适用场景", [
            "小数据集训练",
            "实时推理要求高",
            "移动设备部署",
            "边缘计算场景",
            "传统视觉任务",
        ]),
        ("ViT适用场景", [
            "大数据集预训练",
            "需要全局信息的任务",
            "可解释性要求高",
            "多模态任务",
            "研究和前沿应用",
        ]),
        ("混合架构", [
            "CvT: CNN + Transformer",
            "CoAtNet: Convolution + Attention",
            "ConViT: Convolutional Vision Transformer",
        ]),
    ]
    for category, items in catalog:
        print(f"{category}:")
        for item in items:
            print(f" • {item}")
        print()


use_case_analysis()

# 7. Practice and tuning
#7.1 训练技巧
def training_tips():
    """Print numbered practical tips for training ViTs."""
    advice = (
        "使用大规模数据集预训练",
        "采用AdamW优化器,权重衰减",
        "学习率预热和余弦退火调度",
        "强数据增强(RandAugment, Mixup)",
        "Dropout和标签平滑正则化",
        "知识蒸馏提升小模型性能",
    )
    print("ViT训练技巧:")
    for index, tip in enumerate(advice, start=1):
        print(f"{index}. {tip}")


training_tips()

# 7.2 Performance optimization
def performance_optimization():
    """Print strategies for speeding up / shrinking ViT models."""
    strategies = (
        "模型量化: INT8量化减少内存占用",
        "知识蒸馏: 用大模型训练小模型",
        "模型剪枝: 移除冗余参数",
        "混合精度训练: 节省显存加速训练",
        "序列长度优化: 减少patch数量",
        "注意力稀疏化: 降低计算复杂度",
    )
    print("ViT性能优化策略:")
    for strategy in strategies:
        print(f"• {strategy}")


performance_optimization()

# Related tutorials
#8. 总结
Vision Transformer代表了计算机视觉的新范式:
核心创新:
- 图像分块:将图像转换为序列
- 全局注意力:捕获长距离依赖
- 可扩展性:容易扩展到大规模
关键技术:
- Patch Embedding
- Class Token
- Positional Encoding
- Multi-head Attention
💡 重要提醒:ViT证明了纯Transformer架构在视觉任务上的有效性,开启了视觉与语言统一建模的新时代。
🔗 扩展阅读

