#Swin Transformer:滑动窗口机制与层级特征提取详解
#引言
Swin Transformer是微软亚洲研究院提出的革命性视觉Transformer架构,通过引入滑动窗口机制和层级特征提取,有效解决了原始Vision Transformer在计算效率和多尺度特征提取方面的局限性。Swin Transformer已成为计算机视觉领域的主流架构之一,广泛应用于图像分类、目标检测、语义分割等多个任务。本文将深入探讨Swin Transformer的核心概念、架构细节和实现方法。
📂 所属阶段:第二阶段 — 深度学习视觉基础(Transformer 篇)
🔗 相关章节:Vision Transformer (ViT) 详解 · MAE (Masked Autoencoders)
#1. Swin Transformer核心创新
#1.1 解决ViT的挑战
Swin Transformer主要解决了原始ViT的两个关键问题:
"""
Swin Transformer解决的问题:
1. 计算复杂度问题:
- ViT: O(n²) 全局注意力
- Swin: O(n) 局部注意力
2. 层级结构缺失:
- ViT: 平坦架构,无尺度变化
- Swin: 层级结构,多尺度特征
"""
def swin_vs_vit_problems():
    """Print a side-by-side summary of ViT's limitations and how Swin addresses them."""
    comparison = {
        "ViT问题": [
            "全局注意力计算复杂度过高",
            "缺乏层级特征提取能力",
            "对高分辨率图像处理效率低",
            "缺少类似CNN的归纳偏置",
        ],
        "Swin解决方案": [
            "滑动窗口注意力降低复杂度",
            "层级结构实现多尺度特征",
            "移位窗口实现跨窗口连接",
            "保持了Transformer的优势",
        ],
    }
    print("Swin Transformer解决的问题:")
    for heading, entries in comparison.items():
        print(f"{heading}:")
        for entry in entries:
            print(f" • {entry}")
        print()
swin_vs_vit_problems()  # 1.2 Core innovations
def swin_core_innovations():
    """Print the key architectural innovations introduced by Swin Transformer."""
    summary = {
        "Shifted Window Attention": "局部窗口注意力机制",
        "Hierarchical Feature Maps": "多尺度特征提取",
        "Window Partition": "图像分割为局部窗口",
        "Cyclic Shift": "周期性移位实现跨窗口连接",
        "Patch Merging": "分辨率递减,通道数递增",
    }
    print("Swin Transformer核心创新:")
    for name, description in summary.items():
        print(f"• {name}: {description}")
swin_core_innovations()  # 2. Sliding-window mechanism in detail
#2.1 窗口注意力机制
窗口注意力是Swin Transformer的核心组件,它将全局注意力限制在局部窗口内。
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
class WindowAttention(nn.Module):
    """Window-based multi-head self-attention (W-MSA / SW-MSA) with a learned
    relative position bias.

    Attention is computed only among the N = Wh*Ww tokens inside each window;
    the caller is responsible for partitioning the feature map into windows.

    Args:
        dim: total channel dimension of the input tokens.
        window_size: (Wh, Ww) side lengths of one attention window.
        num_heads: number of attention heads.
        qkv_bias: whether the qkv projection has a bias term.
        qk_scale: override for the 1/sqrt(head_dim) attention scale, or None.
        attn_drop: dropout rate on the attention weights.
        proj_drop: dropout rate on the output projection.
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # (Wh, Ww)
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        # One learnable bias per (relative offset, head): offsets along each axis
        # range over 2*W-1 values, hence the (2*Wh-1)*(2*Ww-1) table rows.
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
        )
        # Precompute, for every (query, key) token pair inside a window, the index
        # into the bias table that corresponds to their relative offset.
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1  # row-major flattening of the 2-D offset
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        # buffer (not a parameter): moves with the module but is not trained
        self.register_buffer("relative_position_index", relative_position_index)
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        nn.init.trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None

        Returns:
            Tensor of shape (num_windows*B, N, C).
        """
        B_, N, C = x.shape
        # project to q, k, v and split heads: 3, B_, num_heads, N, head_dim
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))
        # look up the learned bias for every (query, key) pair in the window
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)
        if mask is not None:
            # shifted-window case: add the per-window mask, broadcasting it over
            # batch and heads, then flatten back to (nW*B, heads, N, N)
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
def window_attention_explanation():
    """Print a short recap of how window attention works."""
    recap = (
        "窗口注意力机制:",
        "1. 将特征图分割为不重叠的窗口",
        "2. 每个窗口内计算局部自注意力",
        "3. 使用相对位置编码",
        "4. 时间复杂度从O(n²)降至O(n)",
    )
    for line in recap:
        print(line)
window_attention_explanation()  # 2.2 Shifted-window mechanism
移位窗口是Swin Transformer的另一个关键创新,通过周期性移位实现跨窗口连接。
def window_partition(x, window_size):
    """Split a feature map into non-overlapping square windows.

    Args:
        x: tensor of shape (B, H, W, C); H and W must be divisible by window_size.
        window_size (int): side length of each window.

    Returns:
        Tensor of shape (num_windows*B, window_size, window_size, C), windows
        ordered row-major within each batch element.
    """
    B, H, W, C = x.shape
    rows, cols = H // window_size, W // window_size
    # carve the spatial grid into (rows x cols) tiles of window_size each
    tiles = x.view(B, rows, window_size, cols, window_size, C)
    tiles = tiles.permute(0, 1, 3, 2, 4, 5).contiguous()
    return tiles.view(-1, window_size, window_size, C)
def window_reverse(windows, window_size, H, W):
    """Reassemble windows back into a full feature map (inverse of window_partition).

    Args:
        windows: tensor of shape (num_windows*B, window_size, window_size, C).
        window_size (int): side length of each window.
        H (int): height of the target feature map.
        W (int): width of the target feature map.

    Returns:
        Tensor of shape (B, H, W, C).
    """
    # recover the batch size from the total window count
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    grid = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    grid = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
    return grid.view(B, H, W, -1)
class SwinTransformerBlock(nn.Module):
    """One Swin Transformer block: (shifted-)window attention followed by an MLP,
    each preceded by LayerNorm and wrapped in a residual connection.

    Args:
        dim: channel dimension of the input tokens.
        input_resolution: (H, W) of the feature map this block operates on.
        num_heads: number of attention heads.
        window_size: side length of each attention window.
        shift_size: cyclic-shift offset; 0 gives W-MSA, window_size // 2 gives SW-MSA.
        mlp_ratio: hidden width of the MLP as a multiple of `dim`.
        qkv_bias, qk_scale, drop, attn_drop: forwarded to WindowAttention / Mlp.
        drop_path: dropout rate applied to each residual branch (see NOTE below).
        act_layer: activation class for the MLP.
        norm_layer: normalization class.
    """

    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        if min(self.input_resolution) <= self.window_size:
            # Feature map no larger than one window: use a single full-size
            # window and disable shifting.
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=torch.Size([self.window_size, self.window_size]), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        # NOTE(review): this uses element-wise nn.Dropout for `drop_path`; the
        # reference Swin implementation uses stochastic-depth DropPath (timm),
        # which drops whole residual branches per sample — confirm this
        # simplification is intentional.
        self.drop_path = nn.Dropout(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
        if self.shift_size > 0:
            # Build the SW-MSA attention mask. After the cyclic shift, one window
            # can contain tokens that came from disjoint image regions; those
            # tokens must not attend to each other. Label each of the 3x3 regions
            # produced by the shift with a distinct integer, then forbid attention
            # between tokens whose labels differ.
            H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1
            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            # pairs with differing region labels get -100 (effectively -inf
            # before softmax); same-region pairs get 0
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None
        self.register_buffer("attn_mask", attn_mask)

    def forward(self, x):
        """Apply (shifted-)window attention and the MLP.

        Args:
            x: token sequence of shape (B, H*W, C) with H, W == input_resolution.

        Returns:
            Tensor of the same shape (B, H*W, C).
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)
        # cyclic shift (only in SW-MSA blocks)
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x
        # partition into windows
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C
        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C
        # merge windows back into a feature map
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C
        # reverse the cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(B, H * W, C)
        # residual connections: attention branch, then FFN branch
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
def shifted_window_explanation():
    """Print a short recap of the shifted-window mechanism."""
    for line in (
        "移位窗口机制:",
        "1. 非移位层:普通窗口注意力",
        "2. 移位层:窗口循环移位",
        "3. 移位大小:window_size//2",
        "4. 实现跨窗口连接",
        "5. 保持计算效率的同时实现全局感受野",
    ):
        print(line)
shifted_window_explanation()  # 3. Hierarchical structure design
#3.1 Patch Merging
层级结构通过Patch Merging实现分辨率递减和通道数递增。
class PatchMerging(nn.Module):
    """Downsampling layer between Swin stages: halves the spatial resolution
    and doubles the channel dimension.

    Each 2x2 spatial neighbourhood is concatenated along channels (C -> 4C),
    normalized, and linearly projected down to 2C.

    Args:
        input_resolution: (H, W) of the incoming feature map; both must be even.
        dim: channel dimension C of the incoming tokens.
        norm_layer: normalization class applied to the concatenated 4C features.
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """Merge patches: (B, H*W, C) -> (B, H/2*W/2, 2C)."""
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
        grid = x.view(B, H, W, C)
        # gather the four pixels of every 2x2 neighbourhood, in the fixed order
        # (even,even), (odd,even), (even,odd), (odd,odd), and stack on channels
        quadrants = [grid[:, r::2, c::2, :] for r, c in ((0, 0), (1, 0), (0, 1), (1, 1))]
        merged = torch.cat(quadrants, -1)  # B H/2 W/2 4*C
        merged = merged.view(B, -1, 4 * C)  # B H/2*W/2 4*C
        return self.reduction(self.norm(merged))
def hierarchical_structure_explanation():
    """Print the stage-by-stage resolution/channel schedule of the hierarchy."""
    table = (
        "Swin Transformer层级结构:",
        "Stage 0: Resolution H×W, Channels C",
        "Stage 1: Resolution H/2×W/2, Channels 2C",
        "Stage 2: Resolution H/4×W/4, Channels 4C",
        "Stage 3: Resolution H/8×W/8, Channels 8C",
    )
    for row in table:
        print(row)
hierarchical_structure_explanation()  # 3.2 Full Swin Transformer architecture
class BasicLayer(nn.Module):
    """One Swin stage: `depth` Swin Transformer blocks (alternating W-MSA and
    SW-MSA) followed by an optional downsampling layer.

    Args:
        dim: channel dimension of the stage.
        input_resolution: (H, W) of the feature map entering the stage.
        depth: number of SwinTransformerBlock modules.
        num_heads, window_size, mlp_ratio, qkv_bias, qk_scale, drop, attn_drop:
            forwarded to every block.
        drop_path: scalar, or per-block list, of drop-path rates.
        norm_layer: normalization class forwarded to blocks and downsample.
        downsample: downsampling class (e.g. PatchMerging) or None.
        use_checkpoint: stored flag for gradient checkpointing (not used here).
    """

    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint
        blocks = []
        for i in range(depth):
            # even-indexed blocks use regular windows, odd-indexed blocks shift
            shift = 0 if i % 2 == 0 else window_size // 2
            rate = drop_path[i] if isinstance(drop_path, list) else drop_path
            blocks.append(SwinTransformerBlock(
                dim=dim, input_resolution=input_resolution,
                num_heads=num_heads, window_size=window_size,
                shift_size=shift,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop, attn_drop=attn_drop,
                drop_path=rate,
                norm_layer=norm_layer))
        self.blocks = nn.ModuleList(blocks)
        # optional resolution-halving layer at the end of the stage
        if downsample is not None:
            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x):
        """Run all blocks in sequence, then the downsample layer if present."""
        for block in self.blocks:
            x = block(x)
        if self.downsample is not None:
            x = self.downsample(x)
        return x
class Mlp(nn.Module):
    """Two-layer feed-forward network used inside each transformer block.

    Computation: fc1 -> activation -> dropout -> fc2 -> dropout.

    Args:
        in_features: input channel dimension.
        hidden_features: hidden width (defaults to in_features when falsy).
        out_features: output width (defaults to in_features when falsy).
        act_layer: activation class, instantiated with no arguments.
        drop: dropout rate applied after the activation and after fc2.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
def swin_architecture_overview():
    """Print the end-to-end pipeline of the Swin Transformer architecture."""
    pipeline = (
        "Swin Transformer整体架构:",
        "1. Patch Partition: 将图像分割为patch",
        "2. Linear Embedding: 线性投影",
        "3. Stage 0: 基础特征提取",
        "4. Stage 1: 分辨率减半,通道加倍",
        "5. Stage 2: 进一步降采样",
        "6. Stage 3: 最深层特征提取",
        "7. Global Average Pooling: 全局池化",
        "8. Classifier: 分类头",
    )
    for step in pipeline:
        print(step)
swin_architecture_overview()  # 4. Computational complexity analysis
#4.1 复杂度对比
def complexity_analysis():
    """Print a complexity comparison between ViT's global attention and Swin's
    window attention."""
    analysis = {
        "ViT (Vision Transformer)": {
            "Self-Attention Complexity": "O(n²)",
            "Where n": "number of patches",
            "For 224x224 image": "196² = 38,416 operations per layer",
            "Memory Usage": "High due to global attention",
        },
        "Swin Transformer": {
            "Window Attention Complexity": "O(n)",
            "Where n": "number of patches",
            "Per Window": "O(window_size²)",
            "Overall": "O(n) for fixed window size",
            "Memory Usage": "Lower due to local attention",
        },
    }
    for headline in (
        "计算复杂度分析:",
        "• ViT: 全局注意力,复杂度O(n²)",
        "• Swin: 窗口注意力,复杂度O(n)",
        "• 优势: Swin在处理高分辨率图像时更加高效",
        "\n具体对比:",
    ):
        print(headline)
    for arch, properties in analysis.items():
        print(f"\n{arch}:")
        for prop, detail in properties.items():
            print(f" - {prop}: {detail}")
complexity_analysis()  # 4.2 Performance advantages
def performance_advantages():
    """Print a numbered list of Swin Transformer's practical advantages."""
    points = [
        "计算效率高:O(n)复杂度 vs ViT的O(n²)",
        "内存占用少:局部注意力减少显存需求",
        "多尺度特征:层级结构提取不同尺度特征",
        "跨窗口连接:移位窗口实现全局信息流动",
        "可扩展性强:适用于不同视觉任务",
    ]
    print("Swin Transformer性能优势:")
    for rank, point in enumerate(points, 1):
        print(f"{rank}. {point}")
performance_advantages()  # 5. Using pretrained models
#5.1 使用timm库的Swin模型
def use_swin_with_timm():
    """Print a ready-to-run code sample for loading pretrained Swin models via timm.

    The sample is only printed, never executed, so this file does not require
    timm to be installed.
    """
    print("使用timm库的Swin Transformer:")
    # the triple-quoted block below is the sample code shown to the reader
    print("""
import torch
import timm
# 加载不同的Swin模型
models = {
'swin_tiny_patch4_window7_224': 'Tiny模型,适合快速原型',
'swin_small_patch4_window7_224': 'Small模型,平衡性能与效率',
'swin_base_patch4_window7_224': 'Base模型,标准配置',
'swin_large_patch4_window7_224': 'Large模型,高性能'
}
# 加载预训练模型
model = timm.create_model('swin_base_patch4_window7_224', pretrained=True)
# 推理示例
input_tensor = torch.randn(1, 3, 224, 224)
output = model(input_tensor)
print(f"输出形状: {output.shape}") # (1, 1000) for classification
# 获取中间特征
features = model.forward_features(input_tensor)
print(f"特征形状: {features.shape}")
""")
use_swin_with_timm()  # 5.2 Using Hugging Face Transformers
def use_swin_with_huggingface():
    """Print a code sample for running Swin via Hugging Face Transformers.

    The sample is only printed, never executed, so transformers/PIL/requests
    are not required to run this file.
    """
    print("使用Hugging Face Transformers:")
    # the triple-quoted block below is the sample code shown to the reader
    print("""
from transformers import SwinImageProcessor, SwinModel
from PIL import Image
import requests
# 加载处理器和模型
processor = SwinImageProcessor.from_pretrained('microsoft/swin-tiny-patch4-window7-224')
model = SwinModel.from_pretrained('microsoft/swin-tiny-patch4-window7-224')
# 处理图像
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(images=image, return_tensors="pt")
# 推理
outputs = model(**inputs)
last_hidden_states = outputs.last_hidden_state
print(f"隐藏状态形状: {last_hidden_states.shape}")
""")
use_swin_with_huggingface()  # 6. Swin Transformer variants
#6.1 Swin Transformer V2
def swin_v2_features():
    """Print the headline improvements introduced in Swin Transformer V2."""
    changes = [
        "Post-norm: 后归一化提高训练稳定性",
        "Scaled cosine attention: 改进的注意力机制",
        "Log-spaced continuous position bias: 连续位置偏置",
        "Larger resolution support: 支持更高分辨率",
    ]
    print("Swin Transformer V2改进:")
    for change in changes:
        print(f"• {change}")
swin_v2_features()  # 6.2 Other variants
def swin_variants():
    """Print the family of Swin-based architectures and their target tasks."""
    family = {
        "Swin Transformer": "原始版本,适用于多种视觉任务",
        "Swin Transformer V2": "改进版本,支持更高分辨率",
        "SwinUNETR": "用于医学图像分割的U-Net变体",
        "Video Swin Transformer": "用于视频理解任务",
        "SwinIR": "用于图像恢复任务",
        "Swin Matting": "用于图像抠图任务",
    }
    print("Swin Transformer变体:")
    for member, role in family.items():
        print(f"• {member}: {role}")
swin_variants()  # 7. Real-world applications
#7.1 目标检测应用
def detection_application():
    """Print a code sample for using Swin as a Detectron2 detection backbone.

    The sample is only printed, never executed; Detectron2 is not required.
    """
    print("Swin Transformer用于目标检测:")
    # the triple-quoted block below is the sample code shown to the reader
    print("""
# 使用Detectron2中的Swin Transformer骨干网络
from detectron2.model_zoo import model_zoo
from detectron2.config import get_cfg
cfg = get_cfg()
cfg.MODEL.BACKBONE.NAME = "build_swintransformer_backbone"
cfg.MODEL.WEIGHTS = "detectron2://ImageNetPretrained/timm/swin_base_patch4_window7_224.pth"
# Swin Transformer作为骨干网络的优势
# 1. 更好的多尺度特征表示
# 2. 更高的检测精度
# 3. 更好的效率
""")
detection_application()  # 7.2 Semantic segmentation
def segmentation_application():
    """Print a sketch of using Swin features with a segmentation head (e.g. UperNet).

    The sample is only printed, never executed; `UperHead` in the sample is
    illustrative pseudo-code, not a symbol defined in this file.
    """
    print("Swin Transformer用于语义分割:")
    # the triple-quoted block below is the sample code shown to the reader
    print("""
# 使用Swin Transformer作为语义分割骨干网络
# 通常与UperNet等头部结合
class SwinSegmentationHead(nn.Module):
def __init__(self, in_channels, num_classes):
super().__init__()
# Swin Transformer提供多尺度特征
self.decode_head = UperHead(
in_channels=in_channels,
num_classes=num_classes
)
def forward(self, features):
# 使用Swin Transformer的多层级特征
seg_logits = self.decode_head(features)
return seg_logits
# 优势:
# 1. 层级特征适合分割任务
# 2. 全局感受野提升精度
# 3. 移位窗口保持边界细节
""")
segmentation_application()  # Related tutorials
#8. 总结
Swin Transformer是计算机视觉领域的里程碑式工作:
核心创新:
- 滑动窗口注意力:局部注意力降低计算复杂度
- 层级结构:多尺度特征提取能力
- 移位窗口:实现跨窗口连接
技术优势:
- 计算效率高
- 多尺度特征
- 适用于多种视觉任务
💡 重要提醒:Swin Transformer已成为视觉任务的标准骨干网络,在分类、检测、分割等任务上都取得了优异性能。理解其架构设计对深度学习研究和应用都很重要。
🔗 扩展阅读

