import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class PatchEmbedding(nn.Module):
    """Split an image into non-overlapping patches and linearly embed each one.

    A Conv2d with kernel_size == stride == patch_size is mathematically
    equivalent to cutting the image into patches and applying one shared
    linear projection to each flattened patch.

    Args:
        img_size: Input image height/width (square images assumed).
        patch_size: Side length of each square patch.
        in_channels: Number of input image channels.
        embed_dim: Dimension of each patch embedding.

    Raises:
        ValueError: If ``img_size`` is not divisible by ``patch_size``.
    """

    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        if img_size % patch_size != 0:
            # Without this check the conv silently drops border pixels and
            # n_patches no longer matches the real number of output tokens,
            # which breaks the positional embedding downstream.
            raise ValueError(
                f"img_size ({img_size}) must be divisible by patch_size ({patch_size})"
            )
        self.img_size = img_size
        self.patch_size = patch_size
        self.n_patches = (img_size // patch_size) ** 2
        # Patch embedding via convolution — equivalent to a per-patch linear projection.
        self.proj = nn.Conv2d(in_channels, embed_dim,
                              kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        """Project a batch of images to patch-embedding sequences.

        Args:
            x: Tensor of shape (batch_size, channels, height, width).

        Returns:
            Tensor of shape (batch_size, n_patches, embed_dim).
        """
        x = self.proj(x)       # (batch_size, embed_dim, n_patches_h, n_patches_w)
        x = x.flatten(2)       # (batch_size, embed_dim, n_patches)
        x = x.transpose(1, 2)  # (batch_size, n_patches, embed_dim)
        return x
class PositionalEmbedding(nn.Module):
    """Learnable positional embedding for the patch sequence plus class token.

    Holds a (1, n_patches + 1, embed_dim) parameter — one position per patch
    plus one for the prepended class token.

    Args:
        n_patches: Number of image patches (excluding the class token).
        embed_dim: Embedding dimension.
    """

    def __init__(self, n_patches, embed_dim):
        super().__init__()
        self.embeddings = nn.Parameter(torch.zeros(1, n_patches + 1, embed_dim))
        self.init_weights()

    def init_weights(self):
        """Initialize with a truncated normal (std=0.02)."""
        nn.init.trunc_normal_(self.embeddings, mean=0.0, std=0.02)

    def forward(self, x):
        """Add positional embeddings to a token sequence.

        Args:
            x: Tensor of shape (batch_size, n_patches + 1, embed_dim) — the
               patch embeddings with the class token already prepended.

        Returns:
            Tensor of the same shape with positions added.
        """
        # Broadcasting over the batch dimension is equivalent to (and cheaper
        # than) materializing a batch-sized copy with .repeat().
        return x + self.embeddings
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention.

    Projects the input once into stacked Q/K/V, runs scaled dot-product
    attention independently per head, and merges the heads through an
    output projection.

    Args:
        embed_dim: Total embedding dimension.
        n_heads: Number of attention heads; must divide embed_dim evenly.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, embed_dim, n_heads, dropout=0.1):
        super().__init__()
        self.embed_dim = embed_dim
        self.n_heads = n_heads
        self.head_dim = embed_dim // n_heads
        assert self.head_dim * n_heads == embed_dim, "embed_dim must be divisible by n_heads"
        # One fused projection produces Q, K and V in a single matmul.
        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.attention_dropout = nn.Dropout(dropout)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        """Apply self-attention.

        Args:
            x: Tensor of shape (batch_size, seq_len, embed_dim).

        Returns:
            Tensor of the same shape.
        """
        batch, tokens, dim = x.size()
        # (batch, tokens, 3*dim) -> (3, batch, n_heads, tokens, head_dim)
        stacked = self.qkv(x).view(batch, tokens, 3, self.n_heads, self.head_dim)
        q, k, v = stacked.permute(2, 0, 3, 1, 4).unbind(0)
        # Scaled dot-product attention weights per head.
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        weights = self.attention_dropout(scores.softmax(dim=-1))
        # Weighted sum over values, then fold the heads back together.
        merged = (weights @ v).transpose(1, 2).reshape(batch, tokens, dim)
        return self.out_proj(merged)
class MLPBlock(nn.Module):
    """Position-wise feed-forward block: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        embed_dim: Input/output dimension.
        mlp_dim: Hidden-layer dimension.
        dropout: Dropout probability after each stage.
    """

    def __init__(self, embed_dim, mlp_dim, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(embed_dim, mlp_dim)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(mlp_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Expand to mlp_dim, apply GELU, project back to embed_dim."""
        hidden = self.dropout(self.act(self.fc1(x)))
        return self.dropout(self.fc2(hidden))
class TransformerEncoderLayer(nn.Module):
    """Pre-norm Transformer encoder layer.

    Structure: LayerNorm -> attention -> residual add, then
    LayerNorm -> MLP -> residual add.

    Args:
        embed_dim: Token embedding dimension.
        n_heads: Number of attention heads.
        mlp_dim: Hidden width of the feed-forward block.
        dropout: Dropout probability passed to both sub-blocks.
    """

    def __init__(self, embed_dim, n_heads, mlp_dim, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadAttention(embed_dim, n_heads, dropout)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = MLPBlock(embed_dim, mlp_dim, dropout)

    def forward(self, x):
        """Run one encoder layer; input and output shapes are identical."""
        # Normalization is applied before each sub-block (pre-norm), with a
        # residual connection around it.
        attended = x + self.attn(self.norm1(x))
        return attended + self.mlp(self.norm2(attended))
class VisionTransformer(nn.Module):
    """Complete Vision Transformer (ViT) image classifier.

    Pipeline: patch embedding -> prepend class token -> add positional
    embeddings -> dropout -> stack of Transformer encoder layers ->
    LayerNorm -> linear classification head on the class token.

    Args:
        img_size: Input image side length (square images).
        patch_size: Patch side length.
        in_channels: Number of input image channels.
        num_classes: Size of the classification output.
        embed_dim: Token embedding dimension.
        depth: Number of encoder layers.
        n_heads: Attention heads per layer.
        mlp_dim: Hidden width of each encoder MLP.
        dropout: Dropout probability used throughout.
    """

    def __init__(self, img_size=224, patch_size=16, in_channels=3, num_classes=1000,
                 embed_dim=768, depth=12, n_heads=12, mlp_dim=3072, dropout=0.1):
        super().__init__()
        # Patch embedding
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        # Learnable class token, prepended to every patch sequence.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # Positional embeddings (cover the patches plus the class token).
        self.pos_embed = PositionalEmbedding(self.patch_embed.n_patches, embed_dim)
        self.dropout = nn.Dropout(dropout)
        # Encoder stack; Sequential suffices because each layer maps x -> x.
        layers = [TransformerEncoderLayer(embed_dim, n_heads, mlp_dim, dropout)
                  for _ in range(depth)]
        self.transformer_encoder = nn.Sequential(*layers)
        # Classification head.
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        """Classify a batch of images.

        Args:
            x: Tensor of shape (batch_size, in_channels, img_size, img_size).

        Returns:
            Logits of shape (batch_size, num_classes).
        """
        tokens = self.patch_embed(x)                         # (B, N, D)
        cls = self.cls_token.expand(tokens.size(0), -1, -1)
        tokens = torch.cat([cls, tokens], dim=1)             # (B, N+1, D)
        tokens = self.dropout(self.pos_embed(tokens))
        tokens = self.transformer_encoder(tokens)
        # Only the class token feeds the classifier head.
        return self.head(self.norm(tokens[:, 0]))
# Factory for a standard model instance
def create_vit_base():
    """Build a ViT-Base/16 model for 224x224 inputs.

    Returns:
        A VisionTransformer configured with the ViT-Base hyperparameters
        (768-dim embeddings, 12 layers, 12 heads, 3072-dim MLP).
    """
    config = dict(
        img_size=224,
        patch_size=16,
        embed_dim=768,
        depth=12,
        n_heads=12,
        mlp_dim=3072,
    )
    return VisionTransformer(**config)
# Parameter counting utility
def count_parameters(model):
    """Return the number of trainable parameters in *model*.

    Args:
        model: Any nn.Module; only parameters with requires_grad=True count.

    Returns:
        Total element count across trainable parameter tensors.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
# Example: build the model and print basic info
if __name__ == "__main__":
    vit = create_vit_base()
    print(f"ViT-Base参数量: {count_parameters(vit):,}")
    # Smoke test: push one random 224x224 RGB image through the network.
    sample = torch.randn(1, 3, 224, 224)
    logits = vit(sample)
    print(f"输入形状: {sample.shape}")
    print(f"输出形状: {logits.shape}")