title: "Deep Learning for Vision, End to End: From CNN Principles to Product Deployment"
description: "A complete deep learning image classification tutorial: a CNN written from scratch, ResNet transfer learning, and Vision Transformer, covering PyTorch implementation, model deployment, and web application development."
keywords: [Deep learning, image classification, CNN, PyTorch, transfer learning, ResNet, Vision Transformer, ViT, machine learning, computer vision]

Deep Learning for Vision, End to End: From CNN Principles to Product Deployment

Introduction

This article is a step-by-step practical guide to deep learning image classification. We'll write a convolutional neural network from scratch, move on to pretrained ResNet and Vision Transformer models, and finally turn the result into a web application that others can use. The whole walkthrough uses a single task, classifying pictures of cats and dogs, so it works both for beginners learning the principles and as a starting point for your own image recognition projects.


1. Overview of image classification concepts

The goal of image classification is simple: given a picture, have the computer decide which category it belongs to. For example, given a photo of a cat, output "cat" rather than "dog".

1.1 Core elements

  • Input: The pixel matrix of an image (color images usually have three RGB channels).
  • Output: A probability for each category, often shown as percentages, such as "cat 90%, dog 10%".
  • Challenge: The same cat produces completely different pixel values under different lighting, viewing angles, and occlusions, yet its category never changes.
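The percentages in the output come from a softmax applied to the network's raw scores (logits). A minimal plain-Python sketch; the logit values are made up purely for illustration:

```python
import math

def softmax(logits):
    """Convert raw scores (logits) into probabilities that sum to 1."""
    # Subtracting the max keeps exp() from overflowing; it does not change the result.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical logits for the classes ["cat", "dog"]
logits = [2.2, 0.0]
probs = softmax(logits)
for name, p in zip(["cat", "dog"], probs):
    print(f"{name}: {p:.0%}")   # cat: 90%, dog: 10%
```

The larger the gap between two logits, the more confident the resulting probability; only the differences between logits matter, not their absolute values.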

1.2 Common applications

  • Automatic classification of products on e-commerce platforms
  • Medical imaging-assisted diagnosis
  • Face recognition and security
  • Automatic tagging of smartphone photo albums

2. Convolutional neural networks (CNNs) explained

Flattening a picture into a long vector of numbers and feeding it to an ordinary fully connected network throws away the spatial relationships between neighboring pixels. CNNs, loosely inspired by how animals perceive visual information, preserve that structure: early layers capture local textures, and stacked layers build up the global structure of the image.

2.1 Core components of CNN

  1. Convolution layer: A set of learnable filters slides over the image to extract low-level features such as edges and color blobs. Stacking several layers combines these into more complex shapes, such as eyes and ears.
  2. Activation layer (such as ReLU): Introduces nonlinearity. ReLU passes positive responses through and sets negative ones to zero; without a nonlinearity, a stack of layers would collapse into a single linear transformation.
  3. Pooling layer: Downsamples the feature map, for example with max pooling, to reduce computation and make the model less sensitive to small shifts in position (local translation invariance).
  4. Fully connected layer: Like a traditional classifier, it combines all extracted features and produces the final category scores.
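To make the first three components concrete, here is what convolution, ReLU, and max pooling compute on a tiny grayscale image, written in plain Python. The edge filter is hand-picked for illustration; a real CNN learns its filter weights during training, and PyTorch layers do the same arithmetic on batches of multi-channel tensors.

```python
def conv2d(image, kernel):
    """Valid (no padding, stride 1) 2D cross-correlation, which is what CNN conv layers compute."""
    kh, kw = len(kernel), len(kernel[0])
    out_h = len(image) - kh + 1
    out_w = len(image[0]) - kw + 1
    return [[sum(image[i + a][j + b] * kernel[a][b]
                 for a in range(kh) for b in range(kw))
             for j in range(out_w)]
            for i in range(out_h)]

def relu(fmap):
    """Keep positive responses and set negative ones to zero."""
    return [[max(0, v) for v in row] for row in fmap]

def max_pool2x2(fmap):
    """2x2 max pooling with stride 2; like nn.MaxPool2d, a leftover odd row/column is dropped."""
    return [[max(fmap[i][j], fmap[i][j + 1], fmap[i + 1][j], fmap[i + 1][j + 1])
             for j in range(0, len(fmap[0]) - 1, 2)]
            for i in range(0, len(fmap) - 1, 2)]

# A 5x5 image with a bright vertical bar in columns 2-3
image = [[0, 0, 1, 1, 0] for _ in range(5)]
# Hand-picked vertical-edge filter: positive for dark-to-bright, negative for bright-to-dark
kernel = [[-1, 0, 1],
          [-1, 0, 1],
          [-1, 0, 1]]

response = conv2d(image, kernel)   # 3x3 map: +3 at the left edge, -3 at the right edge
features = relu(response)          # negative responses are zeroed
pooled = max_pool2x2(features)     # strongest response in the top-left 2x2 window
print(response)
print(features)
print(pooled)
```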

The following PyTorch code implements a basic CNN so you can see how these components fit together.

import torch
import torch.nn as nn
import torch.nn.functional as F

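class BasicCNN(nn.Module):
    """
    Example of a basic CNN architecture
    """
    def __init__(self):
        super().__init__()
        
        # Convolution layer 1: extracts low-level features such as edges and textures
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        
        # Convolution layer 2: combines low-level features into more complex shapes
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        
        # Max pooling halves the feature map's height and width each time it is applied
        self.pool = nn.MaxPool2d(2, 2)
        
        # Fully connected classifier
        self.fc1 = nn.Linear(64 * 56 * 56, 128)  # assumes a 224×224 input: 224 → 112 → 56 after two poolings
        self.fc2 = nn.Linear(128, 2)              # binary classification: cat vs. dog
        
        # Dropout to reduce overfitting
        self.dropout = nn.Dropout(0.5)
        
    def forward(self, x):
        # First convolution block
        x = self.pool(F.relu(self.bn1(self.conv1(x))))
        # Second convolution block
        x = self.pool(F.relu(self.bn2(self.conv2(x))))
        
        # Flatten the (C, H, W) feature maps into one vector per image, keeping the batch dimension
        x = torch.flatten(x, 1)
        # Fully connected forward pass
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

# Create a model instance
model = BasicCNN()
print("Basic CNN model created")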

3. Why write a CNN by hand in PyTorch?

Pretrained ResNet or ViT models perform better in practice, but building a CNN from scratch lets you:

  • Understand dimensional changes: See how the spatial size of the image shrinks and the number of channels grows as it passes through each layer.
  • Master the data flow: Know how tensors are passed through and transformed inside the network.
  • Build lightweight models: A handwritten CNN can be kept small enough to run on a laptop or even a phone. Note that the example models in this article are not that small: most of their parameters sit in the first fully connected layer, which receives a large flattened feature map.
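The shape bookkeeping mentioned in the first bullet follows one formula, output = ⌊(n + 2p − k) / s⌋ + 1, for both convolution and pooling layers. The sketch below traces BasicCNN's layers with it in plain Python and, as a by-product, shows how large its first fully connected layer becomes. The layer table is transcribed by hand from the BasicCNN definition above.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a convolution or pooling layer."""
    return (size + 2 * padding - kernel) // stride + 1

size, channels = 224, 3
# (name, out_channels or None for pooling, kernel, stride, padding), copied from BasicCNN
layers = [
    ("conv1", 32, 3, 1, 1),
    ("pool", None, 2, 2, 0),
    ("conv2", 64, 3, 1, 1),
    ("pool", None, 2, 2, 0),
]
for name, out_ch, k, s, p in layers:
    size = conv_out(size, k, s, p)
    if out_ch is not None:       # pooling keeps the channel count
        channels = out_ch
    print(f"{name:6s} -> {channels} x {size} x {size}")

flat = channels * size * size
print("flattened features:", flat)              # input size of fc1
print("fc1 parameters:", flat * 128 + 128)      # weight matrix + bias
```

With padding=1 a 3×3 convolution keeps the spatial size, so only the pooling layers shrink it; fc1 alone ends up with about 25.7 million parameters.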

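This article uses Python 3.10+ and a recent version of PyTorch. Install the following dependencies in advance (the index URL below selects CUDA 12.1 builds; use the install command from pytorch.org that matches your platform and CUDA version):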

pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
pip install Pillow numpy pandas matplotlib seaborn gradio

5. Cat and dog classification in practice: writing a CNN from scratch

5.1 Data enhancement and data set loading

We use a scaled-down version of the classic Kaggle Dogs vs. Cats dataset. Data augmentation is crucial during training: it "expands" the data with random flips, rotations, slight color changes, and so on, so the model sees more varied samples and generalizes better.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Select the device (GPU first)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Data augmentation for training
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(p=0.5),           # random horizontal flip
    transforms.RandomRotation(degrees=15),            # random rotation within ±15 degrees
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),  # color jitter
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # ImageNet normalization
])

# No random augmentation for validation
val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Expected dataset layout:
# data/dogs_vs_cats/
#   ├── train/
#   │   ├── cats/
#   │   └── dogs/
#   └── val/
#       ├── cats/
#       └── dogs/

try:
    train_data = datasets.ImageFolder(root='data/dogs_vs_cats/train', transform=train_transform)
    val_data = datasets.ImageFolder(root='data/dogs_vs_cats/val', transform=val_transform)
    print(f"Training set size: {len(train_data)}")
    print(f"Validation set size: {len(val_data)}")
    print(f"Number of classes: {len(train_data.classes)}")
    print(f"Class names: {train_data.classes}")
except FileNotFoundError:
    print("Dataset not found; please check the dataset path.")
    print("Falling back to random FakeData images for a smoke test (accuracy will be meaningless)...")
    train_data = datasets.FakeData(size=256, image_size=(3, 224, 224), num_classes=2, transform=train_transform)
    val_data = datasets.FakeData(size=64, image_size=(3, 224, 224), num_classes=2, transform=val_transform)

train_loader = DataLoader(train_data, batch_size=32, shuffle=True, num_workers=4)
val_loader = DataLoader(val_data, batch_size=32, shuffle=False, num_workers=4)
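Both pipelines above end with ToTensor, which scales 0-255 pixel values to [0, 1], followed by Normalize, which computes (x − mean) / std per channel. The plain-Python sketch below mimics those two steps for a single pixel, using the ImageNet statistics from the code, to show where normalized values land:

```python
# ImageNet channel statistics, as passed to transforms.Normalize above
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)

def normalize_pixel(rgb_255):
    """Mimic ToTensor() + Normalize() for one RGB pixel given as 0-255 integers."""
    scaled = [v / 255.0 for v in rgb_255]                        # ToTensor: scale to [0, 1]
    return [(x - m) / s for x, m, s in zip(scaled, MEAN, STD)]   # Normalize: (x - mean) / std

white = normalize_pixel((255, 255, 255))
black = normalize_pixel((0, 0, 0))
print([round(v, 4) for v in white])   # largest possible values per channel
print([round(v, 4) for v in black])   # smallest possible values per channel
```

So every normalized value falls roughly between −2.1 and 2.7, centered near zero. Pretrained ImageNet models expect inputs in this range, which is why inference must repeat exactly the same steps.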

5.2 Build a custom CNN model

This time we design a slightly deeper CNN with three convolution blocks. Each block uses batch normalization, and dropout is applied between the fully connected layers; both help the model train stably and reduce overfitting.

class CustomCNN(nn.Module):
    """
    Custom CNN: three convolution blocks + three fully connected layers
    """
    def __init__(self, num_classes=2):
        super().__init__()
        
        # Convolution block 1
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        
        # Convolution block 2
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        
        # Convolution block 3
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        
        # Pooling and dropout
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout(0.5)
        
        # After three poolings, 224×224 → 28×28 with 128 channels
        self.fc1 = nn.Linear(128 * 28 * 28, 512)
        self.fc2 = nn.Linear(512, 128)
        self.fc3 = nn.Linear(128, num_classes)
        
    def forward(self, x):
        # Three convolution blocks: convolution, batch normalization, ReLU, then pooling
        x = self.pool(F.relu(self.bn1(self.conv1(x))))
        x = self.pool(F.relu(self.bn2(self.conv2(x))))
        x = self.pool(F.relu(self.bn3(self.conv3(x))))
        
        # Flatten (keeping the batch dimension) and feed into the fully connected layers
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = F.relu(self.fc2(x))
        x = self.dropout(x)
        x = self.fc3(x)
        
        return x

custom_model = CustomCNN(num_classes=2).to(device)
print("Custom CNN model created")
print(f"Number of model parameters: {sum(p.numel() for p in custom_model.parameters()):,}")
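The parameter count printed above can be reproduced by hand, which shows where the model's size comes from. A plain-Python sketch using the standard per-layer formulas; the layer list is transcribed from CustomCNN, and BatchNorm's running mean and variance are buffers rather than parameters, so only its scale and shift are counted:

```python
def conv_params(in_ch, out_ch, k):
    return in_ch * out_ch * k * k + out_ch    # filter weights + one bias per filter

def bn_params(ch):
    return 2 * ch                             # learnable scale (gamma) and shift (beta)

def linear_params(n_in, n_out):
    return n_in * n_out + n_out               # weight matrix + bias vector

counts = {
    "conv1": conv_params(3, 32, 3),
    "bn1": bn_params(32),
    "conv2": conv_params(32, 64, 3),
    "bn2": bn_params(64),
    "conv3": conv_params(64, 128, 3),
    "bn3": bn_params(128),
    "fc1": linear_params(128 * 28 * 28, 512),
    "fc2": linear_params(512, 128),
    "fc3": linear_params(128, 2),
}
total = sum(counts.values())
for name, n in counts.items():
    print(f"{name:5s} {n:>12,}  ({n / total:.1%})")
print(f"total {total:>12,}")
```

Nearly all of the roughly 51.5 million parameters belong to fc1, so the saved float32 weights take around 200 MB. Shrinking the flattened feature map before the first fully connected layer is the main lever for making this model smaller.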

5.3 Training loop

The training loop handles the forward pass, loss computation, backpropagation, and parameter updates. After each epoch it also measures accuracy on the validation set and saves the best-performing model.

def train_model(model, train_loader, val_loader, epochs=10, learning_rate=0.001):
    """
    Train the model and return the best validation accuracy
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
    
    best_val_acc = 0.0
    
    for epoch in range(epochs):
        # ---------- Training ----------
        model.train()
        running_loss = 0.0
        correct_train = 0
        total_train = 0
        
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            
            outputs = model(images)
            loss = criterion(outputs, labels)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total_train += labels.size(0)
            correct_train += (predicted == labels).sum().item()
        
        # ---------- Validation ----------
        model.eval()
        correct_val = 0
        total_val = 0
        val_loss = 0.0
        
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                
                _, predicted = torch.max(outputs, 1)
                total_val += labels.size(0)
                correct_val += (predicted == labels).sum().item()
        
        train_acc = 100 * correct_train / total_train
        val_acc = 100 * correct_val / total_val
        
        print(f'Epoch [{epoch+1}/{epochs}]')
        print(f'Train Loss: {running_loss/len(train_loader):.4f}, Train Acc: {train_acc:.2f}%')
        print(f'Val Loss: {val_loss/len(val_loader):.4f}, Val Acc: {val_acc:.2f}%')
        print('-' * 50)
        
        # Save the best model so far
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), 'best_custom_cnn.pth')
            print(f'--> Better model found and saved (Val Acc: {best_val_acc:.2f}%)')
        
        scheduler.step()
    
    print(f'Training complete! Best validation accuracy: {best_val_acc:.2f}%')
    return best_val_acc

# Uncomment to train (train_loader and val_loader come from section 5.1)
# best_acc = train_model(custom_model, train_loader, val_loader, epochs=10)
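The loss and accuracy printed each epoch can be reproduced by hand. nn.CrossEntropyLoss takes raw logits and computes −log(softmax(logits)[label]) per sample, averaged over the batch by default, while accuracy compares the argmax of each row with the label (what torch.max(outputs, 1) returns). A plain-Python sketch with a hypothetical three-image batch:

```python
import math

def cross_entropy(logits, label):
    """Cross-entropy for one sample: -log(softmax(logits)[label]), computed stably."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]

# Hypothetical logits (cat, dog) and true labels (0 = cat, 1 = dog)
batch_logits = [[2.0, 0.0], [0.5, 1.5], [1.0, 1.2]]
labels = [0, 1, 0]

losses = [cross_entropy(z, y) for z, y in zip(batch_logits, labels)]
mean_loss = sum(losses) / len(losses)    # batch mean, like the default reduction='mean'
predicted = [max(range(len(z)), key=z.__getitem__) for z in batch_logits]   # argmax per row
accuracy = 100 * sum(p == y for p, y in zip(predicted, labels)) / len(labels)
print(f"loss={mean_loss:.4f}, acc={accuracy:.2f}%")
```

The third sample is misclassified (dog scores slightly higher), and its loss is the largest of the three: cross-entropy penalizes confident wrong answers far more than hesitant ones.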

6. Advanced techniques and optimization

6.1 Commonly used regularization techniques

  • Dropout: Randomly zeroes a fraction of activations during training (here, in the fully connected layers), forcing the network not to rely too heavily on any particular neurons and reducing overfitting.
  • Batch Normalization: Normalizes each layer's activations using mini-batch statistics, which stabilizes training and often allows a larger learning rate.
  • Weight Decay: Penalizes large weights (an L2 penalty), typically set through the optimizer's weight_decay argument.
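Dropout behaves differently in training and evaluation, which is why model.train() and model.eval() matter in the training loop. nn.Dropout uses "inverted" dropout: during training it zeroes each value with probability p and scales the survivors by 1/(1 − p), so at evaluation time it can simply pass values through. A plain-Python sketch of that rule (the function name and seeded RNG are illustrative, not PyTorch API):

```python
import random

def dropout(activations, p=0.5, training=True, rng=random):
    """Inverted dropout: zero each value with probability p, scale survivors by 1/(1-p).

    Scaling during training keeps each activation's expected value unchanged,
    so evaluation mode can return the input as-is.
    """
    if not training or p == 0:
        return list(activations)
    keep = 1.0 - p
    return [a / keep if rng.random() < keep else 0.0 for a in activations]

rng = random.Random(0)
x = [1.0] * 10
print(dropout(x, p=0.5, training=True, rng=rng))   # roughly half zeros, the rest 2.0
print(dropout(x, p=0.5, training=False))           # unchanged
```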

6.2 Learning rate scheduling

  • Step LR: Multiply the learning rate by a fixed factor every few epochs.
  • Cosine Annealing: Decrease the learning rate smoothly along a cosine curve.
  • ReduceLROnPlateau: Reduce the learning rate automatically when the validation loss stops improving.
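The three schedules can be written as small functions, which makes it easy to preview a schedule before training. A plain-Python sketch: step_lr and cosine_lr follow the closed forms of PyTorch's StepLR and CosineAnnealingLR, while PlateauReducer is a simplified, hypothetical stand-in for ReduceLROnPlateau that omits its threshold, cooldown, and min_lr options.

```python
import math

def step_lr(base_lr, epoch, step_size, gamma):
    """StepLR: multiply the LR by gamma once every step_size epochs."""
    return base_lr * gamma ** (epoch // step_size)

def cosine_lr(base_lr, epoch, t_max, eta_min=0.0):
    """Cosine annealing from base_lr at epoch 0 down to eta_min at epoch t_max."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max)) / 2

class PlateauReducer:
    """Simplified ReduceLROnPlateau: cut the LR after `patience` epochs without improvement."""
    def __init__(self, lr, factor=0.1, patience=2):
        self.lr, self.factor, self.patience = lr, factor, patience
        self.best, self.bad_epochs = float("inf"), 0

    def step(self, val_loss):
        if val_loss < self.best:
            self.best, self.bad_epochs = val_loss, 0
        else:
            self.bad_epochs += 1
            if self.bad_epochs > self.patience:
                self.lr *= self.factor
                self.bad_epochs = 0
        return self.lr

# The StepLR settings from section 5.3: lr=0.001, step_size=7, gamma=0.1
print([step_lr(0.001, e, 7, 0.1) for e in range(10)])
print([round(cosine_lr(0.001, e, 10), 6) for e in range(11)])
```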

At this point you have a basic CNN that can tell cats from dogs: inputs are normalized, convolutions extract features, and fully connected layers produce the classification.


7. Local image test script

To be useful in practice, the trained model has to run inference on individual images. The function below loads the model, applies the same deterministic preprocessing used for validation, and prints the predicted category and its confidence.

7.1 Inference process

  1. Recreate the exact same model structure as during training.
  2. Load the saved weight file.
  3. Resize and Normalize the input image (must be consistent with training).
  4. Perform a forward propagation and output the category with the highest probability.
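The inference code maps the predicted index to a hard-coded class list. That mapping is correct only because torchvision's ImageFolder assigns labels by sorting the class folder names alphabetically, so cats → 0 and dogs → 1. The stdlib sketch below re-implements that rule for illustration (find_classes here is a local helper that mirrors ImageFolder's behavior, not the torchvision method itself). In a real project, saving train_data.class_to_idx alongside the weights avoids depending on folder names.

```python
import os
import tempfile

def find_classes(root):
    """Mimic ImageFolder's labeling: sorted subfolder names get indices 0, 1, 2, ..."""
    classes = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
    return classes, {name: idx for idx, name in enumerate(classes)}

with tempfile.TemporaryDirectory() as root:
    for name in ("dogs", "cats"):          # creation order does not matter
        os.makedirs(os.path.join(root, name))
    classes, class_to_idx = find_classes(root)
    print(classes, class_to_idx)           # ['cats', 'dogs'] {'cats': 0, 'dogs': 1}
```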

7.2 Test code

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import os

def predict_local_image(image_path, model_path='best_custom_cnn.pth', num_classes=2):
    """
    Classify a local image file.
    Requires the CustomCNN class from section 5.2 to be defined in the same script.
    """
    if not os.path.exists(model_path):
        print(f"Error: model file not found: {model_path}")
        return None, 0
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Rebuild the same architecture used during training
    model = CustomCNN(num_classes=num_classes).to(device)
    
    try:
        model.load_state_dict(torch.load(model_path, map_location=device))
        model.eval()
        print("Model weights loaded successfully!")
    except Exception as e:
        print(f"Failed to load model: {e}")
        return None, 0

    # Same preprocessing as validation (no random augmentation)
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    try:
        img = Image.open(image_path).convert('RGB')
        img_tensor = transform(img).unsqueeze(0).to(device)  # add a batch dimension
    except Exception as e:
        print(f"Failed to read image: {e}")
        return None, 0

    # Inference
    with torch.no_grad():
        outputs = model(img_tensor)
        probabilities = F.softmax(outputs, dim=1)
        confidence, predicted = torch.max(probabilities, 1)

    # Index order follows ImageFolder's alphabetical folder order: cats = 0, dogs = 1
    classes = ['Cat', 'Dog']
    result = classes[predicted.item()]
    confidence_percent = confidence.item() * 100
    
    print("=" * 40)
    print(f"Prediction: {result}")
    print(f"Confidence: {confidence_percent:.2f}%")
    print("Per-class probabilities:")
    for i, class_name in enumerate(classes):
        prob = probabilities[0][i].item() * 100
        print(f"  {class_name}: {prob:.2f}%")
    print("=" * 40)
    return result, confidence_percent

# Example usage (uncomment to run)
# predict_local_image('path/to/your/image.jpg')

7.3 Frequently Asked Questions and Solutions

  • Predictions are always the same class: The model may have been trained for too few epochs, or the data may not be diverse enough (a small model needs more data to generalize). Inconsistent preprocessing between training and inference is another common cause.
  • Size mismatch errors: Make sure the Resize size used at inference is exactly the same as during training.
  • Model file extension: PyTorch files commonly use .pth or .pt; there is no functional difference between the two.

8. Transfer learning in practice: ResNet18

8.1 Why does ResNet improve performance so much?

ResNet has been pretrained on a very large dataset (ImageNet), so its convolution filters have already learned to detect common features such as edges, textures, and parts like eyes and ears. We only need to replace its final classification layer with one for our own task (cat vs. dog) and fine-tune; this typically reaches much higher accuracy than training from scratch, even with relatively little data.

8.2 ResNet18 transfer learning setup script

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
from torchvision import models, transforms, datasets
from torch.utils.data import DataLoader

def create_resnet_model(num_classes=2):
    """
    Build a transfer learning model from a pretrained ResNet18
    """
    # Load ImageNet-pretrained weights (downloaded on first use)
    model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
    
    # (Optional) freeze all pretrained layers and train only the new head
    # for param in model.parameters():
    #     param.requires_grad = False
    
    # Replace the final fully connected layer with a 2-class head
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, num_classes)
    
    return model

def train_resnet_transfer(epochs=10):
    """
    Set up ResNet transfer learning (schematic: builds the model and the
    training objects but does not run a training loop)
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    train_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    
    val_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    
    model = create_resnet_model(num_classes=2).to(device)
    
    criterion = nn.CrossEntropyLoss()
    # Transfer learning usually uses a smaller learning rate
    optimizer = optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
    
    print("ResNet18 transfer learning model is ready")
    print(f"Number of model parameters: {sum(p.numel() for p in model.parameters()):,}")
    print(f"Number of trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
    
    # For real training, build DataLoaders with the transforms above and run a loop
    # like the one in section 5.3 with this criterion, optimizer, and scheduler.
    # Note: the file saved here has a randomly initialized 2-class head, so its
    # predictions are essentially random until the model has been trained.
    torch.save(model.state_dict(), 'resnet18_transfer_init.pth')
    return model

resnet_model = train_resnet_transfer()

8.3 ResNet inference script

def predict_with_resnet(image_path, model_path='resnet18_transfer_init.pth'):
    """
    Classify an image with the ResNet18 model
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Rebuild the same architecture used during training
    model = models.resnet18(weights=None)
    model.fc = nn.Linear(model.fc.in_features, 2)

    try:
        model.load_state_dict(torch.load(model_path, map_location=device))
        model = model.to(device)
        model.eval()
        print("ResNet weights loaded successfully!")
    except Exception as e:
        print(f"Failed to load ResNet model: {e}")
        return None, 0

    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    try:
        img = Image.open(image_path).convert('RGB')
        img_tensor = transform(img).unsqueeze(0).to(device)

        with torch.no_grad():
            outputs = model(img_tensor)
            probabilities = F.softmax(outputs, dim=1)
            confidence, predicted = torch.max(probabilities, 1)

        classes = ['Cat', 'Dog']
        result = classes[predicted.item()]
        confidence_percent = confidence.item() * 100
        
        print(f"ResNet18 prediction: {result}")
        print(f"Confidence: {confidence_percent:.2f}%")
        
        return result, confidence_percent
    except Exception as e:
        print(f"Prediction failed: {e}")
        return None, 0

# predict_with_resnet('path/to/your/image.jpg')

9. Model performance comparison

| Feature | Handwritten CustomCNN | Pretrained ResNet18 | Vision Transformer (ViT-B/16) |
| --- | --- | --- | --- |
| Architecture | 3 convolution layers + 3 fully connected layers | 18 layers (8 residual blocks) | 12 Transformer encoder layers with multi-head self-attention |
| Training time (indicative) | About 30 minutes | About 1 hour | 2 to 3 hours |
| Classification accuracy (indicative) | 75% to 80% | 90% to 95% | Above 95% |
| Number of parameters | About 51.5 million (almost all in fc1) | About 11 million | More than 85 million |
| Suitable scenarios | Learning the principles, small data | Real projects, medium-sized data | Large datasets, maximum accuracy |

10. A revolution in computer vision: Vision Transformer (ViT)

If a CNN recognizes images through "local observation", ViT understands them through "global attention". It carries the Transformer architecture, originally designed for text, over to images and, given enough pretraining data, can match or surpass traditional CNNs.

10.1 The core idea of ViT: treat an image as a sentence of patches

  1. Image patching: Cut the image into many small, non-overlapping squares of a fixed size, such as 16×16 pixels.
  2. Linear projection (embedding): Flatten each patch and project it to a vector, just as a word is mapped to a word vector. A learned position embedding is added to each vector so the model knows where each patch came from.
  3. Self-attention: Each patch is compared with every other patch, letting the model capture global relationships across the whole image, such as the relative positions of a cat's ears and tail.
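The patch arithmetic and the attention step can be made concrete in a few lines of plain Python. This is a deliberately simplified sketch: it uses a single head and identity query/key/value projections, whereas a real ViT learns those projections and uses multiple heads (12 in ViT-B/16). The toy 2-D "patch embeddings" are made-up values.

```python
import math

def num_patches(image_size=224, patch_size=16):
    """Number of non-overlapping patches a square image is cut into."""
    return (image_size // patch_size) ** 2

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    s = sum(exps)
    return [e / s for e in exps]

def self_attention(tokens):
    """Single-head scaled dot-product self-attention with identity Q/K/V projections."""
    d = len(tokens[0])
    outputs, all_weights = [], []
    for q in tokens:
        # Similarity of this token (query) with every token (key), scaled by sqrt(d)
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in tokens]
        weights = softmax(scores)   # how strongly this patch attends to each patch
        # New representation: attention-weighted average of all value vectors
        outputs.append([sum(w * v[i] for w, v in zip(weights, tokens)) for i in range(d)])
        all_weights.append(weights)
    return outputs, all_weights

n = num_patches()              # 14 x 14 = 196 patches at 224 x 224 with 16 x 16 patches
patch_values = 16 * 16 * 3     # 768 raw numbers per RGB patch before projection
print(n, patch_values, n + 1)  # ViT also prepends one [class] token: 197 tokens

tokens = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0]]   # hypothetical patch embeddings
_, weights = self_attention(tokens)
print([round(w, 3) for w in weights[0]])         # patch 0 attends most to similar patches
```

Because every token attends to every other token, the cost grows with the square of the number of patches, which is one reason ViT needs more memory than a CNN at the same resolution.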

10.2 ViT transfer learning code

torchvision ships a ViT implementation with pretrained weights, and it is as easy to use as ResNet.

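import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
from torchvision import models, transforms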

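def create_vit_model(num_classes=2):
    """
    Build a transfer learning model from Vision Transformer Base (ViT-B/16)
    """
    weights = models.ViT_B_16_Weights.DEFAULT
    model = models.vit_b_16(weights=weights)
    
    # Replace the classification head
    num_in_features = model.heads.head.in_features
    model.heads.head = nn.Linear(num_in_features, num_classes)
    
    return model

def train_vit_model(epochs=5):
    """
    Set up ViT training (schematic: no training loop is run here)
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # Preprocessing for ViT. The 0.5/0.5 normalization follows the original ViT recipe;
    # the transforms bundled with torchvision's ViT_B_16_Weights.DEFAULT
    # (weights.transforms()) use the ImageNet mean/std instead, which may match the
    # pretrained weights better. Whichever you choose, use the same values at inference.
    vit_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    ])
    
    model = create_vit_model(num_classes=2).to(device)
    
    criterion = nn.CrossEntropyLoss()
    # AdamW applies decoupled weight decay, the usual choice for Transformers
    optimizer = optim.AdamW(model.parameters(), lr=1e-5, weight_decay=0.1)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    
    print("Vision Transformer model is ready")
    print(f"Number of model parameters: {sum(p.numel() for p in model.parameters()):,}")
    
    # As with ResNet, the saved head is untrained until you run a training loop
    torch.save(model.state_dict(), 'vit_initial.pth')
    return model

vit_model = train_vit_model()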

10.3 ViT inference script

def predict_with_vit(image_path, model_path='vit_initial.pth'):
    """
    Classify an image with the ViT model
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = models.vit_b_16(weights=None)
    model.heads.head = nn.Linear(model.heads.head.in_features, 2)

    try:
        model.load_state_dict(torch.load(model_path, map_location=device))
        model = model.to(device)
        model.eval()
        print("ViT weights loaded successfully!")
    except Exception as e:
        print(f"Failed to load ViT model: {e}")
        return None, 0

    # Must match the normalization used during training
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    ])

    try:
        img = Image.open(image_path).convert('RGB')
        img_tensor = transform(img).unsqueeze(0).to(device)

        with torch.no_grad():
            output = model(img_tensor)
            prob = F.softmax(output, dim=1)
            score, pred = torch.max(prob, 1)

        classes = ['Cat', 'Dog']
        result = classes[pred.item()]
        confidence = score.item() * 100
        
        print(f"ViT prediction: {result} (confidence: {confidence:.2f}%)")
        
        return result, confidence
    except Exception as e:
        print(f"ViT prediction failed: {e}")
        return None, 0

# predict_with_vit('path/to/your/image.jpg')

10.4 CNN vs ViT: How to choose?

| Dimension | ResNet (CNN) | ViT (Transformer) |
| --- | --- | --- |
| Dependence on data volume | Moderate; fine-tunes well on small datasets | Very high; the advantage grows with more data |
| Training speed | Faster; convolutions are computationally efficient | Slower; demands more GPU memory and memory bandwidth |
| Strengths | Local details, such as fur texture | Global structure, such as pose and shape |
| Hardware requirements | Low; an ordinary GPU is enough | High; 12 GB or more of GPU memory is recommended |
| Interpretability | Good; feature maps can be visualized | Moderate; relies mainly on attention weights |

11. Convert the model into a web application (Gradio)

A model that can only be called from the command line is of little use to non-technical users. Gradio is one of the most popular front-end frameworks for machine learning: a few lines of code produce a web page with an interactive interface, and it can even generate a public share link.

11.1 Why use Gradio?

  • No front-end code required: No need to write HTML, CSS, or JavaScript.
  • Automatic public link: Setting share=True gives you a temporary public URL (valid for 72 hours at the time of writing), which is convenient for letting friends test the app.
  • Built-in common components: Image upload boxes, progress bars, label displays, and more are available out of the box.

11.2 Complete code for the cat and dog recognition web app

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, transforms
import gradio as gr
from PIL import Image
import numpy as np

def load_model(model_type='resnet', model_path=None):
    """
    Load a model and its matching preprocessing by type.
    The 'custom' option requires the CustomCNN class from section 5.2.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    if model_type == 'resnet':
        model = models.resnet18(weights=None)
        model.fc = nn.Linear(model.fc.in_features, 2)
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
    elif model_type == 'vit':
        model = models.vit_b_16(weights=None)
        model.heads.head = nn.Linear(model.heads.head.in_features, 2)
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
        ])
    else:  # custom CNN
        model = CustomCNN(num_classes=2)
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
    
    if model_path:
        model.load_state_dict(torch.load(model_path, map_location=device))
    
    model.to(device).eval()
    return model, transform, device

# Map each UI label to (model type, weight file); point these at your trained checkpoints
MODEL_CONFIGS = {
    "CustomCNN": ("custom", "best_custom_cnn.pth"),
    "ResNet18": ("resnet", "resnet18_transfer_init.pth"),
    "Vision Transformer": ("vit", "vit_initial.pth"),
}
_model_cache = {}  # load each model once instead of on every request

def classify_image(img, model_choice='ResNet18'):
    """
    Take an image and return per-class probabilities
    """
    if img is None:
        raise gr.Error("Please upload an image first.")
    
    if model_choice not in _model_cache:
        model_type, model_path = MODEL_CONFIGS[model_choice]
        _model_cache[model_choice] = load_model(model_type, model_path)
    model, transform, device = _model_cache[model_choice]
    
    # Convert the numpy array from Gradio into an RGB PIL image
    if isinstance(img, np.ndarray):
        img = Image.fromarray(img.astype('uint8')).convert('RGB')
    
    img_tensor = transform(img).unsqueeze(0).to(device)
    
    with torch.no_grad():
        output = model(img_tensor)
        probabilities = F.softmax(output, dim=1)[0]
    
    classes = ['Cat', 'Dog']
    return {class_name: float(probabilities[i]) for i, class_name in enumerate(classes)}

# Build the Gradio interface
with gr.Blocks(title="Smart Cat & Dog Classifier") as demo:
    gr.Markdown("# 🐱 Smart Cat & Dog Classifier 🐶")
    gr.Markdown("Upload a photo of a cat or a dog, and the AI will return the category and its confidence.")
    
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(label="Upload image", type="numpy")
            model_choice = gr.Radio(
                choices=["CustomCNN", "ResNet18", "Vision Transformer"],
                value="ResNet18",
                label="Choose a model"
            )
            run_button = gr.Button("Classify")
        with gr.Column():
            output_label = gr.Label(label="Result")
    
    run_button.click(
        fn=classify_image,
        inputs=[image_input, model_choice],
        outputs=output_label
    )

# Launch (uncomment to run)
# demo.launch(share=True)

12. Best practices for model deployment

12.1 Model optimization

  • Quantization: Store weights as 8-bit integers instead of 32-bit floats, cutting weight storage roughly 4× and often speeding up inference on hardware with int8 support.
  • Pruning: Remove weights, neurons, or channels that contribute little, reducing computation.
  • Knowledge Distillation: Use a large model to guide small model training, compressing the model while maintaining accuracy.
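At its core, int8 quantization maps floats to 256 integer levels with a scale and a zero point. A plain-Python sketch of per-tensor affine quantization on a few hypothetical weights; real toolkits, including PyTorch's quantization APIs, add per-channel scales, calibration, and optimized int8 kernels, none of which is modeled here.

```python
def quantize_int8(values):
    """Affine (asymmetric) per-tensor quantization of floats to int8."""
    lo, hi = min(min(values), 0.0), max(max(values), 0.0)   # the range must include 0.0
    scale = (hi - lo) / 255 or 1.0                            # 256 levels; avoid a zero scale
    zero_point = round(-128 - lo / scale)                     # int8 code that represents 0.0
    q = [max(-128, min(127, round(v / scale) + zero_point)) for v in values]
    return q, scale, zero_point

def dequantize(q, scale, zero_point):
    """Map int8 codes back to approximate float values."""
    return [(x - zero_point) * scale for x in q]

weights = [-0.62, -0.1, 0.0, 0.37, 0.81]   # hypothetical float32 weights
q, scale, zp = quantize_int8(weights)
restored = dequantize(q, scale, zp)
print(q)
print([round(r, 4) for r in restored])
# Each value now needs 1 byte instead of 4; the price is a rounding error of about scale / 2.
```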

12.2 Inference acceleration

  • ONNX: Export PyTorch models to ONNX format to facilitate running on different frameworks and platforms.
  • TensorRT: An inference optimization library launched by NVIDIA that can fully utilize the parallel capabilities of the GPU.
  • OpenVINO: Intel's inference engine, optimizing inference speed on CPU and integrated graphics.

12.3 Cloud deployment

  • Docker: Package the model and all dependencies into containers to achieve "build once, run anywhere".
  • Kubernetes: Manage automatic scaling and load balancing for model services at scale.
  • API Gateway: Encapsulate the model into a RESTful API for calling by the front-end or mobile terminal.

13. Project summary and advanced direction

13.1 Technical Roadmap Review

Through this tutorial, you have completed a full image classification learning path:

  • Theoretical basis: The core ideas of CNN, ResNet, and ViT.
  • Hands-On: Build, train, and verify models with PyTorch.
  • Project Implementation: Turn the model into a web application, and understand the optimization and deployment methods.

13.2 What can you learn next?

  • Object Detection: Use YOLO or Faster R-CNN to locate multiple objects in one picture with bounding boxes and classify each of them.
  • Semantic Segmentation: Use U-Net or DeepLab to classify every pixel of an image.
  • Multimodal learning: Jointly model images and text, as CLIP does.
  • Self-supervised learning: Pre-training without human annotation.
  • Federated Learning: Complete model training without sharing original data.

13.3 Real World Applications

  • E-commerce: Automatically classify products and search for images.
  • Medical: Auxiliary diagnosis of X-rays and CT images.
  • Security: Face recognition, abnormal behavior detection.
  • Autonomous Driving: Traffic sign recognition, pedestrian detection.
  • Social Media: Content review, smart album classification.

Mastering deep learning image classification takes both theory and practice. First understand the basic concepts, then implement them, and finally consolidate your knowledge through plenty of experiments. Also keep up with newer architectures such as Vision Transformer and Swin Transformer.

💡 Important reminder: Deep learning evolves quickly, so keep following new research and updating your knowledge. Just as important, focus on engineering practice and turn theoretical knowledge into the ability to solve real problems.

🔗 Extended reading