MAE (Masked Autoencoders): A Detailed Explanation of a Self-Supervised Visual Pre-training Method

Introduction

Masked Autoencoders (MAE) is an influential self-supervised learning method proposed by Kaiming He et al. in 2021. It transfers the masked language modeling idea behind BERT in natural language processing to computer vision. The core idea of MAE is to randomly mask a large fraction of the patches in an image (75% by default) and train a model to reconstruct the masked parts. Because so little of the image is visible, the model is forced to learn the image's deeper structure instead of copying nearby pixels, which gives it strong visual representations. The method has greatly advanced self-supervised learning in vision and provides a simple, efficient paradigm for pre-training Vision Transformers (ViT).



1. MAE core ideas and motivations

1.1 The rise of self-supervised learning

Self-supervised learning is an important direction in deep learning today. Its goal is to pre-train models on massive amounts of unlabeled data by having the model construct its own supervision signals from the data itself. Several factors motivate this approach:

  • Data efficiency: It avoids expensive, time-consuming manual annotation and can directly reuse the large volumes of unlabeled images available on the Internet or within industry.
  • Cost effectiveness: No professional labeling team is needed, which significantly reduces the cost of developing AI models.
  • Generalization ability: The model learns more general visual features from unconstrained natural data, rather than features tied to one specific labeling task.
  • Scalability: It adapts naturally to large datasets and large models, and performance tends to keep improving as data and model size grow.

1.2 Innovations of MAE

MAE’s success mainly comes from three key technological innovations:

  1. Asymmetric encoder-decoder architecture: The encoder processes only the visible (unmasked) patches, about 25% of all tokens, so its computation drops by roughly 75% or more, which makes even large encoders efficient to pre-train. A separate lightweight decoder processes all patch positions (the encoded visible patches plus mask tokens) and is responsible solely for reconstructing the masked content.
  2. High-ratio random masking: An extreme masking ratio of 75% forces the model to understand global semantic relationships between regions of the image, because it cannot simply fill in the missing areas by extending local texture from neighboring patches.
  3. Lightweight pixel-level reconstruction target: The model directly predicts the original RGB pixel values of the masked patches. No auxiliary module such as a pre-trained discrete VAE (variational autoencoder) tokenizer is needed, so the implementation is simple and training is stable.
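A rough back-of-the-envelope estimate of the encoder savings, assuming ViT-B/16 at 224×224 (196 patches plus one class token) and the default 75% mask ratio. The cost model is a deliberate simplification and an assumption, not a FLOP count from the paper: per-token layers (MLP, QKV projections) scale linearly with sequence length, while attention-score computation scales quadratically.

```python
def encoder_token_stats(img_size=224, patch_size=16, mask_ratio=0.75):
    """Simplified cost model (assumption): linear layers ~ L, attention scores ~ L^2."""
    num_patches = (img_size // patch_size) ** 2      # 14 * 14 = 196 patches
    len_keep = int(num_patches * (1 - mask_ratio))   # visible patches fed to the encoder
    full_len = num_patches + 1                        # +1 for the class token
    masked_len = len_keep + 1
    linear_ratio = masked_len / full_len              # relative cost of per-token layers
    attn_ratio = (masked_len / full_len) ** 2         # relative cost of attention scores
    return num_patches, len_keep, linear_ratio, attn_ratio

n, keep, lin, attn = encoder_token_stats()
print(n, keep, round(lin, 3), round(attn, 3))  # 196 49 0.254 0.064
```

Under this crude model the encoder's per-token layers cost about a quarter of the unmasked case and its attention scores about 6%, so the encoder-side reduction exceeds 75%. This is in line with the 3× or greater training speedup reported for MAE; the decoder adds some cost on top, but it is kept small by design.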

2. Detailed explanation of MAE architecture

2.1 Asymmetric encoder-decoder design

The encoder is a standard Vision Transformer (ViT) that is applied only to the unmasked patches. It maps the visible patches to high-dimensional feature representations. Its core modules are patch embedding, positional embedding, and a stack of Transformer encoder layers.

The following code sketches the MAE encoder (a simplified PyTorch version, not the official implementation):

import torch
import torch.nn as nn
import torch.nn.functional as F

class PatchEmbedding(nn.Module):
    """
    将224×224图像分割为14×14=196个16×16 patch,并嵌入为768维向量
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.proj(x).flatten(2).transpose(1, 2)  # (B, N, D)
        return x

class MAEEncoder(nn.Module):
    """
    MAE encoder: a standard ViT that processes only the unmasked (visible) patches
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4.):
        super().__init__()
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_chans, embed_dim)

        # Class token (used for downstream classification) + positional embeddings for all patches.
        # Learnable here for simplicity; the original MAE uses fixed 2D sin-cos embeddings.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, self.patch_embed.num_patches + 1, embed_dim))
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

        # ViT encoder layers (pre-norm and no dropout, as in ViT / MAE pre-training)
        self.blocks = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=embed_dim, nhead=num_heads, dim_feedforward=int(embed_dim * mlp_ratio),
                dropout=0.0, activation='gelu', batch_first=True, norm_first=True
            ) for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x, mask):
        # mask: (B, 196) bool, True = masked
        # 1. Patch embedding, then add each patch's positional embedding BEFORE masking,
        #    so every kept patch carries its own position
        x = self.patch_embed(x)  # (B, 196, 768)
        x = x + self.pos_embed[:, 1:, :]

        # 2. Apply the mask: keep only the visible patches (49 of 196 at a 75% ratio).
        #    Boolean indexing keeps them in their original (ascending) patch order.
        B, _, D = x.shape
        x = x[~mask].reshape(B, -1, D)  # (B, 49, 768)

        # 3. Prepend the class token together with its own positional embedding
        cls_token = (self.cls_token + self.pos_embed[:, :1, :]).expand(B, -1, -1)
        x = torch.cat([cls_token, x], dim=1)  # (B, 50, 768)

        # 4. Transformer encoding
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        return x
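PatchEmbedding uses a Conv2d whose kernel size equals its stride. The standalone check below (independent of the classes above; shapes assume ViT-B/16 at 224×224) illustrates that this is the same as cutting the image into non-overlapping 16×16 patches, flattening each patch, and applying one shared linear layer.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
p, C, D = 16, 3, 768
conv = nn.Conv2d(C, D, kernel_size=p, stride=p)
imgs = torch.randn(2, C, 224, 224)

# Path 1: strided convolution, as in PatchEmbedding
out_conv = conv(imgs).flatten(2).transpose(1, 2)           # (2, 196, 768)

# Path 2: cut non-overlapping patches, flatten each in (C, p, p) order, apply a linear map
patches = imgs.unfold(2, p, p).unfold(3, p, p)             # (2, C, 14, 14, p, p)
patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(2, 196, C * p * p)
weight = conv.weight.reshape(D, C * p * p)                  # conv kernel viewed as a matrix
out_linear = patches @ weight.T + conv.bias                 # (2, 196, 768)

print(out_conv.shape, torch.allclose(out_conv, out_linear, atol=1e-4))
```

The convolution form is simply a convenient, fast way to apply the same linear projection to every patch.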

2.2 MAE decoder design

The decoder is much lighter than the encoder (narrower and shallower), but it must process all 196 patch positions, including the masked ones, because its only job is to reconstruct the masked content. It is used only during pre-training. Its key components are:

  • A linear projection that maps the encoder's output features to the decoder's lower embedding dimension.
  • A shared, learnable mask token that fills every masked patch position.
  • A full set of positional embeddings, so the model knows each patch's location in the original image whether or not it was masked.
  • Several Transformer layers that fuse information from the visible patches to infer the contents of the masked areas.
  • A linear prediction layer that maps each patch's features back to its original pixel values (e.g. 16×16×3 RGB values).

class MAEDecoder(nn.Module):
    """
    MAE reconstruction decoder: processes all 196 patch positions (visible + masked)
    """
    def __init__(self, num_patches=196, patch_size=16, embed_dim=768, decoder_embed_dim=512,
                 decoder_depth=8, decoder_num_heads=16):
        super().__init__()
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim)

        # Shared learnable mask token (stands in for every masked patch)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
        nn.init.trunc_normal_(self.mask_token, std=0.02)

        # Decoder positional embeddings for all patches (plus cls)
        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim))
        nn.init.trunc_normal_(self.decoder_pos_embed, std=0.02)

        # Decoder Transformer layers (pre-norm, no dropout)
        self.decoder_blocks = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=decoder_embed_dim, nhead=decoder_num_heads, dim_feedforward=int(decoder_embed_dim * 4),
                dropout=0.0, activation='gelu', batch_first=True, norm_first=True
            ) for _ in range(decoder_depth)
        ])

        # Predict the RGB pixel values of every patch (16*16*3 = 768 values)
        self.decoder_norm = nn.LayerNorm(decoder_embed_dim)
        self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size ** 2 * 3)

    def forward(self, x, ids_restore):
        # x: encoder output (B, 1 + len_keep, 768), visible patches in ascending patch order
        # 1. Project down to the decoder embedding dimension
        x = self.decoder_embed(x)  # (B, 50, 512)
        B, L, D = x.shape
        N = ids_restore.shape[1]   # 196

        # 2. Recover the mask from ids_restore (True = masked), place the visible tokens at
        #    their original positions, and fill every other position with the mask token.
        #    Boolean assignment uses the same row-major order as the encoder's boolean selection.
        mask = ids_restore >= (L - 1)
        x_full = self.mask_token.expand(B, N, D).clone()
        x_full[~mask] = x[:, 1:, :].reshape(-1, D)
        x = torch.cat([x[:, :1, :], x_full], dim=1)  # (B, 197, 512)

        # 3. Add the full positional embeddings and decode
        x = x + self.decoder_pos_embed
        for blk in self.decoder_blocks:
            x = blk(x)
        x = self.decoder_norm(x)

        # 4. Drop the cls token and predict pixels
        x = self.decoder_pred(x[:, 1:, :])  # (B, 196, 768)
        return x
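Because the decoder outputs one 768-value vector (16×16×3 pixels) per patch, viewing a reconstruction requires the inverse of patchify. A minimal standalone sketch, assuming square images and the same pixel layout as MaskedAutoencoder.patchify; the unpatchify helper and the stand-in prediction and mask are illustrative, not part of the code above.

```python
import torch

def patchify(imgs, p=16):
    """(B, C, H, W) -> (B, N, p*p*C), same layout as MaskedAutoencoder.patchify."""
    B, C, H, W = imgs.shape
    h, w = H // p, W // p
    x = imgs.reshape(B, C, h, p, w, p)
    return x.permute(0, 2, 4, 3, 5, 1).reshape(B, h * w, p * p * C)

def unpatchify(x, p=16, C=3):
    """(B, N, p*p*C) -> (B, C, H, W); inverts patchify for a square patch grid."""
    B, N, _ = x.shape
    h = w = int(N ** 0.5)
    x = x.reshape(B, h, w, p, p, C)                 # (B, h, w, row_in_patch, col_in_patch, C)
    return x.permute(0, 5, 1, 3, 2, 4).reshape(B, C, h * p, w * p)

imgs = torch.randn(2, 3, 224, 224)
patches = patchify(imgs)                            # (2, 196, 768): same shape as the decoder output
recon = unpatchify(patches)
print(patches.shape, torch.equal(recon, imgs))

# Hypothetical visualization: keep visible patches from the input, take masked ones from the prediction
pred = torch.zeros_like(patches)                    # stand-in for the decoder output
mask = torch.rand(2, 196) < 0.75                    # stand-in mask, True = masked
pasted = unpatchify(torch.where(mask.unsqueeze(-1), pred, patches))
```

If the model was trained with normalized pixel targets, its predictions are in per-patch normalized space and would need to be de-normalized with each patch's mean and standard deviation before being viewed as an image.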

3. Masking strategy and complete model

3.1 Implementing high-ratio random masking

MAE's masking strategy is simple but crucial to training. For each sample, the order of all patches is randomly shuffled (by sorting uniform random noise); the first part of the shuffled sequence (25% at the default ratio) is kept as visible, and the rest is masked. So that the decoder can put every token back at its correct position in the original image, the inverse permutation ids_restore (the "restore index") is returned along with the mask.

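def generate_random_mask(B, N, mask_ratio=0.75, device=None):
    """
    Generate a random high-ratio mask for each sample.
    Returns mask (B, N) bool with True = masked, and ids_restore (B, N), the inverse of the shuffle.
    """
    len_keep = int(N * (1 - mask_ratio))  # e.g. 49 of 196 patches stay visible

    # Random per-sample permutation obtained by sorting uniform noise
    noise = torch.rand(B, N, device=device)
    ids_shuffle = torch.argsort(noise, dim=1)
    ids_restore = torch.argsort(ids_shuffle, dim=1)  # inverse permutation

    # Build the mask in shuffled order (first len_keep = visible), then unshuffle it
    mask = torch.ones([B, N], device=device)
    mask[:, :len_keep] = 0
    mask = torch.gather(mask, dim=1, index=ids_restore).bool()
    return mask, ids_restore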
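The inverse-permutation trick behind ids_restore can be checked in isolation. In this standalone sketch, toy integer tokens stand in for patch embeddings and -1 stands in for the mask token: the tokens are shuffled, the first len_keep are kept in shuffled order, placeholders are appended, and ids_restore is then used with torch.gather to put every token back at its original index. Whatever approach an implementation uses, the encoder and decoder must agree on the order in which the visible tokens are stored.

```python
import torch

torch.manual_seed(0)
B, N, mask_ratio = 2, 8, 0.75
len_keep = int(N * (1 - mask_ratio))                     # 2 visible tokens per sample

tokens = torch.arange(N).repeat(B, 1)                     # toy "embeddings": each token equals its index
noise = torch.rand(B, N)
ids_shuffle = torch.argsort(noise, dim=1)                 # random permutation per sample
ids_restore = torch.argsort(ids_shuffle, dim=1)           # its inverse permutation

# Keep the first len_keep tokens of the shuffled order (the visible ones)
ids_keep = ids_shuffle[:, :len_keep]
visible = torch.gather(tokens, 1, ids_keep)               # (B, len_keep)

# Append placeholders (-1 plays the role of the mask token), then unshuffle
placeholders = torch.full((B, N - len_keep), -1, dtype=tokens.dtype)
restored = torch.gather(torch.cat([visible, placeholders], dim=1), 1, ids_restore)

# The same mask generate_random_mask returns: True where a patch was not kept
mask = ids_restore >= len_keep
print(restored)
print(mask)
```

Each row of restored holds the original index at visible positions and -1 at masked positions, which is exactly where the decoder's mask tokens belong.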

3.2 The complete MAE model and training process

The complete model combines the encoder, the decoder, and the loss computation. During training, the loss is computed only on the masked patches. By default this implementation uses a normalized pixel loss: each patch's pixels are normalized by that patch's own mean and variance before the MSE is computed, which the MAE paper reports improves the quality of the learned representations.

class MaskedAutoencoder(nn.Module):
    """
    Complete MAE model (ViT-Base encoder + default decoder)
    """
    def __init__(self, img_size=224, patch_size=16, mask_ratio=0.75, norm_pix_loss=True):
        super().__init__()
        self.patch_size = patch_size
        self.mask_ratio = mask_ratio
        self.norm_pix_loss = norm_pix_loss

        # Encoder and decoder (default ViT-Base configuration), sharing the image/patch settings
        self.encoder = MAEEncoder(img_size=img_size, patch_size=patch_size)
        self.decoder = MAEDecoder(num_patches=self.encoder.patch_embed.num_patches, patch_size=patch_size)

    def patchify(self, imgs):
        """Split images (B, 3, H, W) into patches (B, N, p*p*3)"""
        p = self.patch_size
        h = w = imgs.shape[2] // p
        x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
        x = x.permute(0, 2, 4, 3, 5, 1).flatten(1, 2).flatten(2, 4)
        return x

    def forward_loss(self, imgs, pred, mask):
        """Normalized pixel loss, computed on masked patches only"""
        target = self.patchify(imgs)
        if self.norm_pix_loss:
            # Normalize each patch by its own mean and variance
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1e-6)**0.5
        loss = (pred - target)**2
        loss = loss.mean(dim=-1)                   # (B, N): mean loss per patch
        loss = (loss * mask).sum() / mask.sum()    # average over masked patches only
        return loss

    def forward(self, imgs):
        mask, ids_restore = generate_random_mask(imgs.shape[0], self.encoder.patch_embed.num_patches, self.mask_ratio)
        mask, ids_restore = mask.to(imgs.device), ids_restore.to(imgs.device)
        latent = self.encoder(imgs, mask)          # (B, 1 + len_keep, 768)
        pred = self.decoder(latent, ids_restore)   # (B, N, p*p*3)
        loss = self.forward_loss(imgs, pred, mask)
        return loss, pred, mask
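To make the loss behavior concrete, here is a standalone version of the same computation (the function name masked_norm_pix_loss is ours, for illustration). It checks the two properties described above: only masked patches contribute, and the target is each patch normalized by its own mean and variance.

```python
import torch

def masked_norm_pix_loss(pred, target, mask, eps=1e-6):
    """MSE against per-patch normalized pixels, averaged over masked patches only.

    pred, target: (B, N, P) patch pixels; mask: (B, N) bool, True = masked.
    """
    mean = target.mean(dim=-1, keepdim=True)
    var = target.var(dim=-1, keepdim=True)
    target = (target - mean) / (var + eps) ** 0.5       # normalize each patch independently
    per_patch = ((pred - target) ** 2).mean(dim=-1)     # (B, N) mean squared error per patch
    mask = mask.float()
    return (per_patch * mask).sum() / mask.sum()

torch.manual_seed(0)
B, N, P = 2, 196, 16 * 16 * 3
target = torch.rand(B, N, P)
mask = torch.rand(B, N) < 0.75
pred = torch.randn(B, N, P)

loss = masked_norm_pix_loss(pred, target, mask)

# Changing predictions on visible (unmasked) patches leaves the loss untouched
pred_changed = pred.clone()
pred_changed[~mask] += 100.0
loss_changed = masked_norm_pix_loss(pred_changed, target, mask)
print(loss.item(), torch.allclose(loss, loss_changed))
```

Averaging over masked patches only means the model gets no credit for copying visible pixels, so all of the training signal comes from inferring hidden content.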

4. Pre-training and downstream applications

4.1 Key points of pre-training

In practice, pre-training typically uses the AdamW optimizer with a linear learning-rate warm-up followed by cosine decay. Data augmentation can be limited to random resized cropping and horizontal flipping; complex augmentation pipelines are unnecessary because the high masking ratio already acts as a strong regularizer.
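A minimal sketch of the optimizer and schedule. The hyperparameters (base learning rate 1.5e-4 scaled by batch_size / 256, betas (0.9, 0.95), weight decay 0.05, 40 warm-up epochs, 800 total epochs) are the settings reported for MAE pre-training to the best of our knowledge; treat them as starting assumptions, not tuned values for your setup. The nn.Linear model is only a stand-in for the full MaskedAutoencoder.

```python
import math
import torch
import torch.nn as nn

def lr_at_epoch(epoch, total_epochs, warmup_epochs, peak_lr, min_lr=0.0):
    """Linear warm-up to peak_lr, then half-cycle cosine decay to min_lr (epoch may be fractional)."""
    if epoch < warmup_epochs:
        return peak_lr * epoch / warmup_epochs
    progress = (epoch - warmup_epochs) / (total_epochs - warmup_epochs)
    return min_lr + (peak_lr - min_lr) * 0.5 * (1.0 + math.cos(math.pi * progress))

# Assumed settings; adjust to your hardware and budget
batch_size, base_lr = 4096, 1.5e-4
peak_lr = base_lr * batch_size / 256            # linear scaling rule
total_epochs, warmup_epochs = 800, 40

model = nn.Linear(8, 8)                          # stand-in for the MaskedAutoencoder
optimizer = torch.optim.AdamW(model.parameters(), lr=peak_lr,
                              betas=(0.9, 0.95), weight_decay=0.05)

for epoch in range(3):                           # a few epochs only, for illustration
    for group in optimizer.param_groups:
        group["lr"] = lr_at_epoch(epoch, total_epochs, warmup_epochs, peak_lr)
    # ... iterate over batches here: loss.backward(); optimizer.step(); optimizer.zero_grad()
```

Updating the learning rate per iteration (passing a fractional epoch) gives a smoother schedule. A common refinement, omitted here, is to exclude biases and LayerNorm weights from weight decay by putting them in a separate parameter group with weight_decay=0.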

4.2 Fine-tuning steps

After pre-training, applying the model to downstream tasks (such as image classification or object detection) generally involves the following steps:

  1. Extract the encoder: Discard the decoder, which is only needed during pre-training, and keep only the ViT encoder.
  2. Add a task head: For ImageNet classification, for example, attach a linear layer that maps the encoder's class-token output to the 1000 categories.
  3. Choose a fine-tuning strategy: Linear probing freezes the encoder parameters and trains only the classification head; end-to-end fine-tuning unfreezes all parameters and usually gives the best results. The two can also be combined by training the head first and then unfreezing everything.
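The linear-probing step can be sketched as follows. ToyEncoder is a hypothetical stand-in for the pre-trained MAE encoder (it takes already-embedded tokens, not images), SGD is used only for brevity, and the extra BatchNorm without affine parameters before the head follows MAE's linear-probing recipe as we understand it; treat all of these as assumptions.

```python
import torch
import torch.nn as nn

class ToyEncoder(nn.Module):
    """Hypothetical stand-in for the pre-trained MAE encoder: (B, 1 + N, D) -> (B, 1 + N, D)."""
    def __init__(self, dim=768):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        return self.proj(x)

class LinearProbe(nn.Module):
    def __init__(self, encoder, dim=768, num_classes=1000):
        super().__init__()
        self.encoder = encoder
        for p in self.encoder.parameters():
            p.requires_grad = False                       # freeze the pre-trained encoder
        # Parameter-free BatchNorm, then the linear classification head
        self.head = nn.Sequential(nn.BatchNorm1d(dim, affine=False), nn.Linear(dim, num_classes))

    def forward(self, x):
        with torch.no_grad():                             # no gradients through the frozen encoder
            tokens = self.encoder(x)
        return self.head(tokens[:, 0])                    # classify from the class token

torch.manual_seed(0)
encoder = ToyEncoder()
probe = LinearProbe(encoder)
optimizer = torch.optim.SGD([p for p in probe.parameters() if p.requires_grad], lr=0.1)

x = torch.randn(4, 50, 768)                               # stand-in input: 1 cls + 49 patch tokens
labels = torch.randint(0, 1000, (4,))
enc_before = encoder.proj.weight.clone()
head_before = probe.head[1].weight.clone()

loss = nn.functional.cross_entropy(probe(x), labels)
loss.backward()
optimizer.step()                                          # updates the head only
```

For end-to-end fine-tuning, set requires_grad back to True on the encoder parameters, remove the torch.no_grad() context, and pass all parameters to the optimizer.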

4.3 Loading a pre-trained model with the timm library

With the timm library, you can load a pre-trained MAE encoder and extract image features without writing all of the above code from scratch.

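import torch
import timm

# Load a ViT-Base/16 backbone with MAE pre-trained weights.
# In recent timm versions these weights use the pretrained tag '.mae'; check
# timm.list_models('*mae*', pretrained=True) for the names in your installed version.
# num_classes=0 returns pooled features only, without a classification head
model = timm.create_model('vit_base_patch16_224.mae', pretrained=True, num_classes=0)

# Extract image features
model.eval()
with torch.no_grad():
    features = model(torch.randn(2, 3, 224, 224))  # (2, 768)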

Summary

MAE successfully transfers the masked-modeling idea from NLP to computer vision through a concise combination of high-ratio random masking, an asymmetric encoder-decoder, and pixel-level reconstruction. It significantly improves Vision Transformer performance on downstream tasks (e.g., ImageNet top-1 accuracy rises from 82.2% for a purely supervised ViT-B to 83.6% for a ViT-B pre-trained with MAE). The method is simple to implement and data-efficient, and it has become one of the standard paradigms for pre-training modern vision Transformers.

💡 Extended Reading