Diffusion Model详解:扩散模型原理与PyTorch实现

引言

如果说GAN是在博弈中"伪造"数据,CycleGAN是在循环中"翻译"风格,那么扩散模型(Diffusion Model)就是2021年之后图像生成领域的绝对霸主。它是Stable Diffusion、Midjourney和OpenAI Sora等顶尖AI的核心引擎,彻底改变了我们对图像生成的认知。

扩散模型的灵感并非来自艺术,而是来自非平衡热力学。它的逻辑非常反直觉:

  • GAN的目标是直接画出一张好图
  • 扩散模型的目标是学习如何将一堆杂乱无章的噪声,一步步还原成一张有意义的图像

想象你有一杯清水,滴入一滴墨水(数据),墨水会逐渐扩散直到整杯水变黑(噪声)。扩散模型的工作就是记录这个扩散过程,然后"倒放"录像,把墨黑的水变回清澈并提取出那一滴墨水。


1. 扩散模型概述

1.1 核心理念

扩散模型是一种生成模型,它通过一个前向扩散过程将数据逐渐转化为噪声,然后学习一个反向过程将噪声转化为数据。与GAN相比,扩散模型具有以下优势:

  • 训练稳定:不存在GAN的模式崩坏问题
  • 生成质量高:能够生成高质量、多样化的样本
  • 理论基础扎实:基于变分推断和马尔可夫链的理论
  • 可控性强:易于添加条件控制

1.2 两大核心过程

扩散模型主要由两个阶段组成:

1. 前向过程 (Forward Diffusion) —— "毁灭"

向原始图片中不断加入微小的高斯噪声。随着步骤T的增加,图片最终会变成完全无法辨认的纯噪声。这个过程是固定的,不需要学习。

2. 反向过程 (Reverse Diffusion) —— "重建"

这是模型真正学习的地方。它尝试预测在每一步中被加入的噪声是多少,并将其"减去"。通过数千次的微小去噪,模型从纯噪声中"变"出了一张精美的图像。


2. 数学原理详解

2.1 前向扩散过程

前向扩散过程定义了一个马尔可夫链,逐步向数据中添加噪声:

q(xtxt1)=N(xt;1βtxt1,βtI)q(x_t | x_{t-1}) = \mathcal{N}(x_t; \sqrt{1-\beta_t} x_{t-1}, \beta_t I)

其中βt\beta_t是预先设定的噪声调度参数。

完整的前向过程可以表示为:

q(x1:Tx0)=t=1Tq(xtxt1)q(x_{1:T} | x_0) = \prod_{t=1}^T q(x_t | x_{t-1})

根据重参数化技巧,我们可以直接从x0x_0计算任意时刻ttxtx_t

xt=αˉtx0+1αˉtϵx_t = \sqrt{\bar{\alpha}_t} x_0 + \sqrt{1 - \bar{\alpha}_t} \epsilon

其中αt=1βt\alpha_t = 1 - \beta_tαˉt=s=1tαs\bar{\alpha}_t = \prod_{s=1}^t \alpha_sϵN(0,I)\epsilon \sim \mathcal{N}(0, I)

2.2 反向过程

反向过程是学习一个马尔可夫链来逆转前向扩散过程:

pθ(xt1xt)=N(xt1;μθ(xt,t),Σθ(xt,t))p_\theta(x_{t-1} | x_t) = \mathcal{N}(x_{t-1}; \mu_\theta(x_t, t), \Sigma_\theta(x_t, t))

其中θ\theta表示神经网络的参数。

2.3 损失函数

扩散模型的训练目标是最小化变分上界:

L=Eq(x0)Eq(xtx0)Eq(xt1xt,x0)[logq(xt1xt,x0)pθ(xt1xt)]L = \mathbb{E}_{q(x_0)} \mathbb{E}_{q(x_t | x_0)} \mathbb{E}_{q(x_{t-1} | x_t, x_0)} \left[ \log \frac{q(x_{t-1} | x_t, x_0)}{p_\theta(x_{t-1} | x_t)} \right]

经过简化,主要的训练损失可以表示为:

Lt=Ex0,ϵ[ϵϵθ(xt,t)2]L_t = \mathbb{E}_{x_0, \epsilon} \left[ \| \epsilon - \epsilon_\theta(x_t, t) \|^2 \right]

这表明模型需要学习预测在时间步tt加入的噪声ϵ\epsilon


3. DDPM (Denoising Diffusion Probabilistic Models)

DDPM是扩散模型的经典实现,其核心思想是学习去除添加的噪声。

3.1 网络架构

DDPM使用U-Net作为骨干网络:

import torch
import torch.nn as nn
import torch.nn.functional as F

class SinusoidalPositionEmbeddings(nn.Module):
    """
    正弦位置编码,用于时间步嵌入
    """
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings

class Block(nn.Module):
    """
    基础卷积块
    """
    def __init__(self, in_ch, out_ch, time_emb_dim, up=False):
        super().__init__()
        self.time_mlp =  nn.Linear(time_emb_dim, out_ch)
        if up:
            self.conv1 = nn.Conv2d(2*in_ch, out_ch, 3, padding=1)
            self.transform = nn.ConvTranspose2d(out_ch, out_ch, 4, 2, 1)
        else:
            self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
            self.transform = nn.Conv2d(out_ch, out_ch, 4, 2, 1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.bnorm1 = nn.BatchNorm2d(out_ch)
        self.bnorm2 = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU()

    def forward(self, x, t):
        # First Conv
        h = self.bnorm1(self.relu(self.conv1(x)))
        # Time embedding
        time_emb = self.relu(self.time_mlp(t))
        time_emb = time_emb[(..., ) + (None, ) * 2]
        h = h + time_emb
        # Second Conv
        h = self.bnorm2(self.relu(self.conv2(h)))
        # Down or Upsample
        return self.transform(h)

class SimpleUnet(nn.Module):
    """
    简化的U-Net架构
    """
    def __init__(self):
        super().__init__()
        image_channels = 3
        down_channels = (64, 128, 256, 512, 1024)
        up_channels = (1024, 512, 256, 128, 64)
        out_dim = 3
        time_emb_dim = 32

        # Time embedding
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_emb_dim),
            nn.Linear(time_emb_dim, time_emb_dim),
            nn.ReLU()
        )
        
        # Initial projection
        self.conv0 = nn.Conv2d(image_channels, down_channels[0], 3, padding=1)

        # Downsample
        self.downs = nn.ModuleList([Block(down_channels[i], down_channels[i+1], 
                                         time_emb_dim) for i in range(len(down_channels)-1)])

        # Upsample
        self.ups = nn.ModuleList([Block(up_channels[i], up_channels[i+1], 
                                       time_emb_dim, up=True) for i in range(len(up_channels)-1)])

        # Edit: Corrected the final layer
        self.output = nn.Conv2d(up_channels[-1], out_dim, 1)

    def forward(self, x, timestep):
        # Embedd time
        t = self.time_mlp(timestep)
        # Initial conv
        x = self.conv0(x)
        # Unet
        residual_inputs = []
        for down in self.downs:
            x = down(x, t)
            residual_inputs.append(x)
        for up in self.ups:
            residual_x = residual_inputs.pop()
            # Add residual x as additional channels
            x = torch.cat((x, residual_x), dim=1)           
            x = up(x, t)
        return self.output(x)

3.2 噪声调度器

import numpy as np

def linear_beta_schedule(timesteps, start=0.0001, end=0.02):
    """
    线性噪声调度
    """
    return torch.linspace(start, end, timesteps)

def cosine_beta_schedule(timesteps, s=0.008):
    """
    余弦噪声调度
    """
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0.0001, 0.9999)

4. 训练过程

4.1 训练算法

def get_loss(model, x_0, t):
    """
    计算扩散模型的损失
    """
    x_noisy, noise = forward_diffusion_sample(x_0, t, device)
    noise_pred = model(x_noisy, t)
    return F.mse_loss(noise, noise_pred)

def forward_diffusion_sample(x_0, t, device):
    """
    前向扩散采样
    """
    noise = torch.randn_like(x_0)
    x_noisy = x_0 * expand_to_shape(sqrt_alphas_bar[t], x_0.shape) + \
              noise * expand_to_shape(sqrt_one_minus_alphas_bar[t], x_0.shape)
    return x_noisy, noise

def expand_to_shape(arr, shape):
    """
    扩展数组到指定形状
    """
    return arr[(..., ) + (None, ) * (len(shape) - 1)]

def train_model(model, dataloader, epochs, device, timesteps):
    """
    训练扩散模型
    """
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    
    for epoch in range(epochs):
        for step, (batch, _) in enumerate(dataloader):
            optimizer.zero_grad()
            
            batch = batch.to(device)
            # 随机选择时间步
            t = torch.randint(0, timesteps, (batch.shape[0],), device=device).long()
            
            loss = get_loss(model, batch, t)
            loss.backward()
            optimizer.step()
            
            if step % 100 == 0:
                print(f"Epoch {epoch}, Step {step}, Loss: {loss.item():.4f}")

4.2 采样(生成)过程

@torch.no_grad()
def sample_timestep(x, t):
    """
    在时间步t进行采样
    """
    betas_t = extract(betas, t, x.shape)
    sqrt_one_minus_alphas_bar_t = extract(sqrt_one_minus_alphas_bar, t, x.shape)
    sqrt_recip_alphas_t = extract(sqrt_recip_alphas, t, x.shape)
    
    # 计算均值
    model_mean = sqrt_recip_alphas_t * (
        x - betas_t * model(x, t) / sqrt_one_minus_alphas_bar_t
    )
    
    if t == 0:
        return model_mean
    else:
        posterior_variance_t = extract(posterior_variance, t, x.shape)
        noise = torch.randn_like(x)
        # 根据方差进行采样
        return model_mean + torch.sqrt(posterior_variance_t) * noise

@torch.no_grad()
def sample_plot_image():
    """
    生成图像样本
    """
    # 从随机噪声开始
    img = torch.randn((1, 3, 28, 28), device=device)
    plt.figure(figsize=(15,15))
    plt.axis('off')
    num_images = 10
    stepsize = int(T/num_images)

    for i in range(0,T)[::-1]:
        t = torch.full((1,), i, device=device, dtype=torch.long)
        img = sample_timestep(img, t)
        if i % stepsize == 0:
            plt.subplot(1, num_images, int(i/stepsize)+1)
            show_tensor_image(img.detach().cpu())

5. 改进与变体

5.1 DDIM (Denoising Diffusion Implicit Models)

DDIM通过非马尔可夫过程加速采样,可以在较少的步骤内生成高质量图像。

5.2 Latent Diffusion (Stable Diffusion)

Latent Diffusion在潜在空间而非像素空间进行扩散,大大减少了计算复杂度:

class LatentDiffusionModel(nn.Module):
    """
    潜在扩散模型
    """
    def __init__(self, autoencoder, diffusion_model, clip_model):
        super().__init__()
        self.autoencoder = autoencoder  # VAE编码器/解码器
        self.diffusion_model = diffusion_model  # 扩散模型
        self.clip_model = clip_model  # 文本编码器
        
    def encode(self, x):
        """将图像编码到潜在空间"""
        return self.autoencoder.encode(x)
    
    def decode(self, z):
        """将潜在向量解码到图像空间"""
        return self.autoencoder.decode(z)
    
    def forward(self, z, t, context):
        """在潜在空间进行扩散"""
        return self.diffusion_model(z, t, context)

5.3 Classifier-Free Guidance

通过条件和无条件生成的组合,提高生成质量:

def classifier_free_guidance(x, t, text_embeddings, guidance_scale=7.5):
    """
    无分类器指导
    """
    # 无条件生成
    uncond_output = model(x, t, null_embedding)
    # 条件生成
    cond_output = model(x, t, text_embeddings)
    # 组合输出
    output = uncond_output + guidance_scale * (cond_output - uncond_output)
    return output

6. 实际应用

6.1 图像生成

def text_to_image_generation(prompt, model, tokenizer, scheduler):
    """
    文本到图像生成
    """
    # 编码文本提示
    text_embeddings = tokenizer.encode(prompt)
    
    # 从随机噪声开始
    latent = torch.randn((1, 4, 64, 64))  # 在潜在空间
    
    # 逐步去噪
    for t in reversed(range(scheduler.timesteps)):
        latent = scheduler.step(model, t, latent, text_embeddings)
    
    # 解码到图像
    image = model.decode(latent)
    return image

6.2 图像编辑

扩散模型可以用于图像编辑任务,如inpainting、super-resolution等。

6.3 视频生成

Sora等模型将扩散模型扩展到视频领域,实现时空连续的视频生成。


7. 评估指标

7.1 FID (Fréchet Inception Distance)

衡量生成图像与真实图像分布的差异。

7.2 IS (Inception Score)

评估生成图像的质量和多样性。

7.3 LPIPS (Learned Perceptual Image Patch Similarity)

感知相似性度量。


8. 挑战与解决方案

8.1 计算复杂度

  • 问题:需要大量采样步骤
  • 解决方案:DDIM、蒸馏、渐进蒸馏

8.2 模式坍塌

  • 问题:生成多样性不足
  • 解决方案:更好的架构设计、损失函数改进

8.3 采样速度

  • 问题:生成速度慢
  • 解决方案:加速采样算法、模型蒸馏

9. 实践建议

9.1 数据准备

  • 数据质量:使用高质量、一致的数据集
  • 数据预处理:归一化到[-1, 1]范围
  • 数据增强:适度使用,避免破坏语义

9.2 模型调优

  • 学习率调度:使用余弦退火等策略
  • 噪声调度:尝试不同的β调度策略
  • 架构选择:U-Net通常是最佳选择

9.3 部署考虑

  • 推理优化:使用加速采样算法
  • 模型压缩:知识蒸馏减少采样步骤
  • 硬件加速:利用GPU/TPU进行加速

10. 发展趋势与未来方向

10.1 技术发展方向

  • 多模态扩散:结合文本、图像、音频等多种模态
  • 3D扩散模型:生成3D内容
  • 物理感知扩散:结合物理规律的生成模型

10.2 应用前景

  • 创意产业:艺术创作、设计辅助
  • 科学研究:分子设计、材料发现
  • 教育娱乐:个性化内容生成

11. 总结

扩散模型作为当前最主流的生成模型之一,以其稳定的训练过程、高质量的生成结果和强大的可控性,成为AI生成领域的核心技术。从DDPM到Stable Diffusion,再到Sora,扩散模型不断演进,推动着生成AI的发展。

其核心优势在于:

  • 理论基础扎实:基于变分推断的严谨数学框架
  • 训练稳定:避免了GAN的训练不稳定性
  • 生成质量高:能够生成高质量、多样化的样本
  • 可控性强:易于添加条件控制

通过本文的详细分析和代码实现,读者应该对扩散模型的核心原理、实现细节和实际应用有了深入的理解,为进一步研究和应用生成模型打下坚实基础。


相关教程

建议先掌握基础的深度学习和概率论知识,再学习扩散模型。通过复现DDPM等经典模型,可以更好地理解其工作原理和训练技巧。

🔗 扩展阅读