Transfer Learning: Use pre-trained models to quickly build high-performance models



Introduction

Have you run into either of these practical problems?

  • You have only a few hundred privately labeled images, and a CNN trained from scratch on them does not even reach 60% accuracy.
  • You have finally accumulated enough data, but your GPU memory and time budget cannot support training a large model.

Transfer learning was developed to address exactly these pain points. The core idea is simple: reuse visual features that were already learned on a large, general-purpose dataset (such as the 1.2 million images of ImageNet), then make small adjustments for your specific task. This usually yields a strong model quickly.


1. Core concepts and working principles

1.1 Why is transfer learning so effective?

The features learned by a CNN form a natural hierarchy, from general to task-specific:

| Network level | Learned features | Transfer value |
| --- | --- | --- |
| Shallow (low-level) layers | Edges, corners, basic textures | Extremely general; almost no modification required |
| Middle layers | Shape combinations, simple object parts | Highly general; light fine-tuning is enough |
| Deep layers (top layers + FC) | Complete objects, semantic classes (ImageNet categories) | Low generality; must be replaced or substantially fine-tuned |

Because shallow and mid-level features are shared by almost all vision tasks, only the deep, category-specific parts need to be adapted, which is why very little data is enough.

1.2 Core transfer strategies (ordered from most to least conservative)

Based on your data volume and compute budget, you can apply the following decision cheat sheet directly:

| Conditions | Recommended strategy | Resource/time cost | Typical scenarios |
| --- | --- | --- | --- |
| Small dataset (<1,000 images) + low compute | Feature extraction (train only the new classification head) | Extremely low | Personal experiments, rapid prototyping |
| Small dataset (<1,000 images) + task similar to the pre-training domain | Feature extraction, then fine-tune the last 1-2 CNN blocks | Low | Small medical or natural-image classification tasks |
| Medium dataset (1k-10k images) + medium compute | Layered fine-tuning (freeze the first N layers, fine-tune the later layers + new head) | Medium | General industrial classification; datasets similar to the source domain and reasonably large |
| Large dataset (>10k images) + high compute | Full fine-tuning (train all layers with a small learning rate) | High | High accuracy requirements; cross-domain adaptation followed by further optimization |
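The cheat sheet can also be written down as a small lookup function. The sketch below is only an illustration: choose_strategy is a hypothetical helper, the returned names mirror the strategy strings used later in this chapter, and the handling of combinations the table does not list (for example a large dataset with low compute) is an assumption that simply falls back to the more conservative option.

```python
def choose_strategy(num_images: int, compute: str) -> str:
    """Map dataset size and compute budget to a strategy name.

    Hypothetical helper that mirrors the cheat sheet. `compute` is one of
    "low", "medium", "high". Combinations not listed in the table fall back
    to the more conservative strategy (an assumption, not a rule).
    """
    if compute not in {"low", "medium", "high"}:
        raise ValueError(f"Unknown compute level: {compute}")

    if num_images < 1000:
        # Small dataset: train only the new classification head first
        return "feature_extract"
    if num_images <= 10_000:
        # Medium dataset: layered fine-tuning if the budget allows it
        return "fine_tune" if compute in {"medium", "high"} else "feature_extract"
    # Large dataset: full fine-tuning only with a high compute budget
    if compute == "high":
        return "full_tune"
    return "fine_tune" if compute == "medium" else "feature_extract"


print(choose_strategy(500, "low"))
print(choose_strategy(5_000, "medium"))
print(choose_strategy(50_000, "high"))
```

For a small dataset on a task close to ImageNet, the table suggests a second step after feature extraction (fine-tuning the last 1-2 CNN blocks); the helper returns only the first step.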

2. Implementing transfer learning in PyTorch

2.1 Preparation: replace the deprecated API and set up imports

⚠️ IMPORTANT NOTE: torchvision 0.13 and later deprecate pretrained=True. The recommended replacement is the weights= argument, which takes a pre-trained weights enum such as ResNet50_Weights.IMAGENET1K_V1.

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
from torchvision.models import (
    ResNet50_Weights,
    VGG16_Weights,
    MobileNet_V3_Small_Weights
)

2.2 Complete feature extraction/layered fine-tuning code

Using ResNet50 on a private image classification task as the running example, the function below supports feature extraction, layered fine-tuning, and full fine-tuning:

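def build_transfer_model(
    model_name: str = "resnet50",
    num_classes: int = 10,
    strategy: str = "feature_extract",  # one of: feature_extract, fine_tune, full_tune
    freeze_layers: int = 4  # fine_tune only: number of top-level children to freeze
):
    """
    Build a transfer learning model.

    Returns the model and the preprocessing transform that matches its weights.

    Note on freeze_layers: torchvision's ResNet50 has 10 top-level children
    (conv1, bn1, relu, maxpool, layer1-layer4, avgpool, fc), so the default of 4
    freezes only the stem; use 6 to also freeze layer1 and layer2. VGG16 and
    MobileNetV3-Small have only 3 (features, avgpool, classifier), so any value
    of 3 or more freezes everything except the new head.
    """
    # 1. Load the model with pre-trained weights
    if model_name == "resnet50":
        weights = ResNet50_Weights.IMAGENET1K_V1
        model = models.resnet50(weights=weights)
        in_features = model.fc.in_features
    elif model_name == "vgg16":
        weights = VGG16_Weights.IMAGENET1K_V1
        model = models.vgg16(weights=weights)
        in_features = model.classifier[6].in_features
    elif model_name == "mobilenet_v3_small":
        weights = MobileNet_V3_Small_Weights.IMAGENET1K_V1
        model = models.mobilenet_v3_small(weights=weights)
        in_features = model.classifier[3].in_features
    else:
        raise ValueError(f"Unsupported model: {model_name}")

    # 2. Freeze parameters according to the strategy
    if strategy == "feature_extract":
        # Freeze every parameter; only the new classification head will train
        for param in model.parameters():
            param.requires_grad = False
    elif strategy == "fine_tune":
        # Freeze only the first freeze_layers top-level children
        for child in list(model.children())[:freeze_layers]:
            for param in child.parameters():
                param.requires_grad = False
    elif strategy != "full_tune":  # full_tune freezes nothing
        raise ValueError(f"Unsupported strategy: {strategy}")

    # 3. Replace the task-specific classification head
    #    (a newly created layer is trainable by default)
    if model_name == "resnet50":
        model.fc = nn.Linear(in_features, num_classes)
    elif model_name == "vgg16":
        model.classifier[6] = nn.Linear(in_features, num_classes)
    elif model_name == "mobilenet_v3_small":
        model.classifier[3] = nn.Linear(in_features, num_classes)

    # Also return the preprocessing the pre-trained weights expect,
    # so the normalization parameters never have to be typed by hand
    return model, weights.transforms()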

2.3 Data loading (important: preprocessing must match pre-training)

The preprocessing transform returned by build_transfer_model can be reused directly, which keeps our input pipeline consistent with the preprocessing the ImageNet pre-trained weights expect:

def get_data_loaders(
    dataset_path: str = "./data",
    preprocess=None,  # the transform returned by build_transfer_model
    batch_size: int = 32,
    train_augment: bool = True
):
    """
    Build the training/validation data loaders.
    """
    if preprocess is None:
        raise ValueError("preprocess is required; pass the transform returned by build_transfer_model")

    # Training set: optional augmentation followed by the pre-training preprocessing
    train_steps = []
    if train_augment:
        # Simple but effective augmentation, applied in this order.
        # Note: preprocess resizes and center-crops again before normalizing.
        train_steps += [
            transforms.Resize(256),
            transforms.RandomCrop(224),
            transforms.RandomHorizontalFlip(p=0.5),
        ]
    train_transform = transforms.Compose(train_steps + [preprocess])

    # Validation/test set: pre-training preprocessing only, no augmentation
    val_transform = preprocess

    # Load datasets in ImageFolder layout (one sub-folder per class under train/ and val/)
    train_dataset = datasets.ImageFolder(f"{dataset_path}/train", transform=train_transform)
    val_dataset = datasets.ImageFolder(f"{dataset_path}/val", transform=val_transform)

    # Build the DataLoaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

    return train_loader, val_loader, train_dataset.classes

2.4 Training loop (minimal version that keeps the best checkpoint)

def train_model(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    num_epochs: int = 10,
    lr: float = 0.001,
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
):
    """
    Train a transfer learning model and return it with the best validation weights loaded.
    """
    model.to(device)
    criterion = nn.CrossEntropyLoss()
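    # Optimize only the trainable parameters (those with requires_grad=True)
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr)
    # LR scheduler: halve the learning rate when validation accuracy stops improving
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', patience=3, factor=0.5)

    best_val_acc = 0.0
    # Clone the tensors: state_dict() returns references that training overwrites in place
    best_model_state = {k: v.detach().clone() for k, v in model.state_dict().items()}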

    for epoch in range(num_epochs):
        # Training phase
        model.train()
        train_loss, train_correct, train_total = 0.0, 0, 0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            train_loss += loss.item() * inputs.size(0)
            _, preds = outputs.max(1)
            train_total += labels.size(0)
            train_correct += preds.eq(labels).sum().item()

        train_loss /= train_total
        train_acc = train_correct / train_total

        # Validation phase
        model.eval()
        val_loss, val_correct, val_total = 0.0, 0, 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)

                val_loss += loss.item() * inputs.size(0)
                _, preds = outputs.max(1)
                val_total += labels.size(0)
                val_correct += preds.eq(labels).sum().item()

        val_loss /= val_total
        val_acc = val_correct / val_total
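        scheduler.step(val_acc)  # adjust the learning rate based on validation accuracy

        # Keep a snapshot of the best model. A plain dict .copy() is shallow and would
        # still share tensors with the live model, so clone each tensor instead.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_model_state = {k: v.detach().clone() for k, v in model.state_dict().items()}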

        # Log progress for this epoch
        print(f"Epoch {epoch+1}/{num_epochs} | "
              f"Train Loss: {train_loss:.4f} Acc: {train_acc:.4f} | "
              f"Val Loss: {val_loss:.4f} Acc: {val_acc:.4f}")

    print(f"\nTraining complete! Best validation accuracy: {best_val_acc:.4f}")
    model.load_state_dict(best_model_state)
    return model
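The loop above always runs for the full num_epochs and only remembers the best weights. If you also want training to stop once validation accuracy has stalled, a small patience counter is one common way to do it. The EarlyStopping class below is a hypothetical helper written for this illustration (it is not part of PyTorch), and the accuracy values are simulated.

```python
class EarlyStopping:
    """Signal a stop when the metric has not improved for `patience` epochs.

    Hypothetical helper for illustration; higher metric values are better.
    """

    def __init__(self, patience: int = 5, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("-inf")
        self.bad_epochs = 0

    def step(self, metric: float) -> bool:
        """Record this epoch's metric and return True if training should stop."""
        if metric > self.best + self.min_delta:
            self.best = metric
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


# Simulated validation accuracies, one per epoch
history = [0.60, 0.72, 0.75, 0.74, 0.75, 0.73, 0.74]

stopper = EarlyStopping(patience=3)
stopped_at = None
for epoch, val_acc in enumerate(history, start=1):
    if stopper.step(val_acc):
        stopped_at = epoch
        break

print(f"stopped at epoch {stopped_at}, best accuracy {stopper.best}")
```

Inside train_model, the equivalent would be to create the stopper before the epoch loop and break out of the loop when stopper.step(val_acc) returns True.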

3. Quick start example

Putting the above functions together, a complete transfer learning process only requires a dozen lines of code:

if __name__ == "__main__":
    # 1. Build the model (feature extraction in this example)
    model, preprocess = build_transfer_model(
        model_name="resnet50",
        num_classes=5,  # assuming your private dataset has 5 classes
        strategy="feature_extract"
    )

    # 2. Load the data (ImageFolder layout: ./data/train/<class>/..., ./data/val/<class>/...)
    train_loader, val_loader, classes = get_data_loaders(
        dataset_path="./data",
        preprocess=preprocess,
        batch_size=32
    )
    print(f"Classes: {classes}")

    # 3. Train the model
    trained_model = train_model(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        num_epochs=10,
        lr=0.001
    )

    # 4. Save the best model weights
    torch.save(trained_model.state_dict(), "best_transfer_model.pth")

4. Best Practices

To make your transfer learning work go more smoothly, here are some practical tips:

  1. Keep preprocessing consistent: use the transforms() method that ships with the pre-trained weights, which avoids mistakes from typing the normalization mean and standard deviation by hand.
  2. Start with a conservative strategy: first run the whole pipeline with feature extraction. Once the data and code are confirmed to work, try layered fine-tuning or full fine-tuning.
  3. Layered learning rates (optional but highly recommended): when fine-tuning, set different learning rates for different layers, for example 1e-3 for the new classification head, 1e-4 for the last 1-2 CNN blocks, and 1e-5 or smaller for earlier layers. This keeps low-level general features stable while high-level features gradually adapt to the new task.
  4. Early stopping: avoid unnecessary over-training, and always keep the model that performs best on the validation set.
  5. Data augmentation is essential: even with only a few hundred images, add simple augmentation (horizontal flips, random crops); it effectively reduces overfitting.

Transfer learning is one of the cornerstones of modern deep learning. A good way to practice is to start with plain feature extraction and move step by step toward fine-tuning. Keep in mind that in current projects it is rare to train a complete CNN from scratch, unless your task is unrelated to every public pre-training dataset and you have data at a very large scale.

Further reading