Diffusion Model详解:扩散模型原理与PyTorch实现

大家好,这里是道满PythonAI!今天我们拆解的是让Midjourney和Sora爆火的技术核心——Diffusion Model扩散模型

如果说GAN是在博弈中“伪造”数据,CycleGAN是在循环中“翻译”风格,那么扩散模型就是2021年后图像生成的绝对霸主。它的灵感并非来自艺术,而是非平衡热力学:先“倒放”墨水扩散成黑液的过程,把噪声还原成清晰数据。


1. 扩散模型两大核心

扩散模型是一种生成模型,分为固定前向毁灭可学习反向重建两个阶段:

1.1 前向毁灭:加噪变废图

向原始图 x₀ 中不断加入微小高斯噪声,经过 T 步(通常1000-4000)后,x_T 变成完全无法辨认的标准正态分布噪声。这一步不需要训练,是预先设定的马尔可夫链。

1.2 反向重建:去噪变新图

模型学习预测每一步中被加入的噪声 ε,再用数学公式从 x_t 反推 x_{t-1}。经过数千次微小去噪,最终从 x_T 生成一张有意义的新图像。


2. 数学极简推导

2.1 前向公式:一步到位算加噪

原始马尔可夫加噪公式是逐步叠加的,但用重参数化技巧可以直接从 x₀ 算任意步 tx_t

xt=αˉtx0+1αˉtϵ,ϵN(0,I)x_t = \sqrt{\bar{\alpha}_t} x_0 + \sqrt{1 - \bar{\alpha}_t} \epsilon, \quad \epsilon \sim \mathcal{N}(0, I)

其中:

  • αₜ = 1 - βₜβₜ 是噪声调度(越大加噪越快,但不能超过0.02)
  • $\bar{\alpha}_t = \prod_{s=1}^t \alpha_s$,是前 t 步的 α 乘积

2.2 反向训练:预测噪声即可

扩散模型的损失函数非常简单——直接让模型预测的噪声 ε_θ 与真实加入的噪声 ε 做均方误差(MSE):

Lt=Ex0,ϵ[ϵϵθ(xt,t)2]L_t = \mathbb{E}_{x_0, \epsilon} \left[ \| \epsilon - \epsilon_\theta(x_t, t) \|^2 \right]


3. DDPM核心实现

DDPM(Denoising Diffusion Probabilistic Models)是扩散模型的开山经典,我们用PyTorch实现核心模块。

3.1 核心网络:带时间嵌入的简化U-Net

扩散模型需要知道当前是“第几步加噪”,所以要加入正弦位置编码把时间步 t 编码成向量,和图像特征融合:

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class SinusoidalPositionEmbeddings(nn.Module):
    """正弦位置编码,用于时间步嵌入"""
    def __init__(self, dim):
        super().__init__()
        self.half_dim = dim // 2

    def forward(self, time):
        device = time.device
        emb = math.log(10000) / (self.half_dim - 1)
        emb = torch.exp(torch.arange(self.half_dim, device=device) * -emb)
        emb = time[:, None] * emb[None, :]
        return torch.cat((emb.sin(), emb.cos()), dim=-1)

class SimpleUnet(nn.Module):
    """带时间嵌入的简化U-Net"""
    def __init__(self):
        super().__init__()
        img_ch, time_emb_dim = 3, 32
        down_ch = (64, 128, 256, 512)
        up_ch = (512, 256, 128, 64)

        # 时间嵌入层
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_emb_dim),
            nn.Linear(time_emb_dim, time_emb_dim * 2),
            nn.ReLU(),
            nn.Linear(time_emb_dim * 2, time_emb_dim)
        )
        # 初始投影
        self.conv0 = nn.Conv2d(img_ch, down_ch[0], 3, padding=1)
        # 下采样+残差连接
        self.downs = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(down_ch[i], down_ch[i+1], 3, padding=1),
                nn.BatchNorm2d(down_ch[i+1]),
                nn.ReLU(),
                nn.Conv2d(down_ch[i+1], down_ch[i+1], 3, padding=1),
                nn.BatchNorm2d(down_ch[i+1]),
                nn.ReLU(),
                nn.Conv2d(down_ch[i+1], down_ch[i+1], 4, 2, 1)
            ) for i in range(len(down_ch)-1)
        ])
        # 上采样+残差拼接
        self.ups = nn.ModuleList([
            nn.Sequential(
                nn.ConvTranspose2d(up_ch[i], up_ch[i]//2, 4, 2, 1),
                nn.BatchNorm2d(up_ch[i]//2),
                nn.ReLU(),
                nn.Conv2d(up_ch[i], up_ch[i+1], 3, padding=1),
                nn.BatchNorm2d(up_ch[i+1]),
                nn.ReLU(),
                nn.Conv2d(up_ch[i+1], up_ch[i+1], 3, padding=1),
                nn.BatchNorm2d(up_ch[i+1]),
                nn.ReLU()
            ) for i in range(len(up_ch)-1)
        ])
        # 输出层(预测噪声)
        self.output = nn.Conv2d(up_ch[-1], img_ch, 1)

    def forward(self, x, t):
        # 时间嵌入
        t_emb = self.time_mlp(t)[(..., ) + (None, ) * 2]  # 扩展到[B, C, 1, 1]
        # 初始卷积
        x = self.conv0(x)
        # 下采样并保存残差
        residuals = []
        for down in self.downs:
            x = down(x) + t_emb  # 融合时间信息
            residuals.append(x)
        # 上采样并拼接残差
        for up, res in zip(self.ups, reversed(residuals)):
            x = up[:1](x)  # 先上采样
            x = torch.cat([x, res], dim=1)  # 拼接残差
            x = up[1:](x) + t_emb
        # 输出
        return self.output(x)

3.2 噪声调度与辅助函数

def linear_beta_schedule(timesteps, start=0.0001, end=0.02):
    """线性噪声调度"""
    return torch.linspace(start, end, timesteps)

def precompute_schedule(betas):
    """预计算α相关的所有值,避免训练/采样时重复计算"""
    alphas = 1.0 - betas
    alphas_bar = torch.cumprod(alphas, dim=0)
    sqrt_alphas_bar = torch.sqrt(alphas_bar)
    sqrt_one_minus_alphas_bar = torch.sqrt(1.0 - alphas_bar)
    # 反向采样需要的参数
    posterior_variance = betas * (1.0 - alphas_bar[:-1]) / (1.0 - alphas_bar[1:])
    posterior_variance = torch.cat([torch.tensor([1e-20]), posterior_variance])
    return {
        "sqrt_alphas_bar": sqrt_alphas_bar,
        "sqrt_one_minus_alphas_bar": sqrt_one_minus_alphas_bar,
        "posterior_variance": posterior_variance,
        "betas": betas,
        "sqrt_recip_alphas": torch.sqrt(1.0 / alphas),
        "sqrt_recip_alphas_bar_minus_1": torch.sqrt(1.0 / alphas_bar - 1.0)
    }

def extract(arr, t, shape):
    """从预计算数组中提取对应时间步的值,并扩展到指定形状"""
    batch_size = shape[0]
    out = arr.gather(-1, t)
    return out.reshape(batch_size, *([1]*(len(shape)-1)))

3.3 训练与采样循环

def get_loss(model, x0, t, schedule):
    """计算DDPM的MSE损失"""
    noise = torch.randn_like(x0)
    xt = (
        extract(schedule["sqrt_alphas_bar"], t, x0.shape) * x0 +
        extract(schedule["sqrt_one_minus_alphas_bar"], t, x0.shape) * noise
    )
    noise_pred = model(xt, t)
    return F.mse_loss(noise, noise_pred)

@torch.no_grad()
def sample_one_step(model, xt, t, schedule):
    """反向采样一步"""
    betas_t = extract(schedule["betas"], t, xt.shape)
    sqrt_recip_alphas_t = extract(schedule["sqrt_recip_alphas"], t, xt.shape)
    sqrt_recip_alphas_bar_minus_1_t = extract(schedule["sqrt_recip_alphas_bar_minus_1"], t, xt.shape)
    # 计算均值
    mean = sqrt_recip_alphas_t * (
        xt - betas_t * model(xt, t) / sqrt_recip_alphas_bar_minus_1_t
    )
    # 采样噪声(t=0时不加)
    posterior_var_t = extract(schedule["posterior_variance"], t, xt.shape)
    noise = torch.randn_like(xt) if t[0] > 0 else 0.0
    return mean + torch.sqrt(posterior_var_t) * noise

4. 主流改进与应用

4.1 Latent Diffusion(Stable Diffusion核心)

像素空间训练扩散模型太占显存(一张1024×1024的图有300万+像素),所以Stable Diffusion把图压缩到1/8分辨率的潜在空间(4通道)再训练:

  • 预训练VAE压缩/解压图像
  • CLIP文本编码器把提示词转成条件向量
  • 带条件的U-Net在潜在空间去噪

4.2 核心应用场景

  • 文生图/图生图:Midjourney、Stable Diffusion WebUI
  • 视频生成:Sora、Pika Labs(在时空潜在空间扩散)
  • 图像修复:Inpainting、Outpainting、超分辨率
  • 科学研究:分子生成、材料设计

5. 道满的实践建议

5.1 快速上手

  1. 先复现MNIST/CIFAR-10的简化DDPM
  2. 再用Stable Diffusion WebUI玩文生图
  3. 最后读源码理解条件控制和潜在空间

5.2 避坑指南

  • 图像必须归一化到[-1, 1]
  • 时间步用torch.long类型
  • 采样时尽量用混合精度(FP16)加速

总结

扩散模型凭借稳定训练、高质量生成、强可控性三大优势,已经成为AI生成领域的绝对核心。从DDPM到Stable Diffusion再到Sora,它的演进速度极快,应用场景也越来越广。

如果有兴趣深入学习,推荐阅读开篇的三篇扩展论文哦!


基础深度学习 → 变分推断/VAE → GAN → 简化DDPM复现 → Stable Diffusion源码

🔗 扩展阅读