title: Vision Transformer (ViT) Explained: A Vision Revolution from Images to Sequences | Daoman PythonAI
description: An in-depth analysis of the Vision Transformer (ViT) model and its innovative approach of applying the Transformer to computer vision, including a detailed architecture breakdown, a PyTorch implementation, and practical application scenarios.
keywords: [Vision Transformer, ViT, Transformer, computer vision, deep learning, self-attention mechanism, image classification, PyTorch]

# Vision Transformer (ViT) Explained: A Vision Revolution from Images to Sequences

Introduction

In 2020, Google published the paper "An Image is Worth 16x16 Words", which directly challenged the long-standing dominance of CNNs in computer vision. The Vision Transformer (ViT) it proposed was the first pure-Transformer model to gain a firm foothold in image classification, bringing NLP's most powerful architecture to vision, and with large-scale pre-training data it even surpassed classic convolutional networks such as ResNet.

The core idea of ViT fits in a single sentence, yet it is striking enough:

Process images the way we process natural language: cut the image into "visual words", then let the self-attention mechanism capture global relationships in a single step.

Just how radical is this idea? Let's take a closer look.


1. The birth of ViT: the CNN “ceiling” and the breakthrough

1.1 Why are CNNs not enough?

Although convolutional neural networks such as ResNet and EfficientNet excel at local feature extraction, they come with three built-in priors (inductive biases) that cannot be switched off:

| Inductive bias | Advantages | Implicit limitations |
|---|---|---|
| Local receptive field | Quickly learns basic features such as edges and textures | Dozens of layers must be stacked before the network can barely "see" the whole picture (such as the outline of a dog's body) |
| Translation invariance | Fairly insensitive to changes in object position | No explicit modeling of positional relationships ("cat ears sit above cat eyes") |
| Static weights | Efficient inference | Every image is processed with the same set of convolution kernels, so the network cannot dynamically focus on the key regions of the current image |

Put simply, a CNN is like a painter who only looks at details: after painting every part, it still takes considerable effort to assemble the complete picture. Researchers therefore began to ask: is there a way for the model to see the whole picture from the very beginning?
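To put a rough number on the "painter" problem, the sketch below computes the theoretical receptive field of a stack of identical convolutions. It assumes plain 3×3 layers with no pooling or dilation, which simplifies real CNNs (whose effective receptive field also tends to be smaller than the theoretical one); `receptive_field` and `layers_to_cover` are illustrative helpers, not library functions.

```python
def receptive_field(num_layers, kernel_size=3, stride=1):
    """Theoretical receptive field (in input pixels) of a stack of identical conv layers."""
    rf, jump = 1, 1  # jump = distance in input pixels between adjacent output positions
    for _ in range(num_layers):
        rf += (kernel_size - 1) * jump
        jump *= stride
    return rf


def layers_to_cover(image_size, kernel_size=3, stride=1):
    """Smallest number of layers whose receptive field spans the whole image."""
    layers = 0
    while receptive_field(layers, kernel_size, stride) < image_size:
        layers += 1
    return layers


print(layers_to_cover(224))            # 112 stride-1 3x3 layers to span a 224-pixel image
print(layers_to_cover(224, stride=2))  # 7 layers if every layer also downsamples by 2
```

Downsampling is how CNNs such as ResNet speed up this growth, at the cost of spatial resolution. A single self-attention layer over 196 patches, by contrast, connects every patch to every other patch immediately.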

1.2 How ViT breaks the deadlock

ViT's approach is simple: it abandons the CNN's "local-first" design and replaces it with the Transformer's "global-first" paradigm.

Key upgrades that ViT brings:

  1. Global perception in one step: already in the first self-attention layer, the patch in the upper-left corner can "talk" directly to the patch in the lower-right corner
  2. Dynamically generated attention weights: the importance of each region is adjusted automatically according to the image content, instead of being fixed by a static convolution kernel
  3. High scalability: the larger the model and the more data, the more pronounced the performance gains; scaling laws hold in vision as well
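The first two points can be seen in a few lines of PyTorch. The sketch below is a single attention head whose projection matrices `w_q` and `w_k` are random, untrained stand-ins (hypothetical, not learned weights), applied to random stand-in patch embeddings for two images. It shows that the attention map covers all 196×196 patch pairs at once and that the same parameters yield different weights for different inputs.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_patches, dim = 196, 64

# Projection parameters: the SAME matrices are applied to every image
w_q = torch.randn(dim, dim) / dim ** 0.5
w_k = torch.randn(dim, dim) / dim ** 0.5


def attention_map(tokens):
    """Single-head attention weights; row i = how much patch i attends to each patch."""
    q, k = tokens @ w_q, tokens @ w_k
    return F.softmax(q @ k.T / dim ** 0.5, dim=-1)


# Random stand-ins for the patch embeddings of two different images
attn_a = attention_map(torch.randn(num_patches, dim))
attn_b = attention_map(torch.randn(num_patches, dim))

print(attn_a.shape)                    # torch.Size([196, 196]): every patch pair in one layer
print(attn_a[0, -1].item() > 0)        # True: top-left patch directly weighs the bottom-right one
print(torch.allclose(attn_a, attn_b))  # False: same parameters, image-dependent weights
```

A convolution kernel, in contrast, applies identical weights to every image, and only within its local window.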

2. Dissecting a minimalist architecture: what exactly does ViT do?

ViT's overall structure almost entirely reuses the Transformer encoder from NLP; the only change is that the "text sequence" is replaced with a "visual sequence". The whole process can be summarized in four steps:

graph TD
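    A[Input image<br/>224×224×3] --> B[Split into non-overlapping patches<br/>196 patches of 16×16×3]
    B --> C[Linear projection<br/>each patch → 768-dim vector]
    C --> D[Prepend CLS token + add position embeddings<br/>197×768]
    D --> E[Stack of N Transformer encoder layers]
    E --> F[Take the CLS token output<br/>classification head → 1000 classes]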
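The numbers in the diagram follow from simple arithmetic for ViT-B/16. A quick check in Python (the variable names are only for illustration):

```python
img_size, patch_size, in_chans, embed_dim = 224, 16, 3, 768

grid = img_size // patch_size                        # 14 patches along each side
num_patches = grid * grid                            # 196 "visual words"
raw_patch_dim = patch_size * patch_size * in_chans   # 768 pixel values per flattened patch
seq_len = num_patches + 1                            # 197 tokens after prepending the CLS token
seq_shape = (seq_len, embed_dim)                     # (197, 768) input to the encoder
attn_scores = seq_len * seq_len                      # 38,809 pairwise scores per head per layer

print(grid, num_patches, raw_patch_dim, seq_shape, attn_scores)  # 14 196 768 (197, 768) 38809
```

Each raw 16×16×3 patch happens to contain 768 values, the same as ViT-B's embedding width; this is a coincidence, and the linear projection is learned either way. Because self-attention compares every token with every other, the cost per layer grows with the square of the sequence length: halving the patch size to 8 would quadruple the patch count to 784.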

Quick overview of key components

  1. Patch Embedding (the core of turning an image into a sequence): a convolution, or flattening followed by a linear layer, converts each image patch into a fixed-dimensional vector, in effect translating the image into "words" the Transformer can understand
  2. CLS Token: a learnable "global summary vector" is prepended to the input sequence, and its final output is used for classification; this clever idea is borrowed directly from BERT
  3. Learnable position embeddings: inject each patch's position information into its vector (self-attention itself has no built-in notion of order, so the model must be told which patch is where)
  4. Transformer Encoder: multi-head self-attention + feed-forward MLP + residual connections + layer normalization, the classic recipe
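Component 1 says a convolution and "flatten + linear layer" are interchangeable. The sketch below checks this by copying a `Conv2d` kernel into an `nn.Linear` and comparing outputs; `patchify` is a hypothetical helper written for this example, not a library function.

```python
import torch
import torch.nn as nn


def patchify(x, patch_size):
    """(B, C, H, W) -> (B, N, C*P*P): cut into non-overlapping patches and flatten each one."""
    B, C, H, W = x.shape
    p = patch_size
    x = x.reshape(B, C, H // p, p, W // p, p)   # split H and W into (grid, within-patch)
    x = x.permute(0, 2, 4, 1, 3, 5)             # (B, grid_h, grid_w, C, p, p), row-major patch order
    return x.reshape(B, (H // p) * (W // p), C * p * p)


torch.manual_seed(0)
patch_size, embed_dim = 16, 768
conv = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size)
linear = nn.Linear(3 * patch_size * patch_size, embed_dim)

# Reuse the conv kernel as the linear weight: (D, C, p, p) -> (D, C*p*p)
with torch.no_grad():
    linear.weight.copy_(conv.weight.reshape(embed_dim, -1))
    linear.bias.copy_(conv.bias)

x = torch.randn(2, 3, 224, 224)
out_conv = conv(x).flatten(2).transpose(1, 2)   # conv-based patch embedding
out_linear = linear(patchify(x, patch_size))    # flatten + linear layer
print(out_conv.shape, out_linear.shape)         # both torch.Size([2, 196, 768])
print(torch.allclose(out_conv, out_linear, atol=1e-4))
```

The convolution form, used in `PatchEmbedding` in section 3.1, is simply the more convenient and efficient way to write the same operation.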

3. A minimalist PyTorch implementation: building ViT-B/16 from scratch

Below we implement the most classic ViT variant, ViT-B/16 (Base size, 16×16 patches), in PyTorch. The code aims for clarity, and key steps are commented.

3.1 Step 1: Cut the image into “visual words”

import torch
import torch.nn as nn
import torch.nn.functional as F

class PatchEmbedding(nn.Module):
    """
    Patch embedding: a single convolution efficiently implements "split into patches + linear projection"
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        self.n_patches = (img_size // patch_size) ** 2  # 14×14 = 196 patches
        
        # kernel_size = stride = patch_size: patching and projection in one step
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        # Input: (B, C, H, W) → conv: (B, embed_dim, H/P, W/P) → output: (B, n_patches, embed_dim)
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x

# Quick test
if __name__ == "__main__":
    patch_embed = PatchEmbedding()
    dummy_img = torch.randn(2, 3, 224, 224)  # 2 RGB images
    print(f"Shape after patch embedding: {patch_embed(dummy_img).shape}")  # Output: torch.Size([2, 196, 768])

3.2 Step 2: Build the Transformer encoder block

class MultiHeadAttention(nn.Module):
    """
    Multi-head self-attention (simplified implementation)
    """
    def __init__(self, embed_dim=768, n_heads=12, dropout=0.1):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = embed_dim // n_heads
        self.scale = self.head_dim ** -0.5
        
        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, N, C = x.shape
        # 1. Compute Q, K, V and split into heads: (3, B, n_heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.n_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        
        # 2. Scaled dot-product attention: softmax(QK^T / sqrt(head_dim)), shape (B, n_heads, N, N)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = self.dropout(attn.softmax(dim=-1))
        
        # 3. Concatenate the heads back to (B, N, C) and project
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj(x)

class TransformerBlock(nn.Module):
    """
    Transformer encoder block with Pre-Norm (LayerNorm applied before each sub-layer)
    """
    def __init__(self, embed_dim=768, n_heads=12, mlp_ratio=4, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadAttention(embed_dim, n_heads, dropout)
        self.norm2 = nn.LayerNorm(embed_dim)
        
        # MLP: the hidden layer is mlp_ratio (4) times wider than embed_dim
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, int(embed_dim * mlp_ratio)),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(int(embed_dim * mlp_ratio), embed_dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        x = x + self.attn(self.norm1(x))  # residual connection 1 (attention)
        return x + self.mlp(self.norm2(x))  # residual connection 2 (MLP)
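Assuming PyTorch 2.0 or later, which provides `torch.nn.functional.scaled_dot_product_attention`, the hand-written attention in step 2 of `MultiHeadAttention.forward` can be sanity-checked against the fused kernel. This sketch uses random Q, K, V tensors with ViT-B/16 shapes and leaves out dropout, which is only active in training mode anyway.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, n_heads, N, head_dim = 2, 12, 197, 64   # ViT-B/16 sizes: 12 heads x 64 dims = 768
q, k, v = (torch.randn(B, n_heads, N, head_dim) for _ in range(3))

# Hand-written scaled dot-product attention, same formula as MultiHeadAttention (no dropout)
manual = ((q @ k.transpose(-2, -1)) * head_dim ** -0.5).softmax(dim=-1) @ v

# PyTorch's fused kernel (PyTorch >= 2.0); its default scale is also 1 / sqrt(head_dim)
fused = F.scaled_dot_product_attention(q, k, v)

print(manual.shape, torch.allclose(manual, fused, atol=1e-4))
```

If the two agree, replacing step 2 with the fused call is an optional optimization that can save memory and time on long sequences; the original ViT does not depend on it.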

3.3 Step 3: Assemble the complete ViT model

class VisionTransformer(nn.Module):
    """
    Complete ViT-B/16 model
    """
    def __init__(self, img_size=224, n_classes=1000, depth=12):
        super().__init__()
        self.patch_embed = PatchEmbedding(img_size=img_size)
        n_patches = self.patch_embed.n_patches  # 196 for 224×224 inputs
        
        # CLS token + learnable position embeddings
        self.cls_token = nn.Parameter(torch.zeros(1, 1, 768))
        self.pos_embed = nn.Parameter(torch.zeros(1, n_patches + 1, 768))  # +1 for the CLS token
        nn.init.trunc_normal_(self.pos_embed, std=0.02)  # small random init, as in common ViT codebases
        self.pos_drop = nn.Dropout(0.1)
        
        # Stack of `depth` Transformer blocks (12 for ViT-B)
        self.blocks = nn.ModuleList([TransformerBlock() for _ in range(depth)])
        
        # Classification head
        self.norm = nn.LayerNorm(768)
        self.head = nn.Linear(768, n_classes)

    def forward(self, x):
        B = x.shape[0]
        
        # 1. Image → patch tokens: (B, 196, 768)
        x = self.patch_embed(x)
        
        # 2. Prepend the CLS token: (B, 197, 768)
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        
        # 3. Add position embeddings
        x = self.pos_drop(x + self.pos_embed)
        
        # 4. Pass through the Transformer encoder
        for block in self.blocks:
            x = block(x)
        
        # 5. Classify from the CLS token output
        return self.head(self.norm(x)[:, 0])

# Test the complete model
if __name__ == "__main__":
    vit = VisionTransformer()
    dummy_img = torch.randn(4, 3, 224, 224)
    print(f"ViT output shape: {vit(dummy_img).shape}")  # Output: torch.Size([4, 1000])
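As a sanity check on the model above, its parameter count can be worked out by hand. The helper `vit_param_count` below is written for this article (it is not a library function) and adds up weights and biases layer by layer under the same assumptions as the code: a convolutional patch embedding, a fused QKV projection, Pre-Norm blocks, and a single linear classifier.

```python
def vit_param_count(img_size=224, patch_size=16, in_chans=3, embed_dim=768,
                    depth=12, mlp_ratio=4, num_classes=1000):
    """Count trainable parameters (weights + biases) of a standard ViT, layer by layer."""
    d, hidden = embed_dim, embed_dim * mlp_ratio
    num_patches = (img_size // patch_size) ** 2

    patch_embed = in_chans * patch_size * patch_size * d + d  # Conv2d kernel + bias
    cls_and_pos = d + (num_patches + 1) * d                   # CLS token + position embeddings
    attn = (d * 3 * d + 3 * d) + (d * d + d)                  # fused QKV Linear + output projection
    mlp = (d * hidden + hidden) + (hidden * d + d)            # two Linear layers
    norms = 2 * (2 * d)                                       # two LayerNorms (weight + bias)
    block = attn + mlp + norms
    head = 2 * d + (d * num_classes + num_classes)            # final LayerNorm + classifier

    return patch_embed + cls_and_pos + depth * block + head


print(f"{vit_param_count():,}")  # 86,567,656
```

For the default ViT-B/16 configuration this gives 86,567,656 (about 86M), which should match `sum(p.numel() for p in vit.parameters())` for the `VisionTransformer` defined above; over 98% of it sits in the 12 encoder blocks. Calling `vit_param_count(embed_dim=1024, depth=24)` gives roughly 304M, the ViT-L/16 size.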

4. Avoiding pitfalls: three keys to using ViT well

Unlike a CNN, ViT has no built-in "local receptive field" prior. **Trained from scratch on small or medium-sized datasets (fewer than about 1 million images), it will usually perform worse than a CNN.** ✅ The right approach: start from large-scale pre-trained weights (for example, weights pre-trained on ImageNet-21k, or MAE-based self-supervised weights), then fine-tune on your task.

4.1 Hyperparameter selection

| Hyperparameter | Recommended configuration | Notes |
|---|---|---|
| Optimizer | AdamW (lr=1e-3, weight_decay=0.05) | Use AdamW with weight decay |
| Learning-rate schedule | Warmup (10 epochs) + cosine decay | A small learning rate early in training keeps the CLS token and attention weights from being jolted by large initial updates |
| Data augmentation | RandAugment + CutMix + MixUp | Strong augmentation is key to getting ViT to converge on smaller datasets |
| Batch size | The larger the better (at least 256; 1024+ recommended) | Large batches stabilize the training of attention weights |
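The optimizer and schedule rows can be wired up with standard PyTorch classes. The sketch below assumes a 300-epoch run (an illustrative choice, since the table does not fix the total) and uses a `torch.nn.Linear` as a stand-in for the ViT; `warmup_cosine` is a hypothetical helper that returns the learning-rate multiplier for each epoch.

```python
import math

import torch


def warmup_cosine(epoch, warmup_epochs=10, total_epochs=300):
    """LR multiplier: linear warmup for `warmup_epochs`, then cosine decay towards 0."""
    if epoch < warmup_epochs:
        return (epoch + 1) / warmup_epochs
    progress = (epoch - warmup_epochs) / max(1, total_epochs - warmup_epochs)
    return 0.5 * (1.0 + math.cos(math.pi * progress))


model = torch.nn.Linear(768, 10)  # stand-in for a ViT; replace with your model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.05)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_cosine)

lrs = []
for epoch in range(300):
    lrs.append(optimizer.param_groups[0]["lr"])
    # ... one epoch of training (forward, loss.backward(), optimizer.step()) goes here ...
    optimizer.step()   # parameters without gradients are simply skipped
    scheduler.step()   # advance the schedule once per epoch

print(f"epoch 0: {lrs[0]:.1e}, epoch 10: {lrs[10]:.1e}, epoch 299: {lrs[-1]:.1e}")
```

The learning rate climbs linearly from 1e-4 to 1e-3 over the first 10 epochs, then follows a cosine curve towards zero. Libraries such as timm ship ready-made schedulers for this pattern; a `LambdaLR` simply keeps the logic explicit.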

4.2 When should you use ViT, and when a CNN?

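def choose_between_vit_cnn(data_size, is_speed_critical):
    """Rule of thumb; data_size is the number of training images."""
    if data_size < 100_000:
        return "Prefer a CNN (ResNet/EfficientNet); consider DeiT-style distillation"
    elif data_size > 1_000_000 and not is_speed_critical:
        return "Prefer ViT (or Swin Transformer), fine-tuned from large-scale pre-trained weights"
    else:
        return "Compromise: MobileViT (mobile) / CoAtNet (hybrid architecture)"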

5. Summary

Vision Transformer opened a new door for computer vision through the unified paradigm of "sequence modeling". Although it is data-hungry and relatively computationally expensive, under the large-scale pre-training + downstream fine-tuning paradigm ViT has become one of the mainstream choices for tasks such as image classification, object detection, and semantic segmentation.

If you want to explore ViT's family of successors further, we recommend reading in this order:

  1. DeiT: tackles the difficulty of training ViT on smaller datasets and uses distillation to improve performance
  2. Swin Transformer: introduces a hierarchical structure and shifted windows, making it better suited to detection and segmentation tasks
  3. MAE: a landmark in self-supervised pre-training that greatly reduces ViT's need for labeled data

Suggested exercises:

  1. Use the `timm` library to load a pre-trained ViT-B/16, fine-tune it on CIFAR-10, and see how it performs.
  2. Try visualizing ViT's attention heatmaps to build an intuition for which parts of the image the model is "looking at".
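For the attention-heatmap exercise, one simple starting point is to look at how strongly the CLS token attends to each patch. The sketch below assumes you have captured one layer's softmax attention weights with shape (B, heads, 197, 197), for example by adding a hypothetical `self.last_attn = attn` line to `MultiHeadAttention.forward`; random weights stand in for real ones here, and `cls_attention_heatmap` is an illustrative helper.

```python
import torch
import torch.nn.functional as F


def cls_attention_heatmap(attn, grid_size=14, img_size=224):
    """Convert attention weights (B, heads, 1+N, 1+N) into (B, img_size, img_size) heatmaps."""
    # Row 0 is the CLS token's query; columns 1: are the N image patches
    cls_to_patches = attn[:, :, 0, 1:].mean(dim=1)              # (B, N), averaged over heads
    maps = cls_to_patches.reshape(-1, 1, grid_size, grid_size)  # (B, 1, 14, 14) patch grid
    maps = F.interpolate(maps, size=(img_size, img_size),
                         mode="bilinear", align_corners=False).squeeze(1)
    # Rescale each map to [0, 1] so it can be overlaid on the image
    lo = maps.amin(dim=(1, 2), keepdim=True)
    hi = maps.amax(dim=(1, 2), keepdim=True)
    return (maps - lo) / (hi - lo + 1e-8)


# Stand-in for attention captured from one ViT-B/16 layer (2 images, 12 heads, 197 tokens)
attn = torch.rand(2, 12, 197, 197).softmax(dim=-1)
heatmap = cls_attention_heatmap(attn)
print(heatmap.shape)  # torch.Size([2, 224, 224])
```

To overlay the result, show the image and then the heatmap with transparency, e.g. `plt.imshow(heatmap[0], alpha=0.5, cmap="jet")` in matplotlib. Averaging the heads of a single layer is the crudest option; techniques such as attention rollout combine attention across layers.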

🔗 Extended reading