扩散模型
简介
扩散模型(Diffusion Models)通过前向加噪和反向去噪两个过程来建模数据分布:训练时学习在每个噪声水平上预测并去除噪声,推理时从随机噪声出发逐步去噪,生成高质量样本。扩散模型在图像生成的保真度与多样性上已普遍超越 GAN,成为 Stable Diffusion、DALL-E、Midjourney 等产品的核心技术。
扩散模型的核心思想来源于非平衡热力学:前向过程是一个马尔可夫链,逐步向数据中注入高斯噪声,直到数据完全变为各向同性的高斯分布;反向过程则学习每一步的逆变换,从纯噪声中逐步恢复出原始数据分布。这种优雅的数学框架使其不仅适用于图像生成,还被扩展到音频合成、分子设计、三维生成、视频生成等多个领域。
数学原理
前向过程(Forward Process)
前向过程是一个固定的马尔可夫链,用 q(x_t | x_{t-1}) 表示:
q(x_t | x_{t-1}) = N(x_t; sqrt(1 - beta_t) * x_{t-1}, beta_t * I)
其中 beta_t 是预先设定的噪声调度参数。利用重参数化技巧,可以直接从 x_0 采样任意时间步 t 的加噪结果:
q(x_t | x_0) = N(x_t; sqrt(alpha_bar_t) * x_0, (1 - alpha_bar_t) * I)
其中 alpha_bar_t = prod_{s=1}^{t} alpha_s,alpha_s = 1 - beta_s。
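下面用一个极简的数值示例验证这一闭式结论:按单步公式逐步加噪,与直接用 q(x_t | x_0) 一次采样,得到的统计量应当一致(示例取一维数据与常数 beta,仅作示意):
import torch

T = 200
betas = torch.full((T,), 0.01)                 # 常数噪声调度,仅作示意
alpha_bar = torch.cumprod(1 - betas, dim=0)

x0 = torch.randn(10000)                        # 一批一维"数据"
x = x0.clone()
for t in range(T):                             # 按单步公式逐步加噪
    x = torch.sqrt(1 - betas[t]) * x + torch.sqrt(betas[t]) * torch.randn_like(x)

# 用闭式公式 q(x_t | x_0) 直接一步采样到相同的时间步
x_direct = torch.sqrt(alpha_bar[-1]) * x0 + torch.sqrt(1 - alpha_bar[-1]) * torch.randn_like(x0)
print(f"逐步加噪方差: {x.var():.3f} vs 闭式采样方差: {x_direct.var():.3f}")  # 两者应接近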
反向过程(Reverse Process)
反向过程用神经网络参数化 p_theta(x_{t-1} | x_t):
p_theta(x_{t-1} | x_t) = N(x_{t-1}; mu_theta(x_t, t), sigma_t^2 * I)
其中均值通常由噪声预测网络 epsilon_theta 参数化:
mu_theta(x_t, t) = (1 / sqrt(alpha_t)) * (x_t - beta_t / sqrt(1 - alpha_bar_t) * epsilon_theta(x_t, t))
训练目标是最小化预测噪声 epsilon_theta(x_t, t) 与真实噪声 epsilon 的均方误差:
L = E_{t, x_0, epsilon}[|| epsilon - epsilon_theta(x_t, t) ||^2]
分数匹配视角
扩散模型也可以从分数匹配(Score Matching)的角度理解。分数函数 s(x) = nabla_x log p(x) 指向数据密度增长最快的方向。训练噪声预测网络等价于学习带噪声数据的分数函数,这为后续的分数扩散模型(Score SDE)提供了理论基础。
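这一等价关系可以写得更具体:由前面的 q(x_t | x_0) 可以直接算出带噪数据的分数
nabla_{x_t} log q(x_t | x_0) = -(x_t - sqrt(alpha_bar_t) * x_0) / (1 - alpha_bar_t) = -epsilon / sqrt(1 - alpha_bar_t)
因此训练好的噪声预测网络可以换算成分数估计:s_theta(x_t, t) = -epsilon_theta(x_t, t) / sqrt(1 - alpha_bar_t)。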
实现
噪声调度策略
import torch
import math
def linear_beta_schedule(timesteps, beta_start=1e-4, beta_end=0.02):
"""线性噪声调度——最基础的调度方案"""
return torch.linspace(beta_start, beta_end, timesteps)
def cosine_beta_schedule(timesteps, s=0.008):
"""余弦噪声调度——Improved DDPM 提出,在低噪声区域变化更平缓"""
steps = timesteps + 1
x = torch.linspace(0, timesteps, steps)
alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
return torch.clamp(betas, 0.0001, 0.9999)
def sigmoid_beta_schedule(timesteps, beta_start=1e-4, beta_end=0.02):
"""Sigmoid 噪声调度——在起点和终点变化更平缓"""
betas = torch.linspace(-6, 6, timesteps)
betas = torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
return betas
# 对比三种调度策略
T = 1000
lin_betas = linear_beta_schedule(T)
cos_betas = cosine_beta_schedule(T)
sig_betas = sigmoid_beta_schedule(T)
print(f"线性调度 - 前10步 beta 均值: {lin_betas[:10].mean():.6f}")
print(f"余弦调度 - 前10步 beta 均值: {cos_betas[:10].mean():.6f}")
print(f"Sigmoid调度 - 前10步 beta 均值: {sig_betas[:10].mean():.6f}")前向扩散过程
def linear_beta_schedule(timesteps, beta_start=1e-4, beta_end=0.02):
"""线性噪声调度"""
return torch.linspace(beta_start, beta_end, timesteps)
T = 1000
betas = linear_beta_schedule(T)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
def q_sample(x_0, t, noise=None):
"""前向过程:在时间步 t 给 x_0 添加噪声"""
if noise is None:
noise = torch.randn_like(x_0)
sqrt_alpha_bar = torch.sqrt(alphas_cumprod[t])[:, None, None, None]
sqrt_one_minus_alpha_bar = torch.sqrt(1 - alphas_cumprod[t])[:, None, None, None]
return sqrt_alpha_bar * x_0 + sqrt_one_minus_alpha_bar * noise
# 演示:对一张图逐步加噪
x_0 = torch.randn(1, 3, 32, 32) # 模拟原始图像
for step in [0, 100, 500, 999]:
t = torch.tensor([step])
x_t = q_sample(x_0, t)
snr = alphas_cumprod[step] / (1 - alphas_cumprod[step])
print(f"时间步 {step:4d}: 信噪比(SNR)={snr.item():.4f}")加权损失策略
def get_loss_weighting(schedule_type="min_snr"):
"""
不同的损失加权策略可以显著影响训练效果
- min_snr: Min-SNR 加权,降低高噪声步的损失权重(推荐)
- p2: P2 加重权,平方根惩罚
- uniform: 均匀加权,最基础
"""
if schedule_type == "min_snr":
# Min-SNR-gamma: 截断 SNR 到 gamma=5
snr = alphas_cumprod / (1 - alphas_cumprod)
gamma = 5.0
weights = torch.clamp(snr, max=gamma) / snr
elif schedule_type == "p2":
snr = alphas_cumprod / (1 - alphas_cumprod)
weights = 1.0 / torch.sqrt(snr)
else:
weights = torch.ones_like(alphas_cumprod)
return weights
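这些权重在训练循环中的用法可以参考下面的简化示意(假设沿用前文的 q_sample 和去噪模型,weighted_diffusion_loss 为本文虚构的示例函数名):
def weighted_diffusion_loss(model, x_0, t, weights):
    """按时间步加权的噪声预测损失:逐样本 MSE 乘以对应时间步的权重"""
    noise = torch.randn_like(x_0)
    x_t = q_sample(x_0, t, noise)
    noise_pred = model(x_t, t)
    per_sample_mse = ((noise_pred - noise) ** 2).mean(dim=[1, 2, 3])  # (B,)
    return (weights[t] * per_sample_mse).mean()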
weights = get_loss_weighting("min_snr")
print(f"Min-SNR 加权 - 前10步: {weights[:10].tolist()}")
print(f"Min-SNR 加权 - 后10步: {weights[-10:].tolist()}")U-Net 去噪网络
import math
import torch
import torch.nn as nn
class SimpleUNetBlock(nn.Module):
def __init__(self, in_ch, out_ch, time_emb_dim):
super().__init__()
self.conv = nn.Sequential(
nn.Conv2d(in_ch, out_ch, 3, padding=1),
nn.BatchNorm2d(out_ch),
nn.ReLU(),
nn.Conv2d(out_ch, out_ch, 3, padding=1),
nn.BatchNorm2d(out_ch),
)
self.time_mlp = nn.Linear(time_emb_dim, out_ch)
self.relu = nn.ReLU()
def forward(self, x, t_emb):
h = self.conv(x)
t_emb = self.time_mlp(t_emb)[:, :, None, None]
return self.relu(h + t_emb)
class TimeEmbedding(nn.Module):
    """正弦时间步嵌入 + MLP:将标量时间步 t 映射为 dim 维向量"""
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim),
        )
    def forward(self, t):
        # t 形状为 (B,) 的整数时间步,先做标准正弦位置编码,再过 MLP
        half = self.dim // 2
        freqs = torch.exp(-torch.arange(half, device=t.device) * math.log(10000) / half)
        args = t.float()[:, None] * freqs[None, :]
        return self.mlp(torch.cat([torch.sin(args), torch.cos(args)], dim=-1))
class DenoisingUNet(nn.Module):
def __init__(self, in_channels=3, base_dim=64, time_dim=128):
super().__init__()
self.time_embed = TimeEmbedding(time_dim)
self.down1 = SimpleUNetBlock(in_channels, base_dim, time_dim)
self.down2 = SimpleUNetBlock(base_dim, base_dim * 2, time_dim)
self.pool = nn.MaxPool2d(2)
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
self.up1 = SimpleUNetBlock(base_dim * 2, base_dim, time_dim)
self.out_conv = nn.Conv2d(base_dim, in_channels, 1)
def forward(self, x, t):
        t_emb = self.time_embed(t)
d1 = self.down1(x, t_emb)
d2 = self.down2(self.pool(d1), t_emb)
u1 = self.up1(self.up(d2), t_emb)
return self.out_conv(u1)
model = DenoisingUNet()
x_noisy = torch.randn(2, 3, 32, 32)
t = torch.randint(0, 1000, (2,))
noise_pred = model(x_noisy, t)
print(f"去噪网络输出: {noise_pred.shape}")带条件输入的 U-Net(支持分类器引导)
class ConditionalUNet(nn.Module):
"""支持类别条件输入的 U-Net,可用于分类器引导生成"""
def __init__(self, in_channels=3, base_dim=64, time_dim=128, num_classes=10):
super().__init__()
self.time_embed = TimeEmbedding(time_dim)
        # 类别嵌入:额外保留一个索引作为"无条件"空标签(null token),供 CFG 使用
        self.null_class = num_classes
        self.class_embed = nn.Embedding(num_classes + 1, time_dim)
self.down1 = SimpleUNetBlock(in_channels, base_dim, time_dim)
self.down2 = SimpleUNetBlock(base_dim, base_dim * 2, time_dim)
self.pool = nn.MaxPool2d(2)
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
self.up1 = SimpleUNetBlock(base_dim * 2, base_dim, time_dim)
self.out_conv = nn.Conv2d(base_dim, in_channels, 1)
    def forward(self, x, t, class_labels=None, dropout_prob=0.1):
        t_emb = self.time_embed(t)
        # 无条件输入统一映射到空标签,保证训练丢弃条件与推理无条件分支的行为一致
        if class_labels is None:
            class_labels = torch.full((x.shape[0],), self.null_class,
                                      device=x.device, dtype=torch.long)
        elif self.training:
            # Classifier-Free Guidance: 训练时以 dropout_prob 的概率把条件替换为空标签
            drop = torch.rand(class_labels.shape[0], device=class_labels.device) < dropout_prob
            class_labels = torch.where(drop, torch.full_like(class_labels, self.null_class),
                                       class_labels)
        c_emb = self.class_embed(class_labels)
        t_emb = t_emb + c_emb
d1 = self.down1(x, t_emb)
d2 = self.down2(self.pool(d1), t_emb)
u1 = self.up1(self.up(d2), t_emb)
return self.out_conv(u1)
cond_model = ConditionalUNet(num_classes=10)
x_noisy = torch.randn(2, 3, 32, 32)
t = torch.randint(0, 1000, (2,))
labels = torch.randint(0, 10, (2,))
noise_pred = cond_model(x_noisy, t, labels)
print(f"条件去噪网络输出: {noise_pred.shape}")DDPM 训练循环
import torch.optim as optim
def train_ddpm(model, dataloader, timesteps=1000, epochs=5, use_ema=True):
"""
DDPM 完整训练循环
- 支持指数移动平均(EMA)提升生成质量
- 支持梯度裁剪防止训练不稳定
"""
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)
criterion = nn.MSELoss()
# EMA 模型:训练结束后用 EMA 权重生成效果更好
if use_ema:
ema_model = {k: v.clone() for k, v in model.state_dict().items()}
ema_decay = 0.9999
for epoch in range(epochs):
total_loss = 0
for batch_idx, batch in enumerate(dataloader):
x_0 = batch # 原始图像 (B, C, H, W)
batch_size = x_0.shape[0]
# 随机采样时间步
t = torch.randint(0, timesteps, (batch_size,))
# 采样噪声
noise = torch.randn_like(x_0)
# 前向加噪
x_t = q_sample(x_0, t, noise)
# 预测噪声
noise_pred = model(x_t, t)
# 计算损失
loss = criterion(noise_pred, noise)
optimizer.zero_grad()
loss.backward()
# 梯度裁剪
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
total_loss += loss.item()
# 更新 EMA
if use_ema:
for k, v in model.state_dict().items():
ema_model[k] = ema_decay * ema_model[k] + (1 - ema_decay) * v
avg_loss = total_loss / len(dataloader)
print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.6f}")
    return ema_model if use_ema else model
条件扩散模型训练(Classifier-Free Guidance)
def train_conditional_ddpm(model, dataloader, timesteps=1000, epochs=5,
                           uncond_prob=0.1):
    """
    条件扩散模型训练(配合 Classifier-Free Guidance 使用)
    - uncond_prob: 训练时随机丢弃条件的概率(用于 CFG)
    - 引导强度 guidance_scale 只在推理时使用,见下文 cfg_sample
    """
optimizer = optim.AdamW(model.parameters(), lr=1e-4)
criterion = nn.MSELoss()
for epoch in range(epochs):
total_loss = 0
for batch, labels in dataloader:
batch_size = batch.shape[0]
t = torch.randint(0, timesteps, (batch_size,))
noise = torch.randn_like(batch)
x_t = q_sample(batch, t, noise)
noise_pred = model(x_t, t, class_labels=labels, dropout_prob=uncond_prob)
loss = criterion(noise_pred, noise)
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
avg_loss = total_loss / len(dataloader)
print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.6f}")
@torch.no_grad()
def cfg_sample(model, shape, class_label, timesteps=1000, guidance_scale=7.5):
"""
Classifier-Free Guidance 采样
通过对条件预测和无条件预测做线性插值,
在增强条件信号的同时保持生成质量
"""
device = next(model.parameters()).device
x = torch.randn(shape, device=device)
for t in reversed(range(timesteps)):
t_batch = torch.full((shape[0],), t, device=device, dtype=torch.long)
# 无条件预测
noise_uncond = model(x, t_batch, class_labels=None)
# 条件预测
labels = torch.full((shape[0],), class_label, device=device, dtype=torch.long)
noise_cond = model(x, t_batch, class_labels=labels)
# CFG 公式:epsilon = epsilon_uncond + scale * (epsilon_cond - epsilon_uncond)
noise_pred = noise_uncond + guidance_scale * (noise_cond - noise_uncond)
alpha = alphas[t]
alpha_bar = alphas_cumprod[t]
beta = betas[t]
noise_coeff = (1 - alpha) / torch.sqrt(1 - alpha_bar)
x_mean = (x - noise_coeff * noise_pred) / torch.sqrt(alpha)
if t > 0:
x = x_mean + torch.sqrt(beta) * torch.randn_like(x)
else:
x = x_mean
    return x
DDPM 反向采样(推理生成)
@torch.no_grad()
def ddpm_sample(model, shape, timesteps=1000):
"""从纯噪声逐步去噪生成图像"""
device = next(model.parameters()).device
# 从纯高斯噪声开始
x = torch.randn(shape, device=device)
for t in reversed(range(timesteps)):
t_batch = torch.full((shape[0],), t, device=device, dtype=torch.long)
# 预测噪声
noise_pred = model(x, t_batch)
# 计算 x_{t-1}
alpha = alphas[t]
alpha_bar = alphas_cumprod[t]
beta = betas[t]
noise_coeff = (1 - alpha) / torch.sqrt(1 - alpha_bar)
x_mean = (x - noise_coeff * noise_pred) / torch.sqrt(alpha)
if t > 0:
noise = torch.randn_like(x)
sigma = torch.sqrt(beta)
x = x_mean + sigma * noise
else:
x = x_mean
return x
# 生成图像
# generated = ddpm_sample(model, shape=(4, 3, 32, 32))
# print(f"生成图像形状: {generated.shape}")
print("DDPM 采样函数已定义,实际生成需要训练好的模型")DDIM 采样器(加速推理)
@torch.no_grad()
def ddim_sample(model, shape, timesteps=1000, ddim_steps=50, eta=0.0):
"""
DDIM 采样器:通过确定性采样跳步,大幅减少推理步数
参数:
- ddim_steps: 实际采样步数(通常 20-50 即可)
- eta: 随机性控制,0 为完全确定性,1 退化为 DDPM
"""
device = next(model.parameters()).device
x = torch.randn(shape, device=device)
    # 创建时间步子序列:从 timesteps 步中每隔 c 步取一个(由大到小),约 ddim_steps 步
    c = timesteps // ddim_steps
    time_seq = list(reversed(range(timesteps)))
    sub_seq = [time_seq[i] for i in range(0, len(time_seq), c)]
for i in range(len(sub_seq)):
t = sub_seq[i]
t_batch = torch.full((shape[0],), t, device=device, dtype=torch.long)
noise_pred = model(x, t_batch)
alpha_bar_t = alphas_cumprod[t]
# 前一步的 alpha_bar(DDIM 关键)
if i + 1 < len(sub_seq):
prev_t = sub_seq[i + 1]
else:
prev_t = 0
alpha_bar_prev = alphas_cumprod[prev_t] if prev_t >= 0 else torch.tensor(1.0)
# DDIM 更新公式
pred_x0 = (x - torch.sqrt(1 - alpha_bar_t) * noise_pred) / torch.sqrt(alpha_bar_t)
dir_xt = torch.sqrt(1 - alpha_bar_prev) * noise_pred
x_prev = torch.sqrt(alpha_bar_prev) * pred_x0 + dir_xt
# 可选随机性
if eta > 0 and t > 0:
noise = torch.randn_like(x)
sigma = eta * torch.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar_t)) * \
torch.sqrt(1 - alpha_bar_t / alpha_bar_prev)
x_prev = x_prev + sigma * noise
x = x_prev
return x
print("DDIM 采样器已定义,50 步即可生成高质量图像")DPM-Solver++ 采样器
@torch.no_grad()
def dpm_solver_sample(model, shape, timesteps=1000, steps=20, order=2):
    """
    DPM-Solver++ 风格采样器(简化示意):借鉴高阶 ODE 求解器思想的少步采样
    相比 DDIM,在相同步数下质量更高,或相同质量下步数更少
    注意:这里的二阶修正只是简化的多步外推,完整系数推导请参考官方实现
    """
    device = next(model.parameters()).device
    x = torch.randn(shape, device=device)
    # 构建时间步子序列(均匀取 steps+1 个点,由大到小)
    step_indices = torch.linspace(0, timesteps - 1, steps + 1).long()
    timesteps_seq = list(reversed(step_indices.tolist()))
    prev_noise = None  # 缓存上一步的噪声预测,用于二阶(多步)外推
    for i, t in enumerate(timesteps_seq[:-1]):
        t_batch = torch.full((shape[0],), t, device=device, dtype=torch.long)
        noise_pred = model(x, t_batch)
        if order >= 2 and prev_noise is not None:
            # 简化的二阶修正:用前两次预测做线性外推(假设步长近似均匀)
            noise_used = 1.5 * noise_pred - 0.5 * prev_noise
        else:
            # 第一步退化为一阶(欧拉)更新
            noise_used = noise_pred
        prev_noise = noise_pred
        alpha_bar_t = alphas_cumprod[t]
        next_t = timesteps_seq[i + 1]
        alpha_bar_next = alphas_cumprod[next_t]
        # 先由噪声预测恢复 x_0,再确定性地推进到下一个时间步
        pred_x0 = (x - torch.sqrt(1 - alpha_bar_t) * noise_used) / torch.sqrt(alpha_bar_t)
        x = torch.sqrt(alpha_bar_next) * pred_x0 + torch.sqrt(1 - alpha_bar_next) * noise_used
    return x
print("DPM-Solver++ 采样器已定义,20 步即可获得高质量结果")采样器对比
| 采样器 | 推荐步数 | 质量 | 速度 | 随机性 | 适用场景 |
|---|---|---|---|---|---|
| DDPM | 1000 | 最高 | 最慢 | 有 | 研究/基线 |
| DDIM | 20-50 | 高 | 快 | 可控 | 通用 |
| Euler a | 20-30 | 中高 | 很快 | 有 | 快速预览 |
| DPM++ 2M | 20-30 | 高 | 很快 | 无 | 生产推荐 |
| DPM++ SDE | 10-20 | 高 | 快 | 有 | 需要多样性 |
| UniPC | 10-20 | 高 | 最快 | 可控 | 最佳性价比 |
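在 Diffusers 中切换采样器,只需替换 pipeline 的 scheduler。下面给出一个示意(假设已按后文方式加载好 pipe;DPMSolverMultistepScheduler 对应表中的 DPM++ 2M,UniPCMultistepScheduler 对应 UniPC):
# from diffusers import DPMSolverMultistepScheduler, UniPCMultistepScheduler
# pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)  # DPM++ 2M
# # pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)    # UniPC
# image = pipe(prompt, num_inference_steps=25, guidance_scale=7.5).images[0]
print("采样器切换:pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)")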
潜空间扩散(Latent Diffusion)
原理
Stable Diffusion 的核心创新在于潜空间扩散。它不再直接在像素空间(如 512x512x3)上做扩散,而是先用预训练的自编码器(VAE)将图像编码到低维潜空间(如 64x64x4),在潜空间上完成扩散过程后,再用 VAE 解码器恢复为像素图像。这使扩散过程处理的数据量降低到原来的约 1/48(786432 → 16384 个元素),计算开销随之大幅下降。
import torch
import torch.nn as nn
class SimpleVAE(nn.Module):
"""简化的 VAE 编码器-解码器结构示意"""
def __init__(self, in_channels=3, latent_dim=4, base_dim=64):
super().__init__()
# 编码器:将高分辨率图像压缩到低维潜空间
self.encoder = nn.Sequential(
nn.Conv2d(in_channels, base_dim, 4, stride=2, padding=1),
nn.ReLU(),
nn.Conv2d(base_dim, base_dim * 2, 4, stride=2, padding=1),
nn.ReLU(),
nn.Conv2d(base_dim * 2, base_dim * 4, 4, stride=2, padding=1),
nn.ReLU(),
# 输出均值和对数方差
nn.Conv2d(base_dim * 4, latent_dim * 2, 3, padding=1),
)
# 解码器:将低维潜空间恢复为高分辨率图像
self.decoder = nn.Sequential(
nn.Conv2d(latent_dim, base_dim * 4, 3, padding=1),
nn.ReLU(),
nn.ConvTranspose2d(base_dim * 4, base_dim * 2, 4, stride=2, padding=1),
nn.ReLU(),
nn.ConvTranspose2d(base_dim * 2, base_dim, 4, stride=2, padding=1),
nn.ReLU(),
nn.ConvTranspose2d(base_dim, in_channels, 4, stride=2, padding=1),
nn.Sigmoid(),
)
def encode(self, x):
"""编码图像到潜空间"""
h = self.encoder(x)
mu, logvar = h.chunk(2, dim=1)
std = torch.exp(0.5 * logvar)
z = mu + std * torch.randn_like(std)
return z, mu, logvar
def decode(self, z):
"""从潜空间解码图像"""
return self.decoder(z)
def forward(self, x):
z, mu, logvar = self.encode(x)
recon = self.decode(z)
return recon, mu, logvar
# 演示潜空间压缩
vae = SimpleVAE()
x_pixel = torch.randn(1, 3, 256, 256) # 像素空间
z_latent, _, _ = vae.encode(x_pixel) # 潜空间
x_recon = vae.decode(z_latent) # 重建
print(f"像素空间大小: {x_pixel.shape} = {x_pixel.numel() / 1024:.1f}K 参数")
print(f"潜空间大小: {z_latent.shape} = {z_latent.numel() / 1024:.1f}K 参数")
print(f"压缩比: {x_pixel.numel() / z_latent.numel():.1f}x")Latent Diffusion 训练流程
def train_latent_diffusion(vae, unet, dataloader, timesteps=1000, epochs=5):
"""
潜空间扩散训练流程
1. 用 VAE 编码器将图像编码到潜空间
2. 在潜空间上训练扩散模型
3. 生成时先用扩散模型生成潜表示,再用 VAE 解码器恢复
"""
optimizer = optim.AdamW(unet.parameters(), lr=1e-4)
criterion = nn.MSELoss()
# VAE 编码器冻结,只训练扩散 U-Net
vae.eval()
for param in vae.parameters():
param.requires_grad = False
for epoch in range(epochs):
total_loss = 0
for batch in dataloader:
# 第一步:编码到潜空间
with torch.no_grad():
z_0, _, _ = vae.encode(batch) # (B, latent_dim, H/8, W/8)
# 第二步:在潜空间上训练扩散
batch_size = z_0.shape[0]
t = torch.randint(0, timesteps, (batch_size,))
noise = torch.randn_like(z_0)
z_t = q_sample(z_0, t, noise)
noise_pred = unet(z_t, t)
loss = criterion(noise_pred, noise)
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss / len(dataloader):.6f}")
print("潜空间扩散训练流程已定义")条件生成技术
文本条件(Cross-Attention 机制)
class CrossAttentionBlock(nn.Module):
"""
交叉注意力模块:将文本特征注入到图像特征中
Stable Diffusion 通过此机制实现文本到图像的条件生成
"""
def __init__(self, query_dim, context_dim, num_heads=8):
super().__init__()
self.num_heads = num_heads
self.head_dim = query_dim // num_heads
self.scale = self.head_dim ** -0.5
self.to_q = nn.Linear(query_dim, query_dim)
self.to_k = nn.Linear(context_dim, query_dim)
self.to_v = nn.Linear(context_dim, query_dim)
self.to_out = nn.Linear(query_dim, query_dim)
def forward(self, x, context):
"""
x: 图像特征 (B, N, query_dim) 来自 U-Net 中间层
context: 文本特征 (B, L, context_dim) 来自 CLIP 文本编码器
"""
B, N, _ = x.shape
L = context.shape[1]
q = self.to_q(x).reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2)
k = self.to_k(context).reshape(B, L, self.num_heads, self.head_dim).transpose(1, 2)
v = self.to_v(context).reshape(B, L, self.num_heads, self.head_dim).transpose(1, 2)
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
out = (attn @ v).transpose(1, 2).reshape(B, N, -1)
return self.to_out(out)
# 演示交叉注意力
cross_attn = CrossAttentionBlock(query_dim=512, context_dim=768)
img_feat = torch.randn(2, 4096, 512) # 来自 U-Net 展平后的特征
text_feat = torch.randn(2, 77, 768) # 来自 CLIP 文本编码器(77 个 token)
out = cross_attn(img_feat, text_feat)
print(f"交叉注意力输出: {out.shape}")使用 HuggingFace Diffusers 进行条件生成
# 使用 Diffusers 库进行文生图(生产推荐)
from diffusers import StableDiffusionPipeline
import torch
# 加载预训练模型
# model_id = "runwayml/stable-diffusion-v1-5"
# pipe = StableDiffusionPipeline.from_pretrained(
# model_id,
# torch_dtype=torch.float16,
# safety_checker=None
# )
# pipe.to("cuda")
# 生成图像
# prompt = "a professional photograph of a cat sitting on a windowsill, golden hour lighting"
# negative_prompt = "blurry, low quality, distorted, watermark"
# image = pipe(
# prompt=prompt,
# negative_prompt=negative_prompt,
# num_inference_steps=30, # DDIM 步数
# guidance_scale=7.5, # CFG 引导强度
# width=512,
# height=512,
# ).images[0]
# image.save("generated_cat.png")
print("Diffusers 条件生成示例(需 GPU 环境)")LoRA 微调
# LoRA(Low-Rank Adaptation)微调示意
class LoRALayer(nn.Module):
"""
LoRA 通过在预训练权重旁添加低秩分解矩阵实现高效微调
只训练 A 和 B 矩阵,大幅减少显存和存储需求
"""
def __init__(self, original_layer, rank=4, alpha=1.0):
super().__init__()
self.original = original_layer
self.rank = rank
self.alpha = alpha
in_features = original_layer.in_features
out_features = original_layer.out_features
        # 冻结原始层(权重与 bias),只训练低秩矩阵 A、B
        self.original.weight.requires_grad = False
        if self.original.bias is not None:
            self.original.bias.requires_grad = False
# 低秩矩阵
self.lora_A = nn.Parameter(torch.randn(rank, in_features) * 0.01)
self.lora_B = nn.Parameter(torch.zeros(out_features, rank))
self.scaling = alpha / rank
def forward(self, x):
# 原始输出 + LoRA 输出
original_out = self.original(x)
lora_out = (x @ self.lora_A.T @ self.lora_B.T) * self.scaling
return original_out + lora_out
# 演示 LoRA 参数量对比
original = nn.Linear(512, 512)
lora = LoRALayer(original, rank=4)
original_params = sum(p.numel() for p in original.parameters())
lora_trainable = sum(p.numel() for p in lora.parameters() if p.requires_grad)
print(f"原始参数量: {original_params}")
print(f"LoRA 可训练参数量: {lora_trainable}")
print(f"参数减少比: {original_params / lora_trainable:.1f}x")推理加速技术
xFormers 内存高效注意力
# xFormers 可以大幅降低注意力计算的显存占用
# 安装:pip install xformers
#
# 在 Diffusers 中启用:
# pipe.enable_xformers_memory_efficient_attention()
#
# 加速效果:
# - 显存占用降低 30-50%
# - 推理速度提升 20-40%
# - 支持更大分辨率生成
print("xFormers 加速方案:pipe.enable_xformers_memory_efficient_attention()")torch.compile 加速
import torch
# PyTorch 2.0+ 的 torch.compile 可以自动优化模型
# compiled_model = torch.compile(model)
#
# 使用方法:
# 1. 对 U-Net 编译:compiled_unet = torch.compile(pipe.unet)
# 2. 对整个 pipeline 编译:compiled_pipe = torch.compile(pipe)
#
# 注意事项:
# - 首次运行有编译开销(约 1-3 分钟)
# - 后续推理速度提升 20-50%
# - 需要 PyTorch 2.0+
print("torch.compile 加速:compiled_model = torch.compile(model)")批量生成与异步队列
import asyncio
from concurrent.futures import ThreadPoolExecutor
class AsyncImageGenerator:
"""异步图像生成服务——生产环境推荐模式"""
def __init__(self, pipeline, max_batch_size=4):
self.pipeline = pipeline
self.max_batch_size = max_batch_size
self.executor = ThreadPoolExecutor(max_workers=2)
async def generate(self, prompt, **kwargs):
"""异步生成单张图像"""
loop = asyncio.get_event_loop()
image = await loop.run_in_executor(
self.executor,
lambda: self.pipeline(prompt, **kwargs).images[0]
)
return image
async def generate_batch(self, prompts, **kwargs):
"""批量生成多张图像,利用 GPU 并行性"""
loop = asyncio.get_event_loop()
# 分批处理,避免显存溢出
results = []
for i in range(0, len(prompts), self.max_batch_size):
batch = prompts[i:i + self.max_batch_size]
images = await loop.run_in_executor(
self.executor,
lambda b=batch: self.pipeline(b, **kwargs).images
)
results.extend(images)
return results
print("异步批量生成服务模式已定义")评估指标
def compute_fid(real_features, generated_features):
"""
FID(Fréchet Inception Distance)计算示意
衡量生成分布与真实分布的距离,越低越好
"""
import numpy as np
from scipy import linalg
mu_real = np.mean(real_features, axis=0)
mu_gen = np.mean(generated_features, axis=0)
sigma_real = np.cov(real_features, rowvar=False)
sigma_gen = np.cov(generated_features, rowvar=False)
diff = mu_real - mu_gen
covmean, _ = linalg.sqrtm(sigma_real @ sigma_gen, disp=False)
if np.iscomplexobj(covmean):
covmean = covmean.real
fid = diff @ diff + np.trace(sigma_real + sigma_gen - 2 * covmean)
return fid
# FID 参考值:
# - < 10: 非常好
# - 10-50: 好
# - 50-100: 一般
# - > 100: 差
print("FID 评估函数已定义")优点
缺点
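下面是一个最小调用示意(使用随机特征仅验证计算流程;实际评估需用 InceptionV3 等网络分别在真实图像与生成图像上提取特征):
import numpy as np
np.random.seed(0)
real_feats = np.random.randn(256, 64)        # 假设的真实图像特征 (N, D)
gen_feats = np.random.randn(256, 64) + 0.5   # 假设的生成图像特征,分布有偏移
print(f"示例 FID: {compute_fid(real_feats, gen_feats):.2f}")  # 分布差异越大,FID 越高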
总结
扩散模型是当前图像生成领域的主流技术,Stable Diffusion、DALL-E 3 等产品都基于扩散模型。理解前向加噪、反向去噪和条件引导的原理,有助于在实际项目中有效使用和定制扩散模型。核心优化方向包括:使用潜空间扩散降低计算量、采用 CFG 提升条件生成质量、选择合适的采样器(DPM++/DDIM)加速推理、通过 LoRA 在小数据上高效微调。
关键知识点
- 噪声调度(beta schedule)控制加噪速度,常见方案有线性、余弦和 Sigmoid 调度
- Classifier-Free Guidance(CFG)通过在条件和无条件预测之间插值来平衡生成质量和多样性
- 潜空间扩散(Latent Diffusion)先编码到低维潜空间再做扩散,大幅降低计算成本
- DDIM 采样器通过跳步加速推理,可在 20-50 步内获得高质量结果
- EMA(指数移动平均)在训练时维护模型权重的滑动平均,生成质量更稳定
- Min-SNR 加权策略可以平衡不同噪声步的损失贡献,提升训练效率
项目落地视角
- 使用 Stable Diffusion 的 LoRA 微调而非从零训练,可在消费级 GPU 上定制风格
- 推理时使用 xFormers、TensorRT 加速,配合 DDIM/DPM++ 采样器减少步数
- 上线前必须准备 NSFW 过滤、水印添加和生成内容审核机制
- 生产环境建议使用异步队列管理 GPU 资源,避免请求堆积
- 监控生成延迟、GPU 利用率和请求成功率,建立告警机制
常见误区
- 认为扩散模型可以完全替代 GAN——在实时生成和风格迁移场景中 GAN 仍有速度优势
- 忽略采样步数与质量的权衡——不是步数越多越好,通常 30-50 步即可
- 直接在高分辨率空间做扩散——应该用潜空间扩散降低计算量
- CFG Scale 越大越好——过高的 CFG(>15)会导致色彩失真和伪影
- 忽略 VAE 解码器的质量——有时生成问题不是扩散模型本身,而是 VAE 解码
进阶路线
- 学习 ControlNet、T2I-Adapter 等空间条件控制技术
- 深入理解 DiT(Diffusion Transformer)架构,它是 Sora 等视频生成模型的基础
- 探索视频扩散模型(Sora、CogVideo)的时空注意力设计
- 研究一致性模型(Consistency Models)等加速采样方法
- 学习扩散模型的分数蒸馏(Score Distillation)用于三维生成
适用场景
- 文本到图像、图像到图像的创意生成
- 数据增强:为训练集生成多样化的合成样本
- 图像修复、超分辨率和风格迁移
- 设计稿生成、产品原型可视化
- 游戏/影视中的素材生成
落地建议
- 优先基于 Stable Diffusion 生态(Diffusers 库)开发,避免重复造轮子
- 用 LoRA 或 Textual Inversion 在小数据上快速定制,而非全量微调
- 推理服务用批量请求和异步队列管理 GPU 资源
- 建立 A/B 测试机制对比不同采样器、步数和 CFG 的效果
- 预留审核和过滤接口,确保生成内容符合业务规范
排错清单
- 生成图像模糊:增大 CFG Scale(7-15),检查 VAE 解码是否正常
- 生成内容不符合文本描述:优化提示词或使用更强的文本编码器
- 训练 loss 不收敛:检查学习率(通常 1e-4 ~ 1e-5)和噪声调度设置
- 生成速度太慢:切换到 DPM++ 2M 采样器,减少步数到 20-30
- 显存溢出:启用 xFormers,减小批量大小,或使用低精度(FP16)
- 生成结果重复性高:增大随机种子变化范围,检查 CFG Scale 是否过高
复盘问题
- 你使用的采样器(DDIM/DPM++/Euler)对效果和速度的影响是否做过对比?
- 生成图像的一致性和可控性是否满足业务需求?
- 推理成本是否在可接受范围?是否需要考虑蒸馏或加速方案?
- 生成内容的审核和过滤机制是否完善?
- 模型的版本管理和回滚方案是否到位?
