def _run_epoch(
    model: nn.Module,
    loader: DataLoader,
    criterion: nn.Module,
    device: str,
    optimizer: "optim.Optimizer | None" = None,
):
    """Run one full pass over ``loader`` and return ``(mean_loss, accuracy)``.

    When ``optimizer`` is given, every batch is followed by a backward pass
    and an optimizer step (training pass). When it is ``None`` the pass runs
    with gradients disabled (evaluation pass). Switching the model between
    ``train()`` and ``eval()`` mode is left to the caller.

    The loss is averaged per sample (each batch loss is weighted by its batch
    size). An empty loader yields ``(0.0, 0.0)`` instead of dividing by zero.
    """
    is_training = optimizer is not None
    loss_sum, num_correct, num_seen = 0.0, 0, 0

    with torch.set_grad_enabled(is_training):
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)

            if is_training:
                optimizer.zero_grad()

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            if is_training:
                loss.backward()
                optimizer.step()

            # criterion returns the batch mean; re-weight by batch size so the
            # final division gives a per-sample average.
            loss_sum += loss.item() * inputs.size(0)
            _, preds = outputs.max(1)
            num_seen += labels.size(0)
            num_correct += preds.eq(labels).sum().item()

    if num_seen == 0:
        return 0.0, 0.0
    return loss_sum / num_seen, num_correct / num_seen


def train_model(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    num_epochs: int = 10,
    lr: float = 0.001,
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
) -> nn.Module:
    """Train a transfer-learning model and restore its best checkpoint.

    Only parameters with ``requires_grad=True`` are handed to the optimizer,
    so frozen layers stay untouched. After every epoch the model is evaluated
    on ``val_loader``; the weights with the highest validation accuracy seen
    so far are snapshotted and loaded back into the model before returning.

    Args:
        model: Network to train; it is moved to ``device`` and trained in place.
        train_loader: Yields ``(inputs, labels)`` batches used for optimization.
        val_loader: Yields ``(inputs, labels)`` batches used for model selection.
        num_epochs: Number of passes over ``train_loader``.
        lr: Initial learning rate for Adam.
        device: Device string the model and every batch are moved to.

    Returns:
        The same ``model`` object, carrying the weights of the epoch with the
        best validation accuracy (or its initial weights if validation
        accuracy never rose above 0).
    """
    # Function-scope import keeps this block self-contained.
    import copy

    model.to(device)
    criterion = nn.CrossEntropyLoss()

    # Optimize only the trainable (non-frozen) parameters.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.Adam(trainable_params, lr=lr)

    # Halve the learning rate when validation accuracy stops improving.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='max', patience=3, factor=0.5
    )

    best_val_acc = 0.0
    # state_dict() returns tensors that share storage with the live
    # parameters, so a real (deep) copy is required; otherwise later optimizer
    # steps would silently overwrite the snapshot.
    best_model_state = copy.deepcopy(model.state_dict())

    for epoch in range(num_epochs):
        # Training phase.
        model.train()
        train_loss, train_acc = _run_epoch(
            model, train_loader, criterion, device, optimizer
        )

        # Validation phase.
        model.eval()
        val_loss, val_acc = _run_epoch(model, val_loader, criterion, device)

        scheduler.step(val_acc)

        # Keep an independent snapshot of the best model so far.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_model_state = copy.deepcopy(model.state_dict())

        # Per-epoch log line.
        print(f"Epoch {epoch+1}/{num_epochs} | "
              f"Train Loss: {train_loss:.4f} Acc: {train_acc:.4f} | "
              f"Val Loss: {val_loss:.4f} Acc: {val_acc:.4f}")

    print(f"\n训练完成!最佳验证准确率: {best_val_acc:.4f}")
    model.load_state_dict(best_model_state)
    return model