MAE (Masked Autoencoders):自监督学习的视觉预训练方法详解

引言

Masked Autoencoders (MAE) 是何恺明等人在2021年提出的革命性自监督学习方法,它将NLP领域中BERT的掩码语言建模思想成功迁移到计算机视觉领域。MAE通过随机遮盖图像中的75% patch,并训练模型重建被遮盖的部分,实现了高效的视觉表征学习。这一方法极大地推动了自监督学习在计算机视觉中的发展,为视觉Transformer的预训练提供了新的范式。

📂 所属阶段:第二阶段 — 深度学习视觉基础(视觉Transformer 篇)
🔗 相关章节:Swin Transformer · Vision-Language 多模态


1. MAE核心思想与动机

1.1 自监督学习的兴起

自监督学习是当前深度学习的重要发展方向,它旨在利用海量未标注数据自主构造监督信号进行预训练,主要动机包括:

  • 数据效率:避免昂贵且耗时的人工标注,直接复用互联网/工业界的公开/私有无标签图像库
  • 成本效益:无需专业标注团队,大幅降低AI模型的开发成本
  • 泛化能力:从无约束的自然数据中学习更通用的底层视觉特征,而非局限于特定标注任务
  • 可扩展性:天然适配大规模数据集与大模型训练,随着数据量/模型参数量提升性能持续增长

1.2 MAE的创新点

MAE的成功源于3大关键技术创新:

  1. 不对称编码器-解码器架构:编码器仅处理可见patch(计算量降低75%),轻量高效;解码器处理所有patch,专门负责重建任务
  2. 高比例随机掩码:采用75%的极端随机掩码比例,迫使模型学习全局语义关联而非局部纹理先验
  3. 轻量级像素级重建目标:直接预测被掩码patch的RGB像素值,无需额外预训练的VAE/Tokenizer等辅助模块,降低实现复杂度

2. MAE架构详解

2.1 不对称编码器-解码器设计

编码器基于标准Vision Transformer(ViT),仅保留对未掩码patch的处理逻辑,核心模块包括:

import torch
import torch.nn as nn
import torch.nn.functional as F

class PatchEmbedding(nn.Module):
    """
    将224×224图像分割为14×14=196个16×16 patch,并嵌入为768维向量
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        
    def forward(self, x):
        x = self.proj(x).flatten(2).transpose(1, 2)  # (B, N, D)
        return x

class MAEEncoder(nn.Module):
    """
    MAE轻量级编码器,仅处理未被掩码的patch
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, depth=12, 
                 num_heads=12, mlp_ratio=4.):
        super().__init__()
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_chans, embed_dim)
        
        # 类别token(用于下游分类)+ 所有patch的位置嵌入
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, self.patch_embed.num_patches + 1, embed_dim))
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        
        # ViT编码器层
        self.blocks = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=embed_dim, nhead=num_heads, dim_feedforward=int(embed_dim*mlp_ratio),
                dropout=0.1, activation='gelu', batch_first=True
            ) for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(embed_dim)
        
    def forward(self, x, mask):
        # 1. Patch嵌入
        x = self.patch_embed(x)  # (B, 196, 768)
        
        # 2. 应用掩码:仅保留未被遮盖的49个patch
        x = x[~mask].reshape(x.shape[0], -1, x.shape[-1])  # (B, 49, 768)
        
        # 3. 添加类别token和对应位置的嵌入
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat([cls_token, x], dim=1)
        
        pos_keep = self.pos_embed[:, 1:, :][~mask].reshape(x.shape[0], -1, x.shape[-1])
        pos_cls = self.pos_embed[:, :1, :]
        x = x + torch.cat([pos_cls, pos_keep], dim=1)
        
        # 4. Transformer编码
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        return x

2.2 MAE解码器设计

解码器比编码器轻量但更专注重建,包含掩码token占位符、完整位置嵌入和重建投影层:

class MAEDecoder(nn.Module):
    """
    MAE重建解码器,处理所有196个patch
    """
    def __init__(self, num_patches=196, patch_size=16, embed_dim=768, decoder_embed_dim=512,
                 decoder_depth=8, decoder_num_heads=16):
        super().__init__()
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim)
        
        # 掩码token(代表被遮盖的patch)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
        nn.init.trunc_normal_(self.mask_token, std=0.02)
        
        # 所有patch(含cls)的解码器位置嵌入
        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim))
        nn.init.trunc_normal_(self.decoder_pos_embed, std=0.02)
        
        # 解码器Transformer层
        self.decoder_blocks = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=decoder_embed_dim, nhead=decoder_num_heads, dim_feedforward=int(decoder_embed_dim*4),
                dropout=0.1, activation='gelu', batch_first=True
            ) for _ in range(decoder_depth)
        ])
        
        # 重建每个patch的RGB像素值
        self.decoder_norm = nn.LayerNorm(decoder_embed_dim)
        self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size ** 2 * 3)
        
    def forward(self, x, ids_restore):
        # 1. 维度降维到解码器嵌入维度
        x = self.decoder_embed(x)  # (B, 50, 512)
        
        # 2. 拼接可见patch和掩码token,并恢复原始196个patch的顺序
        mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
        x_no_cls = torch.cat([x[:, 1:, :], mask_tokens], dim=1)
        x_no_cls = torch.gather(x_no_cls, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))
        x = torch.cat([x[:, :1, :], x_no_cls], dim=1)
        
        # 3. 添加完整位置嵌入并解码
        x = x + self.decoder_pos_embed
        for blk in self.decoder_blocks:
            x = blk(x)
        x = self.decoder_norm(x)
        
        # 4. 移除cls token,重建像素
        x = self.decoder_pred(x[:, 1:, :])
        return x

3. 掩码策略与完整模型

3.1 随机高比例掩码实现

MAE的随机掩码逻辑需保证批量处理一致性,并返回恢复原始顺序的索引:

def generate_random_mask(B, N, mask_ratio=0.75):
    """
    生成随机高比例掩码
    """
    len_keep = int(N * (1 - mask_ratio))
    
    # 生成随机噪声排序
    noise = torch.rand(B, N)
    ids_shuffle = torch.argsort(noise, dim=1)
    ids_restore = torch.argsort(ids_shuffle, dim=1)
    
    # 创建掩码:True表示被遮盖
    mask = torch.ones([B, N])
    mask[:, :len_keep] = 0
    mask = torch.gather(mask, dim=1, index=ids_restore).bool()
    return mask, ids_restore

3.2 MAE完整模型与训练流程

完整模型整合了编码、解码、损失计算逻辑,训练时仅优化被掩码patch的重建损失:

class MaskedAutoencoder(nn.Module):
    """
    完整MAE模型
    """
    def __init__(self, img_size=224, patch_size=16, mask_ratio=0.75):
        super().__init__()
        self.patch_size = patch_size
        self.mask_ratio = mask_ratio
        self.norm_pix_loss = True
        
        # 编码器与解码器(使用默认ViT-Base配置)
        self.encoder = MAEEncoder()
        self.decoder = MAEDecoder()
        
    def patchify(self, imgs):
        """将图像分割为patch"""
        p = self.patch_size
        h = w = imgs.shape[2] // p
        x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
        x = x.permute(0, 2, 4, 3, 5, 1).flatten(1, 2).flatten(2, 4)
        return x
    
    def forward_loss(self, imgs, pred, mask):
        """仅计算被掩码patch的标准化像素损失"""
        target = self.patchify(imgs)
        if self.norm_pix_loss:
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1e-6)**0.5
        loss = (pred - target)**2
        loss = loss.mean(dim=-1)
        loss = (loss * mask).sum() / mask.sum()
        return loss
    
    def forward(self, imgs):
        mask, ids_restore = generate_random_mask(imgs.shape[0], self.encoder.patch_embed.num_patches, self.mask_ratio)
        latent = self.encoder(imgs, mask)
        pred = self.decoder(latent, ids_restore)
        loss = self.forward_loss(imgs, pred, mask)
        return loss, pred, mask

4. 预训练与下游应用

4.1 预训练要点

使用AdamW优化器,学习率预热+余弦退火,数据增强仅用随机缩放裁剪和水平翻转即可。

4.2 微调步骤

  1. 提取编码器:丢弃解码器,仅保留MAE的ViT编码器
  2. 添加分类头/任务头:例如在ImageNet分类中,接一个线性层映射到1000类
  3. 微调策略:可先冻结编码器只训练任务头(Linear Probe),再端到端全量微调

4.3 使用timm库预训练模型

import torch
import timm

# 加载预训练MAE ViT-Base模型
model = timm.create_model('mae_vit_base_patch16_224', pretrained=True, num_classes=0)  # num_classes=0仅返回特征

# 提取图像特征
model.eval()
with torch.no_grad():
    features = model(torch.randn(2, 3, 224, 224))  # (2, 768)

总结

MAE通过高比例随机掩码+不对称编码器-解码器+像素级重建的组合,成功将NLP的掩码建模迁移到计算机视觉,大幅提升了ViT在下游任务的性能(ImageNet Top-1从监督ViT-B的82.2%提升到MAE+ViT-B的83.6%)。这一方法实现简单、数据效率高,已成为现代视觉Transformer预训练的标配范式。

💡 扩展阅读