Swin Transformer: The shifted window mechanism and hierarchical feature extraction explained

Swin Transformer is a vision Transformer backbone proposed by Microsoft Research Asia in 2021. By restricting self-attention to local windows that are shifted between consecutive blocks, and by building a hierarchical structure, it addresses two weaknesses of the original ViT (Vision Transformer): computational cost and the lack of multi-scale features. It has become a widely used backbone for vision tasks such as image classification, object detection, and semantic segmentation.

📂 Stage: Stage 2 - Deep Learning Vision Basics (Vision Transformer supplement) 🔗 Pre-reading: Vision Transformer (ViT) explained · Attention mechanism


1. Core innovation: solving ViT’s two major pain points

The original ViT treats an image as a flat sequence of patches. Its global attention makes the computational cost grow with the square of the number of tokens, and its single-resolution structure lacks the CNN-style multi-scale inductive bias.

1.1 Improvements compared to ViT

| Dimension | ViT | Swin Transformer |
|---|---|---|
| Self-attention range | Global | Local fixed-size windows |
| Computational complexity | Quadratic in the number of tokens | Linear in the number of tokens (for a fixed window size) |
| Feature levels | Single resolution | 4 stages (similar to ResNet's stem + 3 downsampling steps) |
| Cross-window information flow | No explicit mechanism (only indirectly, via global attention) | Built in via cyclically shifted windows (Shifted Window) |
| High-resolution processing efficiency | Very low | Good |
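
To make the complexity row concrete, the sketch below evaluates the cost formulas reported in the Swin Transformer paper for a feature map of h × w tokens with C channels and window size M: Ω(MSA) = 4hwC² + 2(hw)²C and Ω(W-MSA) = 4hwC² + 2M²hwC. The function names `msa_cost` and `wmsa_cost` are illustrative, and the numbers count only the matrix multiplications of one attention layer (softmax is ignored).

```python
def msa_cost(h, w, C):
    """Global multi-head self-attention: 4hwC^2 + 2(hw)^2 C."""
    return 4 * h * w * C**2 + 2 * (h * w) ** 2 * C


def wmsa_cost(h, w, C, M=7):
    """Window attention with M x M windows: 4hwC^2 + 2 M^2 hw C."""
    return 4 * h * w * C**2 + 2 * M**2 * h * w * C


# Stage-1 feature map of a 224x224 input with 4x4 patches and C = 96
global_cost = msa_cost(56, 56, 96)    # 2_003_828_736
window_cost = wmsa_cost(56, 56, 96)   # 145_108_992
print(f"global: {global_cost:,}  windowed: {window_cost:,}")

# Doubling the resolution quadruples the token count:
# the windowed cost grows exactly 4x, the global cost grows much faster.
print(wmsa_cost(112, 112, 96) / window_cost)
print(msa_cost(112, 112, 96) / global_cost)
```

The second term is the one that differs: it is quadratic in hw for global attention but linear in hw once M is fixed.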

2. Core mechanisms in detail

2.1 Window Attention (W-MSA)

Global attention is replaced by attention computed within non-overlapping local windows, so for a fixed window size the cost grows linearly with the number of tokens instead of quadratically. A learnable relative position bias is added to the attention logits to preserve local spatial relationships.

import torch
import torch.nn as nn
from einops import rearrange

class WindowAttention(nn.Module):
    """Window attention module (the core of W-MSA / SW-MSA)."""
    def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # (Wh, Ww)
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        # Learnable relative position bias table: one entry per (dh, dw) offset and head
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2*window_size[0]-1)*(2*window_size[1]-1), num_heads)
        )
        nn.init.trunc_normal_(self.relative_position_bias_table, std=.02)

        # Pre-compute the relative position index once (avoids recomputing it in forward)
        coords_h, coords_w = torch.arange(window_size[0]), torch.arange(window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing='ij'))  # (2, Wh, Ww)
        coords_flat = torch.flatten(coords, 1)  # (2, Nw), with Nw = Wh*Ww
        relative_coords = coords_flat[:, :, None] - coords_flat[:, None, :]  # (2, Nw, Nw)
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # (Nw, Nw, 2)
        # Shift the offsets so that they start at 0
        relative_coords[:, :, 0] += window_size[0] - 1
        relative_coords[:, :, 1] += window_size[1] - 1
        # Combine (dh, dw) into a single flat index
        relative_coords[:, :, 0] *= 2*window_size[1] - 1
        self.register_buffer("relative_position_index", relative_coords.sum(-1))  # (Nw, Nw)

        self.qkv = nn.Linear(dim, dim*3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        Args:
            x: (num_windows*B, Nw, C)
            mask: (num_windows, Nw, Nw) or None
        Returns:
            x: (num_windows*B, Nw, C)
        """
        B_, Nw, C = x.shape
        qkv = self.qkv(x).reshape(B_, Nw, 3, self.num_heads, C//self.num_heads).permute(2,0,3,1,4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        q = q * self.scale  # out-of-place, so the view into qkv is not modified
        attn = q @ k.transpose(-2, -1)

        # Add the relative position bias
        rel_pos_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)]
        rel_pos_bias = rel_pos_bias.view(Nw, Nw, -1).permute(2,0,1).contiguous()
        attn = attn + rel_pos_bias.unsqueeze(0)

        # Apply the attention mask for shifted windows
        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_//nW, nW, self.num_heads, Nw, Nw) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, Nw, Nw)
        attn = self.softmax(attn)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1,2).reshape(B_, Nw, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
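
The relative position index is easier to follow on a tiny window. The sketch below reproduces the same index arithmetic in plain Python (no torch) for illustration; `relative_position_index` is a hypothetical helper name, not part of any library. Each entry (i, j) is the row of the bias table used when token i attends to token j.

```python
def relative_position_index(Wh, Ww):
    """Flat bias-table index for every (query, key) pair in a Wh x Ww window."""
    coords = [(h, w) for h in range(Wh) for w in range(Ww)]  # row-major token order
    index = []
    for hi, wi in coords:
        row = []
        for hj, wj in coords:
            dh = hi - hj + (Wh - 1)   # shift so the row offset starts at 0
            dw = wi - wj + (Ww - 1)   # shift so the column offset starts at 0
            row.append(dh * (2 * Ww - 1) + dw)  # merge (dh, dw) into one index
        index.append(row)
    return index


for row in relative_position_index(2, 2):
    print(row)
# [4, 3, 1, 0]
# [5, 4, 2, 1]
# [7, 6, 4, 3]
# [8, 7, 5, 4]
```

A 2×2 window has (2·2−1)·(2·2−1) = 9 distinct offsets, so the indices run from 0 to 8, and the diagonal (a token attending to itself, offset (0, 0)) always maps to the same entry. For the usual 7×7 window the table has 13·13 = 169 rows per head.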

2.2 Shifted Window (SW-MSA)

If only W-MSA is used, no information is exchanged between windows, so the receptive field never grows beyond a single window. Swin creates cross-window connections by cyclically shifting the window partition in alternating blocks, and uses an attention mask so that tokens brought together only by the cyclic wrap-around do not attend to each other.

def window_partition(x, window_size):
    """(B, H, W, C) → (num_windows*B, window_size, window_size, C)"""
    B, H, W, C = x.shape
    x = x.view(B, H//window_size, window_size, W//window_size, window_size, C)
    return x.permute(0,1,3,2,4,5).contiguous().view(-1, window_size, window_size, C)

def window_reverse(windows, window_size, H, W):
    """(num_windows*B, window_size, window_size, C) → (B, H, W, C)"""
    B = int(windows.shape[0] / (H*W/window_size/window_size))
    x = windows.view(B, H//window_size, W//window_size, window_size, window_size, -1)
    return x.permute(0,1,3,2,4,5).contiguous().view(B, H, W, -1)

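class SwinTransformerBlock(nn.Module):
    """Swin block: window attention followed by an MLP.

    shift_size=0 gives W-MSA and shift_size>0 gives SW-MSA; consecutive blocks alternate between the two.
    """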
    def __init__(self, dim, input_res, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., drop_path=0., norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.input_res = input_res  # (H, W)
        self.window_size = window_size
        self.shift_size = shift_size
        if min(input_res) <= window_size:
            self.shift_size, self.window_size = 0, min(input_res)

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(dim, (self.window_size,)*2, num_heads)
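        # Note: nn.Dropout is a simplified stand-in here; the official implementation uses stochastic depth (DropPath)
        self.drop_path = nn.Dropout(drop_path) if drop_path>0 else nn.Identity()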
        self.norm2 = norm_layer(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, int(dim*mlp_ratio)),
            nn.GELU(),
            nn.Linear(int(dim*mlp_ratio), dim)
        )

        # Pre-compute the attention mask for SW-MSA
        if self.shift_size > 0:
            H, W = self.input_res
            img_mask = torch.zeros(1, H, W, 1)
            # Label the shifted feature map with 9 regions; only windows that
            # straddle a region boundary end up with masked entries
            h_slices = (slice(0, -self.window_size), slice(-self.window_size, -self.shift_size), slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size), slice(-self.window_size, -self.shift_size), slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1
            mask_windows = window_partition(img_mask, self.window_size).view(-1, self.window_size**2)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            # Pairs from different regions get -100 (near-zero weight after softmax); same-region pairs get 0
            attn_mask = attn_mask.masked_fill(attn_mask != 0, -100.0).masked_fill(attn_mask == 0, 0.0)
        else:
            attn_mask = None
        self.register_buffer("attn_mask", attn_mask)

    def forward(self, x):
        B, L, C = x.shape
        H, W = self.input_res
        assert L == H*W

        shortcut = x
        x = self.norm1(x).view(B, H, W, C)

        # Cyclic shift (only for SW-MSA)
        if self.shift_size > 0:
            x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1,2))

        # Partition into windows → attention → merge windows back
        x = window_partition(x, self.window_size).view(-1, self.window_size**2, C)
        x = self.attn(x, self.attn_mask)
        x = window_reverse(x.view(-1, self.window_size, self.window_size, C), self.window_size, H, W)

        # Reverse the cyclic shift (only for SW-MSA)
        if self.shift_size > 0:
            x = torch.roll(x, shifts=(self.shift_size, self.shift_size), dims=(1,2))

        x = x.view(B, H*W, C)
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
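
The mask generation is the part that most often causes confusion, so here is a small plain-Python sketch of the same region-labelling idea on an 8×8 feature map with window size 4 and shift size 2. It is an illustration under those assumed sizes, not the block above: the helper names `region_labels`, `split_windows`, and `attention_mask` are hypothetical, and nested lists stand in for tensors.

```python
def region_labels(H, W, window_size, shift_size):
    """Assign one of 9 region ids to each position of the shifted map."""
    h_bounds = [(0, H - window_size), (H - window_size, H - shift_size), (H - shift_size, H)]
    w_bounds = [(0, W - window_size), (W - window_size, W - shift_size), (W - shift_size, W)]
    labels = [[0] * W for _ in range(H)]
    cnt = 0
    for h0, h1 in h_bounds:
        for w0, w1 in w_bounds:
            for i in range(h0, h1):
                for j in range(w0, w1):
                    labels[i][j] = cnt
            cnt += 1
    return labels


def split_windows(labels, window_size):
    """Non-overlapping windows in row-major order, each flattened to a list."""
    H, W = len(labels), len(labels[0])
    return [
        [labels[a][b] for a in range(i, i + window_size) for b in range(j, j + window_size)]
        for i in range(0, H, window_size)
        for j in range(0, W, window_size)
    ]


def attention_mask(window):
    """0 where query and key share a region id, -100 otherwise."""
    return [[0.0 if q == k else -100.0 for k in window] for q in window]


windows = split_windows(region_labels(8, 8, window_size=4, shift_size=2), 4)
regions = [sorted(set(w)) for w in windows]
blocked = [sum(v != 0.0 for row in attention_mask(w) for v in row) for w in windows]
print(regions)   # [[0], [1, 2], [3, 6], [4, 5, 7, 8]]
print(blocked)   # [0, 128, 128, 192]
```

Only the top-left window contains a single region, so it needs no masking. The other three windows mix tokens that were not neighbours before the cyclic shift, and the mask blocks exactly those pairs (128 or 192 of the 256 query-key pairs per window).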

2.3 Patch Merging (core of hierarchical structure)

Patch Merging plays the role that pooling plus channel expansion plays in a CNN: the spatial resolution is halved, the number of channels is doubled, and stacking these steps builds a multi-scale feature pyramid:

  1. Take each group of adjacent 2×2 patches
  2. Concatenate them along the channel dimension (C → 4C)
  3. Apply LayerNorm and a linear projection that reduces 4C to 2C

class PatchMerging(nn.Module):
    def __init__(self, input_res, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_res = input_res
        self.norm = norm_layer(4*dim)
        self.reduction = nn.Linear(4*dim, 2*dim, bias=False)

    def forward(self, x):
        H, W = self.input_res
        B, L, C = x.shape
        assert L == H*W and H%2==0 and W%2==0

        x = x.view(B, H, W, C)
        # Gather the four patches of every 2×2 neighbourhood
        x0 = x[:, 0::2, 0::2, :]
        x1 = x[:, 1::2, 0::2, :]
        x2 = x[:, 0::2, 1::2, :]
        x3 = x[:, 1::2, 1::2, :]
        x = torch.cat([x0,x1,x2,x3], -1).view(B, -1, 4*C)
        return self.reduction(self.norm(x))
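
As a quick check on the hierarchy, the sketch below traces the token grid and channel width through the four stages and shows which tokens one merge step groups together. It assumes the commonly used configuration (224×224 input, 4×4 patches, embedding width 96 for the Tiny variant or 128 for Base); `swin_stage_shapes` and `merge_groups` are illustrative helper names written in plain Python.

```python
def swin_stage_shapes(img_size=224, patch_size=4, embed_dim=96, num_stages=4):
    """(height, width, channels) of the feature map at each stage."""
    res, dim = img_size // patch_size, embed_dim
    shapes = []
    for stage in range(num_stages):
        if stage > 0:      # a Patch Merging step precedes stages 2-4
            res //= 2      # resolution halves
            dim *= 2       # channels double
        shapes.append((res, res, dim))
    return shapes


def merge_groups(grid):
    """Group token ids per 2x2 neighbourhood, in the x0, x1, x2, x3 order used above."""
    H, W = len(grid), len(grid[0])
    return [
        [grid[i][j], grid[i + 1][j], grid[i][j + 1], grid[i + 1][j + 1]]
        for i in range(0, H, 2)
        for j in range(0, W, 2)
    ]


print(swin_stage_shapes())               # [(56, 56, 96), (28, 28, 192), (14, 14, 384), (7, 7, 768)]
print(swin_stage_shapes(embed_dim=128))  # [(56, 56, 128), (28, 28, 256), (14, 14, 512), (7, 7, 1024)]

grid = [[r * 4 + c for c in range(4)] for r in range(4)]  # token ids 0..15 on a 4x4 grid
print(merge_groups(grid))  # [[0, 4, 1, 5], [2, 6, 3, 7], [8, 12, 9, 13], [10, 14, 11, 15]]
```

The resulting strides of 4, 8, 16 and 32 relative to the input match the feature pyramid that CNN-based detection and segmentation heads expect.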

3. Quick start with pre-trained models

The two most common ways to load a pre-trained Swin model are the timm library (simple and efficient) and Hugging Face Transformers (more versatile).

3.1 Using the timm library

import torch
import timm

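# List all available pre-trained Swin models
# print(timm.list_models('swin*', pretrained=True))

# Load a pre-trained classification model (Base variant)
model = timm.create_model('swin_base_patch4_window7_224', pretrained=True, num_classes=1000)
model.eval()

# Inference example
input_tensor = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    output = model(input_tensor)  # (1, 1000) ImageNet classification logits
    print(f"Classification output shape: {output.shape}")

# Extract intermediate multi-scale features (for detection / segmentation).
# This needs a timm version whose Swin implementation supports features_only;
# the tensor layout (channels-first or channels-last) can differ between versions, so check the printed shapes.
model = timm.create_model('swin_base_patch4_window7_224', pretrained=True, features_only=True)
with torch.no_grad():
    features = model(input_tensor)  # 4 feature maps at different scales
    for i, feat in enumerate(features):
        print(f"Stage {i+1} feature shape: {feat.shape}")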

3.2 Using Hugging Face

import torch
from transformers import AutoImageProcessor, SwinForImageClassification
from PIL import Image
import requests

# Load the image processor and the model
# (AutoImageProcessor resolves the processor class from the checkpoint's config)
processor = AutoImageProcessor.from_pretrained('microsoft/swin-tiny-patch4-window7-224')
model = SwinForImageClassification.from_pretrained('microsoft/swin-tiny-patch4-window7-224')
model.eval()

# Process a real image
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(images=image, return_tensors="pt")

# Run inference and predict the class
with torch.no_grad():
    outputs = model(**inputs)
    logits = outputs.logits
predicted_class_idx = logits.argmax(-1).item()
print(f"Predicted class: {model.config.id2label[predicted_class_idx]}")

4. Summary and learning suggestions

The three core ideas of Swin Transformer:

  1. Local window attention: makes the computational cost linear in the number of tokens
  2. Shifted windows + mask: let information flow across windows, so the receptive field grows with depth
  3. Patch Merging: builds a multi-scale feature pyramid

Study suggestions:

  • First run a pre-trained model end to end with timm or Hugging Face
  • Focus on understanding the mask generation and cyclic shifting in SW-MSA
  • Then try a downstream task, for example in combination with UperNet/DINOv2

💡 Recommended Reading: