title: Detailed explanation of SRGAN: Super-resolution generative adversarial network principle and PyTorch implementation | Daoman PythonAI
description: In-depth analysis of SRGAN (Super-Resolution GAN), the super-resolution generative adversarial network, covering its use in image super-resolution reconstruction and old photo restoration, with a detailed architecture walkthrough, a PyTorch implementation, and practical application scenarios.
keywords: [SRGAN, super-resolution, generative adversarial network, image reconstruction, image amplification, GAN, deep learning, computer vision, PyTorch]

Detailed explanation of SRGAN: Super-resolution generative adversarial network principle and PyTorch implementation

Imagine you dig out a 320×240-pixel graduation photo from ten years ago. Pinch to zoom and the faces dissolve into mosaic blocks, and the chalk writing on the blackboard becomes illegible. Methods such as bicubic interpolation can only give you a smooth blur, whereas SRGAN (Super-Resolution GAN), proposed by Ledig et al. in 2017, aims to give back sharp, plausible detail. It was among the first works to bring generative adversarial networks to the super-resolution task, moving image upscaling from "pixel filling" to "detail reconstruction".


1. SRGAN Overview

1.1 Pain points of traditional methods

Before SRGAN, mainstream super-resolution methods (such as SRCNN) were mostly trained by minimizing the mean squared error (MSE). This yields high scores on numerical metrics such as PSNR, but the images tend to look airbrushed: high-frequency details such as hair, skin texture, and building edges are smoothed away, so the result looks unnatural.
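
PSNR is a direct function of the MSE, which is why MSE-trained networks score well on it. A minimal sketch of the relationship, assuming pixel values scaled to [0, 1] (the helper name `psnr_from_mse` is illustrative, not from the original article):

```python
import math

def psnr_from_mse(mse: float, max_val: float = 1.0) -> float:
    """PSNR in dB: 10 * log10(max_val^2 / mse)."""
    if mse <= 0:
        return float("inf")  # identical images
    return 10.0 * math.log10(max_val ** 2 / mse)

# Halving the MSE adds about 3 dB, whether or not the image looks sharper.
print(round(psnr_from_mse(0.01), 2))   # 20.0
print(round(psnr_from_mse(0.005), 2))  # 23.01
```

Because the metric only rewards lower average pixel error, a blurry "average" of all plausible textures can score higher than a sharp but slightly misplaced texture.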

1.2 Two core innovations

1. **Adversarial loss**: The generator produces fakes and the discriminator tries to detect them. Trained against each other, the two push the generator to draw texture details that can fool the human eye.
2. **Perceptual loss**: Rather than comparing images naively pixel by pixel, a pre-trained VGG network extracts **deep semantic features**, so that the generated image and the real high-resolution image are close in how they look, not only in raw pixel values.

1.3 Main advantages

  • Visual realism is clearly better than traditional interpolation or CNNs trained purely on MSE
  • Believable details can still be reconstructed at 4x magnification or higher
  • The architecture can be transferred to medical imaging, satellite remote sensing, video enhancement, and other fields

2. Core architecture: three-component collaboration

SRGAN is not a single network. It combines three components: a generator, a discriminator, and a VGG network used to compute the perceptual loss.

2.1 Generator: The magic wand from low resolution to high resolution

The generator uses 16 residual blocks (the SRResNet backbone) followed by PixelShuffle upsampling. The residual blocks perform deep feature extraction, and their skip connections help prevent vanishing gradients. PixelShuffle is a sub-pixel convolution upsampling layer: a convolution expands the channel count by r², and the channels are then rearranged into an r× larger spatial grid, which tends to produce fewer checkerboard artifacts than transposed convolution.
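
To make the rearrangement concrete, here is a small sketch of what `nn.PixelShuffle` does to a tensor's shape and values. The toy input built with `torch.arange` is an assumption for illustration only:

```python
import torch
import torch.nn as nn

r = 2                                      # upscale factor per stage
x = torch.arange(16.0).view(1, 4, 2, 2)    # (batch, C*r^2, H, W) with C = 1
y = nn.PixelShuffle(r)(x)                  # -> (batch, C, H*r, W*r)

print(tuple(x.shape), "->", tuple(y.shape))  # (1, 4, 2, 2) -> (1, 1, 4, 4)
# The top-left 2x2 output patch is filled from the 4 input channels
# at spatial position (0, 0): values 0, 4, 8, 12.
print(y[0, 0, :2, :2])
```

No pixels are invented by the shuffle itself. All new detail comes from the convolution that produces the extra channels, which is why each upsampling stage pairs a convolution with the shuffle.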

Key component code (simplified)

import torch
import torch.nn as nn
import torch.nn.functional as F

class ResidualBlock(nn.Module):
    """SRGAN residual block: the skip connection helps prevent vanishing gradients"""
    def __init__(self, channels):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=1),
            nn.BatchNorm2d(channels),
            nn.PReLU(),  # PReLU tends to suit super-resolution better than ReLU (fewer dead neurons)
            nn.Conv2d(channels, channels, 3, padding=1),
            nn.BatchNorm2d(channels),
        )

    def forward(self, x):
        return x + self.block(x)   # local skip connection

class UpsampleBlock(nn.Module):
    """PixelShuffle upsampling: helps avoid checkerboard artifacts"""
    def __init__(self, in_channels, up_scale=2):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, in_channels * up_scale**2, 3, padding=1),
            nn.PixelShuffle(up_scale),
            nn.PReLU(),
        )

    def forward(self, x):
        return self.block(x)

Complete generator

class Generator(nn.Module):
    """SRGAN generator: 4x upscaling by default"""
    def __init__(self, scale_factor=4, num_res_blocks=16):
        super().__init__()
        self.scale_factor = scale_factor

        # 1. Initial low-level feature extraction
        self.first_conv = nn.Sequential(
            nn.Conv2d(3, 64, 9, padding=4),
            nn.PReLU()
        )

        # 2. 16 residual blocks for deep, high-level feature extraction
        self.res_blocks = nn.Sequential(
            *[ResidualBlock(64) for _ in range(num_res_blocks)]
        )

        # 3. Middle convolution + global skip connection (preserves low-frequency structure)
        self.mid_conv = nn.Sequential(
            nn.Conv2d(64, 64, 3, padding=1),
            nn.BatchNorm2d(64)
        )

        # 4. PixelShuffle upsampling: 2x per stage, so log2(scale_factor) stages
        #    (4x needs 2, 8x needs 3); scale_factor must be an integer power of two
        num_stages = scale_factor.bit_length() - 1
        self.upsample = nn.Sequential(
            *[UpsampleBlock(64, 2) for _ in range(num_stages)]
        )

        # 5. Output RGB image (Tanh maps values to [-1, 1])
        self.last_conv = nn.Sequential(
            nn.Conv2d(64, 3, 9, padding=4),
            nn.Tanh()
        )

    def forward(self, x):
        out1 = self.first_conv(x)
        out = self.res_blocks(out1)
        out = self.mid_conv(out)
        out = out1 + out          # global skip connection: feeds shallow features straight to the deep layers
        out = self.upsample(out)
        out = self.last_conv(out)
        return out

2.2 Discriminator: The judge of real and fake images

The discriminator is essentially an 8-layer convolutional network that alternates stride-1 and stride-2 convolutions to extract features while gradually reducing resolution. It ends with global average pooling and a classification head that outputs a confidence between 0 and 1 (0 = generated/fake image, 1 = real high-resolution image).

class Discriminator(nn.Module):
    """SRGAN discriminator: binary classifier that decides whether an image is a real high-res image or a generated one"""
    def __init__(self, input_shape=(3, 96, 96)):
        super().__init__()
        self.input_shape = input_shape
        
        def conv_block(in_f, out_f, stride=1, norm=True):
            layers = [nn.Conv2d(in_f, out_f, 3, stride, padding=1)]
            if norm:
                layers.append(nn.BatchNorm2d(out_f))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        # Feature extraction backbone
        self.backbone = nn.Sequential(
            *conv_block(3, 64, stride=1, norm=False),
            *conv_block(64, 64, stride=2),
            *conv_block(64, 128, stride=1),
            *conv_block(128, 128, stride=2),
            *conv_block(128, 256, stride=1),
            *conv_block(256, 256, stride=2),
            *conv_block(256, 512, stride=1),
            *conv_block(512, 512, stride=2),
        )

        # Global average pooling + binary classification head
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Sequential(
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.5),
            nn.Linear(1024, 1),
            nn.Sigmoid()
        )

    def forward(self, img):
        feats = self.backbone(img)
        feats = self.pool(feats).flatten(1)
        return self.classifier(feats)
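
As a quick check on the backbone above, the sketch below traces the feature-map side length through the eight 3×3 convolutions using the standard convolution output-size formula. The helper names `conv_out` and `trace_sizes` are illustrative and not part of the original code:

```python
def conv_out(size: int, kernel: int = 3, stride: int = 1, padding: int = 1) -> int:
    """Output side length of a square convolution (floor division, as in nn.Conv2d)."""
    return (size + 2 * padding - kernel) // stride + 1

def trace_sizes(size: int, strides=(1, 2, 1, 2, 1, 2, 1, 2)) -> list:
    """Side length after each of the eight conv layers, starting from `size`."""
    sizes = [size]
    for s in strides:
        sizes.append(conv_out(sizes[-1], stride=s))
    return sizes

print(trace_sizes(96))  # [96, 96, 48, 48, 24, 24, 12, 12, 6]
```

A 96×96 crop therefore reaches the pooling layer as a 512×6×6 feature map. Because `AdaptiveAvgPool2d(1)` collapses whatever spatial size arrives, the classifier head should also accept other crop sizes without changes.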

3. The soul of SRGAN: loss function design

The SRGAN loss is a weighted sum of two parts: a content loss (pixel loss + perceptual loss) and an adversarial loss. Of these, the perceptual loss is the key to making the image "look real".

3.1 Content loss: pixel matching + perceptual matching

First, load a VGG19 network with frozen parameters and use it to extract high-level semantic features from the image. Content loss = pixel MSE with a very small weight + perceptual (feature) MSE with a large weight. The pixel term keeps the overall structure from drifting, while the perceptual term lets the network concentrate on drawing realistic high-frequency textures.

import torchvision.models as models

class VGG19Extractor(nn.Module):
    """Frozen pre-trained VGG19 that extracts features for the perceptual loss"""
    def __init__(self, layer_idx=35):  # index 35 of VGG19.features: the ReLU after conv5_4
        super().__init__()
        vgg = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1).features
        self.extractor = nn.Sequential(*list(vgg.children())[:layer_idx+1])
        
        for p in self.extractor.parameters():
            p.requires_grad = False
        
        # ImageNet normalization statistics
        self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1,3,1,1))
        self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1,3,1,1))

    def forward(self, x):
        # Generator output is in [-1, 1]: map to [0, 1] first, then normalize with the ImageNet mean/std
        x = (x + 1) / 2
        x = (x - self.mean) / self.std
        return self.extractor(x)

def content_loss(sr, hr, vgg, mse_w=0.01):
    """Content loss = pixel MSE (small weight) + perceptual loss (large weight)"""
    pixel_loss = F.mse_loss(sr, hr)
    sr_feats = vgg(sr)
    hr_feats = vgg(hr)
    percept_loss = F.mse_loss(sr_feats, hr_feats)
    return mse_w * pixel_loss + percept_loss

3.2 Adversarial loss: making it hard for the discriminator to tell real from fake

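The original GAN paper used a cross-entropy loss. Here it is replaced with the LSGAN (least-squares GAN) loss, which can alleviate vanishing gradients and make training more stable. Note that the discriminator above ends in a Sigmoid, so the squared error is taken on probabilities in [0, 1]; LSGAN implementations often drop the Sigmoid and regress raw scores instead.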

def adv_loss_g(generator_out):
    """Generator loss: the generator wants the discriminator to label its output as real (target 1)"""
    return F.mse_loss(generator_out, torch.ones_like(generator_out))

def adv_loss_d(real_out, fake_out):
    """Discriminator loss: real images are labeled 1, generated images 0"""
    real_l = F.mse_loss(real_out, torch.ones_like(real_out))
    fake_l = F.mse_loss(fake_out, torch.zeros_like(fake_out))
    return (real_l + fake_l) / 2
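
To see how the two formulations differ numerically, the sketch below evaluates the generator-side loss both ways on the same scores. The three discriminator outputs are made-up values chosen for illustration:

```python
import torch
import torch.nn.functional as F

# Hypothetical discriminator scores for three generated images
fake_out = torch.tensor([[0.1], [0.5], [0.9]])
target = torch.ones_like(fake_out)       # the generator wants these judged "real"

ls_loss = F.mse_loss(fake_out, target)                 # least-squares form
bce_loss = F.binary_cross_entropy(fake_out, target)    # cross-entropy form

print(f"LSGAN: {ls_loss.item():.4f} | BCE: {bce_loss.item():.4f}")
```

The least-squares form stays bounded and penalizes samples in proportion to their distance from the target, whereas cross-entropy grows without bound as a score approaches 0. Which behaves better in practice depends on the dataset and the rest of the training setup.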

4. Training strategy: Two stages are more stable

If the generator has to compete with the discriminator from the very first step, its outputs are so obviously fake that the discriminator learns to reject them almost immediately. The generator's gradients then vanish and training quickly stalls. The usual remedy is to pre-train the generator first and only then introduce adversarial training.

Stage 1: Pre-train the generator (MSE pixel loss only)

This stage effectively trains an SRResNet. The goal is to make the upscaled image as close as possible to the real high-resolution image in terms of pixel-wise mean squared error. Convergence is fast and training is stable.

def pretrain_gen(generator, dataloader, device, epochs=50, lr=1e-4):
    """Stage 1: train the generator alone with pixel-wise MSE."""
    opt = torch.optim.Adam(generator.parameters(), lr=lr)
    mse = nn.MSELoss()
    generator.train()

    for epoch in range(epochs):
        # The dataloader is assumed to yield (high-res, low-res) pairs.
        # Named *_img so the low-res batch does not shadow the learning rate `lr`.
        for batch_idx, (hr_img, lr_img) in enumerate(dataloader):
            hr_img, lr_img = hr_img.to(device), lr_img.to(device)
            opt.zero_grad()
            sr = generator(lr_img)
            loss = mse(sr, hr_img)
            loss.backward()
            opt.step()
            if batch_idx % 100 == 0:
                print(f"Pretrain E{epoch} B{batch_idx} | MSE Loss: {loss.item():.4f}")
    torch.save(generator.state_dict(), "srresnet_pretrain.pth")

Stage 2: Adversarial training (loading the pre-trained weights)

The core loop alternates between the two networks in every iteration, typically updating the discriminator first and then the generator. This keeps either side from overwhelming the other and maintains a dynamic balance.
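
The alternating update can be sketched as a single function. This is one possible arrangement under stated assumptions, not the article's own training loop: the name `adversarial_step` and the `content_fn` callback are hypothetical, the adversarial weight `1e-3` is the value reported in the SRGAN paper and should be treated as a starting point, and the tiny networks at the bottom are stand-ins used only to show that the step runs.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def adversarial_step(gen, disc, content_fn, opt_g, opt_d, lr_img, hr_img,
                     adv_weight=1e-3):
    """One alternating update: discriminator first, then generator."""
    # 1. Discriminator update. detach() stops gradients reaching the generator.
    sr = gen(lr_img)
    real_out = disc(hr_img)
    fake_out = disc(sr.detach())
    d_loss = 0.5 * (F.mse_loss(real_out, torch.ones_like(real_out))
                    + F.mse_loss(fake_out, torch.zeros_like(fake_out)))
    opt_d.zero_grad()
    d_loss.backward()
    opt_d.step()

    # 2. Generator update. Gradients flow through the (just updated) discriminator.
    fake_out = disc(sr)
    g_adv = F.mse_loss(fake_out, torch.ones_like(fake_out))
    g_loss = content_fn(sr, hr_img) + adv_weight * g_adv
    opt_g.zero_grad()
    g_loss.backward()
    opt_g.step()
    return d_loss.item(), g_loss.item()

# Smoke test with tiny stand-in networks (NOT the real SRGAN models)
demo_gen = nn.Sequential(nn.Upsample(scale_factor=4),
                         nn.Conv2d(3, 3, 3, padding=1), nn.Tanh())
demo_disc = nn.Sequential(nn.Conv2d(3, 8, 3, stride=2, padding=1),
                          nn.AdaptiveAvgPool2d(1), nn.Flatten(),
                          nn.Linear(8, 1), nn.Sigmoid())
opt_g = torch.optim.Adam(demo_gen.parameters(), lr=1e-4)
opt_d = torch.optim.Adam(demo_disc.parameters(), lr=1e-4)
lr_img = torch.rand(2, 3, 8, 8) * 2 - 1      # random low-res batch in [-1, 1]
hr_img = torch.rand(2, 3, 32, 32) * 2 - 1    # matching 4x high-res batch
d_loss, g_loss = adversarial_step(demo_gen, demo_disc, F.mse_loss,
                                  opt_g, opt_d, lr_img, hr_img)
print(f"D loss {d_loss:.4f} | G loss {g_loss:.4f}")
```

With the classes from this article, `gen` would be a `Generator` loaded from `srresnet_pretrain.pth`, `disc` a `Discriminator`, and `content_fn` something like `lambda sr, hr: content_loss(sr, hr, vgg)`.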


5. Get started quickly: Use SRGAN to repair old photos

To enlarge a blurry old photo and restore its details, you only need to load a pre-trained generator and write a few lines of pre- and post-processing code.

from PIL import Image
import torchvision.transforms as transforms

def enhance_old_photo(img_path, generator_path, device, scale=4):
    # Load the pre-trained generator
    gen = Generator(scale_factor=scale).to(device)
    gen.load_state_dict(torch.load(generator_path, map_location=device))
    gen.eval()
    
    # Pre-processing: convert to a tensor and normalize to [-1, 1]
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5,0.5,0.5], std=[0.5,0.5,0.5])
    ])
    
    img = Image.open(img_path).convert('RGB')
    lr = transform(img).unsqueeze(0).to(device)
    
    # Inference
    with torch.no_grad():
        sr = gen(lr)
    
    # Post-processing: de-normalize back to [0, 1], then convert to a PIL image
    sr = (sr.squeeze(0).cpu() + 1) / 2
    sr = torch.clamp(sr, 0, 1)
    return transforms.ToPILImage()(sr)
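
When judging the result, it can help to have the bicubic baseline mentioned in the introduction at the same output size. A minimal sketch, assuming a `(B, 3, H, W)` tensor already scaled to [0, 1]; the helper name `bicubic_upscale` is illustrative:

```python
import torch
import torch.nn.functional as F

def bicubic_upscale(img: torch.Tensor, scale: int = 4) -> torch.Tensor:
    """Bicubic baseline for a (B, 3, H, W) tensor with values in [0, 1]."""
    up = F.interpolate(img, scale_factor=scale, mode="bicubic", align_corners=False)
    return up.clamp(0, 1)  # bicubic interpolation can overshoot the input range

low_res = torch.rand(1, 3, 60, 80)       # stand-in for a loaded photo
baseline = bicubic_upscale(low_res)
print(tuple(baseline.shape))             # (1, 3, 240, 320)
```

Viewing this baseline next to the SRGAN output makes it easier to see which textures the generator has actually added.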

Main variants

  • ESRGAN: Replaces the residual block with the Residual‑in‑Residual Dense Block (RRDB), removes the BN layers, and introduces a relativistic GAN discriminator to enhance details.
  • Real‑ESRGAN: Trained on purely synthetic data, which greatly improves generalization to real low-quality images. It is now the engine behind many image enhancement tools.

Existing challenges

  • Inference is slow, so mobile deployment usually requires acceleration techniques such as quantization and pruning.
  • The model occasionally generates "pseudo-realistic details", for example turning blurry skin blemishes into freckles. This remains a problem in scenarios that demand very high fidelity.

Summary

SRGAN marks the shift in super-resolution from "pursuing numerical metrics" to "pursuing visual realism". It combines a residual network, a VGG perceptual loss, and adversarial training to upscale images with convincing detail. Later variants such as ESRGAN and Real-ESRGAN build on this foundation and are now widely used in real-world scenarios such as old photo restoration, video enhancement, and game texture upscaling.


🔗 Extended reading