手写数字识别 (MNIST) 实战:PyTorch图像分类模型完整指南

引言

手写数字识别(MNIST)被誉为深度学习的Hello World,是计算机视觉入门的黄金基准。它由Yann LeCun等人整理发布,包含7万张28×28单通道灰度手写数字,覆盖0-9共10类,完美适合掌握卷积神经网络(CNN)的核心逻辑与完整训练流程。

📂 所属阶段:第二阶段 — 深度学习视觉基础(CNN 篇)
🔗 相关章节:经典 CNN 架构剖析 · 数据增强 (Data Augmentation)


1. MNIST数据集快速上手

1.1 核心参数一览

先通过简洁的代码块明确数据集的基础属性,避免冗长文字:

import torch
import torchvision

def get_mnist_info():
    """Print a short summary of the MNIST dataset's basic properties.

    The summary covers the split sizes, image shape, label range and pixel
    value range. Nothing is returned; the text goes to standard output.
    """
    summary = (
        "📊 MNIST数据集核心参数:",
        "• 训练集:60,000张\n• 测试集:10,000张",
        "• 尺寸:28×28×1(灰度单通道)\n• 类别:0-9(10类)",
        "• 像素范围:原始0-255,经ToTensor转[0,1]",
    )
    # One print per entry keeps the emitted text identical, line for line.
    for entry in summary:
        print(entry)

# Print the dataset summary as soon as this snippet runs.
get_mnist_info()

1.2 样本可视化(简化版)

下面的函数默认只加载数据集并打印首个样本的标签;绘图代码以注释形式保留,若运行环境已安装matplotlib,取消注释即可显示前10张样本:

def visualize_mnist():
    """Load the MNIST training split and report the label of its first sample.

    The dataset is created without a transform and is downloaded into
    ``./data`` when it is not already there. The plotting code is left as a
    commented-out optional snippet.
    """
    dataset = torchvision.datasets.MNIST(root="./data", train=True, download=True)
    # Each sample is an (image, label) pair; only the label is reported here.
    first_label = dataset[0][1]
    print(f"✅ 样本加载成功,标签示例:{first_label}(第1张为数字{first_label})")
    # Optional preview: uncomment the lines below if matplotlib is available.
    # import matplotlib.pyplot as plt
    # fig, axes = plt.subplots(2,5, figsize=(12,5))
    # for i, ax in enumerate(axes.ravel()):
    #     img, lbl = dataset[i]
    #     ax.imshow(img, cmap="gray")
    #     ax.set_title(f"Label: {lbl}")
    #     ax.axis("off")
    # plt.tight_layout()
    # plt.show()

# Run the preview; the dataset is requested with download=True.
visualize_mnist()

2. 数据预处理与加载

2.1 标准预处理管道

MNIST常用的归一化参数(0.1307,)与(0.3081,)分别是训练集像素(缩放到[0,1]之后)的均值与标准差,已成为社区通用取值,直接复用即可:

from torchvision import transforms
from torch.utils.data import DataLoader

def get_dataloaders(batch_size=64, num_workers=2, root="./data"):
    """Build the MNIST training and test ``DataLoader`` objects.

    Both splits share one preprocessing pipeline: ``ToTensor`` followed by
    ``Normalize((0.1307,), (0.3081,))``. The dataset is downloaded into
    ``root`` when it is not already present.

    Args:
        batch_size: Number of samples per batch for both loaders.
        num_workers: Worker processes used by each loader.
        root: Directory the dataset is read from / downloaded to.

    Returns:
        A ``(train_loader, test_loader)`` tuple. Only the training loader
        shuffles its samples.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    train_set = torchvision.datasets.MNIST(root=root, train=True, download=True, transform=transform)
    test_set = torchvision.datasets.MNIST(root=root, train=False, download=True, transform=transform)

    # Pinned host memory only helps host-to-GPU copies; asking for it on a
    # CPU-only machine has no benefit and makes DataLoader emit a warning.
    loader_kwargs = {
        "batch_size": batch_size,
        "num_workers": num_workers,
        "pin_memory": torch.cuda.is_available(),
    }
    train_loader = DataLoader(train_set, shuffle=True, **loader_kwargs)
    test_loader = DataLoader(test_set, shuffle=False, **loader_kwargs)
    return train_loader, test_loader

3. 轻量高效CNN模型设计

本文采用轻量但带BN(批归一化)的SimpleMNISTCNN,既易理解,又能在10个epoch内稳定达到99.2%+的测试准确率:

import torch.nn as nn
import torch.nn.functional as F

class SimpleMNISTCNN(nn.Module):
    """Small two-block CNN for 1-channel 28x28 inputs with 10 output logits.

    Each convolution block is Conv(3x3, padding=1) -> BatchNorm -> ReLU ->
    MaxPool(2), so the spatial size goes 28 -> 14 -> 7. The classifier head is
    Linear(64*7*7, 128) -> ReLU -> Dropout(0.4) -> Linear(128, 10).
    """

    def __init__(self):
        super().__init__()
        # Block 1: 1 -> 32 channels; padding=1 keeps the spatial size.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.pool1 = nn.MaxPool2d(2)  # 28 -> 14
        # Block 2: 32 -> 64 channels.
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.pool2 = nn.MaxPool2d(2)  # 14 -> 7
        # Classifier head; dropout is applied to the hidden activations.
        self.dropout = nn.Dropout(0.4)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Return raw class logits of shape ``(batch, 10)``.

        Args:
            x: Input batch of shape ``(batch, 1, 28, 28)``.

        Raises:
            RuntimeError: If the spatial size is not 28x28, because the
                flattened features no longer match ``fc1``.
        """
        x = self.pool1(F.relu(self.bn1(self.conv1(x))))
        x = self.pool2(F.relu(self.bn2(self.conv2(x))))
        # Flatten everything except the batch dimension. Unlike
        # ``view(-1, 64*7*7)`` this never folds a wrongly sized input into a
        # different batch size, and it also accepts non-contiguous tensors.
        x = x.flatten(1)
        x = self.dropout(F.relu(self.fc1(x)))
        return self.fc2(x)

模型参数量检查

简单统计可训练参数,确保模型轻量:

def count_params(model):
    """Return the total element count of the model's trainable parameters.

    Parameters whose ``requires_grad`` flag is false are left out.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# Build a fresh model and print its trainable parameter count
# (formatted with thousands separators).
model = SimpleMNISTCNN()
print(f"🤖 SimpleMNISTCNN参数量:{count_params(model):,}")

4. 完整训练与评估

4.1 训练配置与循环

整合所有组件,编写可直接运行的训练循环:

import torch
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
import time

def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one full pass over ``loader`` and return ``(mean_loss, accuracy_pct)``.

    When ``optimizer`` is given the model is put in training mode and updated
    after every batch; otherwise it is put in eval mode and gradients are
    disabled. The loss is averaged per sample over ``len(loader.dataset)``.
    """
    training = optimizer is not None
    model.train(training)
    loss_sum = 0.0
    correct = 0
    with torch.set_grad_enabled(training):
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            if training:
                optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            if training:
                loss.backward()
                optimizer.step()
            # criterion returns the batch mean, so weight it by batch size
            # to get a correct per-sample average when the last batch is short.
            loss_sum += loss.item() * data.size(0)
            correct += output.argmax(dim=1).eq(target).sum().item()
    sample_count = len(loader.dataset)
    return loss_sum / sample_count, 100. * correct / sample_count


def train_mnist(epochs=10, lr=0.001, batch_size=64, checkpoint_path="best_mnist_cnn.pth"):
    """Train ``SimpleMNISTCNN`` on MNIST and checkpoint the best epoch.

    Uses Adam (``weight_decay=1e-4``) with a ``StepLR`` schedule that halves
    the learning rate every 5 epochs. After each epoch the model is scored on
    the test loader; whenever that accuracy improves, the ``state_dict`` is
    written to ``checkpoint_path``.

    Note that the test split doubles as the validation set here, so the
    reported best accuracy was also used for model selection.

    Args:
        epochs: Number of passes over the training set.
        lr: Initial learning rate for Adam.
        batch_size: Batch size forwarded to ``get_dataloaders``.
        checkpoint_path: File that receives the best ``state_dict``.

    Returns:
        The model as it stands after the final epoch. Those weights are not
        necessarily the best ones; the best ones are in ``checkpoint_path``.
    """
    # 1. Setup
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"💻 使用设备:{device}")
    model = SimpleMNISTCNN().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
    scheduler = StepLR(optimizer, step_size=5, gamma=0.5)  # halve the LR every 5 epochs
    train_loader, test_loader = get_dataloaders(batch_size=batch_size)

    # 2. Epoch loop
    best_acc = 0.0
    start = time.time()
    for epoch in range(epochs):
        train_loss, train_acc = _run_epoch(model, train_loader, criterion, device, optimizer)
        val_loss, val_acc = _run_epoch(model, test_loader, criterion, device)

        # The schedule advances once per epoch, after the optimizer updates.
        scheduler.step()

        # Keep only the weights of the best-scoring epoch on disk.
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), checkpoint_path)

        print(f"\n🔹 Epoch {epoch+1}/{epochs}")
        print(f"Train: Loss={train_loss:.4f}, Acc={train_acc:.2f}%")
        print(f"Val:   Loss={val_loss:.4f}, Acc={val_acc:.2f}%")

    # 3. Summary
    total_time = time.time() - start
    print(f"\n🎉 训练完成!总耗时:{total_time:.1f}s,最佳Val Acc:{best_acc:.2f}%")
    return model

4.2 单张图像推理

加载最佳模型进行简单的推理测试:

def infer_single_image(index=None):
    """Classify one MNIST test image with the saved best checkpoint.

    Loads ``best_mnist_cnn.pth`` into a fresh ``SimpleMNISTCNN``, runs a single
    test sample through it and prints the true label, the predicted label and
    the softmax confidence. If the checkpoint file is missing, a hint is
    printed and the function returns without doing inference.

    Args:
        index: Position of the test sample to classify. When ``None`` (the
            default) a sample is drawn at random.

    Raises:
        IndexError: If ``index`` is outside the test set.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = SimpleMNISTCNN().to(device)
    # Only the file read is guarded, so unrelated errors are not mistaken for
    # a missing checkpoint.
    # NOTE(review): torch.load unpickles the file; only load checkpoints you
    # trust (recent PyTorch releases offer weights_only=True for this).
    try:
        state_dict = torch.load("best_mnist_cnn.pth", map_location=device)
    except FileNotFoundError:
        print("⚠️ 未找到训练好的模型,请先运行train_mnist()")
        return
    model.load_state_dict(state_dict)
    model.eval()
    print("✅ 最佳模型加载成功")

    # Index the dataset directly: the test loader is not shuffled, so taking
    # its first batch would always return the same image, and iterating it
    # would start worker processes just to fetch one sample.
    _, test_loader = get_dataloaders(batch_size=1)
    test_set = test_loader.dataset
    if index is None:
        index = torch.randint(len(test_set), (1,)).item()
    image, label = test_set[index]
    data = image.unsqueeze(0).to(device)  # add the batch dimension

    with torch.no_grad():
        output = model(data)
        pred = output.argmax(dim=1).item()
        prob = torch.softmax(output, dim=1).max().item() * 100

    print(f"\n🔍 真实标签:{int(label)},预测标签:{pred},置信度:{prob:.2f}%")

5. 总结与学习建议

5.1 核心流程回顾

MNIST的训练流程同样适用于绝大多数图像分类任务:

  1. 数据准备:加载、预处理、批处理
  2. 模型构建:特征提取(CNN)+ 分类头(FC)
  3. 训练循环:前向→反向→更新→验证
  4. 模型部署/推理:保存最佳权重、加载推理

5.2 学习建议

1. 调整超参数(batch_size、lr、dropout率),观察对结果的影响
2. 尝试加入数据增强(随机旋转10°、平移0.1),看能否突破99.5%
3. 对比SimpleMNISTCNN与LeNet-5的性能与参数量

相关教程