你的AI“画室”:GANs最简实操教程
生成对抗网络(GANs)自诞生以来,便以其独特的对抗性训练机制彻底革新了人工智能生成内容的范式。本文通过一个简洁的MNIST手写数字生成实验,直观地展示了GANs如何从随机噪声中学习并创造出有意义的图像。更重要的是,这一核心技术正是当前文生图(Text-to-Image)领域取得突破性进展的基石。GANs的生成器扮演着“艺术家”的角色,从文本描述中提取语义信息,并将其转化为视觉特征,逐步构建出符合
前言摘要
生成对抗网络(GANs)是深度学习领域一项引人入胜的创新,它通过模拟两位艺术家(生成器)和评论家(判别器)之间的“博弈”,实现了从无到有地创造出逼真数据。本文将通过一个最简化的实操教程,引导你亲身体验GANs的魅力。我们将利用Google Colab提供的免费GPU算力,从零开始构建并训练一个GAN模型,让它学会生成MNIST手写数字。你将亲眼见证模型如何从随机噪声中逐渐“绘制”出可识别的数字,从而直观理解生成器与判别器如何协同进化,感受对抗性训练的动态平衡。

一、实验目标与所需工具
本实验旨在让你:
- 理解生成器(Generator)和判别器(Discriminator)如何协同工作。
- 观察GANs在训练过程中如何从随机噪声中学习生成有意义的数据(例如手写数字)。
- 感受对抗训练的动态平衡。
为了顺利进行实验,你需要准备:
- Google 账号:用于登录Google Colab。
- Google Colab:提供免费GPU算力,让你无需配置本地环境。
- 一点点耐心!
二、实验步骤
1. 打开Google Colab并切换到GPU运行时
首先,访问https://colab.research.google.com/。 点击“文件”(File) -> “新建笔记本”(New notebook)。
为了加速图像生成过程,你需要切换到GPU运行时。点击“运行时”(Runtime) -> “更改运行时类型”(Change runtime type)。在“硬件加速器”(Hardware accelerator)下拉菜单中选择“GPU”,然后点击“保存”。
2. 开始你的第一个GAN实验!
在Colab的代码单元格中,按照以下标题复制粘贴代码,并逐个运行。每个代码块运行完毕后,都会有相应的输出或提示。
三、代码实操环节
1. 导入必要的库
我们将导入PyTorch及其相关模块,包括神经网络层(nn)、优化器(optim)、数据集(datasets)、图像变换(transforms)、数据加载器(DataLoader),以及用于绘图的matplotlib和数值计算的numpy。
Python
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
print("所有库已成功导入!")
2. 设置超参数与设备
这些参数是GAN模型行为和训练过程的关键控制点:
latent_dim: 噪声向量的维度,是生成器的输入,可以想象成艺术家的“灵感”源泉。img_size: 图像的尺寸(MNIST手写数字是28x28像素)。img_channels: 图像通道数(MNIST是灰度图,所以通道数是1)。batch_size: 每次训练喂入模型的图片数量。num_epochs: 训练的总轮数。lr: 学习率,决定了模型参数更新的步长。
Python
# 设备配置:优先使用GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"当前使用的设备: {device}")
# 模型超参数
latent_dim = 100 # 生成器输入噪声的维度
img_size = 28 # MNIST图片尺寸
img_channels = 1 # MNIST是灰度图
batch_size = 128 # 每次训练处理的图片数量
num_epochs = 50 # 训练的总轮数
lr = 0.0002 # 学习率
3. 数据准备:MNIST手写数字
我们将使用经典的MNIST手写数字数据集。transforms.Compose定义了一系列预处理操作,包括将图片转换为PyTorch的Tensor格式,并使用transforms.Normalize((0.5,), (0.5,))将像素值从[0, 1]归一化到[-1, 1],这与生成器输出层通常使用的Tanh激活函数范围匹配,有助于模型更快收敛。
Python
transform = transforms.Compose([
transforms.ToTensor(), # 将图片转换为Tensor
transforms.Normalize((0.5,), (0.5,)) # 归一化到[-1, 1]
])
# 下载并加载MNIST训练集
mnist_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
dataloader = DataLoader(mnist_dataset, batch_size=batch_size, shuffle=True, drop_last=True) # drop_last=True确保每个batch有相同数量的样本
print(f"MNIST数据集加载完成,包含 {len(mnist_dataset)} 张图片。")
4. 构建生成器(The Artist / 画家)
生成器的任务是从随机噪声中生成逼真的图片。我们将使用简单的全连接神经网络(MLP):
- 输入:
latent_dim(100) 维的随机噪声向量。 - 输出:
img_size * img_size * img_channels(28*28*1 = 784) 维的图像向量,最终会reshape成28x28的图片。 - 激活函数:中间层使用LeakyReLU;输出层使用Tanh,将像素值缩放到[-1, 1],与归一化后的真实图片范围匹配。
Python
class Generator(nn.Module):
def __init__(self, latent_dim, img_size, img_channels):
super(Generator, self).__init__()
self.img_dim = img_size * img_size * img_channels
self.model = nn.Sequential(
nn.Linear(latent_dim, 256),
nn.LeakyReLU(0.2),
nn.Linear(256, 512),
nn.LeakyReLU(0.2),
nn.Linear(512, 1024),
nn.LeakyReLU(0.2),
nn.Linear(1024, self.img_dim),
nn.Tanh()
)
def forward(self, z):
img = self.model(z)
# 将一维的图像向量重新塑形为图片 (batch_size, channels, height, width)
img = img.view(img.size(0), img_channels, img_size, img_size)
return img
# 初始化生成器并将其移动到指定设备
generator = Generator(latent_dim, img_size, img_channels).to(device)
print("生成器模型已创建。")
5. 构建判别器(The Critic / 鉴赏家)
判别器的任务是区分真实图片和生成器生成的假图片。它也是一个简单的全连接神经网络:
- 输入:
img_size * img_size * img_channels(784) 维的图像向量。 - 输出:一个单一的标量值,代表这张图片是真实的概率。
- 激活函数:中间层使用LeakyReLU;输出层不直接使用Sigmoid,而是使用BCEWithLogitsLoss(它内部包含了Sigmoid和二元交叉熵损失的计算,更稳定)。
Python
class Discriminator(nn.Module):
def __init__(self, img_size, img_channels):
super(Discriminator, self).__init__()
self.img_dim = img_size * img_size * img_channels
self.model = nn.Sequential(
nn.Linear(self.img_dim, 1024),
nn.LeakyReLU(0.2),
nn.Dropout(0.3), # Dropout层,防止过拟合
nn.Linear(1024, 512),
nn.LeakyReLU(0.2),
nn.Dropout(0.3),
nn.Linear(512, 256),
nn.LeakyReLU(0.2),
nn.Dropout(0.3),
nn.Linear(256, 1) # 1个输出(真实/虚假概率)
)
def forward(self, img):
# 将图片展平为一维向量
img_flat = img.view(img.size(0), -1)
validity = self.model(img_flat)
return validity
# 初始化判别器并将其移动到指定设备
discriminator = Discriminator(img_size, img_channels).to(device)
print("判别器模型已创建。")
6. 定义损失函数与优化器
- 损失函数:我们使用BCEWithLogitsLoss(二元交叉熵损失),它适用于二分类问题,并包含Sigmoid层,更稳定。判别器希望真实图片预测为1,虚假图片预测为0。生成器希望它生成的虚假图片被判别器预测为1(也就是成功骗过判别器)。
- 优化器:使用Adam优化器,它在处理稀疏梯度和非平稳目标时表现良好。为生成器和判别器分别设置优化器。
Python
# 损失函数:二元交叉熵损失
criterion = nn.BCEWithLogitsLoss().to(device)
# 优化器:分别优化生成器和判别器
optimizer_G = optim.Adam(generator.parameters(), lr=lr)
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr)
print("损失函数和优化器已设置。")
7. 训练循环:对抗博弈的核心
这是GANs训练最激动人心的部分,生成器和判别器将在这里展开一场永无止境的“猫鼠游戏”。
- 判别器训练:
- 喂入真实图片:判别器接收真实图片,并期望它们被标记为“真实”(标签为1)。
- 喂入虚假图片:生成器生成图片,判别器接收这些图片,并期望它们被标记为“虚假”(标签为0)。
- 通过计算真实图片和虚假图片的损失,判别器更新自己的参数,以便更好地分辨真伪。
- 生成器训练:
- 生成器生成图片,并将其送入判别器。
- 生成器希望这些虚假图片被判别器标记为“真实”(标签为1),因为它的目标是“骗过”判别器。
- 根据判别器的输出,生成器更新自己的参数,以便生成更逼真的图片。
Python
# 用于可视化生成结果的固定噪声,这样每次都能看到相同“灵感”下生成器的进步
fixed_noise = torch.randn(64, latent_dim).to(device)
# 训练记录
g_losses = [] # 记录生成器损失
d_losses = [] # 记录判别器损失
img_list = [] # 记录周期性生成的图片
print("开始训练GAN模型...")
for epoch in range(num_epochs):
for i, (imgs, _) in enumerate(dataloader):
# 将真实图片移动到指定设备,并展平为一维向量
real_imgs = imgs.view(imgs.size(0), -1).to(device)
# 定义真实和虚假的标签
real_labels = torch.ones(imgs.size(0), 1).to(device)
fake_labels = torch.zeros(imgs.size(0), 1).to(device)
# ---------------------
# 训练判别器 D
# ---------------------
optimizer_D.zero_grad() # 清零判别器的梯度
# 1. 训练判别器识别真实图片
output_real = discriminator(real_imgs)
d_loss_real = criterion(output_real, real_labels) # 真实图片希望是1
# 2. 训练判别器识别虚假图片
# 从标准正态分布中采样噪声
noise = torch.randn(imgs.size(0), latent_dim).to(device)
# 用生成器生成虚假图片
fake_imgs = generator(noise).detach() # detach() 阻止梯度流回生成器
output_fake = discriminator(fake_imgs)
d_loss_fake = criterion(output_fake, fake_labels) # 虚假图片希望是0
# 合并真实和虚假损失,并反向传播
d_loss = (d_loss_real + d_loss_fake) / 2
d_loss.backward()
optimizer_D.step() # 更新判别器参数
# ---------------------
# 训练生成器 G
# ---------------------
optimizer_G.zero_grad() # 清零生成器的梯度
# 生成新的虚假图片
# 这一次,我们不使用 detach(),让梯度流回生成器
noise = torch.randn(imgs.size(0), latent_dim).to(device)
generated_imgs = generator(noise)
# 生成器希望判别器将这些虚假图片判为“真实”(标签为1)
output_gen = discriminator(generated_imgs)
g_loss = criterion(output_gen, real_labels) # 生成器希望它的图片被判为1
g_loss.backward()
optimizer_G.step() # 更新生成器参数
# ---------------------
# 打印训练进度
# ---------------------
if i % 100 == 0: # 每100个batch打印一次
print(
f"Epoch [{epoch}/{num_epochs}] Batch [{i}/{len(dataloader)}] "
f"D Loss: {d_loss.item():.4f}, G Loss: {g_loss.item():.4f}"
)
g_losses.append(g_loss.item())
d_losses.append(d_loss.item())
# 每个Epoch结束后,用固定噪声生成图片,并保存以观察生成效果
generator.eval() # 切换到评估模式,不进行Dropout等操作
with torch.no_grad(): # 在此步禁用梯度计算
generated_samples = generator(fixed_noise).cpu().detach()
# 归一化回[0, 1]范围以便显示
generated_samples = (generated_samples + 1) / 2
img_list.append(generated_samples)
generator.train() # 切换回训练模式
print("训练完成!")
8. 可视化训练结果
我们可以绘制损失曲线,看看生成器和判别器在训练过程中是如何“对抗”的。
Python
# 绘制损失曲线
plt.figure(figsize=(10, 5))
plt.plot(d_losses, label='Discriminator Loss')
plt.plot(g_losses, label='Generator Loss')
plt.xlabel('Batch Index (approx)')
plt.ylabel('Loss')
plt.title('GAN Training Loss')
plt.legend()
plt.grid(True)
plt.show()
print("损失曲线图已显示。")
9. 可视化生成图片随Epoch的演变
观看生成器如何从一团模糊的噪声,逐渐学会“画”出清晰的手写数字,这将是整个实验中最令人兴奋的时刻!
Python
# 可视化生成图片的演变过程
def show_images(images, epoch):
fig = plt.figure(figsize=(8, 8))
for i in range(images.shape[0]):
plt.subplot(8, 8, i+1)
plt.imshow(images[i, 0, :, :], cmap='gray')
plt.axis('off')
plt.suptitle(f'Generated Images at Epoch {epoch}')
plt.show()
print("\n展示生成图片随Epoch的演变...")
for i, imgs in enumerate(img_list):
if i % 5 == 0 or i == len(img_list) - 1: # 每5个epoch或最后一个epoch显示一次
show_images(imgs, (i+1)) # (i+1) 是为了匹配 epoch 从 1 开始








四、实验结果解读与原理学习
1. 观察损失曲线
你会看到判别器损失 (D Loss) 和生成器损失 (G Loss) 在一开始可能会波动很大。理想情况下,它们会相互交织,并趋于一个相对稳定的状态,判别器损失大概会在 0.5 左右,生成器损失则根据实际情况。这表明双方达到了一个“纳什均衡”,谁也无法完全击败对方。
- 如果 D Loss 迅速降到接近0而 G Loss 很高,说明判别器太强,生成器“学不会”了。
- 如果 G Loss 迅速降到接近0而 D Loss 很高,说明生成器太强,判别器“辨别不出”了(这通常很难发生)。
2. 观察生成图片
- Epoch 0 或 Epoch 1: 生成的图片会非常模糊,看起来就是一团随机的像素点,难以辨认出任何数字。
- 中期 Epochs (例如 10-30): 你会开始看到一些模糊的数字轮廓出现,生成器正在学习数字的形状和结构。判别器也变得更擅长区分这些“粗糙”的假图片。
- 后期 Epochs (例如 40-50): 生成的数字会变得越来越清晰,越来越像真实的MNIST手写数字。虽然可能不如真实图片完美,但已经具有很高的逼真度,甚至你可能无法一眼分辨出哪些是生成的。
3. 你学到了什么?
通过这个简单的实验,你亲身体验了GANs的以下核心原理:
- 对抗性训练(Adversarial Training):生成器和判别器像两个对手一样相互竞争,一个努力创造,一个努力辨别。正是这种持续的对抗,推动了双方能力的共同提高。
- 生成器的学习:生成器通过判别器的“反馈”(损失信号),不断调整自己的参数,使其能够将简单的随机噪声转化为具有复杂结构和纹理的图像数据。
- 判别器的学习:判别器通过区分真实和虚假数据,学习数据分布的细微特征,变得越来越“挑剔”。
- 从噪声到数据:这种模型能够从完全无序的噪声中,学习并生成出与真实数据分布相似的新数据,这是其强大之处。
这个实验虽然简单,但它揭示了GANs背后的强大思想。在实际应用中,GANs的生成器和判别器会使用更复杂的神经网络结构(如卷积神经网络DCGAN、StyleGAN等),处理更复杂的数据集,从而生成出更令人惊叹的图像。但万变不离其宗,其核心的对抗学习原理是一致的。
更多推荐



所有评论(0)