SRGAN Explained: How the Super-Resolution Generative Adversarial Network Works, with a PyTorch Implementation

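Imagine digging out a 320x240 graduation photo from ten years ago and zooming in on your phone: the faces dissolve into blocks and the writing on the blackboard is unreadable. Bicubic interpolation can only give you a blurry kind of smoothness. SRGAN, proposed by Ledig et al. in 2017, aims to give you the sharpness you remember. It was the first work to bring GANs to super-resolution, moving the task from "filling in pixels" to "reconstructing detail".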


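1. SRGAN overview

1.1 Pain points of earlier methods

Before SRGAN, mainstream methods such as SRCNN were trained by minimizing MSE. That yields a high PSNR (peak signal-to-noise ratio, a measure of pixel-level numerical accuracy), but the output often looks airbrushed: key high-frequency details such as hair strands, pores, and building textures are lost.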
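To make the link between MSE and PSNR concrete, here is a minimal sketch in plain Python. The helper name `psnr` and the `max_val=1.0` default (pixel values scaled to [0, 1]) are assumptions chosen for illustration; they do not come from the SRGAN code in this article.

```python
import math

def psnr(pred, target, max_val=1.0):
    """PSNR in dB between two equal-length sequences of pixel values."""
    if len(pred) != len(target) or not pred:
        raise ValueError("inputs must be non-empty and the same length")
    mse = sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)
    if mse == 0:
        return math.inf  # identical signals
    return 10 * math.log10(max_val ** 2 / mse)

target = [0.0, 0.5, 1.0, 0.5]
blurry = [0.1, 0.5, 0.9, 0.5]  # small, smooth errors
print(round(psnr(blurry, target), 2))  # 23.01
```

Because PSNR is a decreasing function of MSE, a model that minimizes MSE maximizes PSNR by construction. That is why MSE-trained networks can score well on PSNR while still averaging away fine texture.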

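1.2 Two core innovations

1. **Adversarial loss**: a generator (the forger) and a discriminator (the detective) compete, which pushes the generator to produce textures that can fool the eye
2. **Perceptual loss**: instead of comparing pixels, it compares whether the **deep semantic features** extracted by a pretrained VGG network match

1.3 Main advantages

  • Visual realism well beyond traditional interpolation and MSE-trained CNNs
  • Still reconstructs plausible detail at 4x upscaling
  • Can be extended to fields such as medical imaging and satellite remote sensing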

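2. Core architecture: three cooperating components

SRGAN consists of a generator, a discriminator, and a VGG network used for the perceptual loss. All three are required.

2.1 Generator: the magic wand from low resolution to high resolution

The generator uses 16 residual blocks (the SRResNet backbone) followed by PixelShuffle upsampling. The depth provides rich feature extraction, while the skip connections keep gradients from vanishing.

Key component code (condensed)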

import torch
import torch.nn as nn
import torch.nn.functional as F

class ResidualBlock(nn.Module):
    """SRGAN residual block: the skip connection guards against vanishing gradients"""
    def __init__(self, channels):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=1),
            nn.BatchNorm2d(channels),
            nn.PReLU(),  # ⚠️ PReLU suits super-resolution better than ReLU (fewer dead neurons)
            nn.Conv2d(channels, channels, 3, padding=1),
            nn.BatchNorm2d(channels),
        )

    def forward(self, x):
        return x + self.block(x)

class UpsampleBlock(nn.Module):
    """PixelShuffle upsampling: helps reduce checkerboard artifacts"""
    def __init__(self, in_channels, up_scale=2):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, in_channels * up_scale**2, 3, padding=1),
            nn.PixelShuffle(up_scale),
            nn.PReLU(),
        )

    def forward(self, x):
        return self.block(x)
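The shape bookkeeping inside the upsample block can be checked directly. The sketch below, using an arbitrary 24x24 input chosen for illustration, shows that the convolution multiplies the channel count by `up_scale**2` and that `nn.PixelShuffle` then trades those extra channels for spatial resolution.

```python
import torch
import torch.nn as nn

x = torch.randn(1, 64, 24, 24)                   # (batch, channels, H, W)
conv = nn.Conv2d(64, 64 * 2 ** 2, 3, padding=1)  # expand channels by up_scale**2
shuffle = nn.PixelShuffle(2)                     # rearrange channels into space

expanded = conv(x)
upsampled = shuffle(expanded)
print(expanded.shape)   # torch.Size([1, 256, 24, 24])
print(upsampled.shape)  # torch.Size([1, 64, 48, 48])
```

The channel count returns to 64 after the shuffle, which is what lets several of these blocks be chained one after another.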

Full generator

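import math

class Generator(nn.Module):
    """SRGAN generator: 4x upscaling by default"""
    def __init__(self, scale_factor=4, num_res_blocks=16):
        super().__init__()
        # Each UpsampleBlock doubles the resolution, so the scale must be a power of two
        if scale_factor < 2 or scale_factor & (scale_factor - 1):
            raise ValueError("scale_factor must be a power of two (2, 4, 8, ...)")
        self.scale_factor = scale_factor

        # 1. Initial low-level feature extraction
        self.first_conv = nn.Sequential(
            nn.Conv2d(3, 64, 9, padding=4),
            nn.PReLU()
        )

        # 2. Deep feature extraction with 16 residual blocks
        self.res_blocks = nn.Sequential(
            *[ResidualBlock(64) for _ in range(num_res_blocks)]
        )

        # 3. Middle convolution, followed by the global skip connection in forward()
        self.mid_conv = nn.Sequential(
            nn.Conv2d(64, 64, 3, padding=1),
            nn.BatchNorm2d(64)
        )

        # 4. PixelShuffle upsampling (2x per block, so log2(scale_factor) blocks; 4x needs 2)
        self.upsample = nn.Sequential(
            *[UpsampleBlock(64, 2) for _ in range(int(math.log2(scale_factor)))]
        )

        # 5. Output an RGB image (Tanh keeps values in [-1, 1])
        self.last_conv = nn.Sequential(
            nn.Conv2d(64, 3, 9, padding=4),
            nn.Tanh()
        )

    def forward(self, x):
        out1 = self.first_conv(x)
        out = self.res_blocks(out1)
        out = self.mid_conv(out)
        out = out1 + out  # global skip connection
        out = self.upsample(out)
        out = self.last_conv(out)
        return out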
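Since each upsample block doubles the resolution, the number of blocks needed for a power-of-two scale is log2(scale). The sketch below uses a hypothetical helper, `num_upsample_stages`, introduced only for this illustration. Note that simply halving the scale gives the same count for 2x and 4x but overshoots at 8x.

```python
import math

def num_upsample_stages(scale_factor):
    """Number of 2x PixelShuffle stages needed for a power-of-two scale."""
    if scale_factor < 2 or scale_factor & (scale_factor - 1):
        raise ValueError("scale_factor must be a power of two")
    return int(math.log2(scale_factor))

lr_side = 24  # example low-resolution side length (an arbitrary choice)
for s in (2, 4, 8):
    stages = num_upsample_stages(s)
    # Each stage doubles the side length, so the total scale is 2**stages
    print(f"scale={s}: {stages} stage(s), {lr_side} -> {lr_side * 2 ** stages}")
```

For a 4x generator this gives two stages, so a 24x24 input becomes a 96x96 output.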

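2.2 Discriminator: the judge of real versus fake

The discriminator here is an 8-layer CNN whose convolutions alternate between stride 1 and stride 2, followed by global average pooling and a classification head. It outputs a confidence between 0 and 1 (0 = fake, 1 = real).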

class Discriminator(nn.Module):
    """SRGAN discriminator: binary classifier for real vs. generated high-resolution images"""
    def __init__(self, input_shape=(3, 96, 96)):
        super().__init__()
        self.input_shape = input_shape
        
        def conv_block(in_f, out_f, stride=1, norm=True):
            layers = [nn.Conv2d(in_f, out_f, 3, stride, padding=1)]
            if norm: layers.append(nn.BatchNorm2d(out_f))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        # Feature-extraction backbone
        self.backbone = nn.Sequential(
            *conv_block(3, 64, stride=1, norm=False),
            *conv_block(64, 64, stride=2),
            *conv_block(64, 128, stride=1),
            *conv_block(128, 128, stride=2),
            *conv_block(128, 256, stride=1),
            *conv_block(256, 256, stride=2),
            *conv_block(256, 512, stride=1),
            *conv_block(512, 512, stride=2),
        )

        # Global average pooling + classification head
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Sequential(
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.5),
            nn.Linear(1024, 1),
            nn.Sigmoid()
        )

    def forward(self, img):
        feats = self.backbone(img)
        feats = self.pool(feats).flatten(1)
        return self.classifier(feats)
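To see what the backbone does to a 96x96 crop, the standard convolution output-size formula can be applied layer by layer. This is a plain-Python sketch; the helper `conv_out` is introduced here for illustration and mirrors the 3x3, padding-1 convolutions used above.

```python
def conv_out(size, kernel=3, stride=1, padding=1):
    """Output length of a convolution along one spatial dimension."""
    return (size + 2 * padding - kernel) // stride + 1

strides = [1, 2, 1, 2, 1, 2, 1, 2]                 # the eight conv blocks
channels = [64, 64, 128, 128, 256, 256, 512, 512]  # their output channels

size = 96  # HR crop side length, matching input_shape=(3, 96, 96)
for s, c in zip(strides, channels):
    size = conv_out(size, stride=s)
    print(f"stride={s} -> {c} x {size} x {size}")
```

The four stride-2 layers halve the side length each time (96, 48, 24, 12, 6), leaving a 512 x 6 x 6 feature map. `nn.AdaptiveAvgPool2d(1)` then collapses it to a 512-dimensional vector, which is why the first `nn.Linear` takes 512 inputs regardless of the crop size.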

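3. The soul of SRGAN: loss function design

SRGAN's loss is a weighted sum of a content loss (pixel + perceptual) and an adversarial loss. The perceptual loss is the core.

3.1 Content loss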

import torchvision.models as models

class VGG19Extractor(nn.Module):
    """Frozen pretrained VGG19 that extracts features for the perceptual loss"""
    def __init__(self, layer_idx=35):  # ⚠️ index 35 of VGG19 .features (the ReLU after conv5_4)
        super().__init__()
        vgg = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1).features
        self.extractor = nn.Sequential(*list(vgg.children())[:layer_idx+1])
        
        # Freeze all parameters
        for p in self.extractor.parameters():
            p.requires_grad = False
        
        # ImageNet normalization statistics
        self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1,3,1,1))
        self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1,3,1,1))

    def forward(self, x):
        # Map [-1, 1] back to [0, 1] first, then apply ImageNet normalization
        x = (x + 1) / 2
        x = (x - self.mean) / self.std
        return self.extractor(x)

def content_loss(sr, hr, vgg, mse_w=0.01):
    """Content loss = lightly weighted pixel MSE + perceptual (VGG feature) loss"""
    pixel_loss = F.mse_loss(sr, hr)
    sr_feats = vgg(sr)
    with torch.no_grad():  # the HR target needs no gradient
        hr_feats = vgg(hr)
    percept_loss = F.mse_loss(sr_feats, hr_feats)
    return mse_w * pixel_loss + percept_loss
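The choice of `layer_idx=35` can be verified without downloading any weights by rebuilding the layer order of VGG19. The sketch below assumes the layout of torchvision's non-batch-norm VGG19 (every convolution is followed by a ReLU, and each stage ends with a max-pool). The names such as `relu5_4` are conventional labels used for readability, not torchvision attribute names.

```python
# VGG19 layout: numbers are conv output channels, "M" marks a max-pool
cfg = [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M",
       512, 512, 512, 512, "M", 512, 512, 512, 512, "M"]

layers = []
block, conv = 1, 0
for v in cfg:
    if v == "M":
        layers.append(f"pool{block}")
        block, conv = block + 1, 0
    else:
        conv += 1
        layers.append(f"conv{block}_{conv}")
        layers.append(f"relu{block}_{conv}")

print(len(layers))             # 37
print(layers[34], layers[35])  # conv5_4 relu5_4
```

Slicing `[:layer_idx + 1]` with `layer_idx=35` therefore keeps everything up to and including the ReLU after the last convolution and drops only the final max-pool.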

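3.2 Adversarial loss

This implementation uses an LSGAN (least-squares) loss in place of the cross-entropy loss of the original GAN formulation, which helps mitigate vanishing gradients: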

def adv_loss_g(generator_out):
    """Generator loss: make the discriminator score generated images as real (label 1)"""
    return F.mse_loss(generator_out, torch.ones_like(generator_out))

def adv_loss_d(real_out, fake_out):
    """Discriminator loss: separate real from fake (real = 1, fake = 0)"""
    real_l = F.mse_loss(real_out, torch.ones_like(real_out))
    fake_l = F.mse_loss(fake_out, torch.zeros_like(fake_out))
    return (real_l + fake_l) / 2
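A small worked example shows how the two least-squares losses behave. The discriminator scores below are made-up numbers for illustration, and the two functions are repeated so the snippet runs on its own.

```python
import torch
import torch.nn.functional as F

def adv_loss_g(generator_out):
    return F.mse_loss(generator_out, torch.ones_like(generator_out))

def adv_loss_d(real_out, fake_out):
    real_l = F.mse_loss(real_out, torch.ones_like(real_out))
    fake_l = F.mse_loss(fake_out, torch.zeros_like(fake_out))
    return (real_l + fake_l) / 2

# Hypothetical scores from a discriminator that is currently doing well
real_out = torch.tensor([[0.9], [0.8]])  # real images scored close to 1
fake_out = torch.tensor([[0.2], [0.1]])  # generated images scored close to 0

d_loss = adv_loss_d(real_out, fake_out)
g_loss = adv_loss_g(fake_out)
print(f"D loss: {d_loss.item():.3f}")  # D loss: 0.025
print(f"G loss: {g_loss.item():.3f}")  # G loss: 0.725
```

The discriminator's loss is small because its scores are near the targets, while the generator's loss is large because its images are scored far from 1. The penalty grows quadratically with that distance, so the generator still receives a useful gradient when the discriminator is confident.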

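4. Training strategy: two stages for stability

If pretraining is skipped, the generator's early outputs are so obviously fake that the discriminator learns to reject them almost immediately. The generator's gradients then vanish and training fails to converge.

Stage 1: pretrain the generator (SRResNet-style MSE loss)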

def pretrain_gen(generator, dataloader, device, epochs=50, lr=1e-4):
    """Stage 1: train the generator alone with pixel-wise MSE."""
    opt = torch.optim.Adam(generator.parameters(), lr=lr)
    mse = nn.MSELoss()
    generator.train()

    for epoch in range(epochs):
        # Named hr_img/lr_img so the batch does not shadow the learning-rate argument `lr`
        for batch_idx, (hr_img, lr_img) in enumerate(dataloader):
            hr_img, lr_img = hr_img.to(device), lr_img.to(device)
            opt.zero_grad()
            sr = generator(lr_img)
            loss = mse(sr, hr_img)
            loss.backward()
            opt.step()
            if batch_idx % 100 == 0:
                print(f"Pretrain E{epoch} B{batch_idx} | Loss: {loss.item():.4f}")
    torch.save(generator.state_dict(), "srresnet_pretrain.pth")

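Stage 2: adversarial training (load the pretrained weights)

The core logic: update the discriminator once, then the generator once, so that neither side overpowers the other.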
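The alternation described above can be sketched as a single training step. This is a minimal illustration under stated assumptions, not the article's full training loop: the function name `adversarial_step`, the adversarial weight `adv_w=1e-3`, and the tiny stand-in networks are all hypothetical, and plain pixel MSE stands in for the full content loss so the snippet needs no pretrained VGG weights.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def adversarial_step(gen, disc, opt_g, opt_d, lr_img, hr_img, adv_w=1e-3):
    """One alternating update: discriminator first, then generator."""
    sr = gen(lr_img)

    # 1. Discriminator update: push real -> 1 and fake -> 0
    opt_d.zero_grad()
    real_out = disc(hr_img)
    fake_out = disc(sr.detach())  # detach: no generator gradients in this step
    d_loss = 0.5 * (F.mse_loss(real_out, torch.ones_like(real_out))
                    + F.mse_loss(fake_out, torch.zeros_like(fake_out)))
    d_loss.backward()
    opt_d.step()

    # 2. Generator update: content term + weighted adversarial term
    opt_g.zero_grad()
    fake_out = disc(sr)  # re-score with the freshly updated discriminator
    g_adv = F.mse_loss(fake_out, torch.ones_like(fake_out))
    g_content = F.mse_loss(sr, hr_img)  # assumption: pixel MSE as the content term
    g_loss = g_content + adv_w * g_adv
    g_loss.backward()
    opt_g.step()
    return d_loss.item(), g_loss.item()

# Hypothetical tiny stand-ins (2x upscaling) so the sketch runs on its own
torch.manual_seed(0)
tiny_gen = nn.Sequential(
    nn.Conv2d(3, 3 * 2 ** 2, 3, padding=1), nn.PixelShuffle(2), nn.Tanh())
tiny_disc = nn.Sequential(
    nn.Conv2d(3, 8, 3, stride=2, padding=1), nn.LeakyReLU(0.2),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 1), nn.Sigmoid())
opt_g = torch.optim.Adam(tiny_gen.parameters(), lr=1e-4)
opt_d = torch.optim.Adam(tiny_disc.parameters(), lr=1e-4)

lr_img = torch.rand(4, 3, 8, 8) * 2 - 1    # random LR batch in [-1, 1]
hr_img = torch.rand(4, 3, 16, 16) * 2 - 1  # random HR batch in [-1, 1]
d_loss, g_loss = adversarial_step(tiny_gen, tiny_disc, opt_g, opt_d, lr_img, hr_img)
print(f"D loss: {d_loss:.4f} | G loss: {g_loss:.4f}")
```

In a real run, the stand-ins would be replaced by the full generator (initialized from the stage 1 checkpoint) and the discriminator, and the pixel term by the content loss from section 3.1. Detaching `sr` in the discriminator step is what keeps the two updates separate.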


5. Quick start: restoring an old photo

from PIL import Image
import torchvision.transforms as transforms

def enhance_old_photo(img_path, generator_path, device, scale=4):
    # Load the trained generator
    gen = Generator(scale_factor=scale).to(device)
    gen.load_state_dict(torch.load(generator_path, map_location=device))
    gen.eval()
    
    # Preprocessing: scale to [0, 1], then normalize to [-1, 1]
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5,0.5,0.5], std=[0.5,0.5,0.5])
    ])
    
    img = Image.open(img_path).convert('RGB')
    lr = transform(img).unsqueeze(0).to(device)
    
    # Inference
    with torch.no_grad():
        sr = gen(lr)
    
    # Post-processing: map [-1, 1] back to [0, 1]
    sr = (sr.squeeze(0).cpu() + 1) / 2
    sr = torch.clamp(sr, 0, 1)
    return transforms.ToPILImage()(sr)
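The pre- and post-processing steps above are inverses of each other, which is easy to confirm on single pixel values. The helper names `normalize` and `denormalize` are introduced only for this plain-Python sketch.

```python
def normalize(x, mean=0.5, std=0.5):
    """Map a pixel value from [0, 1] to [-1, 1], as Normalize(0.5, 0.5) does per channel."""
    return (x - mean) / std

def denormalize(y):
    """Map a value from the generator's Tanh range [-1, 1] back to [0, 1]."""
    return (y + 1) / 2

for pixel in (0.0, 0.25, 1.0):
    y = normalize(pixel)
    print(pixel, y, denormalize(y))
```

Black (0.0) maps to -1.0 and white (1.0) maps to 1.0, matching the range of the generator's final `Tanh`. Mixing up the two conventions, for example feeding [0, 1] images to a generator trained on [-1, 1], is a common cause of washed-out results.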

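6. Trends and challenges

Main variants

  • ESRGAN: replaces the residual block with RRDB (Residual-in-Residual Dense Block), removes BN, and uses a relativistic GAN
  • Real-ESRGAN: trains on purely synthetic data to tackle real-world super-resolution

Open challenges

  • Inference is relatively slow (mobile deployment needs quantization or pruning)
  • It may hallucinate detail that looks real but is not (for example, turning a blurry mole into freckles)

Summary

SRGAN is a milestone in super-resolution's shift from "math first" to "perception first". Through three core ideas (a residual network architecture, a perceptual loss, and adversarial training), it achieves high-quality image upscaling. Its successors ESRGAN and Real-ESRGAN keep refining this foundation and can already handle many real-world use cases.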

