title: Detailed explanation of Diffusion Model: Diffusion model principle and PyTorch implementation | Daoman PythonAI
description: An in-depth analysis of diffusion models and their applications in tasks such as text-to-image, image-to-image, and video generation, including a detailed architecture analysis, a PyTorch implementation, and practical application scenarios, covering core techniques such as DDPM and Stable Diffusion.
keywords: [Diffusion Model, diffusion model, image generation, DDPM, Stable Diffusion, text-to-image, AI painting, deep learning, computer vision, PyTorch]

Detailed explanation of Diffusion Model: Diffusion model principle and PyTorch implementation

Hello everyone, this is Daoman PythonAI! Today we are taking apart the core technology behind Midjourney and Sora: the diffusion model.

If a GAN "forges" data through an adversarial game, and CycleGAN "translates" styles through a cycle, then the diffusion model has been the dominant approach to image generation since 2021. Its inspiration comes not from art but from non-equilibrium thermodynamics: just as a drop of ink diffuses until the water turns uniformly dark, noise gradually destroys an image, and the model learns to run that process in reverse, restoring clean data from noise.


1. The two core stages of a diffusion model

A diffusion model is a generative model built from two stages: a fixed forward destruction process and a learnable reverse reconstruction process.

1.1 Forward destruction: adding noise until the image is ruined

Small amounts of Gaussian noise are repeatedly added to the original image x₀. After T steps (typically 1000 to 4000), x_T is indistinguishable from standard normal noise. This stage requires no training: it is a fixed, predefined Markov chain.
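A minimal sketch of that forward chain, assuming DDPM's linear schedule (β rising from 0.0001 to 0.02 over T = 1000 steps) and using scalar "images" in plain Python instead of tensors. Each step applies x_t = sqrt(1 - β_t) * x_{t-1} + sqrt(β_t) * ε; after T steps the samples should look like standard normal noise regardless of where they started. The helper names `linear_betas` and `forward_chain` are hypothetical.

```python
import math
import random
import statistics


def linear_betas(T=1000, start=1e-4, end=0.02):
    """Linearly spaced noise levels beta_1..beta_T (DDPM-style linear schedule)."""
    return [start + (end - start) * i / (T - 1) for i in range(T)]


def forward_chain(x0, betas, rng):
    """Run the fixed Markov chain q(x_t | x_{t-1}) for every step and return x_T."""
    x = x0
    for beta in betas:
        eps = rng.gauss(0.0, 1.0)
        x = math.sqrt(1.0 - beta) * x + math.sqrt(beta) * eps
    return x


rng = random.Random(0)
betas = linear_betas()
# Every sample starts from the same clean value x0 = 1.0
samples = [forward_chain(1.0, betas, rng) for _ in range(1000)]
print(f"mean={statistics.fmean(samples):.3f}, var={statistics.pvariance(samples):.3f}")
```

The mean ends up near 0 and the variance near 1: the information about x₀ has been washed out, which is exactly the starting point the reverse process needs.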

1.2 Reverse reconstruction: denoising to create new images

The model learns to predict the noise ε contained in x_t at each step, then uses a closed-form formula to go from x_t back to x_{t-1}. After a thousand or so small denoising steps, it turns the pure noise x_T into a meaningful new image.
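For reference, this is the reverse update used by DDPM sampling, written with α_t, β_t and ᾱ_t as defined in Section 2.1. Here σ_t² is a fixed (not learned) variance, computed as `posterior_variance` in Section 3.2, and the extra noise z is dropped at the final step:

$$x_{t-1} = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\,\epsilon_\theta(x_t, t)\right) + \sigma_t z, \qquad z \sim \mathcal{N}(0, I)$$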


2. A minimal mathematical derivation

2.1 Forward formula: jumping to any noise level in one step

The Markov chain as defined adds noise one step at a time, but with the reparameterization trick (plus the fact that a sum of independent Gaussians is itself Gaussian), we can jump straight from the original image x₀ to the noisy image x_t at any step t, which greatly simplifies computation:

$$x_t = \sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon$$

where ε is noise sampled from the standard normal distribution N(0, I).

The meaning of the key symbols in the formula:

  • α_t = 1 - β_t, where β_t is the noise schedule that controls how much noise is added at step t. The larger β_t, the faster the image is destroyed; in practice β_t usually stays at or below 0.02.
  • alpha_bar_t (ᾱ_t) is the cumulative product of all α values up to step t: alpha_bar_t = α_1 * α_2 * ... * α_t
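To make these symbols concrete, here is a plain-Python sketch (no tensors) that derives α_t and alpha_bar_t from a linear β schedule, assuming DDPM's defaults of T = 1000 and β from 0.0001 to 0.02. The helper names are hypothetical.

```python
import math


def linear_betas(T=1000, start=1e-4, end=0.02):
    """Linear noise schedule beta_1..beta_T."""
    return [start + (end - start) * i / (T - 1) for i in range(T)]


def alpha_bars(betas):
    """Cumulative products alpha_bar_t = alpha_1 * ... * alpha_t, with alpha_t = 1 - beta_t."""
    out, prod = [], 1.0
    for beta in betas:
        prod *= 1.0 - beta
        out.append(prod)
    return out


abar = alpha_bars(linear_betas())
for t in (1, 250, 500, 1000):
    # sqrt(alpha_bar_t): how much of x0 survives; sqrt(1 - alpha_bar_t): the noise scale
    print(t, round(math.sqrt(abar[t - 1]), 4), round(math.sqrt(1 - abar[t - 1]), 4))
```

alpha_bar_t falls monotonically from almost 1 at t = 1 to a tiny value at t = T, so the signal coefficient shrinks toward 0 while the noise coefficient approaches 1.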

With this one-step formula, a training step can produce x_t at any noise level with a single operation instead of simulating t consecutive noising steps, which makes training very efficient.
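A small check of why one call is enough, in plain Python with scalar data and DDPM's linear schedule assumed: the sketch tracks the coefficient on x₀ and the accumulated noise variance through the step-by-step chain, confirms they equal sqrt(alpha_bar_t) and 1 - alpha_bar_t, and then draws noisy samples at random time steps directly with the closed form. `q_sample` is a hypothetical name for the one-step sampler.

```python
import math
import random


def linear_betas(T=1000, start=1e-4, end=0.02):
    return [start + (end - start) * i / (T - 1) for i in range(T)]


betas = linear_betas()
alpha_bar = []
prod = 1.0
for beta in betas:
    prod *= 1.0 - beta
    alpha_bar.append(prod)


def q_sample(x0, t, eps):
    """Closed form: x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps (t is 1-based)."""
    ab = alpha_bar[t - 1]
    return math.sqrt(ab) * x0 + math.sqrt(1.0 - ab) * eps


# Step-by-step bookkeeping for x_t = sqrt(alpha_t) * x_{t-1} + sqrt(beta_t) * eps_t:
# the x0 coefficient multiplies by sqrt(alpha_t); the noise variance becomes alpha_t * var + beta_t.
coef, var = 1.0, 0.0
max_err = 0.0
for t, beta in enumerate(betas, start=1):
    alpha = 1.0 - beta
    coef *= math.sqrt(alpha)
    var = alpha * var + beta
    max_err = max(max_err,
                  abs(coef - math.sqrt(alpha_bar[t - 1])),
                  abs(var - (1.0 - alpha_bar[t - 1])))

# Training-style usage: every example gets its own random t, with no loop over steps.
rng = random.Random(0)
batch = [(x0, rng.randint(1, 1000)) for x0 in (0.5, -0.2, 0.9)]
noisy = [q_sample(x0, t, rng.gauss(0.0, 1.0)) for x0, t in batch]
print(f"max mismatch = {max_err:.2e}", noisy)
```

The variance recursion collapses algebraically: α_t(1 - ᾱ_{t-1}) + β_t = 1 - α_t ᾱ_{t-1} = 1 - ᾱ_t, which is the whole content of the one-step formula.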

2.2 Reverse training: just predict the noise

The training objective of a diffusion model is remarkably simple: compare the model's noise prediction ε_θ(x_t, t) with the noise ε that was actually added, using mean squared error (MSE). The loss function is:

$$L_t = \mathbb{E}_{x_0,\,\epsilon}\left[\,\lVert \epsilon - \epsilon_\theta(x_t, t) \rVert^2\,\right]$$

In other words, we train a neural network ε_θ that takes the noisy image x_t and the time step t as input and outputs its estimate of the noise, then compute the squared difference between that estimate and the true noise. This simple, stable loss is a key reason diffusion models work so well.
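Predicting ε is enough because the forward formula can be inverted: given x_t and a noise estimate, x₀ follows directly. The plain-Python sketch below assumes a single noise level alpha_bar_t = 0.5, a 3-pixel "image", and a hypothetical imperfect network output `eps_hat`; it computes the MSE loss and the implied x₀ estimate.

```python
import math


def mse(a, b):
    """Mean squared error between two equal-length lists."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def predict_x0(xt, eps_hat, alpha_bar_t):
    """Invert the forward formula: x0 = (x_t - sqrt(1 - alpha_bar_t) * eps_hat) / sqrt(alpha_bar_t)."""
    return [(x - math.sqrt(1 - alpha_bar_t) * e) / math.sqrt(alpha_bar_t)
            for x, e in zip(xt, eps_hat)]


alpha_bar_t = 0.5            # assumed noise level for this example
x0 = [0.8, -0.3, 0.1]        # tiny "image", already in [-1, 1]
eps = [0.5, -1.2, 2.0]       # the noise that was actually added
xt = [math.sqrt(alpha_bar_t) * a + math.sqrt(1 - alpha_bar_t) * e for a, e in zip(x0, eps)]

eps_hat = [0.4, -1.0, 2.1]   # hypothetical network prediction
print("loss:", mse(eps, eps_hat))
print("x0 estimate:", predict_x0(xt, eps_hat, alpha_bar_t))
```

A perfect noise prediction gives zero loss and recovers x₀ exactly; any error in `eps_hat` shows up, scaled by sqrt((1 - ᾱ_t) / ᾱ_t), in the x₀ estimate.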


3. DDPM core implementation

DDPM (Denoising Diffusion Probabilistic Models) is the pioneering classic among diffusion models. Below we implement its core modules in PyTorch.

3.1 Core Network: Simplified U-Net with Time Embedding

The network needs to know which time step (that is, which noise level) it is currently denoising, so the time step t is encoded into a vector with sinusoidal position embeddings and fused with the image features:

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

class SinusoidalPositionEmbeddings(nn.Module):
    """Sinusoidal position encoding used to embed the time step t."""
    def __init__(self, dim):
        super().__init__()
        self.half_dim = dim // 2

    def forward(self, time):
        # time: [B] tensor of time steps -> [B, dim] embedding
        device = time.device
        emb = math.log(10000) / (self.half_dim - 1)
        emb = torch.exp(torch.arange(self.half_dim, device=device) * -emb)
        emb = time[:, None].float() * emb[None, :]
        return torch.cat((emb.sin(), emb.cos()), dim=-1)

class Block(nn.Module):
    """Two 3x3 convs with time injection, followed by 2x down- or up-sampling."""
    def __init__(self, in_ch, out_ch, time_emb_dim, up=False):
        super().__init__()
        # Project the time embedding to this block's channel count so it can be added
        self.time_mlp = nn.Linear(time_emb_dim, out_ch)
        if up:
            # Up blocks receive [x, skip] concatenated, hence 2 * in_ch input channels
            self.conv1 = nn.Conv2d(2 * in_ch, out_ch, 3, padding=1)
            self.transform = nn.ConvTranspose2d(out_ch, out_ch, 4, 2, 1)  # H -> 2H
        else:
            self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
            self.transform = nn.Conv2d(out_ch, out_ch, 4, 2, 1)  # H -> H/2
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_ch)
        self.bn2 = nn.BatchNorm2d(out_ch)

    def forward(self, x, t_emb):
        h = F.relu(self.bn1(self.conv1(x)))
        # [B, out_ch] -> [B, out_ch, 1, 1], broadcast over H and W
        h = h + F.relu(self.time_mlp(t_emb))[:, :, None, None]
        h = F.relu(self.bn2(self.conv2(h)))
        return self.transform(h)

class SimpleUnet(nn.Module):
    """Simplified U-Net with time embedding; predicts the noise contained in x_t."""
    def __init__(self, img_ch=3, time_emb_dim=32):
        super().__init__()
        down_ch = (64, 128, 256, 512)
        up_ch = (512, 256, 128, 64)

        # Time embedding: sinusoidal encoding followed by a small MLP
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_emb_dim),
            nn.Linear(time_emb_dim, time_emb_dim * 2),
            nn.ReLU(),
            nn.Linear(time_emb_dim * 2, time_emb_dim)
        )
        # Initial projection
        self.conv0 = nn.Conv2d(img_ch, down_ch[0], 3, padding=1)
        # Downsampling path
        self.downs = nn.ModuleList([
            Block(down_ch[i], down_ch[i + 1], time_emb_dim) for i in range(len(down_ch) - 1)
        ])
        # Upsampling path, fed with skip connections from the downsampling path
        self.ups = nn.ModuleList([
            Block(up_ch[i], up_ch[i + 1], time_emb_dim, up=True) for i in range(len(up_ch) - 1)
        ])
        # Output layer: predicted noise with the same shape as the input image
        self.output = nn.Conv2d(up_ch[-1], img_ch, 1)

    def forward(self, x, t):
        # x: [B, C, H, W] with H and W divisible by 8; t: [B] torch.long time steps
        t_emb = self.time_mlp(t)  # [B, time_emb_dim]
        x = self.conv0(x)
        residuals = []
        for down in self.downs:
            x = down(x, t_emb)
            residuals.append(x)  # keep each feature map for its skip connection
        for up in self.ups:
            # Concatenate with the down-path feature map of the same resolution
            x = torch.cat([x, residuals.pop()], dim=1)
            x = up(x, t_emb)
        return self.output(x)

3.2 Noise schedule and helper functions

def linear_beta_schedule(timesteps, start=0.0001, end=0.02):
    """Linear noise schedule: beta_1..beta_T evenly spaced from start to end."""
    return torch.linspace(start, end, timesteps)

def precompute_schedule(betas):
    """Precompute every alpha-related quantity once, so training/sampling never recomputes them."""
    alphas = 1.0 - betas
    alphas_bar = torch.cumprod(alphas, dim=0)
    # alpha_bar_{t-1}, using the convention alpha_bar_0 = 1 for the first step
    alphas_bar_prev = torch.cat([alphas_bar.new_ones(1), alphas_bar[:-1]])
    # Variance of q(x_{t-1} | x_t, x_0), needed by reverse sampling (0 at the first step)
    posterior_variance = betas * (1.0 - alphas_bar_prev) / (1.0 - alphas_bar)
    return {
        "sqrt_alphas_bar": torch.sqrt(alphas_bar),
        "sqrt_one_minus_alphas_bar": torch.sqrt(1.0 - alphas_bar),
        "posterior_variance": posterior_variance,
        "betas": betas,
        "sqrt_recip_alphas": torch.sqrt(1.0 / alphas),
        "sqrt_recip_alphas_bar_minus_1": torch.sqrt(1.0 / alphas_bar - 1.0)
    }

def extract(arr, t, shape):
    """Pick each sample's value at time step t and reshape to [B, 1, 1, ...] for broadcasting."""
    batch_size = shape[0]
    out = arr.to(t.device).gather(-1, t)  # t must be a torch.long tensor of shape [B]
    return out.reshape(batch_size, *([1] * (len(shape) - 1)))
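For reference, `posterior_variance` is the variance of the forward-process posterior q(x_{t-1} | x_t, x₀) from the DDPM paper, written with the symbols of Section 2.1. Taking ᾱ_0 = 1 makes it zero at the first step, which is why the first entry needs special handling:

$$\tilde\beta_t = \frac{1-\bar\alpha_{t-1}}{1-\bar\alpha_t}\,\beta_t, \qquad \bar\alpha_0 = 1 \;\Rightarrow\; \tilde\beta_1 = 0$$

The DDPM paper reports that using σ_t² = β_t instead of β̃_t gives similar results.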

3.3 Training and sampling loop

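def get_loss(model, x0, t, schedule):
    """DDPM training loss: MSE between the predicted noise and the true noise."""
    noise = torch.randn_like(x0)
    # One-step forward noising: x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise
    xt = (
        extract(schedule["sqrt_alphas_bar"], t, x0.shape) * x0 +
        extract(schedule["sqrt_one_minus_alphas_bar"], t, x0.shape) * noise
    )
    noise_pred = model(xt, t)
    return F.mse_loss(noise_pred, noise)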

@torch.no_grad()
def sample_one_step(model, xt, t, schedule):
    """One reverse step: sample x_{t-1} from x_t (all entries of t share the same step)."""
    betas_t = extract(schedule["betas"], t, xt.shape)
    sqrt_recip_alphas_t = extract(schedule["sqrt_recip_alphas"], t, xt.shape)
    sqrt_one_minus_alphas_bar_t = extract(schedule["sqrt_one_minus_alphas_bar"], t, xt.shape)
    # Mean: 1/sqrt(alpha_t) * (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps_theta(x_t, t))
    mean = sqrt_recip_alphas_t * (
        xt - betas_t * model(xt, t) / sqrt_one_minus_alphas_bar_t
    )
    if t[0] == 0:
        return mean  # final step: no noise is added
    posterior_var_t = extract(schedule["posterior_variance"], t, xt.shape)
    return mean + torch.sqrt(posterior_var_t) * torch.randn_like(xt)
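Section 3.3 stops at a single reverse step; a full sampling loop simply applies that step from t = T down to t = 1. To keep the sketch runnable without a trained network, it swaps in a toy assumption: 1-D data drawn from N(0, 0.5²), for which the best possible noise predictor E[ε | x_t] has a closed form. `eps_model` is that hypothetical stand-in, not a neural network. If the loop is right, the generated samples should have a standard deviation close to 0.5.

```python
import math
import random
import statistics

# Linear schedule with DDPM's defaults (assumed)
T = 1000
betas = [1e-4 + (0.02 - 1e-4) * i / (T - 1) for i in range(T)]
alphas = [1.0 - b for b in betas]
alpha_bar = []
prod = 1.0
for a in alphas:
    prod *= a
    alpha_bar.append(prod)
# Posterior variance with alpha_bar_0 = 1, so the first entry is exactly 0
alpha_bar_prev = [1.0] + alpha_bar[:-1]
post_var = [b * (1.0 - ap) / (1.0 - ab) for b, ap, ab in zip(betas, alpha_bar_prev, alpha_bar)]

DATA_STD = 0.5  # toy assumption: training data x0 ~ N(0, 0.5^2)


def eps_model(xt, i):
    """Hypothetical stand-in for a trained network: exact E[eps | x_t] for the toy data."""
    ab = alpha_bar[i]
    return math.sqrt(1.0 - ab) * xt / (ab * DATA_STD ** 2 + 1.0 - ab)


def sample(rng):
    """DDPM ancestral sampling: start from x_T ~ N(0, 1) and apply T reverse steps."""
    x = rng.gauss(0.0, 1.0)
    for i in reversed(range(T)):  # i is the 0-based index of time step t = i + 1
        mean = (x - betas[i] / math.sqrt(1.0 - alpha_bar[i]) * eps_model(x, i)) / math.sqrt(alphas[i])
        if i > 0:
            x = mean + math.sqrt(post_var[i]) * rng.gauss(0.0, 1.0)
        else:
            x = mean  # final step: return the mean without adding noise
    return x


rng = random.Random(0)
xs = [sample(rng) for _ in range(500)]
print(f"mean={statistics.fmean(xs):.3f}, std={statistics.pstdev(xs):.3f} (target std {DATA_STD})")
```

With a trained `SimpleUnet` in place of `eps_model` and tensors in place of scalars, the loop has the same shape: start from `torch.randn`, then call `sample_one_step` for t = T-1 down to 0.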

4. Mainstream improvements and applications

4.1 Latent Diffusion (Stable Diffusion Core)

Running diffusion in pixel space consumes too much GPU memory (a 1024×1024 RGB image contains more than 3 million values), so Stable Diffusion first compresses images into a latent space at 1/8 of the original resolution in each dimension, with 4 channels, and trains the diffusion model there:

  • A pretrained VAE compresses images into latents and decodes latents back into images
  • A CLIP text encoder turns the prompt into conditioning vectors
  • A conditional U-Net denoises in the latent space
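A back-of-the-envelope check of the savings, assuming the 8× downsampling and 4 latent channels described above (the helper `num_values` is only for illustration):

```python
def num_values(height, width, channels):
    """Number of scalars the diffusion model has to denoise per image."""
    return height * width * channels


pixel = num_values(1024, 1024, 3)             # RGB image in pixel space
latent = num_values(1024 // 8, 1024 // 8, 4)  # 1/8 resolution, 4 latent channels
print(pixel, latent, pixel / latent)          # prints: 3145728 65536 48.0
```

The U-Net therefore works on 48× fewer values per image, although the VAE encoder and decoder add their own cost at the start and end of the pipeline.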

4.2 Core application scenarios

  • Text-to-image / image-to-image: Midjourney, Stable Diffusion WebUI
  • Video generation: Sora, Pika Labs (diffusion in a spatiotemporal latent space)
  • Image restoration and editing: inpainting, outpainting, super-resolution
  • Scientific research: molecule generation, materials design

5. Daoman’s practical suggestions

5.1 Get started quickly

  1. First, reproduce a simplified DDPM on MNIST or CIFAR-10
  2. Experiment with text-to-image generation in Stable Diffusion WebUI
  3. Finally, read the source code to understand conditioning and the latent space

5.2 Common pitfalls

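  • Images must be normalized to [-1, 1] so their scale matches the N(0, I) noise
  • Time steps must be torch.long tensors, since they are used as indices into the precomputed schedule
  • Use mixed precision (FP16) to speed up sampling where possible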
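A minimal sketch of the normalization round trip for 8-bit pixel values, in plain Python with hypothetical helper names. With torchvision, the same forward mapping is often written as `t * 2 - 1` after `ToTensor()`, which already scales images to [0, 1].

```python
def to_model_range(pixel):
    """Map an 8-bit pixel value in [0, 255] to [-1, 1]."""
    return pixel / 127.5 - 1.0


def to_pixel_range(x):
    """Map a model output back to [0, 255], clipping values that overshoot [-1, 1]."""
    x = max(-1.0, min(1.0, x))
    return round((x + 1.0) * 127.5)


print(to_model_range(0), to_model_range(255), to_pixel_range(1.3))
```

Clipping before converting back matters because generated samples can land slightly outside [-1, 1].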

Summary

With three major strengths (stable training, high-quality samples, and strong controllability), diffusion models have become the core technology of generative AI. From DDPM to Stable Diffusion to Sora, the field has evolved very quickly, and its applications keep broadening.

If you want to go deeper, see the extended reading at the end of the article!


Suggested learning path: deep learning basics → variational inference/VAE → GAN → reproducing a simplified DDPM → Stable Diffusion source code

🔗 Extended reading