#MAE (Masked Autoencoders):自监督学习的视觉预训练方法详解
#引言
Masked Autoencoders (MAE) 是何恺明等人在2021年提出的革命性自监督学习方法,它将NLP领域中BERT的掩码语言建模思想成功迁移到计算机视觉领域。MAE通过随机遮盖图像中的大部分patch,并训练模型重建被遮盖的部分,实现了高效的视觉表征学习。这一方法极大地推动了自监督学习在计算机视觉中的发展,为视觉Transformer的预训练提供了新的范式。
📂 所属阶段:第二阶段 — 深度学习视觉基础(CNN 篇)
🔗 相关章节:Swin Transformer · Vision-Language 多模态
#1. MAE核心思想与动机
#1.1 自监督学习的兴起
自监督学习是当前深度学习的一个重要发展方向,它旨在利用大量未标注数据进行预训练。
"""
MAE的核心动机:
1. 标注数据稀缺且昂贵
2. 人类视觉系统无需监督即可学习
3. 从NLP的掩码语言模型获得启发
4. 探索更高效的视觉表征学习方法
"""
def self_supervised_learning_motivation():
"""
自监督学习动机分析
"""
motivations = {
"数据效率": "利用海量未标注数据进行预训练",
"成本效益": "避免昂贵的数据标注过程",
"泛化能力": "学习更通用的视觉表征",
"可扩展性": "适应大规模数据集训练"
}
print("自监督学习动机:")
for motivation, desc in motivations.items():
print(f"• {motivation}: {desc}")
self_supervised_learning_motivation()#1.2 MAE的创新点
def mae_innovations():
"""
MAE核心创新点
"""
innovations = [
"不对称编码器-解码器架构",
"高比例掩码策略(75%)",
"仅对掩码部分计算损失",
"高效的重建目标"
]
print("MAE核心创新:")
for i, innovation in enumerate(innovations, 1):
print(f"{i}. {innovation}")
mae_innovations()#2. MAE架构详解
#2.1 不对称编码器-解码器设计
MAE的最大特点是采用了不对称的编码器-解码器架构。
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
class PatchEmbedding(nn.Module):
"""
图像到patch的嵌入层
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
super().__init__()
num_patches = (img_size // patch_size) ** 2
self.img_size = img_size
self.patch_size = patch_size
self.num_patches = num_patches
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
def forward(self, x):
B, C, H, W = x.shape
x = self.proj(x).flatten(2).transpose(1, 2)
return x
class MAEEncoder(nn.Module):
"""
MAE编码器 - 基于Vision Transformer
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, depth=12,
num_heads=12, mlp_ratio=4., norm_layer=nn.LayerNorm):
super().__init__()
self.patch_embed = PatchEmbedding(img_size, patch_size, in_chans, embed_dim)
num_patches = self.patch_embed.num_patches
# 类别token和位置嵌入
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
# Transformer编码器层
self.blocks = nn.ModuleList([
nn.TransformerEncoderLayer(
d_model=embed_dim,
nhead=num_heads,
dim_feedforward=int(embed_dim * mlp_ratio),
dropout=0.1,
activation='gelu',
batch_first=True
) for _ in range(depth)
])
self.norm = norm_layer(embed_dim)
# 初始化位置嵌入
nn.init.trunc_normal_(self.pos_embed, std=0.02)
nn.init.trunc_normal_(self.cls_token, std=0.02)
def forward(self, x, mask):
# patch嵌入
x = self.patch_embed(x) # (B, N, D)
# 应用掩码 - 只保留未被掩码的patch
x = x[~mask].reshape(x.shape[0], -1, x.shape[-1]) # (B, N_visible, D)
# 添加类别token
cls_token = self.cls_token.expand(x.shape[0], -1, -1)
x = torch.cat([cls_token, x], dim=1) # (B, 1+N_visible, D)
# 添加位置嵌入
pos_embed = self.pos_embed[:, 1:, :] # 移除cls token的位置嵌入
pos_embed = pos_embed[~mask].reshape(pos_embed.shape[0], -1, pos_embed.shape[-1])
cls_pos_embed = self.pos_embed[:, :1, :] # cls token的位置嵌入
pos_embed = torch.cat([cls_pos_embed, pos_embed], dim=1)
x = x + pos_embed
# Transformer编码
for blk in self.blocks:
x = blk(x)
x = self.norm(x)
return x
def asymmetric_architecture_explanation():
"""
不对称架构解释
"""
print("MAE不对称架构设计:")
print("• 编码器: 只处理可见patch,轻量级")
print("• 解码器: 处理所有patch,重建图像")
print("• 高效性: 编码器跳过掩码patch")
print("• 重建目标: 仅重建被掩码的部分")
asymmetric_architecture_explanation()#2.2 MAE解码器设计
class MAEDecoder(nn.Module):
"""
MAE解码器 - 用于重建被掩码的patch
"""
def __init__(self, num_patches=196, patch_size=16, embed_dim=768, decoder_embed_dim=512,
decoder_depth=8, decoder_num_heads=16, mlp_ratio=4., norm_layer=nn.LayerNorm):
super().__init__()
self.embed_dim = embed_dim
self.decoder_embed_dim = decoder_embed_dim
self.num_patches = num_patches
self.patch_size = patch_size
# 解码器嵌入层
self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim)
# 掩码token
self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
# 位置嵌入
self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim))
# Transformer解码器层
self.decoder_blocks = nn.ModuleList([
nn.TransformerEncoderLayer(
d_model=decoder_embed_dim,
nhead=decoder_num_heads,
dim_feedforward=int(decoder_embed_dim * mlp_ratio),
dropout=0.1,
activation='gelu',
batch_first=True
) for _ in range(decoder_depth)
])
# 输出投影
self.decoder_norm = norm_layer(decoder_embed_dim)
self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size ** 2 * 3) # 每个patch的像素
# 初始化
nn.init.trunc_normal_(self.mask_token, std=0.02)
nn.init.trunc_normal_(self.decoder_pos_embed, std=0.02)
def forward(self, x, ids_restore):
# 嵌入维度转换
x = self.decoder_embed(x) # (B, 1+N_visible, D_dec)
# 获取可见patch的数量
N_visible = x.shape[1] - 1 # 减去cls token
# 扩展mask tokens
mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - N_visible, 1)
# 拼接可见patch和mask tokens
x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # 移除cls token,拼接mask tokens
x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))
x = torch.cat([x[:, :1, :], x_], dim=1) # 重新加入cls token
# 添加位置嵌入
x = x + self.decoder_pos_embed
# 解码器Transformer
for blk in self.decoder_blocks:
x = blk(x)
x = self.decoder_norm(x)
# 移除cls token
x = x[:, 1:, :]
# 预测被掩码patch的像素值
x = self.decoder_pred(x)
return x
def decoder_design_insights():
"""
解码器设计要点
"""
print("MAE解码器设计要点:")
print("1. 掩码token: 代表被掩码patch的占位符")
print("2. ids_restore: 恢复原始patch顺序")
print("3. 重建目标: 每个patch的RGB像素值")
print("4. 轻量编码器 + 重型解码器的设计")
decoder_design_insights()#3. 掩码策略详解
#3.1 掩码实现
MAE使用高比例的掩码策略,通常掩码75%的patch。
def generate_random_mask(B, N, mask_ratio):
"""
生成随机掩码
Args:
B: batch size
N: number of patches
mask_ratio: 掩码比例
Returns:
mask: bool tensor, True表示被掩码
ids_shuffle: 按可见patch优先排序的索引
ids_restore: 恢复原始顺序的索引
"""
len_keep = int(N * (1 - mask_ratio))
noise = torch.rand(B, N) # [0, 1), 随机噪声
# 排序得到索引
ids_shuffle = torch.argsort(noise, dim=1) # 升序排列的索引
ids_restore = torch.argsort(ids_shuffle, dim=1) # 恢复原始索引的索引
# 获取要保留的patch索引
ids_keep = ids_shuffle[:, :len_keep]
# 创建掩码
mask = torch.ones([B, N])
mask[:, :len_keep] = 0
# 未shuffle前的掩码状态
mask = torch.gather(mask, dim=1, index=ids_restore)
return mask.bool(), ids_shuffle, ids_restore
def masking_strategy_analysis():
"""
掩码策略分析
"""
strategies = {
"高掩码比例": "75%掩码促使模型学习更好的表征",
"随机掩码": "避免模型学习位置先验",
"结构化掩码": "可以考虑使用块状掩码策略",
"自适应掩码": "根据patch重要性动态调整掩码"
}
print("MAE掩码策略分析:")
for strategy, desc in strategies.items():
print(f"• {strategy}: {desc}")
masking_strategy_analysis()#3.2 MAE完整模型
class MaskedAutoencoder(nn.Module):
"""
完整的Masked Autoencoder模型
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3,
encoder_embed_dim=768, encoder_depth=12, encoder_num_heads=12,
decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
mask_ratio=0.75):
super().__init__()
self.patch_size = patch_size
self.mask_ratio = mask_ratio
# 编码器
self.encoder = MAEEncoder(
img_size=img_size, patch_size=patch_size, in_chans=in_chans,
embed_dim=encoder_embed_dim, depth=encoder_depth, num_heads=encoder_num_heads
)
# 解码器
num_patches = self.encoder.patch_embed.num_patches
self.decoder = MAEDecoder(
num_patches=num_patches, patch_size=patch_size,
embed_dim=encoder_embed_dim, decoder_embed_dim=decoder_embed_dim,
decoder_depth=decoder_depth, decoder_num_heads=decoder_num_heads
)
# 用于重建的标准化参数
self.norm_pix_loss = True
def patchify(self, imgs):
"""
将图像分割成patch
"""
p = self.patch_size
assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0
h = w = imgs.shape[2] // p
x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
x = torch.einsum('nchpwq->nhwpqc', x)
x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
return x
def unpatchify(self, x):
"""
将patch重建为图像
"""
p = self.patch_size
h = w = int(x.shape[1]**.5)
assert h * w == x.shape[1]
x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
x = torch.einsum('nhwpqc->nchpwq', x)
x = x.reshape(shape=(x.shape[0], 3, h * p, h * p))
return x
def forward_encoder(self, x, mask_ratio):
# 生成掩码
mask, ids_shuffle, ids_restore = generate_random_mask(x.shape[0],
self.encoder.patch_embed.num_patches,
mask_ratio)
# 编码
x = self.encoder(x, mask)
return x, mask, ids_restore
def forward_decoder(self, x, ids_restore):
# 解码
x = self.decoder(x, ids_restore)
return x
def forward_loss(self, imgs, pred, mask):
"""
计算重建损失
"""
target = self.patchify(imgs)
if self.norm_pix_loss:
# 对每个patch进行标准化
mean = target.mean(dim=-1, keepdim=True)
var = target.var(dim=-1, keepdim=True)
target = (target - mean) / (var + 1.e-6)**.5
loss = (pred - target) ** 2
loss = loss.mean(dim=-1) # [N, L], mean loss per patch
# 只计算被掩码patch的损失
loss = (loss * mask).sum() / mask.sum() # 平均掩码patch的损失
return loss
def forward(self, imgs):
latent, mask, ids_restore = self.forward_encoder(imgs, self.mask_ratio)
pred = self.forward_decoder(latent, ids_restore)
loss = self.forward_loss(imgs, pred, mask)
return loss, pred, mask
def mae_training_process():
"""
MAE训练过程详解
"""
print("MAE训练过程:")
print("1. 输入图像分割为196个16x16 patch")
print("2. 随机掩码75%的patch (147个)")
print("3. 编码器处理25%可见patch (49个)")
print("4. 解码器重建所有196个patch")
print("5. 计算仅在掩码patch上的重建损失")
print("6. 反向传播更新模型参数")
mae_training_process()#4. 自监督学习与预训练
#4.1 自监督学习范式
def self_supervised_paradigms():
"""
自监督学习范式对比
"""
paradigms = {
"对比学习": "通过对比正负样本来学习表征",
"生成式方法": "通过重建原始数据来学习表征",
"掩码建模": "通过预测被掩码部分来学习表征",
"预测式方法": "通过预测未来帧或上下文来学习"
}
print("自监督学习范式:")
for paradigm, desc in paradigms.items():
print(f"• {paradigm}: {desc}")
print("\nMAE属于生成式掩码建模方法")
self_supervised_paradigms()#4.2 预训练与微调
def pretraining_finetuning_pipeline():
"""
预训练与微调流水线
"""
print("MAE预训练与微调流程:")
print("""
# 1. 预训练阶段
mae_model = MaskedAutoencoder()
# 在大规模未标注图像上训练
# loss = reconstruction_loss
# 2. 提取编码器
encoder = mae_model.encoder # 取出编码器部分
# 3. 微调阶段
# 添加分类头
classifier = nn.Linear(encoder.embed_dim, num_classes)
model = nn.Sequential(encoder, classifier)
# 在标注数据上微调
# loss = cross_entropy_loss
""")
pretraining_finetuning_pipeline()#5. MAE变体与改进
#5.1 相关工作对比
def compare_masked_methods():
"""
掩码方法对比
"""
methods = {
"BEiT": "使用离散VAE进行图像tokenization",
"SimMIM": "简化掩码策略,无编码器-解码器结构",
"CAE": "对比掩码图像建模",
"iBOT": "使用知识蒸馏的掩码图像建模"
}
print("掩码图像建模方法对比:")
for method, desc in methods.items():
print(f"• {method}: {desc}")
compare_masked_methods()#5.2 MAE的优势分析
def mae_advantages():
"""
MAE优势分析
"""
advantages = [
"高掩码比例促进学习更好的表征",
"不对称架构提高训练效率",
"适用于各种视觉任务",
"可扩展到大规模模型",
"数据效率高"
]
print("MAE的主要优势:")
for i, advantage in enumerate(advantages, 1):
print(f"{i}. {advantage}")
mae_advantages()#6. 实际应用与实验
#6.1 使用预训练MAE
def use_pretrained_mae():
"""
使用预训练MAE模型
"""
print("使用预训练MAE模型:")
print("""
import torch
import timm
# 加载预训练MAE模型
model = timm.create_model('mae_vit_base_patch16', pretrained=True)
# 提取特征用于下游任务
model.eval()
with torch.no_grad():
features = model.forward_features(images)
# 微调分类任务
classifier = nn.Linear(model.embed_dim, num_classes)
full_model = nn.Sequential(model, classifier)
# 微调
for param in model.parameters():
param.requires_grad = False # 冻结预训练模型参数
for param in classifier.parameters():
param.requires_grad = True # 只训练分类器
# 或者进行端到端微调
for param in model.parameters():
param.requires_grad = True # 解冻所有参数
""")
use_pretrained_mae()#6.2 与其他方法比较
def performance_comparison():
"""
性能对比分析
"""
comparison = {
"ImageNet Top-1 Acc": {
"Supervised ViT-B": "82.2%",
"MAE + ViT-B": "83.6%",
"Supervised Swin-B": "83.5%",
"MAE + ViT-L": "85.9%"
},
"迁移学习效果": {
"COCO检测": "显著提升",
"ADE20K分割": "显著提升",
"下游任务": "普遍提升3-5%"
}
}
print("MAE性能对比:")
for metric, results in comparison.items():
print(f"\n{metric}:")
for method, score in results.items():
print(f" • {method}: {score}")
performance_comparison()#7. 实现细节与技巧
#7.1 训练技巧
def mae_training_tricks():
"""
MAE训练技巧
"""
tricks = [
"使用高掩码比例(0.75)提高学习效率",
"不对称编码器-解码器架构设计",
"标准化像素值减少重建难度",
"学习率预热和余弦退火调度",
"大批量训练获得更好表征",
"数据增强提高泛化能力"
]
print("MAE训练技巧:")
for i, trick in enumerate(tricks, 1):
print(f"{i}. {trick}")
mae_training_tricks()#7.2 代码实现注意事项
def implementation_notes():
"""
实现注意事项
"""
notes = [
"掩码索引的正确处理",
"位置嵌入的恰当使用",
"批处理维度的一致性",
"梯度裁剪防止爆炸",
"模型检查点保存策略"
]
print("MAE实现注意事项:")
for note in notes:
print(f"• {note}")
implementation_notes()#相关教程
#8. 总结
Masked Autoencoders代表了自监督视觉学习的新范式:
核心创新:
- 高比例掩码:75%掩码比例促进学习
- 不对称架构:高效的编码器-解码器设计
- 重建目标:像素级重建学习表征
技术影响:
- 推动自监督学习发展
- 提升下游任务性能
- 降低数据标注依赖
💡 重要提醒:MAE证明了掩码建模在视觉领域的有效性,为视觉Transformer的预训练提供了新思路。这是现代视觉模型不可或缺的预训练方法。
🔗 扩展阅读

