Detailed practical explanation of handwritten digit recognition MNIST - A complete guide to PyTorch image classification model | Daoman PythonAI

# Handwritten digit recognition (MNIST) in practice: a complete guide to building a PyTorch image classification model

Introduction

Handwritten digit recognition on MNIST is often called the "Hello World" of deep learning and is a standard entry-level benchmark in computer vision. The dataset was compiled and released by Yann LeCun and colleagues. It contains 70,000 single-channel grayscale images of handwritten digits, each 28×28 pixels, covering the 10 classes 0 to 9. This makes it well suited to learning the core ideas and the complete training workflow of convolutional neural networks (CNNs).


In this tutorial we start from scratch and use PyTorch to build a lightweight CNN, working step by step through data loading, preprocessing, model construction, training, validation, and single-image inference. Whether you are new to deep learning or want a quick refresher on the image classification workflow, this guide should get you up and running.


1. Get started quickly with the MNIST data set

1.1 List of core parameters

Before writing any training code, a short function summarizes the basic properties of the dataset so that the numbers are at hand in later steps.

import torch
import torchvision

def get_mnist_info():
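    print("📊 MNIST dataset core parameters:")
    print("• Training set: 60,000 images\n• Test set: 10,000 images")
    print("• Size: 28×28×1 (single-channel grayscale)\n• Classes: 0-9 (10 classes)")
    print("• Pixel range: 0-255 raw; ToTensor rescales to [0, 1]")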

get_mnist_info()

1.2 Quick sample preview

To get an intuitive feel for the data, we can load the raw dataset and look at a few images. The example below is simplified: if matplotlib is installed in your environment, uncomment the plotting code to draw the images.

def visualize_mnist():
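    # Load the raw PIL images (no transform) just for previewing
    raw_train = torchvision.datasets.MNIST(root="./data", train=True, download=True)
    print(f"✅ Samples loaded; example label: {raw_train[0][1]} (the first image is the digit {raw_train[0][1]})")
    # If matplotlib is available in your environment, uncomment the code below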
    # import matplotlib.pyplot as plt
    # fig, axes = plt.subplots(2,5, figsize=(12,5))
    # for i, ax in enumerate(axes.ravel()):
    #     img, lbl = raw_train[i]
    #     ax.imshow(img, cmap="gray")
    #     ax.set_title(f"Label: {lbl}")
    #     ax.axis("off")
    # plt.tight_layout()
    # plt.show()

visualize_mnist()

2. Data preprocessing and loading

2.1 Standard preprocessing pipeline

To make training more stable, we apply two transforms to each image:

  • Convert it to a tensor, which also rescales pixel values from 0-255 to the range [0, 1].
  • Normalize it with the mean and standard deviation of the MNIST training set, so that the values are roughly centered on 0.

The values (0.1307,) and (0.3081,) are the widely used single-channel mean and standard deviation of the MNIST training images (after scaling to [0, 1]), so you can reuse them directly instead of computing them yourself.
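To make the two steps concrete, here is a minimal pure-Python sketch of what they do to a single pixel. The helper name normalize_pixel is hypothetical and is not part of torchvision; it only mirrors the arithmetic of ToTensor followed by Normalize for one value.

```python
# Hypothetical helper: mirrors ToTensor + Normalize for one grayscale pixel.
MNIST_MEAN = 0.1307
MNIST_STD = 0.3081

def normalize_pixel(raw_value, mean=MNIST_MEAN, std=MNIST_STD):
    """Map a raw 0-255 grayscale value to its normalized value."""
    scaled = raw_value / 255.0      # ToTensor: uint8 [0, 255] -> float [0, 1]
    return (scaled - mean) / std    # Normalize: subtract mean, divide by std

for raw in (0, 128, 255):
    print(f"{raw:3d} -> {normalize_pixel(raw):+.4f}")
```

Black background pixels end up at about -0.42 and full-white strokes at about 2.82. Because most MNIST pixels are background, the average over the training set lands near 0.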

from torchvision import transforms
from torch.utils.data import DataLoader

def get_dataloaders(batch_size=64):
    # Shared train/test preprocessing (99%+ is reachable without heavy augmentation)
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    # Load the datasets
    train_set = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
    test_set  = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transform)

    # Build the loaders: shuffle the training set, but not the test set
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True,  num_workers=2, pin_memory=True)
    test_loader  = DataLoader(test_set,  batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)
    return train_loader, test_loader

💡 Parameter notes:

  • shuffle=True is used only for the training set, so that batches differ from epoch to epoch and the model does not see the samples in a fixed order.
  • num_workers sets the number of data-loading subprocesses. Tune it to your CPU core count; 2 to 4 is usually enough.
  • pin_memory=True speeds up host-to-GPU data transfer when a GPU is available.
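The batch size also determines how many optimizer steps make up one epoch. The sketch below works this out in plain Python for the default batch_size=64. The helper batches_per_epoch is a hypothetical name, and it assumes the DataLoader default of keeping the final, smaller batch (drop_last=False).

```python
import math

def batches_per_epoch(num_samples, batch_size, drop_last=False):
    """Number of batches a loader yields per pass over the dataset."""
    if drop_last:
        return num_samples // batch_size          # incomplete batch discarded
    return math.ceil(num_samples / batch_size)    # incomplete batch kept

train_batches = batches_per_epoch(60_000, 64)
test_batches = batches_per_epoch(10_000, 64)
last_train_batch = 60_000 - (train_batches - 1) * 64

print("training batches per epoch:", train_batches)
print("test batches per epoch:", test_batches)
print("size of the last training batch:", last_train_batch)
```

With 60,000 training images this gives 938 batches, the last one holding only 32 images. That is why the training loop later weights each batch loss by data.size(0) before averaging.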

3. Lightweight and efficient CNN model design

To keep the code easy to follow while still converging quickly to high accuracy, we use a simple convolutional network with batch normalization (BatchNorm), called SimpleMNISTCNN. It can reliably reach a test accuracy above 99.2% within 10 epochs.

3.1 Network structure

The model consists of two convolutional blocks and a classification head:

  • Convolution block 1: 1 → 32 channels, extracts edges and simple textures
  • Convolution block 2: 32 → 64 channels, extracts more complex shape features
  • Classification head: adds Dropout to reduce overfitting, then outputs scores for the 10 classes

import torch.nn as nn
import torch.nn.functional as F

class SimpleMNISTCNN(nn.Module):
    def __init__(self):
        super().__init__()
        # Convolution block 1: extracts edges and line features
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)  # padding=1 keeps the spatial size
        self.bn1   = nn.BatchNorm2d(32)
        self.pool1 = nn.MaxPool2d(2)  # 28×28 → 14×14

        # Convolution block 2: extracts more complex shapes
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2   = nn.BatchNorm2d(64)
        self.pool2 = nn.MaxPool2d(2)  # 14×14 → 7×7

        # Classification head: dropout reduces overfitting
        self.dropout = nn.Dropout(0.4)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        # Convolution block 1: conv → batch norm → ReLU → pool
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.pool1(x)
        # Convolution block 2
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.pool2(x)
        # Flatten to (batch, 64*7*7)
        x = x.view(-1, 64 * 7 * 7)
        # Classification head: returns raw logits (no softmax)
        x = self.dropout(F.relu(self.fc1(x)))
        x = self.fc2(x)
        return x
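The size comments in the model (28×28 → 14×14 → 7×7) follow from the standard output-size formula for convolution and pooling layers, floor((n + 2p - k) / s) + 1. The sketch below traces the shapes in plain Python. The helper names conv_out and trace_simple_mnist_cnn are hypothetical and only mirror the layer settings shown above.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output side length of a conv or pooling layer (floor mode)."""
    return (size + 2 * padding - kernel) // stride + 1

def trace_simple_mnist_cnn(size=28):
    """Return (stage name, channels, side length) for each stage."""
    stages = [("input", 1, size)]
    size = conv_out(size, kernel=3, padding=1)   # conv1: padding keeps the size
    stages.append(("conv1", 32, size))
    size = conv_out(size, kernel=2, stride=2)    # pool1: halves the side length
    stages.append(("pool1", 32, size))
    size = conv_out(size, kernel=3, padding=1)   # conv2: size unchanged again
    stages.append(("conv2", 64, size))
    size = conv_out(size, kernel=2, stride=2)    # pool2: halves it once more
    stages.append(("pool2", 64, size))
    return stages

stages = trace_simple_mnist_cnn()
for name, channels, side in stages:
    print(f"{name:6s} {channels:3d} x {side} x {side}")

_, channels, side = stages[-1]
flat_features = channels * side * side
print("flattened features:", flat_features)
```

The final value, 64 × 7 × 7 = 3136, is the input width of fc1. If you change the input size or add a pooling layer, this number has to change with it.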

3.2 Checking the parameter count

A lightweight model is very friendly for entry-level learning. Let’s count the total number of trainable parameters:

def count_params(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

model = SimpleMNISTCNN()
print(f"🤖 SimpleMNISTCNN parameter count: {count_params(model):,}")

For this model the count is 421,834 trainable parameters, about 95% of which sit in fc1. That is very small by modern standards, so training is reasonably fast even on a CPU.
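The same total can be checked by hand from the layer definitions. This is a minimal sketch with hypothetical helper names. It assumes the PyTorch defaults used in the model above: every convolution and linear layer has a bias, and each BatchNorm2d layer has one trainable weight and one trainable bias per channel.

```python
def conv2d_params(c_in, c_out, kernel):
    """Weights (c_out * c_in * k * k) plus one bias per output channel."""
    return c_out * c_in * kernel * kernel + c_out

def batchnorm_params(channels):
    """Trainable scale and shift per channel (running stats are buffers)."""
    return 2 * channels

def linear_params(n_in, n_out):
    """Weight matrix plus one bias per output unit."""
    return n_in * n_out + n_out

layer_params = {
    "conv1": conv2d_params(1, 32, 3),
    "bn1": batchnorm_params(32),
    "conv2": conv2d_params(32, 64, 3),
    "bn2": batchnorm_params(64),
    "fc1": linear_params(64 * 7 * 7, 128),
    "fc2": linear_params(128, 10),
}
total_params = sum(layer_params.values())

for name, count in layer_params.items():
    print(f"{name:6s} {count:>8,}")
print(f"total  {total_params:>8,}")
```

The breakdown shows that the first fully connected layer dominates the budget, so shrinking its width or its 7×7 input is the quickest way to make the model smaller.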


4. Complete training and evaluation

4.1 Training configuration and loop

The function below combines the model, loss function, optimizer, learning rate scheduler, and data loaders into one ready-to-run training routine. After every epoch it prints the training and validation metrics, and it saves the weights whenever validation accuracy improves. Note that the MNIST test set doubles as the validation set here.

import torch
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
import time

def train_mnist(epochs=10, lr=0.001):
    # 1. Set up the device, model, loss, optimizer, scheduler, and data
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"💻 Using device: {device}")

    model = SimpleMNISTCNN().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
    scheduler = StepLR(optimizer, step_size=5, gamma=0.5)  # halve the learning rate every 5 epochs
    train_loader, test_loader = get_dataloaders()

    # 2. Training loop
    best_acc = 0.0
    start = time.time()
    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        train_correct = 0

        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            # Accumulate the summed loss and the number of correct predictions
            train_loss += loss.item() * data.size(0)
            pred = output.argmax(dim=1)
            train_correct += pred.eq(target).sum().item()

        # Average loss and accuracy for this epoch
        train_loss /= len(train_loader.dataset)
        train_acc = 100. * train_correct / len(train_loader.dataset)

        # 3. Validation phase (eval mode, no gradients)
        model.eval()
        val_loss = 0.0
        val_correct = 0
        with torch.no_grad():
            for data, target in test_loader:
                data, target = data.to(device), target.to(device)
                output = model(data)
                loss = criterion(output, target)
                val_loss += loss.item() * data.size(0)
                pred = output.argmax(dim=1)
                val_correct += pred.eq(target).sum().item()

        val_loss /= len(test_loader.dataset)
        val_acc = 100. * val_correct / len(test_loader.dataset)

        # Update the learning rate
        scheduler.step()

        # Save the best model so far
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), "best_mnist_cnn.pth")

        # Log this epoch
        print(f"\n🔹 Epoch {epoch+1}/{epochs}")
        print(f"Train: Loss={train_loss:.4f}, Acc={train_acc:.2f}%")
        print(f"Val:   Loss={val_loss:.4f}, Acc={val_acc:.2f}%")

    # Training finished
    total_time = time.time() - start
    print(f"\n🎉 Training finished! Total time: {total_time:.1f}s, best Val Acc: {best_acc:.2f}%")
    return model
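To see what StepLR(optimizer, step_size=5, gamma=0.5) does over the default 10 epochs, the schedule can be reproduced with one line of arithmetic. The function step_lr below is a hypothetical stand-in written in plain Python, not the PyTorch scheduler itself. It assumes scheduler.step() is called once at the end of every epoch, as in the training function above.

```python
def step_lr(base_lr, epoch, step_size=5, gamma=0.5):
    """Learning rate in effect during a given 0-based epoch."""
    return base_lr * gamma ** (epoch // step_size)

schedule = [step_lr(0.001, epoch) for epoch in range(10)]
for epoch, lr in enumerate(schedule, start=1):
    print(f"epoch {epoch:2d}: lr = {lr:.6f}")
```

Epochs 1 to 5 run at 0.001 and epochs 6 to 10 at 0.0005. Training for more epochs would halve the rate again at epoch 11.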

4.2 Single-image inference

After training, we can load the best saved weights and predict a single image to see how the model behaves in practice.

def infer_single_image():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Load the model weights
    model = SimpleMNISTCNN().to(device)
    try:
        model.load_state_dict(torch.load("best_mnist_cnn.pth", map_location=device))
        model.eval()
        print("✅ Best model loaded successfully")
    except FileNotFoundError:
        print("⚠️ No trained model found; run train_mnist() first")
        return

    # Take one test image (the test loader is not shuffled, so this is always the first one)
    _, test_loader = get_dataloaders(batch_size=1)
    data, target = next(iter(test_loader))
    data, target = data.to(device), target.to(device)

    # Inference
    with torch.no_grad():
        output = model(data)
        pred = output.argmax(dim=1).item()
        prob = torch.softmax(output, dim=1).max().item() * 100

    print(f"\n🔍 True label: {target.item()}, predicted label: {pred}, confidence: {prob:.2f}%")
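The "confidence" printed above is the largest softmax probability computed from the model's raw scores (logits). The sketch below reproduces that calculation in plain Python. The logits are made-up illustrative values, not real model output, and the helper names are hypothetical.

```python
import math

def softmax(logits):
    """Convert raw scores to probabilities that sum to 1."""
    peak = max(logits)                            # subtract max for numerical stability
    exps = [math.exp(value - peak) for value in logits]
    total = sum(exps)
    return [e / total for e in exps]

def predict_with_confidence(logits):
    """Return (predicted class index, confidence in percent)."""
    probs = softmax(logits)
    pred = max(range(len(probs)), key=probs.__getitem__)  # argmax
    return pred, probs[pred] * 100

# Illustrative logits for the 10 digit classes (assumed values)
example_logits = [1.2, -0.3, 0.5, 4.1, 0.0, -1.0, 0.2, 0.9, -0.5, 0.1]
pred, confidence = predict_with_confidence(example_logits)
print(f"predicted digit: {pred}, confidence: {confidence:.2f}%")
```

Softmax preserves the ordering of the logits, so the predicted class is the same whether argmax is applied before or after it. A high softmax value shows that one score dominates the others, which is not the same as a calibrated probability of being correct.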

5. Summary and learning suggestions

5.1 Core process review

The MNIST project follows the standard workflow of an image classification task. Whatever dataset you work with next, the steps will be much the same:

  1. Data preparation: loading, preprocessing, batching
  2. Model construction: feature extraction (CNN) + classification head (fully connected layers)
  3. Training loop: forward pass → backpropagation → parameter update → validation
  4. Model saving and inference: save the best weights and load them for actual predictions

Once you have mastered this workflow, you have the foundation for tackling other image classification problems.

5.2 Learning Suggestions

  1. Adjust hyperparameters (batch_size, learning rate, dropout rate) and observe how they affect training speed and final accuracy.
  2. Try adding data augmentation (such as random rotation of up to 10° and translation of up to 0.1 of the image size) and see whether you can exceed 99.5% accuracy.
  3. Compare the performance and parameter count of SimpleMNISTCNN with the classic LeNet-5 to get a feel for how architectures differ.