#手写数字识别 (MNIST) 实战:PyTorch图像分类模型完整指南
#引言
手写数字识别(MNIST)被誉为深度学习的Hello World,是计算机视觉入门的黄金基准。它由Yann LeCun等人整理发布,包含7万张28×28单通道灰度手写数字,覆盖0-9共10类,完美适合掌握卷积神经网络(CNN)的核心逻辑与完整训练流程。
📂 所属阶段:第二阶段 — 深度学习视觉基础(CNN 篇)
🔗 相关章节:经典 CNN 架构剖析 · 数据增强 (Data Augmentation)
#1. MNIST数据集快速上手
#1.1 核心参数一览
先通过简洁的代码块明确数据集的基础属性,避免冗长文字:
import torch
import torchvision
def get_mnist_info() -> None:
    """Print a fixed summary of the MNIST dataset's key properties.

    The summary lists the train/test split sizes, the image shape, the
    label set, and the pixel range before and after ``ToTensor``. Nothing
    is loaded or downloaded; the text is constant.
    """
    print("📊 MNIST数据集核心参数:")
    print("• 训练集:60,000张\n• 测试集:10,000张")
    print("• 尺寸:28×28×1(灰度单通道)\n• 类别:0-9(10类)")
    print("• 像素范围:原始0-255,经ToTensor转[0,1]")
get_mnist_info()

#1.2 样本可视化(简化版)
无需完整运行matplotlib,保留核心可执行逻辑:
def visualize_mnist():
    """Load the raw MNIST training split and report the first sample's label.

    The dataset is opened without a transform, so each item is an
    ``(image, label)`` pair holding the original PIL image. Files are
    downloaded into ``./data`` when they are missing. An optional matplotlib
    preview of the first ten samples is kept below as commented-out code
    for environments that can display figures.
    """
    # No transform: the preview only needs the untouched PIL images.
    raw_train = torchvision.datasets.MNIST(root="./data", train=True, download=True)
    # Fetch the first sample once and reuse its label in both placeholders.
    _, first_label = raw_train[0]
    print(f"✅ 样本加载成功,标签示例:{first_label}(第1张为数字{first_label})")
    # Uncomment the lines below if matplotlib is available in your environment.
    # import matplotlib.pyplot as plt
    # fig, axes = plt.subplots(2, 5, figsize=(12, 5))
    # for i, ax in enumerate(axes.ravel()):
    #     img, lbl = raw_train[i]
    #     ax.imshow(img, cmap="gray")
    #     ax.set_title(f"Label: {lbl}")
    #     ax.axis("off")
    # plt.tight_layout()
    # plt.show()
visualize_mnist()

#2. 数据预处理与加载
#2.1 标准预处理管道
MNIST的归一化参数(0.1307,)和(0.3081,)是在训练集像素(已缩放到[0,1])上统计得到的均值与标准差,属于社区通用取值,直接复用即可:
from torchvision import transforms
from torch.utils.data import DataLoader
def get_dataloaders(batch_size=64, root="./data", num_workers=2):
    """Build the MNIST training and test ``DataLoader`` objects.

    Both splits share one preprocessing pipeline: ``ToTensor`` followed by
    ``Normalize`` with mean 0.1307 and std 0.3081. The dataset files are
    downloaded into ``root`` when they are missing.

    Args:
        batch_size: Number of samples per batch, used by both loaders.
        root: Directory where the MNIST files are stored or downloaded.
        num_workers: Worker processes per loader; 0 loads in the main
            process.

    Returns:
        A ``(train_loader, test_loader)`` tuple. Only the training loader
        shuffles its data.
    """
    # Same pipeline for training and evaluation; no augmentation is applied.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])

    train_set = torchvision.datasets.MNIST(
        root=root, train=True, download=True, transform=transform
    )
    test_set = torchvision.datasets.MNIST(
        root=root, train=False, download=True, transform=transform
    )

    # Pinned host memory only speeds up host-to-GPU copies, so request it
    # only when CUDA is present; on a CPU-only machine it has no benefit.
    loader_kwargs = {
        "batch_size": batch_size,
        "num_workers": num_workers,
        "pin_memory": torch.cuda.is_available(),
    }
    train_loader = DataLoader(train_set, shuffle=True, **loader_kwargs)
    test_loader = DataLoader(test_set, shuffle=False, **loader_kwargs)
    return train_loader, test_loader

#3. 轻量高效CNN模型设计
本文采用轻量但带BN(批归一化)的SimpleMNISTCNN,既易理解,又能在10个epoch内稳定达到99.2%+的测试准确率:
import torch.nn as nn
import torch.nn.functional as F
class SimpleMNISTCNN(nn.Module):
    """Small two-block CNN with batch normalization for 1x28x28 inputs.

    Each convolutional block is Conv(3x3, padding=1) -> BatchNorm -> ReLU ->
    MaxPool(2), so the spatial size goes 28 -> 14 -> 7. The classifier head
    is Linear(64*7*7, 128) -> ReLU -> Dropout(0.4) -> Linear(128, 10) and
    returns raw logits (no softmax is applied).
    """

    def __init__(self):
        super().__init__()
        # Block 1: 1 -> 32 channels; padding=1 keeps the spatial size.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.pool1 = nn.MaxPool2d(2)  # 28 -> 14
        # Block 2: 32 -> 64 channels.
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.pool2 = nn.MaxPool2d(2)  # 14 -> 7
        # Classifier head; dropout is applied to the hidden layer's output.
        self.dropout = nn.Dropout(0.4)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Return class logits of shape ``(batch, 10)``.

        Args:
            x: Input batch of shape ``(batch, 1, 28, 28)``.

        Raises:
            RuntimeError: If the spatial size is not 28x28, because the
                flattened features then no longer match ``fc1``.
        """
        x = self.pool1(F.relu(self.bn1(self.conv1(x))))
        x = self.pool2(F.relu(self.bn2(self.conv2(x))))
        # Flatten everything except the batch dimension. Unlike
        # view(-1, 64*7*7), this never folds a wrong-sized input into a
        # larger batch, and it also accepts non-contiguous tensors.
        x = x.flatten(1)
        x = self.dropout(F.relu(self.fc1(x)))
        return self.fc2(x)

#模型参数量检查
简单统计可训练参数,确保模型轻量:
def count_params(model: nn.Module) -> int:
    """Return the number of trainable scalar parameters in ``model``.

    Only parameters with ``requires_grad`` set are counted, so frozen
    parameters do not contribute. A module without parameters yields 0.
    """
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
model = SimpleMNISTCNN()
print(f"🤖 SimpleMNISTCNN参数量:{count_params(model):,}")

#4. 完整训练与评估
#4.1 训练配置与循环
整合所有组件,编写可直接运行的训练循环:
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
import time
def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over ``loader``; return ``(mean_loss, accuracy_percent)``.

    With an ``optimizer`` the model is put in training mode and updated after
    every batch. Without one it is put in eval mode and gradient tracking is
    disabled.
    """
    training = optimizer is not None
    model.train(training)
    total_loss = 0.0
    correct = 0
    with torch.set_grad_enabled(training):
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            if training:
                optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            if training:
                loss.backward()
                optimizer.step()
            # The criterion returns a batch mean; weight it by the batch size
            # so a smaller final batch does not skew the epoch average.
            total_loss += loss.item() * data.size(0)
            correct += output.argmax(dim=1).eq(target).sum().item()
    n_samples = len(loader.dataset)
    return total_loss / n_samples, 100. * correct / n_samples


def train_mnist(epochs=10, lr=0.001):
    """Train ``SimpleMNISTCNN`` on MNIST and checkpoint the best epoch.

    Uses cross-entropy loss, Adam with weight decay 1e-4, and a ``StepLR``
    schedule that halves the learning rate every 5 epochs. After each epoch
    the model is evaluated on the test loader; whenever that accuracy
    improves, the weights are saved to ``best_mnist_cnn.pth``. Progress is
    printed per epoch.

    Args:
        epochs: Number of full passes over the training set.
        lr: Initial learning rate for Adam.

    Returns:
        The model as it stands after the final epoch. This is not
        necessarily the best-scoring one; load ``best_mnist_cnn.pth`` for
        the best checkpoint.
    """
    # 1. Setup
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"💻 使用设备:{device}")
    model = SimpleMNISTCNN().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
    scheduler = StepLR(optimizer, step_size=5, gamma=0.5)
    train_loader, test_loader = get_dataloaders()

    # 2. Epoch loop
    best_acc = 0.0
    start = time.time()
    for epoch in range(epochs):
        train_loss, train_acc = _run_epoch(
            model, train_loader, criterion, device, optimizer
        )
        # NOTE: the "Val" numbers are measured on the MNIST test split,
        # which is also what selects the checkpoint.
        val_loss, val_acc = _run_epoch(model, test_loader, criterion, device)

        # Step the schedule once per epoch, after the optimizer updates.
        scheduler.step()

        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), "best_mnist_cnn.pth")

        print(f"\n🔹 Epoch {epoch+1}/{epochs}")
        print(f"Train: Loss={train_loss:.4f}, Acc={train_acc:.2f}%")
        print(f"Val: Loss={val_loss:.4f}, Acc={val_acc:.2f}%")

    total_time = time.time() - start
    print(f"\n🎉 训练完成!总耗时:{total_time:.1f}s,最佳Val Acc:{best_acc:.2f}%")
    return model

#4.2 单张图像推理
加载最佳模型进行简单的推理测试:
def infer_single_image(index=None):
    """Classify one MNIST test image with the saved best checkpoint.

    Loads ``best_mnist_cnn.pth`` into a fresh ``SimpleMNISTCNN``, runs it on a
    single test sample, and prints the true label, the predicted label and
    the softmax probability of the prediction as a percentage. If the
    checkpoint file does not exist, a hint is printed and the function
    returns without running inference.

    Args:
        index: Position of the test sample to classify. ``None`` (the
            default) picks a position at random.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = SimpleMNISTCNN().to(device)

    # Only the file read is guarded, so unrelated errors are not mistaken
    # for a missing checkpoint.
    # NOTE(review): torch.load relies on pickle; only load checkpoints from
    # a trusted source.
    try:
        state_dict = torch.load("best_mnist_cnn.pth", map_location=device)
    except FileNotFoundError:
        print("⚠️ 未找到训练好的模型,请先运行train_mnist()")
        return
    model.load_state_dict(state_dict)
    model.eval()
    print("✅ 最佳模型加载成功")

    # Index the dataset directly: the test loader does not shuffle, so its
    # first batch would be the same image on every call.
    _, test_loader = get_dataloaders(batch_size=1)
    test_set = test_loader.dataset
    if index is None:
        index = torch.randint(len(test_set), (1,)).item()
    image, label = test_set[index]
    data = image.unsqueeze(0).to(device)  # add the batch dimension

    with torch.no_grad():
        output = model(data)
    pred = output.argmax(dim=1).item()
    prob = torch.softmax(output, dim=1).max().item() * 100
    print(f"\n🔍 真实标签:{label},预测标签:{pred},置信度:{prob:.2f}%")

#5. 总结与学习建议
#5.1 核心流程回顾
MNIST的训练流程同样适用于绝大多数图像分类任务:
- 数据准备:加载、预处理、批处理
- 模型构建:特征提取(CNN)+ 分类头(FC)
- 训练循环:前向→反向→更新→验证
- 模型部署/推理:保存最佳权重、加载推理

