GAN详解:生成对抗网络原理与PyTorch实现

如果说传统CNN是让计算机"看懂"图像,那么GAN就是让计算机"学会创造"——它由Ian Goodfellow等人在2014年提出,用博弈论纳什均衡开创了无监督生成的新篇章。


1. GAN核心:造假者与鉴定师的博弈

1.1 生动比喻

  • 生成器 G:技艺精湛的造假者,输入随机噪声,输出「以假乱真」的样本
  • 判别器 D:经验丰富的鉴定师,输入样本,输出「为真」的概率

两者零和博弈:造假者不断精进,鉴定师同步升级,最终纳什均衡——生成样本与真实数据分布几乎一致,判别器输出恒为0.5。

1.2 核心优势

无需标注数据即可学习;可生成图像/音频/文本;图像生成质量远超传统生成模型(如VAE)。


2. 架构与数学原理

2.1 最小最大博弈

GAN的目标函数是经典的极小极大问题minGmaxDV(D,G)=Expdata(x)[logD(x)]+Ezpz(z)[log(1D(G(z)))]\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{data}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]

简单拆解:

  • D的任务:最大化对真实样本的置信度(D(x)1D(x)\to1)、最小化对生成样本的置信度(D(G(z))0D(G(z))\to0
  • G的任务反向博弈——要么让D(G(z))1D(G(z))\to1(原始优化有饱和问题,实践中用-log(D(G(z)))代替)

2.2 基础组件

  • 输入:标准正态分布随机向量 zN(0,1)z \sim N(0,1)
  • 生成器:用转置卷积(反卷积)把低维噪声上采样为高维数据
  • 判别器:用标准卷积做二分类器
  • 输出:D输出[0,1]概率,G输出[-1,1](Tanh激活)归一化的样本

3. DCGAN PyTorch快速实现

DCGAN是GAN的卷积化标准实现,解决了原始GAN的训练不稳定问题:

3.1 环境准备

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

3.2 生成器

class DCGANGenerator(nn.Module):
    """输入100维噪声 -> 输出3×64×64归一化图像"""
    def __init__(self, nz=100, ngf=64, nc=3):
        super().__init__()
        self.layers = nn.Sequential(
            # nz -> ngf*8×4×4
            nn.ConvTranspose2d(nz, ngf*8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf*8),
            nn.ReLU(True),
            # ngf*8×4×4 -> ngf*4×8×8
            nn.ConvTranspose2d(ngf*8, ngf*4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf*4),
            nn.ReLU(True),
            # ngf*4×8×8 -> ngf×32×32
            nn.ConvTranspose2d(ngf*4, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # ngf×32×32 -> nc×64×64
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, z):
        return self.layers(z)

3.3 判别器

class DCGANDiscriminator(nn.Module):
    """输入3×64×64图像 -> 输出[0,1]置信度"""
    def __init__(self, nc=3, ndf=64):
        super().__init__()
        self.layers = nn.Sequential(
            # nc×64×64 -> ndf×32×32
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # ndf×32×32 -> ndf×2×16×16
            nn.Conv2d(ndf, ndf*2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf*2),
            nn.LeakyReLU(0.2, inplace=True),
            # ndf×2×16×16 -> ndf×4×8×8
            nn.Conv2d(ndf×2, ndf×4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf×4),
            nn.LeakyReLU(0.2, inplace=True),
            # ndf×4×8×8 -> 1×1×1
            nn.Conv2d(ndf×4, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.layers(x).view(-1)

3.4 数据加载与训练

数据预处理

transform = transforms.Compose([
    transforms.Resize(64),
    transforms.CenterCrop(64),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # 归一化到[-1,1]
])

# 替换为你的本地数据集路径
dataset = datasets.CIFAR10(root='./data', download=True, transform=transform)
dataloader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=2)

训练循环

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 初始化模型、优化器、损失
G = DCGANGenerator().to(device)
D = DCGANDiscriminator().to(device)
criterion = nn.BCELoss()
optimizerG = optim.Adam(G.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizerD = optim.Adam(D.parameters(), lr=0.0002, betas=(0.5, 0.999))

# 固定噪声用于可视化
fixed_noise = torch.randn(16, 100, 1, 1, device=device)

num_epochs = 5
G_losses, D_losses = [], []

for epoch in range(num_epochs):
    for i, (real_imgs, _) in enumerate(dataloader):
        real_imgs = real_imgs.to(device)
        batch_size = real_imgs.size(0)
        real_label = torch.ones(batch_size, device=device)
        fake_label = torch.zeros(batch_size, device=device)

        ###########################
        # 训练判别器
        ###########################
        D.zero_grad()
        # 真实样本损失
        output = D(real_imgs)
        errD_real = criterion(output, real_label)
        errD_real.backward()
        # 生成样本损失
        noise = torch.randn(batch_size, 100, 1, 1, device=device)
        fake_imgs = G(noise)
        output = D(fake_imgs.detach()) # 冻结生成器梯度
        errD_fake = criterion(output, fake_label)
        errD_fake.backward()
        # 更新判别器
        errD = errD_real + errD_fake
        optimizerD.step()

        ###########################
        # 训练生成器
        ###########################
        G.zero_grad()
        # 反向欺骗判别器
        output = D(fake_imgs)
        errG = criterion(output, real_label)
        errG.backward()
        # 更新生成器
        optimizerG.step()

        # 记录损失
        G_losses.append(errG.item())
        D_losses.append(errD.item())

        # 每50步输出训练信息
        if i % 50 == 0:
            print(f"[Epoch {epoch+1}/{num_epochs}][Batch {i}/{len(dataloader)}] "
                  f"Loss_D: {errD:.4f} Loss_G: {errG:.4f} "
                  f"D(x): {output.mean().item():.4f}")

    # 每轮可视化生成结果
    with torch.no_grad():
        fake_fixed = G(fixed_noise).detach().cpu()
    grid = torchvision.utils.make_grid(fake_fixed, nrow=4, normalize=True)
    plt.figure(figsize=(8,8))
    plt.imshow(grid.permute(1,2,0))
    plt.axis('off')
    plt.show()

4. 常见挑战与改进方向

4.1 主要挑战

  1. 模式崩坏:生成器只输出有限种类样本
  2. 训练不稳定:损失震荡,判别器/生成器一方过强
  3. 评估困难:缺乏绝对客观的质量指标

4.2 经典改进方案

改进方案解决问题核心思路
WGAN模式崩坏/不稳定用Wasserstein距离代替JS散度
WGAN-GPWGAN梯度裁剪生硬用梯度惩罚代替硬裁剪
CycleGAN无配对图像翻译加入循环一致性损失
StyleGAN精细控制生成风格引入风格映射网络、AdaIN归一化

5. 实践建议

  1. 数据预处理:严格归一化到[-1,1],用Tanh激活生成器
  2. 优化器:固定Adam,lr=0.0002,beta1=0.5
  3. 训练策略:交替训练D和G,避免一方过强
  4. 监控指标:除了损失,还要直观观察生成样本(损失可能有误导性)

总结

GAN用简单的博弈思想实现了惊人的生成效果,虽有训练挑战,但仍是AI创作领域的核心工具之一。通过本文的DCGAN实现,你可以快速上手GAN,后续可根据需求选择CycleGAN、StyleGAN等变体。


🔗 扩展阅读