Swin Transformer:滑动窗口机制与层级特征提取详解

引言

Swin Transformer是微软亚洲研究院提出的革命性视觉Transformer架构,通过引入滑动窗口机制和层级特征提取,有效解决了原始Vision Transformer在计算效率和多尺度特征提取方面的局限性。Swin Transformer已成为计算机视觉领域的主流架构之一,广泛应用于图像分类、目标检测、语义分割等多个任务。本文将深入探讨Swin Transformer的核心概念、架构细节和实现方法。

📂 所属阶段:第二阶段 — 深度学习视觉基础(CNN 篇)
🔗 相关章节:Vision Transformer (ViT) 详解 · MAE (Masked Autoencoders)


1. Swin Transformer核心创新

1.1 解决ViT的挑战

Swin Transformer主要解决了原始ViT的两个关键问题:

"""
Swin Transformer解决的问题:

1. 计算复杂度问题:
   - ViT: O(n²) 全局注意力
   - Swin: O(n) 局部注意力
   
2. 层级结构缺失:
   - ViT: 平坦架构,无尺度变化
   - Swin: 层级结构,多尺度特征
"""

def swin_vs_vit_problems():
    """
    Swin Transformer与ViT的对比问题
    """
    problems = {
        "ViT问题": [
            "全局注意力计算复杂度过高",
            "缺乏层级特征提取能力",
            "对高分辨率图像处理效率低",
            "缺少类似CNN的归纳偏置"
        ],
        "Swin解决方案": [
            "滑动窗口注意力降低复杂度",
            "层级结构实现多尺度特征",
            "移位窗口实现跨窗口连接",
            "保持了Transformer的优势"
        ]
    }
    
    print("Swin Transformer解决的问题:")
    for category, issues in problems.items():
        print(f"{category}:")
        for issue in issues:
            print(f"  • {issue}")
        print()

swin_vs_vit_problems()

1.2 核心创新点

def swin_core_innovations():
    """
    Swin Transformer核心创新点
    """
    innovations = {
        "Shifted Window Attention": "局部窗口注意力机制",
        "Hierarchical Feature Maps": "多尺度特征提取",
        "Window Partition": "图像分割为局部窗口",
        "Cyclic Shift": "周期性移位实现跨窗口连接",
        "Patch Merging": "分辨率递减,通道数递增"
    }
    
    print("Swin Transformer核心创新:")
    for innovation, desc in innovations.items():
        print(f"• {innovation}: {desc}")

swin_core_innovations()

2. 滑动窗口机制详解

2.1 窗口注意力机制

窗口注意力是Swin Transformer的核心组件,它将全局注意力限制在局部窗口内。

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

class WindowAttention(nn.Module):
    """
    窗口注意力模块
    """
    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # 相对位置偏置
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
        )

        # 位置索引
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        nn.init.trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

def window_attention_explanation():
    """
    窗口注意力机制解释
    """
    print("窗口注意力机制:")
    print("1. 将特征图分割为不重叠的窗口")
    print("2. 每个窗口内计算局部自注意力")
    print("3. 使用相对位置编码")
    print("4. 时间复杂度从O(n²)降至O(n)")

window_attention_explanation()

2.2 移位窗口机制

移位窗口是Swin Transformer的另一个关键创新,通过周期性移位实现跨窗口连接。

def window_partition(x, window_size):
    """
    将特征图分割为不重叠的窗口
    Args:
        x: (B, H, W, C)
        window_size (int): 窗口大小
    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows

def window_reverse(windows, window_size, H, W):
    """
    将窗口重新组合为特征图
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): 窗口大小
        H (int): 图像高度
        W (int): 图像宽度
    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x

class SwinTransformerBlock(nn.Module):
    """
    Swin Transformer基本块
    """
    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        if min(self.input_resolution) <= self.window_size:
            # 如果分辨率小于等于窗口大小,不移位
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=torch.Size([self.window_size, self.window_size]), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = nn.Dropout(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        if self.shift_size > 0:
            # 计算注意力掩码
            H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1

            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None

        self.register_buffer("attn_mask", attn_mask)

    def forward(self, x):
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # 循环移位
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # 分割为窗口
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C

        # 窗口重排
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C

        # 逆转移位
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(B, H * W, C)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x

def shifted_window_explanation():
    """
    移位窗口机制解释
    """
    print("移位窗口机制:")
    print("1. 非移位层:普通窗口注意力")
    print("2. 移位层:窗口循环移位")
    print("3. 移位大小:window_size//2")
    print("4. 实现跨窗口连接")
    print("5. 保持计算效率的同时实现全局感受野")

shifted_window_explanation()

3. 层级结构设计

3.1 Patch Merging

层级结构通过Patch Merging实现分辨率递减和通道数递增。

class PatchMerging(nn.Module):
    """
    Patch Merging层,用于降低分辨率、增加通道数
    """
    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

        x = x.view(B, H, W, C)

        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)

        return x

def hierarchical_structure_explanation():
    """
    层级结构解释
    """
    print("Swin Transformer层级结构:")
    print("Stage 0: Resolution H×W, Channels C")
    print("Stage 1: Resolution H/2×W/2, Channels 2C")
    print("Stage 2: Resolution H/4×W/4, Channels 4C")
    print("Stage 3: Resolution H/8×W/8, Channels 8C")

hierarchical_structure_explanation()

3.2 完整Swin Transformer架构

class BasicLayer(nn.Module):
    """
    一个Swin Transformer基本层,包含多个Swin Transformer块
    """
    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # 构建Swin Transformer块
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(
                dim=dim, input_resolution=input_resolution,
                num_heads=num_heads, window_size=window_size,
                shift_size=0 if (i % 2 == 0) else window_size // 2,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop, attn_drop=attn_drop,
                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                norm_layer=norm_layer)
            for i in range(depth)])

        # 下采样层
        if downsample is not None:
            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x):
        for blk in self.blocks:
            x = blk(x)
        if self.downsample is not None:
            x = self.downsample(x)
        return x

class Mlp(nn.Module):
    """
    MLP模块
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

def swin_architecture_overview():
    """
    Swin Transformer架构概览
    """
    print("Swin Transformer整体架构:")
    print("1. Patch Partition: 将图像分割为patch")
    print("2. Linear Embedding: 线性投影")
    print("3. Stage 0: 基础特征提取")
    print("4. Stage 1: 分辨率减半,通道加倍")
    print("5. Stage 2: 进一步降采样")
    print("6. Stage 3: 最深层特征提取")
    print("7. Global Average Pooling: 全局池化")
    print("8. Classifier: 分类头")

swin_architecture_overview()

4. 计算复杂度分析

4.1 复杂度对比

def complexity_analysis():
    """
    Swin Transformer与ViT复杂度对比
    """
    analysis = {
        "ViT (Vision Transformer)": {
            "Self-Attention Complexity": "O(n²)",
            "Where n": "number of patches",
            "For 224x224 image": "196² = 38,416 operations per layer",
            "Memory Usage": "High due to global attention"
        },
        "Swin Transformer": {
            "Window Attention Complexity": "O(n)",
            "Where n": "number of patches",
            "Per Window": "O(window_size²)",
            "Overall": "O(n) for fixed window size",
            "Memory Usage": "Lower due to local attention"
        }
    }
    
    print("计算复杂度分析:")
    print("• ViT: 全局注意力,复杂度O(n²)")
    print("• Swin: 窗口注意力,复杂度O(n)")
    print("• 优势: Swin在处理高分辨率图像时更加高效")
    
    print("\n具体对比:")
    for model, specs in analysis.items():
        print(f"\n{model}:")
        for spec, value in specs.items():
            print(f"  - {spec}: {value}")

complexity_analysis()

4.2 性能优势

def performance_advantages():
    """
    Swin Transformer性能优势
    """
    advantages = [
        "计算效率高:O(n)复杂度 vs ViT的O(n²)",
        "内存占用少:局部注意力减少显存需求",
        "多尺度特征:层级结构提取不同尺度特征",
        "跨窗口连接:移位窗口实现全局信息流动",
        "可扩展性强:适用于不同视觉任务"
    ]
    
    print("Swin Transformer性能优势:")
    for i, advantage in enumerate(advantages, 1):
        print(f"{i}. {advantage}")

performance_advantages()

5. 预训练模型使用

5.1 使用timm库的Swin模型

def use_swin_with_timm():
    """
    使用timm库的Swin Transformer模型
    """
    print("使用timm库的Swin Transformer:")
    print("""
import torch
import timm

# 加载不同的Swin模型
models = {
    'swin_tiny_patch4_window7_224': 'Tiny模型,适合快速原型',
    'swin_small_patch4_window7_224': 'Small模型,平衡性能与效率',
    'swin_base_patch4_window7_224': 'Base模型,标准配置',
    'swin_large_patch4_window7_224': 'Large模型,高性能'
}

# 加载预训练模型
model = timm.create_model('swin_base_patch4_window7_224', pretrained=True)

# 推理示例
input_tensor = torch.randn(1, 3, 224, 224)
output = model(input_tensor)
print(f"输出形状: {output.shape}")  # (1, 1000) for classification

# 获取中间特征
features = model.forward_features(input_tensor)
print(f"特征形状: {features.shape}")
""")

use_swin_with_timm()

5.2 使用Hugging Face Transformers

def use_swin_with_huggingface():
    """
    使用Hugging Face的Swin Transformer
    """
    print("使用Hugging Face Transformers:")
    print("""
from transformers import SwinImageProcessor, SwinModel
from PIL import Image
import requests

# 加载处理器和模型
processor = SwinImageProcessor.from_pretrained('microsoft/swin-tiny-patch4-window7-224')
model = SwinModel.from_pretrained('microsoft/swin-tiny-patch4-window7-224')

# 处理图像
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(images=image, return_tensors="pt")

# 推理
outputs = model(**inputs)
last_hidden_states = outputs.last_hidden_state
print(f"隐藏状态形状: {last_hidden_states.shape}")
""")

use_swin_with_huggingface()

6. Swin Transformer变体

6.1 Swin Transformer V2

def swin_v2_features():
    """
    Swin Transformer V2改进
    """
    improvements = [
        "Post-norm: 后归一化提高训练稳定性",
        "Scaled cosine attention: 改进的注意力机制",
        "Log-spaced continuous position bias: 连续位置偏置",
        "Larger resolution support: 支持更高分辨率"
    ]
    
    print("Swin Transformer V2改进:")
    for improvement in improvements:
        print(f"• {improvement}")

swin_v2_features()

6.2 其他变体

def swin_variants():
    """
    Swin Transformer变体
    """
    variants = {
        "Swin Transformer": "原始版本,适用于多种视觉任务",
        "Swin Transformer V2": "改进版本,支持更高分辨率",
        "SwinUNETR": "用于医学图像分割的U-Net变体",
        "Video Swin Transformer": "用于视频理解任务",
        "SwinIR": "用于图像恢复任务",
        "Swin Matting": "用于图像抠图任务"
    }
    
    print("Swin Transformer变体:")
    for variant, desc in variants.items():
        print(f"• {variant}: {desc}")

swin_variants()

7. 实际应用案例

7.1 目标检测应用

def detection_application():
    """
    Swin Transformer在目标检测中的应用
    """
    print("Swin Transformer用于目标检测:")
    print("""
# 使用Detectron2中的Swin Transformer骨干网络
from detectron2.model_zoo import model_zoo
from detectron2.config import get_cfg

cfg = get_cfg()
cfg.MODEL.BACKBONE.NAME = "build_swintransformer_backbone"
cfg.MODEL.WEIGHTS = "detectron2://ImageNetPretrained/timm/swin_base_patch4_window7_224.pth"

# Swin Transformer作为骨干网络的优势
# 1. 更好的多尺度特征表示
# 2. 更高的检测精度
# 3. 更好的效率
""")

detection_application()

7.2 语义分割应用

def segmentation_application():
    """
    Swin Transformer在语义分割中的应用
    """
    print("Swin Transformer用于语义分割:")
    print("""
# 使用Swin Transformer作为语义分割骨干网络
# 通常与UperNet等头部结合

class SwinSegmentationHead(nn.Module):
    def __init__(self, in_channels, num_classes):
        super().__init__()
        # Swin Transformer提供多尺度特征
        self.decode_head = UperHead(
            in_channels=in_channels,
            num_classes=num_classes
        )
    
    def forward(self, features):
        # 使用Swin Transformer的多层级特征
        seg_logits = self.decode_head(features)
        return seg_logits

# 优势:
# 1. 层级特征适合分割任务
# 2. 全局感受野提升精度
# 3. 移位窗口保持边界细节
""")

segmentation_application()

相关教程

Swin Transformer是当前视觉领域的主流架构。建议深入理解滑动窗口机制和层级结构设计,这些创新点使其在效率和性能上都有显著提升。在实践中可以先从预训练模型开始,逐步了解架构细节。

8. 总结

Swin Transformer是计算机视觉领域的里程碑式工作:

核心创新:

  1. 滑动窗口注意力:局部注意力降低计算复杂度
  2. 层级结构:多尺度特征提取能力
  3. 移位窗口:实现跨窗口连接

技术优势:

  • 计算效率高
  • 多尺度特征
  • 适用于多种视觉任务

💡 重要提醒:Swin Transformer已成为视觉任务的标准骨干网络,在分类、检测、分割等任务上都取得了优异性能。理解其架构设计对深度学习研究和应用都很重要。

🔗 扩展阅读