迁移学习 (Transfer Learning):利用预训练模型快速构建高性能模型

📂 所属阶段:第二阶段 — 深度学习视觉基础(CNN 篇)
🔗 相关章节:数据增强 (Data Augmentation) · 目标检测理论


引言

你是否遇到过这些痛点?

  • 只有几百张私有标注数据,从零训练CNN准确率连60%都达不到?
  • 好不容易攒够数据,却发现GPU显存/时间撑不住训练一个大模型?

迁移学习(Transfer Learning) 就是解决这些问题的「现代深度学习利器」。它通过复用在大规模通用数据集(如ImageNet 120万张)上预训练的视觉特征,仅需微调少量参数,就能快速获得针对特定任务的高性能模型。


1. 核心概念与工作原理

1.1 为什么迁移学习这么有效?

CNN的层级天然具有特征通用性分层的特性:

网络层级学到的特征迁移价值
浅层(低层)边缘、角点、基础纹理通用度极高,几乎无需修改
中层形状组合、简单对象部件通用度较高,可微调少量
深层(顶层+FC)完整对象、语义分类(ImageNet类)通用度低,必须替换/大幅微调

1.2 核心迁移策略(按保守程度排序)

为了让你快速选择,整理了一张决策速查表:

条件组合推荐策略资源/时间消耗适用场景
小数据集(<1000)+ 低算力特征提取(仅训练新分类头)极低个人实验、快速原型验证
小数据集(<1000)+ 相似通用任务特征提取 → 微调最后1-2个CNN模块医学/自然小分类任务
中等数据集(1k-10k)+ 中算力分层微调(冻结前N层,微调节下层+新头)通用工业分类、数据集相似但规模够大
大数据集(>10k)+ 高算力全量微调(用小学习率训练所有层)高精度需求、跨领域适配后再优化

2. PyTorch实现迁移学习

2.1 准备工作:修正过时API+基础配置

⚠️ 注意:torchvision 0.13+ 已弃用 pretrained=True,改用 weights=预训练权重类 更规范。

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
from torchvision.models import (
    ResNet50_Weights,
    VGG16_Weights,
    MobileNet_V3_Small_Weights
)

2.2 完整的特征提取/分层微调代码

这里以最常用的 ResNet50 + 私有图像分类 为例,同时兼容特征提取和分层微调:

def build_transfer_model(
    model_name: str = "resnet50",
    num_classes: int = 10,
    strategy: str = "feature_extract",  # 可选: feature_extract, fine_tune, full_tune
    freeze_layers: int = 4  # 仅fine_tune有效,ResNet50有children()[:8]是骨干网络
):
    """
    构建迁移学习模型
    """
    # 1. 加载带预训练权重的模型
    if model_name == "resnet50":
        weights = ResNet50_Weights.IMAGENET1K_V1
        model = models.resnet50(weights=weights)
        in_features = model.fc.in_features
        head = "fc"
    elif model_name == "vgg16":
        weights = VGG16_Weights.IMAGENET1K_V1
        model = models.vgg16(weights=weights)
        in_features = model.classifier[6].in_features
        head = "classifier"
    elif model_name == "mobilenet_v3_small":
        weights = MobileNet_V3_Small_Weights.IMAGENET1K_V1
        model = models.mobilenet_v3_small(weights=weights)
        in_features = model.classifier[3].in_features
        head = "classifier"
    else:
        raise ValueError(f"不支持的模型: {model_name}")

    # 2. 根据策略冻结参数
    if strategy == "feature_extract":
        for param in model.parameters():
            param.requires_grad = False
    elif strategy == "fine_tune":
        children = list(model.children()) if hasattr(model, "children") else []
        for child in children[:freeze_layers]:
            for param in child.parameters():
                param.requires_grad = False
    # full_tune不冻结任何参数

    # 3. 替换任务特定的分类头
    if model_name == "resnet50":
        model.fc = nn.Linear(in_features, num_classes)
    elif model_name == "vgg16":
        model.classifier[6] = nn.Linear(in_features, num_classes)
    elif model_name == "mobilenet_v3_small":
        model.classifier[3] = nn.Linear(in_features, num_classes)

    return model, weights.transforms()  # 直接获取预训练模型要求的预处理!

2.3 数据加载(注意预处理必须和预训练一致!)

刚刚的 build_transfer_model 返回了预训练权重自带的标准预处理,我们可以直接复用,避免手动归一化参数写错:

def get_data_loaders(
    dataset_path: str = "./data",
    preprocess: transforms.Compose = None,
    batch_size: int = 32,
    train_augment: bool = True
):
    """
    获取训练/验证数据加载器
    """
    # 训练集:用预训练预处理 + 数据增强(可选)
    train_transforms = [preprocess]
    if train_augment:
        train_transforms.insert(0, transforms.RandomHorizontalFlip(p=0.5))
        train_transforms.insert(0, transforms.Resize(256))
        train_transforms.insert(1, transforms.RandomCrop(224))
    train_transform = transforms.Compose(train_transforms)

    # 验证/测试集:只用预训练预处理
    val_transform = preprocess

    # 加载ImageFolder格式的数据集(最常用:根目录下是class文件夹)
    train_dataset = datasets.ImageFolder(f"{dataset_path}/train", transform=train_transform)
    val_dataset = datasets.ImageFolder(f"{dataset_path}/val", transform=val_transform)

    # 构建DataLoader
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

    return train_loader, val_loader, train_dataset.classes

2.4 训练循环(极简版,含早停逻辑)

def train_model(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    num_epochs: int = 10,
    lr: float = 0.001,
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
):
    """
    训练迁移学习模型
    """
    model.to(device)
    criterion = nn.CrossEntropyLoss()
    # 仅优化可训练的参数!
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr)
    # 简单的学习率调度(可选)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', patience=3, factor=0.5)

    best_val_acc = 0.0
    best_model_state = model.state_dict()

    for epoch in range(num_epochs):
        # 训练阶段
        model.train()
        train_loss, train_correct, train_total = 0.0, 0, 0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            train_loss += loss.item() * inputs.size(0)
            _, preds = outputs.max(1)
            train_total += labels.size(0)
            train_correct += preds.eq(labels).sum().item()

        train_loss /= train_total
        train_acc = train_correct / train_total

        # 验证阶段
        model.eval()
        val_loss, val_correct, val_total = 0.0, 0, 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)

                val_loss += loss.item() * inputs.size(0)
                _, preds = outputs.max(1)
                val_total += labels.size(0)
                val_correct += preds.eq(labels).sum().item()

        val_loss /= val_total
        val_acc = val_correct / val_total
        scheduler.step(val_acc)

        # 更新最佳模型
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_model_state = model.state_dict().copy()

        # 打印日志
        print(f"Epoch {epoch+1}/{num_epochs} | "
              f"Train Loss: {train_loss:.4f} Acc: {train_acc:.4f} | "
              f"Val Loss: {val_loss:.4f} Acc: {val_acc:.4f}")

    print(f"\n训练完成!最佳验证准确率: {best_val_acc:.4f}")
    model.load_state_dict(best_model_state)
    return model

3. 快速上手示例

把上面的函数拼起来就能用:

if __name__ == "__main__":
    # 1. 构建模型(以特征提取为例)
    model, preprocess = build_transfer_model(
        model_name="resnet50",
        num_classes=5,  # 假设你的私有数据集有5类
        strategy="feature_extract"
    )

    # 2. 加载数据(确保你的数据集是ImageFolder格式:./data/train/类1, ./data/val/类1...)
    train_loader, val_loader, classes = get_data_loaders(
        dataset_path="./data",
        preprocess=preprocess,
        batch_size=32
    )
    print(f"类别列表: {classes}")

    # 3. 训练模型
    trained_model = train_model(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        num_epochs=10,
        lr=0.001
    )

    # 4. 保存最佳模型
    torch.save(trained_model.state_dict(), "best_transfer_model.pth")

4. 最佳实践

  1. 预处理必须一致:直接用预训练权重的 transforms(),避免手动写归一化参数。
  2. 从保守策略开始:先用特征提取跑通流程,再尝试分层/全量微调。
  3. 分层学习率(可选但推荐):微调时,新分类头用 1e-3,最后1-2个CNN模块用 1e-4,前面用 1e-5 或更小。
  4. 早停机制:避免过度训练,保存验证集最佳模型。
  5. 数据增强:即使是小数据集,也要加上简单的数据增强(水平翻转、随机裁剪)。
迁移学习是现代深度学习的基石。建议从简单的特征提取开始练习,逐步掌握微调技术。记住,在2026年,从零开始训练CNN模型已经很少见——除非你的任务和预训练数据集完全无关且数据量超大。

扩展阅读