GAN 基础
大约 16 分钟约 4656 字
GAN 基础
简介
生成对抗网络(Generative Adversarial Network,GAN)通过生成器和判别器的对抗训练来学习数据分布。生成器努力产生逼真的假样本,判别器努力区分真假样本,二者在博弈中共同进步。GAN 在图像生成、风格迁移、超分辨率和数据增强等领域有广泛应用,虽然扩散模型在图像质量上已超越 GAN,但 GAN 在推理速度上仍有明显优势。
GAN 由 Ian Goodfellow 在 2014 年提出,灵感来自伪造者与鉴定专家的博弈过程。这个看似简单的思想产生了深远的影响:GAN 的对抗训练范式不仅推动了图像生成技术的发展,还影响了对抗训练、域适应、半监督学习等多个领域。Goodfellow 因此被称为"GAN 之父"。
从数学角度看,GAN 的训练过程等价于在两个参数化分布之间最小化 Jensen-Shannon 散度。生成器定义了一个参数化的分布 p_g(z),判别器定义了一个二分类器。当训练达到纳什均衡时,生成器产生的分布 p_g 与真实数据分布 p_data 完全一致,判别器无法区分真假(输出 0.5)。
特点
GAN 家族图谱
- 原始 GAN (2014): 基础框架,全连接网络
- DCGAN (2015): 使用卷积网络,稳定训练
- Conditional GAN (cGAN): 条件生成(类别、文本)
- WGAN (2017): Wasserstein 距离,更稳定的训练
- WGAN-GP (2017): 梯度惩罚替代权重裁剪
- Pix2Pix (2017): 配对图像翻译
- CycleGAN (2017): 无配对风格迁移
- Progressive GAN (2018): 渐进式增长训练
- StyleGAN (2019): 风格注入,高质量人脸生成
- StyleGAN2/3 (2020-2021): 消除伪影,Alias-Free
实现
# 示例1:构建 DCGAN 的生成器和判别器
import torch
import torch.nn as nn
class Generator(nn.Module):
def __init__(self, latent_dim=100, img_channels=3, base_features=64):
super().__init__()
self.net = nn.Sequential(
# 输入: (latent_dim, 1, 1) -> 输出: (base_features*8, 4, 4)
nn.ConvTranspose2d(latent_dim, base_features * 8, 4, 1, 0, bias=False),
nn.BatchNorm2d(base_features * 8),
nn.ReLU(True),
# -> (base_features*4, 8, 8)
nn.ConvTranspose2d(base_features * 8, base_features * 4, 4, 2, 1, bias=False),
nn.BatchNorm2d(base_features * 4),
nn.ReLU(True),
# -> (base_features*2, 16, 16)
nn.ConvTranspose2d(base_features * 4, base_features * 2, 4, 2, 1, bias=False),
nn.BatchNorm2d(base_features * 2),
nn.ReLU(True),
# -> (base_features, 32, 32)
nn.ConvTranspose2d(base_features * 2, base_features, 4, 2, 1, bias=False),
nn.BatchNorm2d(base_features),
nn.ReLU(True),
# -> (img_channels, 64, 64)
nn.ConvTranspose2d(base_features, img_channels, 4, 2, 1, bias=False),
nn.Tanh(),
)
def forward(self, z):
return self.net(z.view(z.size(0), -1, 1, 1))
class Discriminator(nn.Module):
def __init__(self, img_channels=3, base_features=64):
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(img_channels, base_features, 4, 2, 1, bias=False),
nn.LeakyReLU(0.2, True),
nn.Conv2d(base_features, base_features * 2, 4, 2, 1, bias=False),
nn.BatchNorm2d(base_features * 2),
nn.LeakyReLU(0.2, True),
nn.Conv2d(base_features * 2, base_features * 4, 4, 2, 1, bias=False),
nn.BatchNorm2d(base_features * 4),
nn.LeakyReLU(0.2, True),
nn.Conv2d(base_features * 4, 1, 4, 1, 0, bias=False),
nn.Sigmoid(),
)
def forward(self, img):
return self.net(img).view(-1)
G = Generator(latent_dim=100, img_channels=3)
D = Discriminator(img_channels=3)
z = torch.randn(4, 100)
fake_img = G(z)
print(f"生成图像: {fake_img.shape}, 判别输出: {D(fake_img).shape}")# 示例2:GAN 训练循环
import torch.optim as optim
def train_gan(G, D, dataloader, epochs=50, latent_dim=100, lr=2e-4, device='cpu'):
G, D = G.to(device), D.to(device)
criterion = nn.BCELoss()
opt_G = optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
opt_D = optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))
for epoch in range(epochs):
d_loss_total, g_loss_total = 0, 0
for real_imgs in dataloader:
batch_size = real_imgs.size(0)
real_imgs = real_imgs.to(device)
real_labels = torch.ones(batch_size, device=device)
fake_labels = torch.zeros(batch_size, device=device)
# === 训练判别器 ===
opt_D.zero_grad()
# 真实图片
d_real = D(real_imgs)
loss_real = criterion(d_real, real_labels)
# 生成假图片
z = torch.randn(batch_size, latent_dim, device=device)
fake_imgs = G(z).detach()
d_fake = D(fake_imgs)
loss_fake = criterion(d_fake, fake_labels)
d_loss = (loss_real + loss_fake) / 2
d_loss.backward()
opt_D.step()
# === 训练生成器 ===
opt_G.zero_grad()
z = torch.randn(batch_size, latent_dim, device=device)
fake_imgs = G(z)
d_fake = D(fake_imgs)
g_loss = criterion(d_fake, real_labels) # 目标是骗过判别器
g_loss.backward()
opt_G.step()
d_loss_total += d_loss.item()
g_loss_total += g_loss.item()
print(f"Epoch {epoch+1}: D_loss={d_loss_total/len(dataloader):.4f}, "
f"G_loss={g_loss_total/len(dataloader):.4f}")# 示例3:WGAN-GP 改进训练稳定性
import torch.autograd as autograd
def gradient_penalty(D, real, fake, device='cpu'):
"""Wasserstein 距离的梯度惩罚"""
alpha = torch.rand(real.size(0), 1, 1, 1, device=device)
interpolates = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
d_interp = D(interpolates)
gradients = autograd.grad(
outputs=d_interp, inputs=interpolates,
grad_outputs=torch.ones_like(d_interp),
create_graph=True, retain_graph=True,
)[0]
gradients = gradients.view(gradients.size(0), -1)
gp = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
return gp
# WGAN-GP 判别器损失(使用 Wasserstein 距离而非 BCE)
def wgan_d_loss(D, real, fake, lambda_gp=10):
d_real = D(real)
d_fake = D(fake)
gp = gradient_penalty(D, real, fake)
return d_fake.mean() - d_real.mean() + lambda_gp * gp
# WGAN-GP 生成器损失
def wgan_g_loss(D, fake):
return -D(fake).mean()# 示例4:隐空间插值生成过渡图像
@torch.no_grad()
def interpolate_latent(G, z1, z2, steps=10):
"""在隐空间中线性插值,生成平滑过渡的图像"""
G.eval()
ratios = torch.linspace(0, 1, steps).unsqueeze(1)
z_interp = z1 * (1 - ratios) + z2 * ratios # 线性插值
images = G(z_interp)
return images
# 演示
z1 = torch.randn(1, 100)
z2 = torch.randn(1, 100)
transition = interpolate_latent(G, z1, z2, steps=8)
print(f"插值生成图像数量: {transition.shape[0]}")
print("这些图像展示了从 z1 到 z2 的平滑过渡")深入理解:GAN 训练的常见问题与解决方案
模式崩塌
def explain_mode_collapse():
"""模式崩塌(Mode Collapse)详解
现象:生成器只产生少数几种样本(甚至只有一种),丢失了数据分布的多样性。
原因:
1. 生成器找到了判别器的"弱点",只生成容易骗过判别器的少数样本
2. 判别器过于强大,生成器无法探索多样化的输出
3. 梯度信号过于稀疏,生成器无法获得足够的学习信号
检测方法:
1. 可视化:观察生成样本是否只有几种模式
2. FID 分数:FID 突然上升可能意味着模式崩塌
3. 多样性统计:计算生成样本的特征方差
解决方案:
1. Minibatch Discrimination:让判别器同时考虑一个 batch 的多样性
2. Unrolled GAN:生成器更新时考虑判别器的未来反应
3. WGAN-GP:使用 Wasserstein 距离提供更好的梯度信号
4. 增加噪声:在判别器输入中添加噪声
5. 多个生成器:使用混合生成器增加多样性
"""
print("模式崩塌的解决方案:")
print(" 1. 使用 WGAN-GP 替代原始 GAN")
print(" 2. 降低判别器容量或学习率")
print(" 3. 增加隐空间噪声")
print(" 4. 使用 Minibatch Discrimination")
print(" 5. 监控生成样本的多样性指标")
explain_mode_collapse()条件 GAN (cGAN)
import torch
import torch.nn as nn
class ConditionalGenerator(nn.Module):
"""条件 GAN:通过条件信息控制生成内容
cGAN 在生成器和判别器中都输入条件信息(类别标签、文本描述等),
使得生成器可以根据条件生成特定类别的样本。
数学形式:
G(z, y) -> 生成样本(z 是噪声,y 是条件)
D(x, y) -> 真假判断(x 是样本,y 是条件)
应用场景:
- 类别条件生成:生成特定类别的图像
- 文本到图像:根据文本描述生成图像
- 图像到图像:将一种图像转换为另一种(Pix2Pix)
"""
def __init__(self, latent_dim=100, num_classes=10, img_channels=3, base_features=64):
super().__init__()
# 将条件标签映射为嵌入向量
self.label_embedding = nn.Embedding(num_classes, latent_dim)
self.net = nn.Sequential(
nn.ConvTranspose2d(latent_dim * 2, base_features * 8, 4, 1, 0, bias=False),
nn.BatchNorm2d(base_features * 8),
nn.ReLU(True),
nn.ConvTranspose2d(base_features * 8, base_features * 4, 4, 2, 1, bias=False),
nn.BatchNorm2d(base_features * 4),
nn.ReLU(True),
nn.ConvTranspose2d(base_features * 4, base_features * 2, 4, 2, 1, bias=False),
nn.BatchNorm2d(base_features * 2),
nn.ReLU(True),
nn.ConvTranspose2d(base_features * 2, img_channels, 4, 2, 1, bias=False),
nn.Tanh(),
)
def forward(self, z, labels):
# 拼接噪声向量和标签嵌入
label_emb = self.label_embedding(labels)
x = torch.cat([z, label_emb], dim=1)
return self.net(x.view(x.size(0), -1, 1, 1))
cG = ConditionalGenerator(latent_dim=100, num_classes=10)
z = torch.randn(4, 100)
labels = torch.randint(0, 10, (4,))
fake_imgs = cG(z, labels)
print(f"条件生成: 噪声{z.shape} + 标签{labels.shape} -> 图像{fake_imgs.shape}")GAN 评估指标
def explain_gan_evaluation():
"""GAN 评估指标
GAN 的评估比分类/检测更困难,因为生成质量是主观的。
主流指标:
1. FID (Frechet Inception Distance):
- 计算真实图像和生成图像在 Inception 网络特征空间中的距离
- 越低越好(0 = 完美)
- 典型范围:StyleGAN2 人脸 ~2.7
2. IS (Inception Score):
- 衡量生成图像的清晰度和多样性
- 越高越好
- 局限:不与真实数据比较
3. LPIPS (Learned Perceptual Image Patch Similarity):
- 基于深度学习的感知相似度
- 越低表示图像越相似
4. Precision & Recall (P&R):
- Precision: 生成样本的质量(是否真实)
- Recall: 生成样本覆盖真实分布的程度
- 可以同时评估质量和多样性
5. CLIP Score:
- 使用 CLIP 模型评估文本-图像一致性
- 适合文本条件生成任务
"""
print("GAN 评估指标选择建议:")
print(" 通用: FID (质量+多样性的综合指标)")
print(" 多样性: Precision & Recall")
print(" 文本条件: CLIP Score")
print(" 感知质量: LPIPS")
print(" 实际部署: 人工评估 + A/B 测试")
explain_gan_evaluation()GAN vs 扩散模型
def gan_vs_diffusion():
"""GAN vs 扩散模型的对比
GAN 的优势:
- 推理速度:1 次前向传播 vs 扩散模型的 20-1000 步采样
- 实时性:适合视频生成、实时滤镜
- 生成质量:在特定领域(人脸)仍然领先
扩散模型的优势:
- 训练稳定性:不需要对抗训练
- 生成多样性:更好的覆盖数据分布
- 文本控制:DALL-E、Stable Diffusion 的精确文本条件
- 评估指标:FID 通常更好
选择建议:
- 需要实时推理 -> GAN
- 需要高质量+多样性 -> 扩散模型
- 需要精确文本控制 -> 扩散模型
- 特定领域高质量 -> GAN (如 StyleGAN 人脸)
"""
print("场景选择:")
print(" 实时视频滤镜/游戏 -> GAN")
print(" 文生图/艺术创作 -> 扩散模型")
print(" 人脸生成/编辑 -> StyleGAN")
print(" 数据增强 -> GAN (速度快)")
print(" 超分辨率 -> 两者都可 (Real-ESRGAN vs StableSR)")
gan_vs_diffusion()Pix2Pix 图像翻译实现
class UNetDown(nn.Module):
"""UNet 下采样块(用于 Pix2Pix 生成器)"""
def __init__(self, in_channels, out_channels, use_batchnorm=True):
super().__init__()
layers = [
nn.Conv2d(in_channels, out_channels, 4, 2, 1, bias=False),
]
if use_batchnorm:
layers.append(nn.BatchNorm2d(out_channels))
layers.append(nn.LeakyReLU(0.2, True))
self.net = nn.Sequential(*layers)
def forward(self, x):
return self.net(x)
class UNetUp(nn.Module):
"""UNet 上采样块(用于 Pix2Pix 生成器)"""
def __init__(self, in_channels, out_channels, use_dropout=False):
super().__init__()
layers = [
nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(True),
]
if use_dropout:
layers.append(nn.Dropout(0.5))
self.net = nn.Sequential(*layers)
def forward(self, x, skip):
x = self.net(x)
return torch.cat([x, skip], dim=1)
class Pix2PixGenerator(nn.Module):
"""Pix2Pix 生成器:UNet 架构
用于配对图像翻译任务:
- 语义分割图 -> 真实图像
- 黑白图像 -> 彩色图像
- 白天 -> 夜晚
- 草图 -> 真实图像
"""
def __init__(self, in_channels=3, out_channels=3, base_features=64):
super().__init__()
# 编码器
self.down1 = UNetDown(in_channels, base_features, use_batchnorm=False)
self.down2 = UNetDown(base_features, base_features * 2)
self.down3 = UNetDown(base_features * 2, base_features * 4)
self.down4 = UNetDown(base_features * 4, base_features * 8)
self.down5 = UNetDown(base_features * 8, base_features * 8)
self.down6 = UNetDown(base_features * 8, base_features * 8)
self.down7 = UNetDown(base_features * 8, base_features * 8, use_batchnorm=False)
# 瓶颈层
self.bottleneck = nn.Sequential(
nn.Conv2d(base_features * 8, base_features * 8, 4, 2, 1),
nn.ReLU(True)
)
# 解码器(带跳跃连接)
self.up1 = UNetUp(base_features * 8, base_features * 8, use_dropout=True)
self.up2 = UNetUp(base_features * 16, base_features * 8, use_dropout=True)
self.up3 = UNetUp(base_features * 16, base_features * 8, use_dropout=True)
self.up4 = UNetUp(base_features * 16, base_features * 8)
self.up5 = UNetUp(base_features * 16, base_features * 4)
self.up6 = UNetUp(base_features * 8, base_features * 2)
self.up7 = UNetUp(base_features * 4, base_features)
self.final = nn.Sequential(
nn.ConvTranspose2d(base_features * 2, out_channels, 4, 2, 1),
nn.Tanh()
)
def forward(self, x):
d1 = self.down1(x)
d2 = self.down2(d1)
d3 = self.down3(d2)
d4 = self.down4(d3)
d5 = self.down5(d4)
d6 = self.down6(d5)
d7 = self.down7(d6)
bottleneck = self.bottleneck(d7)
u1 = self.up1(bottleneck, d7)
u2 = self.up2(u1, d6)
u3 = self.up3(u2, d5)
u4 = self.up4(u3, d4)
u5 = self.up5(u4, d3)
u6 = self.up6(u5, d2)
u7 = self.up7(u6, d1)
return self.final(u7)
# 测试 Pix2Pix 生成器
gen = Pix2PixGenerator(in_channels=3, out_channels=3)
input_img = torch.randn(2, 3, 256, 256)
output = gen(input_img)
print(f"Pix2Pix: 输入{input_img.shape} -> 输出{output.shape}")
print(f"参数量: {sum(p.numel() for p in gen.parameters()):,}")GAN 训练稳定性技巧总结
def gan_training_tricks():
"""GAN 训练稳定性技巧清单"""
tricks = {
"网络架构": [
"判别器使用 LeakyReLU(0.2 斜率),不要用 ReLU",
"生成器使用 ReLU,输出层用 Tanh",
"使用 BatchNorm(但生成器输出层和判别器输入层不放)",
"使用 Spectral Normalization 稳定判别器训练",
"生成器使用转置卷积或 PixelShuffle 上采样",
],
"训练策略": [
"Adam 优化器,lr=2e-4,beta1=0.5(不是默认的 0.9)",
"判别器训练步数多于生成器(如 5:1)",
"使用标签平滑:真实标签 0.9 而非 1.0",
"添加噪声到判别器输入",
"使用 Progressive Growing 逐步增加分辨率",
],
"损失函数": [
"原始 GAN 使用 BCE Loss",
"WGAN-GP 使用 Wasserstein 距离 + 梯度惩罚",
"最小二乘 GAN (LSGAN) 使用 MSE Loss",
"Hinge Loss 在 SAGAN/BigGAN 中效果良好",
],
"正则化": [
"Spectral Normalization:限制判别器 Lipschitz 常数",
"梯度惩罚(GP):限制判别器梯度范数接近 1",
"DRAGAN:在真实数据附近添加梯度惩罚",
]
}
for category, items in tricks.items():
print(f"\n=== {category} ===")
for item in items:
print(f" - {item}")
gan_training_tricks()Spectral Normalization 实现
class SpectralNormConv2d(nn.Module):
"""带谱归一化的卷积层(简化版)
谱归一化通过限制权重矩阵的最大奇异值来约束判别器的 Lipschitz 常数,
从而稳定 GAN 训练。PyTorch 内置了 nn.utils.spectral_norm。
"""
def __init__(self, in_channels, out_channels, kernel_size=4, stride=2, padding=1):
super().__init__()
self.conv = nn.utils.spectral_norm(
nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
)
self.activation = nn.LeakyReLU(0.2, True)
def forward(self, x):
return self.activation(self.conv(x))
class SNDiscriminator(nn.Module):
"""使用 Spectral Normalization 的判别器
SN-GAN 的核心改进:在每个卷积层上应用谱归一化,
省去了 WGAN-GP 中计算梯度惩罚的额外开销。
"""
def __init__(self, img_channels=3, base_features=64):
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(img_channels, base_features, 4, 2, 1, bias=False),
nn.LeakyReLU(0.2, True),
SpectralNormConv2d(base_features, base_features * 2),
SpectralNormConv2d(base_features * 2, base_features * 4),
SpectralNormConv2d(base_features * 4, base_features * 8),
nn.Conv2d(base_features * 8, 1, 4, 1, 0, bias=False),
)
def forward(self, x):
return self.net(x).view(-1)
sn_D = SNDiscriminator()
print(f"SN-Discriminator 参数量: {sum(p.numel() for p in sn_D.parameters()):,}")GAN Inversion(隐空间逆向映射)
class GANInversion:
"""GAN Inversion:将真实图像映射回隐空间
应用场景:
- 图像编辑:找到隐向量后,通过编辑隐向量来修改图像属性
- 图像混合:将两张图像的隐向量混合
- 风格迁移:交换不同层的隐向量
方法分类:
1. 编码器方法:训练一个编码器 E(x) -> z
2. 优化方法:固定 G,优化 z 使 G(z) ≈ x
3. 混合方法:编码器提供初始 z,再优化微调
"""
@staticmethod
def optimize_inversion(G, target_image, latent_dim=100,
steps=500, lr=0.01, device='cpu'):
"""优化方法:通过梯度下降找到最佳隐向量
原理:固定生成器 G,优化噪声向量 z,
使得 G(z) 与目标图像的像素差最小。
"""
G.eval()
z = torch.randn(1, latent_dim, device=device, requires_grad=True)
optimizer = torch.optim.Adam([z], lr=lr)
mse_loss = nn.MSELoss()
losses = []
for step in range(steps):
optimizer.zero_grad()
generated = G(z)
loss = mse_loss(generated, target_image)
loss.backward()
optimizer.step()
losses.append(loss.item())
if step % 100 == 0:
print(f" Step {step}: 重建损失 = {loss.item():.6f}")
return z.detach(), losses
# 演示
print("GAN Inversion: 将真实图像映射回隐空间")
print(" 优化方法: z* = argmin_z ||G(z) - x||^2")
print(" 编码器方法: z = E(x), 训练 E 使 G(E(x)) ≈ x")优点
缺点
总结
GAN 是深度学习中最具影响力的生成模型之一,它的对抗训练思想影响了整个 AI 生成领域。虽然扩散模型在图像生成质量上已占据主导地位,但 GAN 在实时推理、风格迁移和特定领域的生成任务中仍然是重要工具。
关键知识点
- 原始 GAN 的损失函数基于 JS 散度,WGAN 改用 Wasserstein 距离并加入梯度惩罚(GP)来提升稳定性
- 模式崩塌(Mode Collapse)指生成器只产生少数几种样本,是 GAN 训练中最常见的问题
- 条件 GAN(cGAN)通过将条件信息(类别、文本、图像)输入生成器和判别器来控制生成内容
- FID(Frechet Inception Distance)是评估 GAN 生成质量的常用指标,越低越好
项目落地视角
- 用 StyleGAN 做人脸/角色生成时,先检查训练数据的多样性是否足够
- 实时推理场景(视频滤镜、游戏渲染)优先考虑 GAN 而非扩散模型
- 数据增强场景中,GAN 生成样本可以扩充训练集,但要验证不会引入噪声
常见误区
- 认为 GAN 已被扩散模型完全替代——实时生成和特定风格迁移场景中 GAN 仍不可替代
- 训练 GAN 时不监控模式崩塌——应定期检查生成样本的多样性
- 直接在大型数据集上训练而不先用小数据集验证训练管线
- 忘记在判别器中使用 LeakyReLU 而非 ReLU——ReLU 会导致梯度消失
进阶路线
- 深入学习 StyleGAN3 的 Alias-Free 设计,理解如何消除伪影
- 探索 GAN 与扩散模型的结合(如 GAN 作为扩散模型的加速初始化)
- 学习 GAN Inversion,将真实图片映射回隐空间进行编辑
- 研究 GAN 的评估方法(FID、IS、Precision & Recall)的原理和局限
适用场景
- 实时图像生成和视频滤镜应用
- 图像风格迁移(CycleGAN、Pix2Pix)
- 数据增强:为分类或检测模型生成训练样本
- 超分辨率、图像修复和图像编辑
落地建议
- 从 DCGAN 或条件 GAN 的成熟实现开始,避免从零搭建训练管线
- 训练时使用 WGAN-GP 或 SN-GAN 来提升稳定性
- 记录每个检查点的 FID 分数,选择最佳模型而非最后一个
排错清单
- 训练 loss 震荡:降低学习率,检查 Adam 的 beta1 是否设为 0.5
- 生成图像模糊或噪声:增加网络容量,检查是否使用了 BatchNorm
- 模式崩塌:增加噪声注入、尝试 WGAN-GP、增加判别器的正则化
- 判别器 loss 趋近 0:判别器太强,需要降低其容量或学习率
复盘问题
- 你的 GAN 训练是否出现过模式崩塌?如何检测和处理的?
- 生成样本的 FID 分数是多少?是否与主观质量一致?
- 在推理速度和生成质量之间,你的项目更侧重哪个?
