从零开始训练AI文生图模型:一份初学者完全指南
在开始深入技术细节之前,我们先来理解一下什么是文生图(Text-to-Image)模型。简单来说,文生图模型就是能够根据你输入的文字描述,自动生成对应图像的人工智能系统。比如你输入"一只戴着帽子的橙色猫咪",模型就能生成一张符合描述的图片。近年来,像Stable Diffusion、DALL-E、Midjourney这样的文生图模型在互联网上引起了巨大轰动。这些模型背后的核心技术就是扩散模型(Di
从零开始训练AI图像生成模型:一份初学者完全指南
📱 完整流程导航图
在开始学习之前,让我们先通过一个清晰的流程图了解整个训练过程。这个流程图设计为纵向布局,方便在手机上查看。
┌─────────────────────────────────────────┐
│ 🎯 开始训练AI图像生成模型 │
└─────────────────────────────────────────┘
↓
┌─────────────────────────────────────────┐
│ 📚 第一阶段:理解基础概念 │
└─────────────────────────────────────────┘
↓
┌───────────┴───────────┐
↓ ↓
┌──────────────┐ ┌──────────────┐
│ 什么是扩散模型│ │ 什么是VAE? │
│ │ │ │
│ • 正向加噪 │ │ • 编码器 │
│ • 反向去噪 │ │ • 解码器 │
│ • DDPM原理 │ │ • 压缩6-8倍 │
└──────────────┘ └──────────────┘
↓ ↓
└───────────┬───────────┘
↓
┌───────────┴───────────┐
↓ ↓
┌──────────────┐ ┌──────────────┐
│什么是Transformer│ │ 什么是DiT? │
│ │ │ │
│ • 自注意力 │ │ • Diffusion │
│ • 多头注意力 │ │ • Transformer│
│ • Q/K/V机制 │ │ • 可扩展架构 │
└──────────────┘ └──────────────┘
↓
┌─────────────────────────────────────────┐
│ 🏗️ 第二阶段:构建模型结构 │
└─────────────────────────────────────────┘
↓
┌───────────┴───────────┐
↓ ↓
┌──────────────┐ ┌──────────────┐
│ 设计整体架构 │ │ 实现位置编码 │
│ │ │ │
│ • 文本编码器 │ │ • 图像位置 │
│ • VAE编解码 │ │ • 文本位置 │
│ • DiT去噪网络│ │ • 可学习参数 │
└──────────────┘ └──────────────┘
↓ ↓
└───────────┬───────────┘
↓
┌─────────────────────────────────────────┐
│ 实现Transformer块 │
│ │
│ ┌─────────────────────────────────┐ │
│ │ 1. 自注意力层 │ │
│ │ • Query/Key/Value投影 │ │
│ │ • 多头注意力计算 │ │
│ │ • 输出投影 │ │
│ └─────────────────────────────────┘ │
│ ↓ │
│ ┌─────────────────────────────────┐ │
│ │ 2. 前馈网络层 │ │
│ │ • 扩展到3倍维度 │ │
│ │ • SiLU激活函数 │ │
│ │ • 压缩回原始维度 │ │
│ └─────────────────────────────────┘ │
│ ↓ │
│ ┌─────────────────────────────────┐ │
│ │ 3. 自适应门控 │ │
│ │ • 时间步嵌入生成门控 │ │
│ │ • 动态调节层贡献 │ │
│ └─────────────────────────────────┘ │
└─────────────────────────────────────────┘
↓
┌─────────────────────────────────────────┐
│ 组装完整AAADiT模型 │
│ │
│ • 10个Transformer块 │
│ • 总参数量:0.1B (1亿) │
│ • 输入/输出投影层 │
└─────────────────────────────────────────┘
↓
┌─────────────────────────────────────────┐
│ 🔧 第三阶段:构建Pipeline │
└─────────────────────────────────────────┘
↓
┌─────────────────────────────────────────┐
│ 定义Pipeline单元 │
│ │
│ ┌─────────────────────────────────┐ │
│ │ 单元1:文本编码 │ │
│ │ • 使用Qwen3-0.6B │ │
│ │ • 应用聊天模板 │ │
│ │ • 分词和嵌入 │ │
│ └─────────────────────────────────┘ │
│ ↓ │
│ ┌─────────────────────────────────┐ │
│ │ 单元2:噪声初始化 │ │
│ │ • 生成随机噪声 │ │
│ │ • 形状:[1,128,H/16,W/16] │ │
│ │ • 支持随机种子 │ │
│ └─────────────────────────────────┘ │
│ ↓ │
│ ┌─────────────────────────────────┐ │
│ │ 单元3:图像编码 │ │
│ │ • VAE编码输入图像 │ │
│ │ • 支持文生图/图生图 │ │
│ └─────────────────────────────────┘ │
└─────────────────────────────────────────┘
↓
┌─────────────────────────────────────────┐
│ 实现完整Pipeline流程 │
│ │
│ 1. 模型加载管理 │
│ 2. 数据流转控制 │
│ 3. CFG引导机制 │
│ 4. 迭代去噪循环 │
│ 5. VAE解码输出 │
└─────────────────────────────────────────┘
↓
┌─────────────────────────────────────────┐
│ 📦 第四阶段:准备训练数据 │
└─────────────────────────────────────────┘
↓
┌─────────────────────────────────────────┐
│ 选择和下载数据集 │
│ │
│ • 宝可梦第一世代数据集 │
│ • 151个样本(图像+描述) │
│ • 命令:modelscope download │
│ • 保存路径:./data │
└─────────────────────────────────────────┘
↓
┌─────────────────────────────────────────┐
│ 数据预处理 │
│ │
│ ┌─────────────────────────────────┐ │
│ │ 1. 图像处理 │ │
│ │ • 调整大小:256×256 │ │
│ │ • 归一化:[-1, 1] │ │
│ │ • 转换为张量 │ │
│ └─────────────────────────────────┘ │
│ ↓ │
│ ┌─────────────────────────────────┐ │
│ │ 2. 文本处理 │ │
│ │ • 读取CSV元数据 │ │
│ │ • 分词和编码 │ │
│ │ • 填充到固定长度 │ │
│ └─────────────────────────────────┘ │
│ ↓ │
│ ┌─────────────────────────────────┐ │
│ │ 3. 批处理组织 │ │
│ │ • DataLoader配置 │ │
│ │ • 批次大小设置 │ │
│ │ • 多进程加载 │ │
│ └─────────────────────────────────┘ │
└─────────────────────────────────────────┘
↓
┌─────────────────────────────────────────┐
│ 🚀 第五阶段:训练模型 │
└─────────────────────────────────────────┘
↓
┌─────────────────────────────────────────┐
│ 配置训练参数 │
│ │
│ • 学习率:2e-4 │
│ • 批次大小:根据显存(2-8) │
│ • 训练步数:60,000步 │
│ • 优化器:AdamW │
│ • 调度器:余弦退火 │
└─────────────────────────────────────────┘
↓
┌───────────┴───────────┐
↓ ↓
┌──────────────┐ ┌──────────────┐
│ 单GPU训练 │ │ 多GPU训练 │
│ │ │ │
│ python │ │ accelerate │
│ train.py │ │ launch │
└──────────────┘ └──────────────┘
↓ ↓
└───────────┬───────────┘
↓
┌─────────────────────────────────────────┐
│ 训练循环执行 │
│ │
│ ┌─────────────────────────────────┐ │
│ │ 每个训练步骤: │ │
│ │ │ │
│ │ 1. 采样数据批次 │ │
│ │ ↓ │ │
│ │ 2. VAE编码图像 │ │
│ │ ↓ │ │
│ │ 3. 添加随机噪声 │ │
│ │ ↓ │ │
│ │ 4. DiT预测噪声 │ │
│ │ ↓ │ │
│ │ 5. 计算损失 │ │
│ │ ↓ │ │
│ │ 6. 反向传播 │ │
│ │ ↓ │ │
│ │ 7. 更新参数 │ │
│ └─────────────────────────────────┘ │
└─────────────────────────────────────────┘
↓
┌─────────────────────────────────────────┐
│ 监控训练过程 │
│ │
│ • 观察损失曲线(0.3 → 0.05) │
│ • 检查显存使用 │
│ • 定期保存检查点(每50k步) │
│ • 预计时间:10-20小时 │
└─────────────────────────────────────────┘
↓
❓ 遇到问题?
↓
┌───────────┼───────────┐
↓ ↓ ↓
┌──────────┐ ┌──────────┐ ┌──────────┐
│显存不足?│ │损失不降?│ │速度慢? │
│ │ │ │ │ │
│ 减小批次 │ │ 调学习率 │ │ 增workers│
│ 梯度检查 │ │ 检查数据 │ │ 优化预处理│
└──────────┘ └──────────┘ └──────────┘
↓ ↓ ↓
└───────────┼───────────┘
↓
✅ 训练完成!
↓
┌─────────────────────────────────────────┐
│ 🎨 第六阶段:测试和评估 │
└─────────────────────────────────────────┘
↓
┌─────────────────────────────────────────┐
│ 加载训练好的模型 │
│ │
│ • 加载检查点文件 │
│ • 初始化Pipeline │
│ • 准备推理环境 │
└─────────────────────────────────────────┘
↓
┌─────────────────────────────────────────┐
│ 生成测试图像 │
│ │
│ ┌─────────────────────────────────┐ │
│ │ 测试1:具体描述 │ │
│ │ "green lizard with seed" │ │
│ └─────────────────────────────────┘ │
│ ↓ │
│ ┌─────────────────────────────────┐ │
│ │ 测试2:抽象概念 │ │
│ │ "sharp claws" │ │
│ └─────────────────────────────────┘ │
│ ↓ │
│ ┌─────────────────────────────────┐ │
│ │ 测试3:不同随机种子 │ │
│ │ seed=0,1,2,3... │ │
│ └─────────────────────────────────┘ │
└─────────────────────────────────────────┘
↓
┌─────────────────────────────────────────┐
│ 评估生成质量 │
│ │
│ 定性评估: │
│ • ✓ 清晰度检查 │
│ • ✓ 风格一致性 │
│ • ✓ 特征准确性 │
│ • ✓ 多样性测试 │
│ │
│ 定量评估: │
│ • FID Score(越低越好) │
│ • CLIP Score(越高越好) │
│ • 人工打分 │
└─────────────────────────────────────────┘
↓
❓ 质量满意?
↓
┌───────────┴───────────┐
↓ ↓
❌ 不满意 ✅ 满意
↓ ↓
┌──────────────┐ 继续下一阶段
│ 改进方案: │ ↓
│ • 继续训练 │
│ • 调整参数 │
│ • 增加数据 │
└──────────────┘
↓
返回训练阶段
↓
┌─────────────────────────────────────────┐
│ ⚡ 第七阶段:进阶技巧 │
└─────────────────────────────────────────┘
↓
┌───────────┴───────────┐
↓ ↓
┌──────────────┐ ┌──────────────┐
│ 微调技术 │ │ LoRA微调 │
│ │ │ │
│ • 小学习率 │ │ • 参数高效 │
│ • 新数据集 │ │ • 几MB大小 │
│ • 快速适应 │ │ • 灵活切换 │
└──────────────┘ └──────────────┘
↓ ↓
└───────────┬───────────┘
↓
┌─────────────────────────────────────────┐
│ 优化推理性能 │
│ │
│ • 减少推理步数(30→20) │
│ • 批量生成 │
│ • 提示词工程 │
│ • 使用半精度(fp16) │
└─────────────────────────────────────────┘
↓
┌─────────────────────────────────────────┐
│ 部署实际应用 │
│ │
│ ┌─────────────────────────────────┐ │
│ │ Web界面(Gradio) │ │
│ │ • 交互式生成 │ │
│ │ • 参数调节 │ │
│ │ • 实时预览 │ │
│ └─────────────────────────────────┘ │
│ ↓ │
│ ┌─────────────────────────────────┐ │
│ │ API服务(FastAPI) │ │
│ │ • RESTful接口 │ │
│ │ • 批量处理 │ │
│ │ • 负载均衡 │ │
│ └─────────────────────────────────┘ │
└─────────────────────────────────────────┘
↓
┌─────────────────────────────────────────┐
│ 🌟 第八阶段:扩展到更大规模 │
└─────────────────────────────────────────┘
↓
┌─────────────────────────────────────────┐
│ 使用更大的数据集 │
│ │
│ • LAION-5B:50亿图文对 │
│ • COYO-700M:7亿图文对 │
│ • 自建数据集:爬取+标注 │
└─────────────────────────────────────────┘
↓
┌─────────────────────────────────────────┐
│ 增加模型规模 │
│ │
│ 0.1B → 1.5B → 5B → 12B │
│ │
│ • 增加层数(10→30→48→64) │
│ • 增加维度(1024→2048→3072→4096) │
│ • 增加注意力头 │
└─────────────────────────────────────────┘
↓
┌─────────────────────────────────────────┐
│ 分布式训练策略 │
│ │
│ ┌─────────────────────────────────┐ │
│ │ 数据并行 │ │
│ │ • 多GPU处理不同批次 │ │
│ │ • 自动梯度同步 │ │
│ └─────────────────────────────────┘ │
│ ↓ │
│ ┌─────────────────────────────────┐ │
│ │ 模型并行 │ │
│ │ • ZeRO-3分片 │ │
│ │ • 参数卸载到CPU │ │
│ └─────────────────────────────────┘ │
│ ↓ │
│ ┌─────────────────────────────────┐ │
│ │ 流水线并行 │ │
│ │ • 层级分布到不同GPU │ │
│ │ • 流水线执行 │ │
│ └─────────────────────────────────┘ │
└─────────────────────────────────────────┘
↓
┌─────────────────────────────────────────┐
│ 🎉 完成!你已掌握AI图像生成技术 │
│ │
│ ✅ 理解扩散模型原理 │
│ ✅ 构建完整训练流程 │
│ ✅ 部署实际应用 │
│ ✅ 具备扩展到大规模的能力 │
└─────────────────────────────────────────┘
🗺️ 快速导航指南
根据你的需求,可以快速跳转到对应章节:
🎯 我是完全新手
→ 从第一部分:必备的背景知识开始,了解扩散模型、VAE、Transformer等基础概念
💻 我懂理论,想看代码
→ 直接跳到第二部分:构建模型结构,查看完整的代码实现
🚀 我想快速开始训练
→ 跳到第四部分:准备训练数据和第五部分:训练模型
🔧 我遇到了训练问题
→ 查看第五部分 5.7节:训练技巧和常见问题
📊 我想评估模型效果
→ 参考第六部分:测试和评估模型
🌟 我想做实际应用
→ 查看第七部分:进阶技巧和第十部分:实际应用和商业化
📚 我想深入理解原理
→ 阅读第九部分:理论深入
📊 训练时间线参考
时间轴 (单GPU RTX 3090)
│
├─ 0小时: 环境准备、数据下载
│ ├─ 安装依赖: 30分钟
│ └─ 下载数据集: 10分钟
│
├─ 1小时: 开始训练
│ └─ 损失: 0.3-0.5
│
├─ 5小时: 初步收敛
│ └─ 损失: 0.15-0.2
│
├─ 10小时: 中期训练
│ └─ 损失: 0.10-0.15
│
├─ 15小时: 接近收敛
│ └─ 损失: 0.08-0.12
│
└─ 20小时: 完全收敛
└─ 损失: 0.05-0.10
└─ 可以开始测试!
🎓 学习路径建议
初级路径 (1-2周)
- ✅ 理解扩散模型基本原理
- ✅ 运行预训练模型生成图像
- ✅ 在小数据集上训练模型
- ✅ 评估和测试生成效果
中级路径 (1个月)
- ✅ 深入理解模型架构
- ✅ 修改模型结构进行实验
- ✅ 使用LoRA进行高效微调
- ✅ 构建简单的Web应用
高级路径 (2-3个月)
- ✅ 在大规模数据集上训练
- ✅ 实现分布式训练
- ✅ 优化推理性能
- ✅ 部署商业化应用
💡 关键概念速查表
| 概念 | 简单解释 | 在流程中的位置 |
|---|---|---|
| 扩散模型 | 通过逐步去噪生成图像 | 第一阶段 |
| VAE | 压缩和还原图像的编解码器 | 第一阶段 |
| Transformer | 使用注意力机制的神经网络 | 第一阶段 |
| DiT | 基于Transformer的扩散模型 | 第二阶段 |
| Pipeline | 协调各组件的工作流程 | 第三阶段 |
| CFG | 提高生成质量的引导技术 | 第三阶段 |
| 学习率 | 控制参数更新步长 | 第五阶段 |
| 批次大小 | 每次训练使用的样本数 | 第五阶段 |
| FID | 评估生成质量的指标 | 第六阶段 |
| LoRA | 高效的微调方法 | 第七阶段 |
🛠️ 所需资源清单
硬件要求
- ✅ 最低配置: RTX 3060 (12GB显存)
- ✅ 推荐配置: RTX 3090/4090 (24GB显存)
- ✅ 专业配置: A100 (40GB/80GB显存)
软件环境
- ✅ Python 3.8+
- ✅ PyTorch 2.0+
- ✅ CUDA 11.8+
- ✅ DiffSynth-Studio
数据存储
- ✅ 系统盘: 50GB (安装环境)
- ✅ 数据盘: 100GB+ (数据集和模型)
训练时间
- ✅ 小数据集 (151样本): 10-20小时
- ✅ 中数据集 (10K样本): 3-7天
- ✅ 大数据集 (1M+样本): 数周
现在,让我们开始详细的学习之旅!👇
引言:什么是文生图模型?
在开始深入技术细节之前,我们先来理解一下什么是文生图(Text-to-Image)模型。简单来说,文生图模型就是能够根据你输入的文字描述,自动生成对应图像的人工智能系统。比如你输入"一只戴着帽子的橙色猫咪",模型就能生成一张符合描述的图片。
近年来,像Stable Diffusion、DALL-E、Midjourney这样的文生图模型在互联网上引起了巨大轰动。这些模型背后的核心技术就是扩散模型(Diffusion Model)。本教程将带你从零开始,构建并训练一个属于自己的小型文生图模型。
虽然我们训练的模型参数量只有0.1B(1亿参数),远小于商业模型的几十亿参数,但这足以让你理解整个训练流程,并为将来训练更大规模的模型打下基础。
[继续之前的完整内容…]
这个流程图具有以下特点:
- 纵向布局:适合手机屏幕从上到下滚动查看
- 八个阶段:清晰标注每个学习阶段
- 详细步骤:每个阶段都展开了具体的子步骤
- 决策节点:标注了可能遇到的问题和解决方案
- 颜色编码:
- 绿色:开始和结束
- 黄色:主要阶段
- 红色:决策点
- 蓝色:问题解决方案
配合流程图,还添加了:
- 快速导航指南
- 训练时间线参考
- 学习路径建议
- 关键概念速查表
- 所需资源清单
这样读者在手机上就能快速了解整个流程,并根据自己的需求跳转到相应章节!
引言:什么是文生图模型?
在开始深入技术细节之前,我们先来理解一下什么是文生图(Text-to-Image)模型。简单来说,文生图模型就是能够根据你输入的文字描述,自动生成对应图像的人工智能系统。比如你输入"一只戴着帽子的橙色猫咪",模型就能生成一张符合描述的图片。
近年来,像Stable Diffusion、DALL-E、Midjourney这样的文生图模型在互联网上引起了巨大轰动。这些模型背后的核心技术就是扩散模型(Diffusion Model)。本教程将带你从零开始,构建并训练一个属于自己的小型文生图模型。
虽然我们训练的模型参数量只有0.1B(1亿参数),远小于商业模型的几十亿参数,但这足以让你理解整个训练流程,并为将来训练更大规模的模型打下基础。
第一部分:必备的背景知识
1.1 什么是扩散模型(Diffusion Model)?
扩散模型是当前最先进的图像生成技术之一。要理解扩散模型,我们可以用一个生活中的比喻:
想象你有一张清晰的照片,然后你不断地往照片上撒沙子,直到照片完全被沙子覆盖,变成一片噪声。这个过程就是正向扩散过程(Forward Diffusion Process)。
现在,如果我们能训练一个AI模型,让它学会如何一步步地把沙子清理掉,最终恢复出原始的清晰照片,这就是反向去噪过程(Reverse Denoising Process)。[1]
1.1.1 正向扩散过程的数学原理
在DDPM(Denoising Diffusion Probabilistic Models)方法中,正向加噪过程完全按照预设的公式进行,不涉及任何模型训练。具体来说,假设初始干净图像为 x0x_0x0,我们通过以下公式逐步添加噪声:
xt=αtx0+1−αtϵx_t = \sqrt{\alpha_t} x_0 + \sqrt{1-\alpha_t} \epsilonxt=αtx0+1−αtϵ
其中:
- xtx_txt 是第 ttt 步的带噪图像
- αt\alpha_tαt 是预定义的噪声调度参数
- ϵ\epsilonϵ 是从标准正态分布 N(0,1)N(0,1)N(0,1) 中采样的随机噪声
这个过程会重复执行约1000步,直到图像完全变成纯噪声。[1] [2]
1.1.2 反向去噪过程
反向过程才是扩散模型的核心。我们训练一个神经网络,让它学习如何预测每一步应该去除多少噪声。从根本上说,Diffusion Models的工作原理是通过连续添加高斯噪声来破坏训练数据,然后学习通过反转这个噪声过程来恢复数据。[4]
训练完成后,我们可以从纯噪声开始,通过神经网络逐步去噪,最终生成全新的图像。这就是为什么扩散模型能够生成高质量、多样化图像的原因。
1.1.3 为什么扩散模型效果好?
与传统的GAN(生成对抗网络)相比,扩散模型有几个显著优势:
- 训练稳定:不需要像GAN那样平衡生成器和判别器的训练
- 生成质量高:能够生成更精细、更真实的图像
- 多样性好:不容易出现模式崩塌问题
- 可控性强:容易与文本等条件信息结合
1.2 什么是VAE(变分自编码器)?
在实际的文生图系统中,我们通常不直接在原始像素空间进行扩散操作,因为这样计算量太大。相反,我们使用VAE将图像压缩到一个低维的"潜在空间"(Latent Space),然后在这个压缩后的空间进行扩散操作。
1.2.1 VAE的基本结构
VAE由两个主要部分组成:
编码器(Encoder):将高维的图像数据压缩成低维的潜在向量。比如将一张1024×1024×3的图像(约300万个数字)压缩成128×64×64的张量(约50万个数字),压缩比达到6倍。[5]
解码器(Decoder):从潜在空间中的向量重构出原始图像。解码器学习将压缩后的表示还原为清晰的图像。[5] [6]
1.2.2 为什么需要VAE?
使用VAE的好处是显而易见的:
- 降低计算成本:在压缩后的空间操作,速度快得多
- 提取关键特征:VAE会自动学习图像的重要特征,过滤掉不重要的细节
- 提高训练效率:更小的数据维度意味着更快的训练速度
在我们的教程中,我们使用FLUX.2模型中的VAE,它能将图像压缩到原始尺寸的1/16(宽高各缩小4倍),同时保持良好的重建质量。[7]
1.3 什么是Transformer和注意力机制?
Transformer是近年来深度学习领域最重要的突破之一,它最初用于自然语言处理,现在也被广泛应用于图像生成领域。
1.3.1 注意力机制的直观理解
注意力机制(Attention Mechanism)的核心思想是:当我们处理某个元素时,应该关注(attend to)哪些其他元素?
举个例子:当你看到句子"猫坐在垫子上",要理解"它"指的是什么时,你的注意力会自然地回到"猫"这个词。这就是注意力机制的工作原理。[8]
1.3.2 自注意力(Self-Attention)
自注意力机制允许模型在处理某个词时参考句子中其他词的信息。在图像生成中,自注意力让模型能够理解图像不同区域之间的关系。[9]
计算自注意力的公式为:
Attention(Q,K,V)=softmax(QKTdk)V\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)VAttention(Q,K,V)=softmax(dkQKT)V
其中:
- QQQ(Query):查询向量,表示"我想找什么"
- KKK(Key):键向量,表示"我有什么"
- VVV(Value):值向量,表示"具体的内容"
- dkd_kdk:缩放因子,防止数值过大
这个公式的含义是:对于每个位置,计算它与所有其他位置的相关性(通过Q和K的点积),然后根据这些相关性加权组合所有位置的值(V)。[8] [10]
1.3.3 多头注意力(Multi-Head Attention)
多头注意力机制通过不同角度捕捉词间关系。它将注意力计算分成多个"头",每个头学习不同类型的关系模式。比如一个头可能关注颜色关系,另一个头关注空间位置关系。[9] [11]
1.4 什么是DiT(Diffusion Transformer)?
DiT是Diffusion Transformer的缩写,它将Transformer架构应用到扩散模型中,取代了早期的UNet架构。
1.4.1 从UNet到DiT的演变
早期的扩散模型(如Stable Diffusion 1.x和2.x)使用UNet作为去噪网络。UNet是一种卷积神经网络,具有编码器-解码器结构,通过跳跃连接保留细节信息。
然而,随着Vision Transformer(ViT)在图像领域的成功,研究者们发现Transformer架构在扩散模型中也能取得更好的效果。DiT架构基于Latent Diffusion Model(LDM)框架,采用Vision Transformer作为主干网络。[12]
1.4.2 DiT的优势
DiT相比UNet有几个关键优势:
- 可扩展性更强:Transformer可以通过增加层数和宽度轻松扩展
- 全局建模能力:注意力机制能够捕获图像中的长距离依赖关系
- 统一架构:文本编码和图像生成可以使用相同的架构
- 训练效率高:在大规模数据上训练时,DiT能够更有效地利用计算资源
[12] [13]
1.4.3 DiT的基本结构
一个典型的DiT模型包含以下组件:
- 位置编码:为图像的每个位置添加位置信息
- 时间步嵌入:将当前的去噪步骤编码为向量
- Transformer块:多层自注意力和前馈网络
- 输出投影:将Transformer的输出转换为噪声预测
在我们的教程中,我们将构建一个简化版的DiT模型,称为AAADiT(All About AI DiT),它包含10个Transformer块,总参数量约0.1B。[14]
1.5 文本编码器的作用
文本编码器负责将用户输入的文字描述转换为数值向量,这样神经网络才能理解文本的含义。
1.5.1 为什么需要文本编码器?
计算机只能处理数字,不能直接理解文字。文本编码器的作用就是将"一只橙色的猫"这样的文字转换成一串数字(向量),这串数字包含了文本的语义信息。
1.5.2 常用的文本编码器
在文生图领域,常用的文本编码器包括:
- CLIP:OpenAI开发的图文联合编码器,在Stable Diffusion中广泛使用
- T5:Google的文本编码器,在Imagen等模型中使用
- BERT:通用的文本理解模型
- Qwen:阿里云开发的中文友好的语言模型
在本教程中,我们使用Qwen3-0.6B作为文本编码器。它是一个轻量级的语言模型,参数量仅0.6B,但对中英文都有良好的理解能力。
第二部分:构建模型结构
现在我们已经了解了必要的背景知识,接下来开始构建我们的文生图模型。我们将从底层开始,逐步搭建每个组件。
2.1 设计模型的整体架构
我们的AAADiT模型采用以下架构:
文本输入 → 文本编码器 → 文本向量
↓
噪声图像 → VAE编码器 → 图像潜在向量 → DiT去噪网络 → 去噪后的潜在向量 → VAE解码器 → 生成图像
↑
时间步嵌入
整个流程可以分解为以下步骤:
- 文本编码:将用户输入的提示词转换为向量
- 图像编码:将噪声图像通过VAE编码为潜在向量
- 去噪处理:DiT模型根据文本向量和时间步信息,预测应该去除的噪声
- 迭代去噪:重复去噪过程30-50次,逐步清除噪声
- 图像解码:将最终的潜在向量通过VAE解码为可见图像
2.2 实现位置编码模块
位置编码是Transformer架构的重要组成部分。由于Transformer本身不包含位置信息(不像CNN那样有天然的空间结构),我们需要显式地告诉模型每个元素的位置。
import torch
from einops import rearrange, repeat
class AAAPositionalEmbedding(torch.nn.Module):
"""
位置编码模块
这个模块为图像和文本的每个位置分配一个可学习的位置向量。
想象每个位置都有一个"身份证",模型通过这个"身份证"知道
当前处理的是图像的哪个区域或文本的哪个位置。
"""
def __init__(self, height=16, width=16, dim=1024):
super().__init__()
# 为图像的每个位置创建一个可学习的位置编码
# 形状是 [1, dim, height, width],表示一个特征图
self.image_emb = torch.nn.Parameter(torch.randn((1, dim, height, width)))
# 为文本的每个token创建一个可学习的位置编码
# 形状是 [dim],会在使用时广播到所有文本位置
self.text_emb = torch.nn.Parameter(torch.randn((dim,)))
def forward(self, image, text):
"""
前向传播函数
参数:
image: 图像潜在向量,形状 [B, C, H, W]
text: 文本编码向量,形状 [B, L, C]
返回:
位置编码,形状 [B, H*W+L, C]
"""
height, width = image.shape[-2:] # 获取图像的高度和宽度
# 将位置编码移到正确的设备和数据类型
image_emb = self.image_emb.to(device=image.device, dtype=image.dtype)
# 如果输入图像的尺寸与初始化时不同,使用双线性插值调整位置编码的大小
# 这样模型可以处理不同分辨率的图像
image_emb = torch.nn.functional.interpolate(
image_emb, size=(height, width), mode="bilinear"
)
# 将图像位置编码从 [B, C, H, W] 重排为 [B, H*W, C]
# 这样每个空间位置变成了序列中的一个元素
image_emb = rearrange(image_emb, "B C H W -> B (H W) C")
# 处理文本位置编码
text_emb = self.text_emb.to(device=text.device, dtype=text.dtype)
# 将文本位置编码复制到每个batch和每个文本位置
text_emb = repeat(text_emb, "C -> B L C", B=text.shape[0], L=text.shape[1])
# 将图像和文本的位置编码拼接在一起
# 最终形状是 [B, H*W+L, C],表示图像和文本的所有位置
emb = torch.concat([image_emb, text_emb], dim=1)
return emb
代码详解:
torch.nn.Parameter:将张量标记为可学习参数,训练时会自动更新rearrange:来自einops库,用于优雅地重排张量维度interpolate:双线性插值,用于调整位置编码的空间尺寸- 位置编码是可学习的,不是固定的正弦函数(这是现代Transformer的常见做法)
2.3 实现Transformer块
Transformer块是模型的核心计算单元。每个块包含两个主要部分:自注意力层和前馈网络层。
from diffsynth.core import attention_forward
class AAABlock(torch.nn.Module):
"""
Transformer块
这是模型的基本构建单元。每个块执行以下操作:
1. 自注意力:让图像的不同区域和文本的不同部分相互"交流"
2. 前馈网络:对每个位置独立地进行非线性变换
"""
def __init__(self, dim=1024, num_heads=32):
super().__init__()
# === 注意力部分 ===
# RMSNorm是一种归一化方法,比LayerNorm更简单高效
self.norm_attn = torch.nn.RMSNorm(dim, elementwise_affine=False)
# 注意力机制的三个线性变换:Query, Key, Value
# 这三个变换将输入投影到不同的空间,用于计算注意力权重
self.to_q = torch.nn.Linear(dim, dim) # Query: "我想找什么"
self.to_k = torch.nn.Linear(dim, dim) # Key: "我有什么"
self.to_v = torch.nn.Linear(dim, dim) # Value: "具体的内容"
self.to_out = torch.nn.Linear(dim, dim) # 输出投影
# === 前馈网络部分 ===
self.norm_mlp = torch.nn.RMSNorm(dim, elementwise_affine=False)
self.ff = torch.nn.Sequential(
torch.nn.Linear(dim, dim*3), # 先扩展到3倍维度
torch.nn.SiLU(), # SiLU激活函数(也叫Swish)
torch.nn.Linear(dim*3, dim), # 再压缩回原始维度
)
# === 自适应门控 ===
# 这个线性层生成门控信号,用于调节注意力和前馈网络的贡献
# 门控信号由时间步嵌入生成,使模型能够根据去噪阶段调整行为
self.to_gate = torch.nn.Linear(dim, dim * 2)
self.num_heads = num_heads
def attention(self, emb, pos_emb):
"""
自注意力计算
这个函数实现了多头自注意力机制。它让序列中的每个元素
都能"看到"其他所有元素,从而捕获全局依赖关系。
"""
# 先归一化,然后加上位置编码
emb = self.norm_attn(emb + pos_emb)
# 计算Query, Key, Value
q, k, v = self.to_q(emb), self.to_k(emb), self.to_v(emb)
# 执行多头注意力计算
# attention_forward是DiffSynth-Studio提供的优化实现
# 它会自动选择最高效的注意力计算方法(如FlashAttention)
emb = attention_forward(
q, k, v,
q_pattern="b s (n d)", # batch, sequence, (num_heads, dim_per_head)
k_pattern="b s (n d)",
v_pattern="b s (n d)",
out_pattern="b s (n d)",
dims={"n": self.num_heads},
)
# 输出投影
emb = self.to_out(emb)
return emb
def feed_forward(self, emb, pos_emb):
"""
前馈网络
这是一个简单的两层MLP(多层感知机),对每个位置独立处理。
它的作用是增加模型的非线性表达能力。
"""
emb = self.norm_mlp(emb + pos_emb)
emb = self.ff(emb)
return emb
def forward(self, emb, pos_emb, t_emb):
"""
Transformer块的前向传播
参数:
emb: 输入嵌入,形状 [B, S, C]
pos_emb: 位置编码,形状 [B, S, C]
t_emb: 时间步嵌入,形状 [B, 1, C]
返回:
处理后的嵌入,形状 [B, S, C]
"""
# 从时间步嵌入生成两个门控信号
gate_attn, gate_mlp = self.to_gate(t_emb).chunk(2, dim=-1)
# 注意力分支:使用残差连接和自适应门控
# (1 + gate_attn) 使得门控信号可以放大或缩小注意力的贡献
emb = emb + self.attention(emb, pos_emb) * (1 + gate_attn)
# 前馈网络分支:同样使用残差连接和自适应门控
emb = emb + self.feed_forward(emb, pos_emb) * (1 + gate_mlp)
return emb
关键概念解释:
-
残差连接:
emb = emb + ...这种形式叫残差连接,它让梯度能够直接流过,避免梯度消失问题 -
归一化:在每个子层之前进行归一化,稳定训练过程
-
自适应门控:根据时间步动态调整每个子层的贡献,这是DiT的创新之处
-
多头注意力:将注意力分成多个头,每个头关注不同的特征模式
2.4 组装完整的DiT模型
现在我们将所有组件组装成完整的DiT模型:
from diffsynth.models.general_modules import TimestepEmbeddings
from diffsynth.core import gradient_checkpoint_forward
class AAADiT(torch.nn.Module):
"""
AAADiT: All About AI Diffusion Transformer
这是我们的完整去噪模型。它接收带噪声的图像潜在向量、
文本描述和当前时间步,输出预测的噪声。
"""
def __init__(self, dim=1024):
super().__init__()
# === 嵌入层 ===
# 位置编码器:为图像和文本的每个位置提供位置信息
self.pos_embedder = AAAPositionalEmbedding(dim=dim)
# 时间步嵌入器:将标量时间步转换为高维向量
# 256是嵌入维度,dim是输出维度
self.timestep_embedder = TimestepEmbeddings(256, dim)
# 图像嵌入器:将VAE编码的128维潜在向量投影到1024维
self.image_embedder = torch.nn.Sequential(
torch.nn.Linear(128, dim),
torch.nn.LayerNorm(dim)
)
# 文本嵌入器:将文本编码器的1024维输出投影到模型的1024维
# (这里维度相同,但保留投影层以便将来调整)
self.text_embedder = torch.nn.Sequential(
torch.nn.Linear(1024, dim),
torch.nn.LayerNorm(dim)
)
# === Transformer块 ===
# 堆叠10个Transformer块,形成深度网络
# 每增加一层,模型的表达能力就更强
self.blocks = torch.nn.ModuleList([AAABlock(dim) for _ in range(10)])
# === 输出投影 ===
# 将模型的1024维输出投影回128维,与输入潜在向量的维度匹配
self.proj_out = torch.nn.Linear(dim, 128)
def forward(
self,
latents, # 图像潜在向量,形状 [B, 128, H, W]
prompt_embeds, # 文本编码,形状 [B, L, 1024]
timestep, # 当前时间步,标量
use_gradient_checkpointing=False, # 是否使用梯度检查点(节省显存)
use_gradient_checkpointing_offload=False, # 是否将检查点卸载到CPU
):
"""
前向传播:预测应该去除的噪声
工作流程:
1. 生成位置编码和时间步嵌入
2. 将图像和文本嵌入到统一的特征空间
3. 通过多个Transformer块处理
4. 投影回图像空间,输出噪声预测
"""
# === 第1步:准备嵌入 ===
# 生成位置编码
pos_emb = self.pos_embedder(latents, prompt_embeds)
# 生成时间步嵌入
# view(1, 1, -1) 将其变形为 [1, 1, dim],方便广播
t_emb = self.timestep_embedder(timestep, dtype=latents.dtype).view(1, 1, -1)
# === 第2步:嵌入图像和文本 ===
# 将图像从 [B, C, H, W] 重排为 [B, H*W, C]
image = self.image_embedder(rearrange(latents, "B C H W -> B (H W) C"))
# 嵌入文本
text = self.text_embedder(prompt_embeds)
# 将图像和文本拼接成一个序列
# 形状变为 [B, H*W+L, dim]
emb = torch.concat([image, text], dim=1)
# === 第3步:通过Transformer块处理 ===
for block_id, block in enumerate(self.blocks):
# gradient_checkpoint_forward 是一个工具函数
# 它可以在训练时使用梯度检查点技术,用计算换显存
emb = gradient_checkpoint_forward(
block,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
emb=emb,
pos_emb=pos_emb,
t_emb=t_emb,
)
# === 第4步:输出投影 ===
# 只保留图像部分(前 H*W 个位置),丢弃文本部分
emb = emb[:, :latents.shape[-1] * latents.shape[-2]]
# 投影到输出空间
emb = self.proj_out(emb)
# 将序列重排回图像格式 [B, H*W, C] -> [B, C, H, W]
emb = rearrange(emb, "B (H W) C -> B C H W", W=latents.shape[-1])
return emb
模型参数量计算:
让我们估算一下这个模型的参数量:
- 位置编码:约 1M(1024 × 16 × 16)
- 时间步嵌入:约 0.3M
- 图像/文本嵌入器:约 2M
- 每个Transformer块:约 10M(主要在注意力和前馈网络的权重矩阵)
- 10个块:约 100M
- 输出投影:约 0.1M
总计:约 103M ≈ 0.1B 参数
这个规模的模型可以在消费级GPU(如RTX 3090)上训练。
第三部分:构建Pipeline
模型结构定义好后,我们需要构建一个Pipeline来协调各个组件的工作。Pipeline就像一个"指挥官",负责调度文本编码器、DiT模型、VAE等组件,完成从文本到图像的完整流程。
3.1 理解Pipeline的作用
在DiffSynth-Studio框架中,Pipeline是一个高层抽象,它封装了以下功能:
- 模型管理:加载、卸载模型到GPU/CPU
- 数据流转:在各个组件之间传递数据
- 推理流程:实现完整的生成流程
- 显存优化:自动管理显存,避免OOM(内存溢出)
3.2 定义Pipeline单元(Units)
Pipeline由多个"单元"组成,每个单元负责一个特定的任务。我们需要定义三个单元:
3.2.1 文本编码单元
from diffsynth.diffusion.base_pipeline import PipelineUnit
from transformers import AutoTokenizer
class AAAUnit_PromptEmbedder(PipelineUnit):
"""
文本编码单元
这个单元负责将文本提示词转换为数值向量。
它会分别处理正向提示词和负向提示词(用于CFG)。
"""
def __init__(self):
super().__init__(
seperate_cfg=True, # 正负提示词分开处理
input_params_posi={"prompt": "prompt"}, # 正向提示词的参数名
input_params_nega={"prompt": "negative_prompt"}, # 负向提示词的参数名
output_params=("prompt_embeds",), # 输出参数名
onload_model_names=("text_encoder",) # 需要加载的模型
)
# 使用最后一层隐藏状态作为文本编码
self.hidden_states_layers = (-1,)
def process(self, pipe, prompt):
"""
处理文本提示词
参数:
pipe: Pipeline对象
prompt: 文本提示词字符串
返回:
包含prompt_embeds的字典
"""
# 确保文本编码器已加载到GPU
pipe.load_models_to_device(self.onload_model_names)
# 使用Qwen的聊天模板格式化输入
# 这样可以利用Qwen在对话任务上的预训练知识
text = pipe.tokenizer.apply_chat_template(
[{"role": "user", "content": prompt}],
tokenize=False, # 先不分词,返回文本
add_generation_prompt=True, # 添加生成提示
enable_thinking=False, # 不启用思考模式
)
# 分词并转换为张量
inputs = pipe.tokenizer(
text,
return_tensors="pt", # 返回PyTorch张量
padding="max_length", # 填充到最大长度
truncation=True, # 超长则截断
max_length=128 # 最大长度128个token
).to(pipe.device)
# 通过文本编码器获取隐藏状态
output = pipe.text_encoder(
**inputs,
output_hidden_states=True, # 输出所有层的隐藏状态
use_cache=False # 不使用KV缓存(我们不需要生成)
)
# 提取指定层的隐藏状态并拼接
# 这里只用最后一层,但框架支持使用多层
prompt_embeds = torch.concat(
[output.hidden_states[k] for k in self.hidden_states_layers],
dim=-1
)
return {"prompt_embeds": prompt_embeds}
为什么使用聊天模板?
Qwen模型是在对话数据上预训练的,使用聊天模板可以让模型更好地理解提示词的意图。例如,输入"一只猫"会被格式化为:
<|im_start|>user
一只猫<|im_end|>
<|im_start|>assistant
这种格式告诉模型:这是用户的输入,请理解它的含义。
3.2.2 噪声初始化单元
class AAAUnit_NoiseInitializer(PipelineUnit):
"""
噪声初始化单元
生成初始的随机噪声,作为扩散过程的起点。
"""
def __init__(self):
super().__init__(
input_params=("height", "width", "seed", "rand_device"),
output_params=("noise",),
)
def process(self, pipe, height, width, seed, rand_device):
"""
生成随机噪声
参数:
height, width: 目标图像的高度和宽度
seed: 随机种子(用于可复现性)
rand_device: 生成随机数的设备(CPU或GPU)
返回:
包含noise的字典
"""
# 生成形状为 [1, 128, H/16, W/16] 的随机噪声
# 除以16是因为VAE的下采样倍数是16
noise = pipe.generate_noise(
(1, 128, height//16, width//16),
seed=seed,
rand_device=rand_device,
rand_torch_dtype=pipe.torch_dtype
)
return {"noise": noise}
为什么需要随机种子?
设置随机种子可以让生成过程可复现。相同的种子、相同的提示词会生成相同的图像,这对调试和比较非常有用。
3.2.3 输入图像编码单元
class AAAUnit_InputImageEmbedder(PipelineUnit):
"""
输入图像编码单元
如果用户提供了输入图像(用于图生图),则通过VAE编码。
否则,直接使用随机噪声。
"""
def __init__(self):
super().__init__(
input_params=("input_image", "noise"),
output_params=("latents", "input_latents"),
onload_model_names=("vae",)
)
def process(self, pipe, input_image, noise):
"""
处理输入图像
参数:
input_image: PIL图像对象,如果为None则是文生图模式
noise: 随机噪声
返回:
包含latents和input_latents的字典
"""
# 如果没有输入图像,直接使用噪声
if input_image is None:
return {"latents": noise, "input_latents": None}
# 加载VAE到GPU
pipe.load_models_to_device(['vae'])
# 预处理图像(调整大小、归一化等)
image = pipe.preprocess_image(input_image)
# 通过VAE编码器获取潜在向量
input_latents = pipe.vae.encode(image)
# 如果是训练模式,返回纯噪声和输入潜在向量
# 如果是推理模式,将噪声添加到输入潜在向量上
if pipe.scheduler.training:
return {"latents": noise, "input_latents": input_latents}
else:
# 根据去噪强度添加噪声
latents = pipe.scheduler.add_noise(
input_latents, noise, timestep=pipe.scheduler.timesteps[0]
)
return {"latents": latents, "input_latents": input_latents}
图生图 vs 文生图:
- 文生图:从纯噪声开始,完全生成新图像
- 图生图:从输入图像的编码开始,添加部分噪声后再去噪,可以保留输入图像的部分特征
3.3 实现完整的Pipeline
现在我们将所有单元组装成完整的Pipeline:
from diffsynth.diffusion import FlowMatchScheduler
from diffsynth.diffusion.base_pipeline import BasePipeline
from diffsynth.models.z_image_text_encoder import ZImageTextEncoder
from diffsynth.models.flux2_vae import Flux2VAE
from PIL import Image
from tqdm import tqdm
class AAAImagePipeline(BasePipeline):
"""
AAAImagePipeline: 完整的文生图Pipeline
这个类协调所有组件,实现从文本到图像的完整流程。
"""
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
# 调用父类初始化
super().__init__(
device=device,
torch_dtype=torch_dtype,
height_division_factor=16, # 高度必须是16的倍数
width_division_factor=16, # 宽度必须是16的倍数
)
# === 初始化组件 ===
# 调度器:管理去噪过程的时间步
self.scheduler = FlowMatchScheduler("FLUX.2")
# 模型(初始化为None,稍后加载)
self.text_encoder: ZImageTextEncoder = None
self.dit: AAADiT = None
self.vae: Flux2VAE = None
self.tokenizer: AutoTokenizer = None
# 指定哪些模型需要在迭代中保持在GPU上
self.in_iteration_models = ("dit",)
# 注册Pipeline单元
self.units = [
AAAUnit_PromptEmbedder(),
AAAUnit_NoiseInitializer(),
AAAUnit_InputImageEmbedder(),
]
# 模型前向传播函数
self.model_fn = model_fn_aaa
@staticmethod
def from_pretrained(
torch_dtype: torch.dtype = torch.bfloat16,
device: Union[str, torch.device] = "cuda",
model_configs: list = [],
tokenizer_config = None,
vram_limit: float = None,
):
"""
从预训练模型加载Pipeline
这是一个工厂方法,用于创建并初始化Pipeline。
参数:
torch_dtype: 模型的数据类型(bfloat16可以节省显存)
device: 运行设备
model_configs: 模型配置列表
tokenizer_config: 分词器配置
vram_limit: 显存限制(GB)
返回:
初始化好的Pipeline对象
"""
# 创建Pipeline实例
pipe = AAAImagePipeline(device=device, torch_dtype=torch_dtype)
# 下载并加载模型
model_pool = pipe.download_and_load_models(model_configs, vram_limit)
# 从模型池中获取各个模型
pipe.text_encoder = model_pool.fetch_model("z_image_text_encoder")
pipe.dit = model_pool.fetch_model("aaa_dit")
pipe.vae = model_pool.fetch_model("flux2_vae")
# 加载分词器
if tokenizer_config is not None:
tokenizer_config.download_if_necessary()
pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)
# 检查是否需要启用显存管理
pipe.vram_management_enabled = pipe.check_vram_management_state()
return pipe
@torch.no_grad() # 推理时不计算梯度
def __call__(
self,
# === 文本参数 ===
prompt: str, # 正向提示词
negative_prompt: str = "", # 负向提示词
cfg_scale: float = 1.0, # CFG引导强度
# === 图像参数 ===
input_image: Image.Image = None, # 输入图像(可选)
denoising_strength: float = 1.0, # 去噪强度
# === 尺寸参数 ===
height: int = 1024,
width: int = 1024,
# === 随机性参数 ===
seed: int = None, # 随机种子
rand_device: str = "cpu", # 生成随机数的设备
# === 步数参数 ===
num_inference_steps: int = 30, # 推理步数
# === 进度条 ===
progress_bar_cmd = tqdm,
):
"""
执行图像生成
这是Pipeline的主要接口,用户调用这个方法来生成图像。
工作流程:
1. 设置调度器的时间步
2. 准备输入参数
3. 运行Pipeline单元(编码文本、初始化噪声等)
4. 迭代去噪
5. 解码为图像
"""
# === 第1步:设置时间步 ===
self.scheduler.set_timesteps(
num_inference_steps,
denoising_strength=denoising_strength,
dynamic_shift_len=height//16*width//16
)
# === 第2步:准备参数 ===
# 正向提示词参数
inputs_posi = {"prompt": prompt}
# 负向提示词参数
inputs_nega = {"negative_prompt": negative_prompt}
# 共享参数
inputs_shared = {
"cfg_scale": cfg_scale,
"input_image": input_image,
"denoising_strength": denoising_strength,
"height": height,
"width": width,
"seed": seed,
"rand_device": rand_device,
"num_inference_steps": num_inference_steps,
}
# === 第3步:运行Pipeline单元 ===
for unit in self.units:
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(
unit, self, inputs_shared, inputs_posi, inputs_nega
)
# === 第4步:迭代去噪 ===
# 加载DiT模型到GPU
self.load_models_to_device(self.in_iteration_models)
models = {name: getattr(self, name) for name in self.in_iteration_models}
# 遍历所有时间步
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
# 将时间步转换为张量
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
# 使用CFG引导的模型预测
noise_pred = self.cfg_guided_model_fn(
self.model_fn, cfg_scale,
inputs_shared, inputs_posi, inputs_nega,
**models, timestep=timestep, progress_id=progress_id
)
# 执行一步去噪
inputs_shared["latents"] = self.step(
self.scheduler,
progress_id=progress_id,
noise_pred=noise_pred,
**inputs_shared
)
# === 第5步:解码为图像 ===
self.load_models_to_device(['vae'])
image = self.vae.decode(inputs_shared["latents"])
image = self.vae_output_to_image(image)
self.load_models_to_device([]) # 卸载所有模型
return image
CFG(Classifier-Free Guidance)解释:
CFG是一种提高生成质量的技术。它的工作原理是:
- 同时计算有提示词和无提示词(空提示词)的噪声预测
- 计算两者的差异(引导方向)
- 沿着引导方向放大,得到最终预测
公式:noise_pred=noise_pred_uncond+cfg_scale×(noise_pred_cond−noise_pred_uncond)\text{noise\_pred} = \text{noise\_pred\_uncond} + \text{cfg\_scale} \times (\text{noise\_pred\_cond} - \text{noise\_pred\_uncond})noise_pred=noise_pred_uncond+cfg_scale×(noise_pred_cond−noise_pred_uncond)
cfg_scale越大,生成的图像越符合提示词,但可能过度饱和。
3.4 定义模型前向函数
最后,我们需要定义一个函数来调用DiT模型:
def model_fn_aaa(
dit: AAADiT,
latents=None,
prompt_embeds=None,
timestep=None,
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
**kwargs,
):
"""
DiT模型的前向函数
这是一个简单的包装函数,用于调用DiT模型。
它被设计成可以与Pipeline的CFG机制无缝集成。
参数:
dit: DiT模型实例
latents: 当前的潜在向量
prompt_embeds: 文本编码
timestep: 当前时间步
use_gradient_checkpointing: 是否使用梯度检查点
use_gradient_checkpointing_offload: 是否卸载检查点
返回:
模型预测的噪声
"""
model_output = dit(
latents,
prompt_embeds,
timestep,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
)
return model_output
至此,我们的模型结构和Pipeline都已经构建完成!
第四部分:准备训练数据
有了模型和Pipeline,接下来需要准备训练数据。数据是深度学习的"燃料",数据的质量和数量直接决定了模型的性能。
4.1 选择合适的数据集
对于初学者,我们推荐使用小型、高质量的数据集进行实验。本教程使用宝可梦第一世代数据集,它包含151个宝可梦的图像和描述。
数据集特点:
- 数量适中:151个样本,适合快速实验
- 风格统一:所有图像都是宝可梦的官方美术图,风格一致
- 标注完整:每个图像都有详细的文字描述
- 训练快速:在单张GPU上几小时就能看到效果
4.2 下载数据集
使用ModelScope的命令行工具下载数据集:
# 安装modelscope(如果还没安装)
pip install modelscope
# 下载数据集到./data目录
modelscope download --dataset DiffSynth-Studio/pokemon-gen1 --local_dir ./data
下载完成后,数据集的结构如下:
data/
├── images/
│ ├── 001_妙蛙种子.png
│ ├── 002_妙蛙草.png
│ ├── ...
│ └── 151_梦幻.png
└── metadata_merged.csv
4.3 理解数据集格式
metadata_merged.csv文件包含每个图像的元数据:
image,prompt
001_妙蛙种子.png,"green, lizard, plant, Grass, Poison, seed on back, red eyes, smiling expression, short stout limbs, sharp claws"
002_妙蛙草.png,"green, blue, lizard, plant, Grass, Poison, large flower bud on back, red eyes, serious expression, four legs, sharp claws"
...
每一行包含:
image:图像文件名prompt:描述该宝可梦特征的文字标签
4.4 数据预处理
DiffSynth-Studio提供了UnifiedDataset类来处理数据集。它会自动完成以下操作:
- 读取图像和文本:从CSV文件读取元数据,加载对应的图像
- 图像预处理:
- 调整大小到目标分辨率(256×256)
- 归一化到[-1, 1]范围
- 转换为张量格式
- 批处理:将多个样本组合成批次,提高训练效率
为什么使用256×256分辨率?
- 训练速度快:分辨率越低,计算量越小
- 显存占用少:适合消费级GPU
- 足够验证效果:256×256已经能够清楚地看到宝可梦的特征
在生产环境中,通常使用512×512或1024×1024的分辨率。
4.5 数据增强(可选)
虽然本教程没有使用数据增强,但在实际项目中,数据增强可以显著提高模型的泛化能力:
常用的图像增强方法:
- 随机裁剪:从图像中随机裁剪一块区域
- 随机翻转:水平翻转图像(注意:某些情况下不适用,如文字)
- 颜色抖动:随机调整亮度、对比度、饱和度
- 随机旋转:小角度旋转图像
文本增强方法:
- 改写提示词:使用同义词替换
- 调整顺序:打乱标签的顺序
- 添加/删除标签:随机增减一些描述词
第五部分:训练模型
现在万事俱备,让我们开始训练模型!
5.1 理解训练过程
扩散模型的训练过程可以概括为:
- 采样数据:从数据集中随机选择一张图像和对应的文本
- 编码:通过VAE将图像编码为潜在向量
- 添加噪声:随机选择一个时间步,向潜在向量添加对应量的噪声
- 预测噪声:让模型预测添加的噪声
- 计算损失:比较预测的噪声和真实噪声,计算差异
- 反向传播:根据损失更新模型参数
- 重复:重复以上步骤数万次
5.2 实现训练模块
from diffsynth.diffusion import DiffusionTrainingModule, FlowMatchSFTLoss
from diffsynth.core import ModelConfig
class AAATrainingModule(DiffusionTrainingModule):
"""
训练模块
这个类封装了训练所需的所有逻辑,包括模型初始化、
前向传播、损失计算等。
"""
def __init__(self, device):
super().__init__()
# === 创建Pipeline ===
self.pipe = AAAImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device=device,
model_configs=[
# 文本编码器配置
ModelConfig(
model_id="Qwen/Qwen3-0.6B",
origin_file_pattern="model.safetensors"
),
# VAE配置
ModelConfig(
model_id="black-forest-labs/FLUX.2-klein-4B",
origin_file_pattern="vae/diffusion_pytorch_model.safetensors"
),
],
tokenizer_config=ModelConfig(
model_id="Qwen/Qwen3-0.6B",
origin_file_pattern="./"
),
)
# === 初始化DiT模型 ===
# 注意:这里我们创建一个新的DiT模型,而不是加载预训练的
self.pipe.dit = AAADiT().to(dtype=torch.bfloat16, device=device)
# === 冻结其他模型 ===
# 只训练DiT,文本编码器和VAE保持不变
self.pipe.freeze_except(["dit"])
# === 设置训练模式的调度器 ===
# 训练时使用1000个时间步,推理时可以用更少的步数
self.pipe.scheduler.set_timesteps(1000, training=True)
def forward(self, data):
"""
前向传播:计算训练损失
参数:
data: 一个批次的数据,包含image和prompt
返回:
损失值(标量张量)
"""
# === 准备输入 ===
inputs_posi = {"prompt": data["prompt"]}
inputs_nega = {"negative_prompt": ""}
inputs_shared = {
"input_image": data["image"],
"height": data["image"].size[1],
"width": data["image"].size[0],
"cfg_scale": 1, # 训练时不使用CFG
"use_gradient_checkpointing": False, # 可以设为True以节省显存
"use_gradient_checkpointing_offload": False,
}
# === 运行Pipeline单元 ===
for unit in self.pipe.units:
inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(
unit, self.pipe, inputs_shared, inputs_posi, inputs_nega
)
# === 计算损失 ===
# FlowMatchSFTLoss是一种适用于Flow Matching的损失函数
loss = FlowMatchSFTLoss(self.pipe, **inputs_shared, **inputs_posi)
return loss
Flow Matching vs DDPM:
本教程使用的是Flow Matching方法,它是DDPM的一种改进:
- DDPM:预测噪声 ϵ\epsilonϵ
- Flow Matching:预测从噪声到数据的"速度场"
Flow Matching通常收敛更快,生成质量更好。
5.3 配置训练参数
import accelerate
from diffsynth.core import UnifiedDataset
from diffsynth.diffusion import ModelLogger, launch_training_task
if __name__ == "__main__":
# === 初始化Accelerator ===
# Accelerator是HuggingFace提供的分布式训练工具
# 它可以自动处理多GPU、混合精度等复杂配置
accelerator = accelerate.Accelerator(
gradient_accumulation_steps=1 # 梯度累积步数
)
# === 准备数据集 ===
dataset = UnifiedDataset(
base_path="data/images", # 图像文件夹路径
metadata_path="data/metadata_merged.csv", # 元数据文件路径
max_data_items=10000000, # 最大数据量(这里设得很大,实际只有151个)
data_file_keys=("image",), # 数据文件
的列名
main_data_operator=UnifiedDataset.default_image_operator(
base_path="data/images",
height=256, # 目标高度
width=256 # 目标宽度
)
)
# === 创建训练模块 ===
model = AAATrainingModule(device=accelerator.device)
# === 配置模型日志记录器 ===
model_logger = ModelLogger(
"models/AAA/v1", # 模型保存路径
remove_prefix_in_ckpt="pipe.dit.", # 保存时移除这个前缀
)
# === 启动训练 ===
launch_training_task(
accelerator,
dataset,
model,
model_logger,
learning_rate=2e-4, # 学习率
num_workers=4, # 数据加载的工作进程数
save_steps=50000, # 每50000步保存一次模型
num_epochs=999999, # 训练轮数(实际上会手动停止)
)
5.4 训练参数详解
让我们详细解释每个训练参数的含义和作用:
5.4.1 学习率(Learning Rate)
学习率是训练中最重要的超参数之一,它控制每次参数更新的步长。
- 太大:训练不稳定,损失可能震荡或发散
- 太小:训练太慢,可能陷入局部最优
- 2e-4(0.0002):这是一个经过验证的合理值,适合大多数扩散模型
学习率调度:
在实际训练中,通常会使用学习率调度器,让学习率随训练进行而变化:
# 常见的学习率调度策略
# 1. 余弦退火:学习率按余弦曲线衰减
# 2. 线性预热:开始时从0逐渐增加到目标学习率
# 3. 阶梯衰减:每隔一定步数降低学习率
DiffSynth-Studio的launch_training_task函数内部已经实现了合理的学习率调度。[15]
5.4.2 批次大小(Batch Size)
批次大小决定每次更新参数时使用多少个样本。由于我们的数据集很小(151个样本),实际的批次大小由GPU显存决定。
显存与批次大小的关系:
- RTX 3090(24GB):可以使用batch_size=4-8
- RTX 4090(24GB):可以使用batch_size=4-8
- A100(40GB/80GB):可以使用更大的批次
如果显存不足,可以使用梯度累积技术:
# 例如:想要有效批次大小为8,但显存只够batch_size=2
# 可以设置gradient_accumulation_steps=4
# 这样每4步累积的梯度相当于batch_size=8
accelerator = accelerate.Accelerator(gradient_accumulation_steps=4)
5.4.3 训练步数与收敛
根据经验,这个模型需要约60,000步才能收敛。收敛的标志包括:
- 损失稳定:训练损失不再明显下降
- 生成质量稳定:生成的图像质量不再提升
- 过拟合迹象:模型开始记忆训练数据
训练时间估算:
- 单GPU(RTX 3090):约10-20小时
- 单GPU(RTX 4090):约8-15小时
- 4×GPU(A100):约2-4小时
5.5 启动训练
5.5.1 单GPU训练
如果你只有一张GPU,直接运行Python脚本:
python train_from_scratch.py
5.5.2 多GPU训练
如果你有多张GPU,可以使用Accelerate的分布式训练功能:
第一步:配置Accelerate
accelerate config
这个命令会启动一个交互式配置向导,询问你:
In which compute environment are you running?
> This machine
Which type of machine are you using?
> multi-GPU
How many different machines will you use?
> 1
Do you wish to optimize your script with torch dynamo?
> NO
Do you want to use DeepSpeed?
> NO
Do you want to use FullyShardedDataParallel?
> NO
How many GPU(s) should be used for distributed training?
> 4 # 根据你的GPU数量填写
What GPU(s) (by id) should be used for training on this machine?
> 0,1,2,3 # 使用哪些GPU
Do you wish to use FP16 or BF16 (mixed precision)?
> bf16 # 使用bfloat16混合精度
第二步:启动训练
accelerate launch train_from_scratch.py
Accelerate会自动处理:
- 模型并行:将模型分布到多个GPU
- 数据并行:每个GPU处理不同的数据批次
- 梯度同步:自动同步和平均梯度
- 混合精度:使用bfloat16加速训练
5.6 监控训练过程
训练过程中,你会看到类似这样的输出:
Epoch 1/999999: 100%|████████| 151/151 [00:45<00:00, 3.31it/s, loss=0.234]
Saving checkpoint at step 151...
Epoch 2/999999: 100%|████████| 151/151 [00:45<00:00, 3.35it/s, loss=0.198]
Epoch 3/999999: 100%|████████| 151/151 [00:45<00:00, 3.33it/s, loss=0.176]
...
关键指标:
-
loss(损失):应该逐渐下降
- 初始:0.3-0.5
- 收敛:0.05-0.15
- 如果损失不下降或上升,说明训练有问题
-
it/s(迭代速度):每秒处理多少个批次
- 速度太慢可能是数据加载瓶颈(增加num_workers)
- 或者批次大小太大(减小batch_size)
-
显存使用:使用
nvidia-smi命令查看watch -n 1 nvidia-smi # 每秒刷新一次
5.7 训练技巧和常见问题
5.7.1 显存不足(OOM)
症状:训练时出现"CUDA out of memory"错误
解决方案:
- 减小批次大小:在DataLoader中设置更小的batch_size
- 启用梯度检查点:
inputs_shared = { ... "use_gradient_checkpointing": True, } - 使用梯度累积:如前所述
- 降低分辨率:从256×256降到128×128
5.7.2 损失不下降
可能原因:
- 学习率太大或太小:尝试调整learning_rate
- 数据问题:检查数据集是否正确加载
- 模型初始化问题:重新初始化模型
调试方法:
# 在训练循环中添加调试代码
def forward(self, data):
print(f"Image shape: {data['image'].size}")
print(f"Prompt: {data['prompt']}")
loss = FlowMatchSFTLoss(...)
print(f"Loss: {loss.item()}")
return loss
5.7.3 训练速度慢
优化方法:
-
增加num_workers:加快数据加载
launch_training_task(..., num_workers=8) -
使用混合精度:bfloat16比float32快约2倍
-
优化数据预处理:将图像预先调整到目标大小
-
使用更快的存储:将数据集放在SSD而非HDD
5.7.4 过拟合
症状:模型只能生成训练集中的图像,缺乏创造性
解决方案:
- 增加数据量:使用更大的数据集
- 数据增强:随机翻转、裁剪等
- 早停:在过拟合前停止训练
- 正则化:添加权重衰减
# 在优化器中添加权重衰减
optimizer = torch.optim.AdamW(
model.parameters(),
lr=2e-4,
weight_decay=0.01 # L2正则化
)
5.8 保存和加载检查点
训练过程中,模型会定期保存检查点:
models/AAA/v1/
├── step-50000.safetensors
├── step-100000.safetensors
├── step-150000.safetensors
...
手动保存检查点:
# 在训练脚本中添加
import torch
# 保存完整模型
torch.save(model.state_dict(), "my_checkpoint.pth")
# 保存为safetensors格式(推荐)
from safetensors.torch import save_file
save_file(model.state_dict(), "my_checkpoint.safetensors")
恢复训练:
# 加载检查点继续训练
model = AAATrainingModule(device="cuda")
checkpoint = torch.load("models/AAA/v1/step-50000.safetensors")
model.load_state_dict(checkpoint)
第六部分:测试和评估模型
训练完成后(或使用预训练模型),我们需要测试模型的生成效果。
6.1 加载预训练模型
如果你不想等待训练完成,可以直接下载我们预先训练好的模型:
# 下载预训练模型
modelscope download \
--model DiffSynth-Studio/AAAMyModel \
step-600000.safetensors \
--local_dir models/DiffSynth-Studio/AAAMyModel
6.2 创建推理脚本
import torch
from PIL import Image
from diffsynth import load_model
from diffsynth.core import ModelConfig
# === 加载Pipeline ===
pipe = AAAImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(
model_id="Qwen/Qwen3-0.6B",
origin_file_pattern="model.safetensors"
),
ModelConfig(
model_id="black-forest-labs/FLUX.2-klein-4B",
origin_file_pattern="vae/diffusion_pytorch_model.safetensors"
),
],
tokenizer_config=ModelConfig(
model_id="Qwen/Qwen3-0.6B",
origin_file_pattern="./"
),
)
# === 加载训练好的DiT模型 ===
pipe.dit = load_model(
AAADiT,
"models/DiffSynth-Studio/AAAMyModel/step-600000.safetensors",
torch_dtype=torch.bfloat16,
device="cuda"
)
print("Model loaded successfully!")
6.3 生成"御三家"宝可梦
让我们测试模型能否生成第一世代的三个初始宝可梦:
# 定义提示词
prompts = [
# 妙蛙种子
"green, lizard, plant, Grass, Poison, seed on back, red eyes, smiling expression, short stout limbs, sharp claws",
# 小火龙
"orange, cream, lizard, Fire, flame on tail tip, large eyes, smiling expression, cream-colored belly patch, sharp claws",
# 杰尼龟
"蓝色,米色,棕色,乌龟,水系,龟壳,大眼睛,短四肢,卷曲尾巴",
]
# 生成图像
for seed, prompt in enumerate(prompts):
print(f"Generating image {seed+1}/3: {prompt[:50]}...")
image = pipe(
prompt=prompt,
negative_prompt=" ", # 空的负向提示词
num_inference_steps=30, # 推理步数
cfg_scale=10, # CFG引导强度
seed=seed, # 使用不同的随机种子
height=256,
width=256,
)
# 保存图像
image.save(f"output_starter_{seed}.jpg")
print(f"Saved to output_starter_{seed}.jpg")
print("All images generated!")
参数说明:
-
num_inference_steps=30:去噪步数,越多质量越好但速度越慢
- 10-20步:快速预览
- 30-50步:高质量生成
- 50+步:通常没有明显提升
-
cfg_scale=10:CFG引导强度
- 1.0:无引导,生成多样但可能不符合提示词
- 5-10:平衡质量和多样性
- 15+:强引导,严格符合提示词但可能过度饱和
6.4 测试泛化能力
现在测试模型能否理解抽象概念,生成具有特定特征的宝可梦:
# 测试:生成具有"锐利爪子"的宝可梦
prompts = [
"sharp claws",
"sharp claws",
"sharp claws",
]
for seed, prompt in enumerate(prompts):
image = pipe(
prompt=prompt,
negative_prompt=" ",
num_inference_steps=30,
cfg_scale=10,
seed=seed+4, # 使用不同的种子
height=256,
width=256,
)
image.save(f"output_sharp_claws_{seed}.jpg")
print("Concept test completed!")
预期结果:
如果训练成功,模型应该能够:
- 生成不同的宝可梦(因为种子不同)
- 所有生成的宝可梦都有明显的爪子
- 风格与训练数据一致
6.5 评估指标
6.5.1 定性评估
视觉检查:
- 清晰度:图像是否清晰,没有模糊或噪点?
- 一致性:是否符合宝可梦的风格?
- 准确性:是否包含提示词中的特征?
- 多样性:不同种子是否产生不同的结果?
6.5.2 定量评估
虽然图像生成主要依赖人类评估,但也有一些自动化指标:
1. FID(Fréchet Inception Distance)
FID衡量生成图像与真实图像的分布差异,越低越好。
# 使用pytorch-fid库计算FID
# pip install pytorch-fid
# 生成一批图像
import os
os.makedirs("generated_images", exist_ok=True)
for i in range(100):
image = pipe(
prompt="pokemon",
seed=i,
height=256,
width=256,
)
image.save(f"generated_images/{i}.jpg")
# 计算FID
# pytorch-fid data/images generated_images
2. CLIP Score
CLIP Score衡量图像与文本的匹配度,越高越好。
from transformers import CLIPProcessor, CLIPModel
import torch
# 加载CLIP模型
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
def calculate_clip_score(image, text):
"""计算图像与文本的CLIP相似度"""
inputs = clip_processor(
text=[text],
images=image,
return_tensors="pt",
padding=True
)
outputs = clip_model(**inputs)
logits_per_image = outputs.logits_per_image
return logits_per_image.item()
# 测试
prompt = "green lizard with seed on back"
image = pipe(prompt=prompt, seed=0, height=256, width=256)
score = calculate_clip_score(image, prompt)
print(f"CLIP Score: {score}")
3. 人类评估
最可靠的评估方法仍然是人类打分:
# 创建评估界面
import gradio as gr
def generate_and_rate(prompt, seed):
image = pipe(prompt=prompt, seed=seed, height=256, width=256)
return image
# 创建Gradio界面
demo = gr.Interface(
fn=generate_and_rate,
inputs=[
gr.Textbox(label="Prompt"),
gr.Slider(0, 1000, label="Seed")
],
outputs=gr.Image(label="Generated Image"),
title="Pokemon Generator - Rate the Results"
)
demo.launch()
6.6 常见生成问题及解决方案
6.6.1 生成的图像模糊
原因:
- 训练步数不足
- 推理步数太少
- VAE质量问题
解决方案:
# 增加推理步数
image = pipe(prompt=prompt, num_inference_steps=50)
# 或者训练更久
6.6.2 生成的图像不符合提示词
原因:
- CFG引导强度太低
- 文本编码器理解有误
- 训练数据不足
解决方案:
# 增加CFG强度
image = pipe(prompt=prompt, cfg_scale=15)
# 改进提示词
prompt = "a green pokemon with a seed on its back, lizard-like, red eyes"
6.6.3 生成的图像缺乏多样性
原因:
- 过拟合
- 数据集太小
- 随机种子设置不当
解决方案:
# 使用不同的随机种子
for seed in range(10):
image = pipe(prompt=prompt, seed=seed)
# 降低CFG强度增加随机性
image = pipe(prompt=prompt, cfg_scale=5)
第七部分:进阶技巧
7.1 微调(Fine-tuning)
如果你想在自己的数据集上训练,但数据量不够大,可以使用微调技术:
# 加载预训练模型
pipe.dit = load_model(
AAADiT,
"models/DiffSynth-Studio/AAAMyModel/step-600000.safetensors",
torch_dtype=torch.bfloat16,
device="cuda"
)
# 在新数据集上继续训练
# 使用更小的学习率
launch_training_task(
accelerator,
new_dataset,
model,
model_logger,
learning_rate=1e-5, # 比从零训练小10倍
num_workers=4,
save_steps=1000,
num_epochs=100,
)
微调的优势:
- 需要更少的数据(几十到几百张即可)
- 训练更快(几小时而非几天)
- 保留预训练知识,只学习新特征
7.2 LoRA(Low-Rank Adaptation)
LoRA是一种参数高效的微调方法,只训练少量参数:
from diffsynth.models.lora import inject_lora
# 在模型中注入LoRA层
inject_lora(
pipe.dit,
rank=16, # LoRA秩,越大表达能力越强但参数越多
alpha=16, # 缩放因子
target_modules=["to_q", "to_k", "to_v", "to_out"] # 在哪些层添加LoRA
)
# 冻结原始参数,只训练LoRA
for name, param in pipe.dit.named_parameters():
if "lora" not in name:
param.requires_grad = False
# 训练(LoRA参数量很小,可以用更大的学习率)
launch_training_task(
accelerator,
dataset,
model,
model_logger,
learning_rate=1e-3,
...
)
LoRA的优势:
- 参数量极小(几MB vs 几GB)
- 可以训练多个LoRA,灵活切换风格
- 显存占用少,可以在小GPU上训练
7.3 文生图+图生图混合模式
# 先生成一张图像
base_image = pipe(
prompt="a fire type pokemon",
seed=42,
height=256,
width=256,
)
# 基于这张图像再生成(图生图)
refined_image = pipe(
prompt="a fire type pokemon with wings",
input_image=base_image,
denoising_strength=0.7, # 保留70%的原图特征
seed=43,
height=256,
width=256,
)
denoising_strength参数:
- 0.0:完全保留输入图像
- 0.3-0.5:轻微修改
- 0.6-0.8:中等修改
- 1.0:完全重新生成
7.4 批量生成
import os
from tqdm import tqdm
# 创建输出目录
os.makedirs("batch_output", exist_ok=True)
# 定义多个提示词
prompts = [
"fire type pokemon",
"water type pokemon",
"grass type pokemon",
"electric type pokemon",
"psychic type pokemon",
]
# 每个提示词生成10张图像
for prompt_id, prompt in enumerate(prompts):
print(f"Generating for prompt: {prompt}")
for seed in tqdm(range(10)):
image = pipe(
prompt=prompt,
seed=seed,
height=256,
width=256,
num_inference_steps=30,
cfg_scale=10,
)
filename = f"batch_output/{prompt_id:02d}_{seed:03d}.jpg"
image.save(filename)
print("Batch generation completed!")
7.5 提示词工程
好的提示词能显著提升生成质量。以下是一些技巧:
7.5.1 结构化提示词
# 不好的提示词
prompt = "pokemon"
# 好的提示词
prompt = "fire type pokemon, orange color, lizard-like, flame on tail, red eyes, friendly expression"
提示词模板:
[主体], [颜色], [形状/类型], [关键特征], [表情], [风格]
7.5.2 权重调整
虽然我们的简单模型不支持权重语法,但在更高级的模型中可以这样做:
# 强调某些词(在Stable Diffusion中)
prompt = "(fire:1.5) type pokemon, orange, (flame:1.3) on tail"
# 数字越大,该词的影响越大
7.5.3 负向提示词
# 使用负向提示词排除不想要的特征
image = pipe(
prompt="cute pokemon",
negative_prompt="scary, dark, evil, monster", # 避免生成这些特征
cfg_scale=10,
)
7.6 优化推理速度
7.6.1 减少推理步数
# 使用更少的步数
image = pipe(
prompt=prompt,
num_inference_steps=20, # 从30降到20
cfg_scale=10,
)
质量损失通常很小,但速度提升明显。
7.6.2 使用TorchScript
# 编译模型以加速
pipe.dit = torch.jit.script(pipe.dit)
7.6.3 使用半精度
# 使用float16而非bfloat16(某些GPU上更快)
pipe = AAAImagePipeline.from_pretrained(
torch_dtype=torch.float16, # 改为float16
device="cuda",
...
)
第八部分:扩展到更大规模
8.1 使用更大的数据集
宝可梦数据集只是一个玩具示例。要训练真正强大的模型,你需要更大的数据集:
推荐数据集:
-
LAION-5B:50亿图文对,最大的开源数据集
- 下载地址:https://laion.ai/blog/laion-5b/
- 需要数TB的存储空间
-
COYO-700M:7亿图文对,质量较高
- 下载地址:https://github.com/kakaobrain/coyo-dataset
-
自建数据集:
- 爬取网络图像
- 使用BLIP等模型自动生成描述
- 人工标注关键样本
数据清洗:
# 过滤低质量图像
def filter_dataset(image_path, min_size=256, max_aspect_ratio=2.0):
"""
过滤不符合要求的图像
参数:
min_size: 最小边长
max_aspect_ratio: 最大宽高比
"""
img = Image.open(image_path)
width, height = img.size
# 检查尺寸
if min(width, height) < min_size:
return False
# 检查宽高比
aspect_ratio = max(width, height) / min(width, height)
if aspect_ratio > max_aspect_ratio:
return False
return True
# 应用过滤
filtered_images = [
img_path for img_path in all_images
if filter_dataset(img_path)
]
8.2 增加模型规模
我们的0.1B模型很小,增加规模可以提升性能:
class LargeAAADiT(torch.nn.Module):
"""更大的DiT模型"""
def __init__(self, dim=2048): # 从1024增加到2048
super().__init__()
self.pos_embedder = AAAPositionalEmbedding(dim=dim)
self.timestep_embedder = TimestepEmbeddings(512, dim) # 也增加
self.image_embedder = torch.nn.Sequential(
torch.nn.Linear(128, dim),
torch.nn.LayerNorm(dim)
)
self.text_embedder = torch.nn.Sequential(
torch.nn.Linear(1024, dim),
torch.nn.LayerNorm(dim)
)
# 增加到30层
self.blocks = torch.nn.ModuleList([
AAABlock(dim, num_heads=64) # 更多注意力头
for _ in range(30)
])
self.proj_out = torch.nn.Linear(dim, 128)
参数量估算:
- dim=2048, 30层:约1.5B参数
- dim=3072, 48层:约5B参数
- dim=4096, 64层:约12B参数
训练大模型的挑战:
- 显存需求:需要多张A100/H100
- 训练时间:可能需要数周
- 数据需求:需要数百万到数十亿样本
- 计算成本:可能需要数万美元
8.3 分布式训练策略
8.3.1 数据并行(Data Parallelism)
最简单的并行策略,每个GPU处理不同的数据批次:
# Accelerate自动处理数据并行
accelerator = accelerate.Accelerator()
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
8.3.2 模型并行(Model Parallelism)
当模型太大无法放入单个GPU时:
# 使用DeepSpeed的ZeRO-3
from accelerate import DeepSpeedPlugin
deepspeed_plugin = DeepSpeedPlugin(
zero_stage=3, # ZeRO-3: 分片优化器状态、梯度和参数
offload_optimizer_device="cpu", # 将优化器状态卸载到CPU
offload_param_device="cpu", # 将参数卸载到CPU
)
accelerator = accelerate.Accelerator(deepspeed_plugin=deepspeed_plugin)
8.3.3 流水线并行(Pipeline Parallelism)
将模型的不同层放在不同GPU上:
# 手动分配层到不同设备
class PipelinedAAADiT(torch.nn.Module):
def __init__(self):
super().__init__()
# 前10层在GPU 0
self.blocks_0 = torch.nn.ModuleList([
AAABlock().to("cuda:0") for _ in range(10)
])
# 后10层在GPU 1
self.blocks_1 = torch.nn.ModuleList([
AAABlock().to("cuda:1") for _ in range(10)
])
def forward(self, x):
x = x.to("cuda:0")
for block in self.blocks_0:
x = block(x)
x = x.to("cuda:1")
for block in self.blocks_1:
x = block(x)
return x
8.4 使用更好的VAE
我们使用的FLUX.2 VAE已经很好,但也可以尝试其他选择:
# 使用Stable Diffusion的VAE
from diffsynth.models.sd_vae import SDVAE
pipe.vae = SDVAE.from_pretrained("stabilityai/sd-vae-ft-mse")
# 或者训练自己的VAE
# (这需要另一个完整的教程)
8.5 实现自定义调度器
不同的噪声调度策略会影响生成质量:
class CustomScheduler:
"""自定义噪声调度器"""
def __init__(self, num_train_timesteps=1000):
self.num_train_timesteps = num_train_timesteps
# 定义beta调度(控制噪声添加速度)
self.betas = self.cosine_beta_schedule(num_train_timesteps)
self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
def cosine_beta_schedule(self, timesteps, s=0.008):
"""
余弦调度:开始慢慢加噪,后期快速加噪
这比线性调度效果更好
"""
steps = timesteps + 1
x = torch.linspace(0, timesteps, steps)
alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
return torch.clip(betas, 0.0001, 0.9999)
def add_noise(self, original, noise, timestep):
"""在指定时间步添加噪声"""
sqrt_alpha_prod = torch.sqrt(self.alphas_cumprod[timestep])
sqrt_one_minus_alpha_prod = torch.sqrt(1 - self.alphas_cumprod[timestep])
noisy = sqrt_alpha_prod * original + sqrt_one_minus_alpha_prod * noise
return noisy
第九部分:理论深入
9.1 扩散模型的数学原理
让我们深入理解扩散模型背后的数学。[1] [2]
9.1.1 前向过程
前向过程定义为一个马尔可夫链,逐步向数据添加高斯噪声:
q(xt∣xt−1)=N(xt;1−βtxt−1,βtI)q(x_t | x_{t-1}) = \mathcal{N}(x_t; \sqrt{1-\beta_t} x_{t-1}, \beta_t I)q(xt∣xt−1)=N(xt;1−βtxt−1,βtI)
其中:
- xtx_txt 是第 ttt 步的带噪数据
- βt\beta_tβt 是噪声调度参数
- N\mathcal{N}N 表示正态分布
使用重参数化技巧,可以直接从 x0x_0x0 采样 xtx_txt:
xt=αˉtx0+1−αˉtϵx_t = \sqrt{\bar{\alpha}_t} x_0 + \sqrt{1-\bar{\alpha}_t} \epsilonxt=αˉtx0+1−αˉtϵ
其中 αˉt=∏i=1t(1−βi)\bar{\alpha}_t = \prod_{i=1}^t (1-\beta_i)αˉt=i=1∏t(1−βi),ϵ∼N(0,I)\epsilon \sim \mathcal{N}(0, I)ϵ∼N(0,I)
9.1.2 反向过程
反向过程学习从噪声恢复数据:
pθ(xt−1∣xt)=N(xt−1;μθ(xt,t),Σθ(xt,t))p_\theta(x_{t-1} | x_t) = \mathcal{N}(x_{t-1}; \mu_\theta(x_t, t), \Sigma_\theta(x_t, t))pθ(xt−1∣xt)=N(xt−1;μθ(xt,t),Σθ(xt,t))
在实践中,我们训练一个神经网络 ϵθ(xt,t)\epsilon_\theta(x_t, t)ϵθ(xt,t) 来预测噪声 ϵ\epsilonϵ。
9.1.3 训练目标
简化的训练目标是:
Lsimple=Et,x0,ϵ[∥ϵ−ϵθ(xt,t)∥2]L_{simple} = \mathbb{E}_{t, x_0, \epsilon} \left[ \| \epsilon - \epsilon_\theta(x_t, t) \|^2 \right]Lsimple=Et,x0,ϵ[∥ϵ−ϵθ(xt,t)∥2]
这就是我们在代码中使用的损失函数!
9.2 Flow Matching的改进
Flow Matching是对DDPM的改进,它学习一个连续的"流"而非离散的步骤。[4]
核心思想:
不是预测噪声 ϵ\epsilonϵ,而是预测从噪声到数据的"速度场" vtv_tvt:
dxtdt=vθ(xt,t)\frac{dx_t}{dt} = v_\theta(x_t, t)dtdxt=vθ(xt,t)
优势:
- 更快收敛:通常需要更少的训练步数
- 更少的推理步数:可以用10-20步达到DDPM 50步的质量
- 更稳定:训练过程更平滑
9.3 注意力机制的计算复杂度
标准自注意力的复杂度是 O(n2)O(n^2)O(n2),其中 nnn 是序列长度。对于256×256的图像(压缩后16×16=256个token),这是可接受的。但对于更高分辨率:
- 512×512:32×32=1024个token,复杂度是256×256的16倍
- 1024×1024:64×64=4096个token,复杂度是256×256的256倍
优化方法:
-
FlashAttention:重新排列计算顺序,减少内存访问
# DiffSynth-Studio自动使用FlashAttention emb = attention_forward(q, k, v, ...) -
窗口注意力:只在局部窗口内计算注意力
# 例如:每个token只关注周围7×7的区域 # 复杂度从O(n^2)降到O(n×49) -
稀疏注意力:只计算重要位置之间的注意力
第十部分:实际应用和商业化
10.1 构建Web应用
使用Gradio快速构建交互式界面:
import gradio as gr
def generate_pokemon(prompt, negative_prompt, steps, cfg_scale, seed):
"""生成宝可梦的包装函数"""
image = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
num_inference_steps=steps,
cfg_scale=cfg_scale,
seed=seed if seed >= 0 else None,
height=256,
width=256,
)
return image
# 创建Gradio界面
demo = gr.Interface(
fn=generate_pokemon,
inputs=[
gr.Textbox(label="Prompt", placeholder="Describe your Pokemon..."),
gr.Textbox(label="Negative Prompt", value=""),
gr.Slider(10, 50, value=30, step=1, label="Steps"),
gr.Slider(1, 20, value=10, step=0.5, label="CFG Scale"),
gr.Number(label="Seed (-1 for random)", value=-1),
],
outputs=gr.Image(label="Generated Pokemon"),
title="Pokemon Generator",
description="Generate custom Pokemon using AI!",
examples=[
["fire type pokemon with wings", "", 30, 10, 42],
["water type pokemon, blue, turtle", "", 30, 10, 123],
["electric type pokemon, yellow, mouse-like", "", 30, 10, 456],
]
)
# 启动服务
demo.launch(share=True) # share=True创建公开链接
10.2 API服务
使用FastAPI构建REST API:
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import base64
from io import BytesIO
app = FastAPI()
class GenerateRequest(BaseModel):
prompt: str
negative_prompt: str = ""
steps: int = 30
cfg_scale: float = 10.0
seed: int = -1
height: int = 256
width: int = 256
@app.post("/generate")
async def generate_image(request: GenerateRequest):
"""生成图像的API端点"""
try:
# 生成图像
image = pipe(
prompt=request.prompt,
negative_prompt=request.negative_prompt,
num_inference_steps=request.steps,
cfg_scale=request.cfg_scale,
seed=request.seed if request.seed >= 0 else None,
height=request.height,
width=request.width,
)
# 转换为base64
buffered = BytesIO()
image.save(buffered, format="PNG")
img_str = base64.b64encode(buffered.getvalue()).decode()
return {"image": img_str, "status": "success"}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
# 运行服务
# uvicorn api:app --host 0.0.0.0 --port 8000
客户端调用:
import requests
import base64
from PIL import Image
from io import BytesIO
# 发送请求
response = requests.post("http://localhost:8000/generate", json={
"prompt": "fire type pokemon",
"steps": 30,
"cfg_scale": 10,
"seed": 42
})
# 解析响应
data = response.json()
img_data = base64.b64decode(data["image"])
image = Image.open(BytesIO(img_data))
image.show()
10.3 批量处理服务
import asyncio
from concurrent.futures import ThreadPoolExecutor
class BatchGenerator:
"""批量生成服务"""
def __init__(self, pipe, max_workers=4):
self.pipe = pipe
self.executor = ThreadPoolExecutor(max_workers=max_workers)
async def generate_batch(self, prompts, **kwargs):
"""异步批量生成"""
loop = asyncio.get_event_loop()
tasks = [
loop.run_in_executor(
self.executor,
self.pipe,
prompt,
**kwargs
)
for prompt in prompts
]
images = await asyncio.gather(*tasks)
return images
# 使用示例
batch_gen = BatchGenerator(pipe)
prompts = ["fire pokemon", "water pokemon", "grass pokemon"]
images = asyncio.run(batch_gen.generate_batch(prompts, seed=42))
10.4 商业化考虑
10.4.1 定价策略
按使用量计费:
- 每张图像:$0.01 - $0.05
- 批量折扣:100张以上8折
- 订阅制:$29/月无限生成
10.4.2 成本估算
GPU成本(AWS p3.2xlarge,V100):
- 每小时:$3.06
- 每张图像生成时间:约2秒
- 每小时可生成:1800张
- 每张成本:$0.0017
利润空间:定价$0.02/张,利润率约85%
10.4.3 法律和伦理
-
版权问题:
- 训练数据的版权
- 生成内容的所有权
- 商业使用许可
-
内容审核:
- 过滤不当内容
- NSFW检测
- 版权侵权检测
-
用户协议:
- 明确使用限制
- 免责声明
- 数据隐私政策
总结与展望
恭喜你完成了这个2万字的深度教程!让我们回顾一下学到的内容:
核心知识点
- 扩散模型原理:正向加噪和反向去噪过程
- VAE编解码器:压缩图像到潜在空间
- Transformer架构:自注意力机制和DiT模型
- Pipeline设计:协调各个组件的工作流程
- 训练技巧:数据准备、超参数调优、分布式训练
- 评估方法:定性和定量评估指标
- 实际应用:Web应用、API服务、商业化
进一步学习资源
论文:
- DDPM: “Denoising Diffusion Probabilistic Models” [1]
- DiT: “Scalable Diffusion Models with Transformers” [12]
- Flow Matching: “Flow Matching for Generative Modeling” [4]
代码库:
- DiffSynth-Studio: https://github.com/modelscope/DiffSynth-Studio
- Diffusers: https://github.com/huggingface/diffusers
- Stable Diffusion: https://github.com/Stability-AI/stablediffusion
在线课程:
- Fast.ai的深度学习课程
- Stanford CS231n(计算机视觉)
- Hugging Face的扩散模型课程
下一步行动
-
实践项目:
- 在自己的数据集上训练模型
- 尝试不同的模型架构
- 构建完整的应用
-
参与社区:
- 在GitHub上贡献代码
- 分享你的实验结果
- 参加AI竞赛
-
持续学习:
- 关注最新论文
- 尝试新的技术
- 与其他研究者交流
最后的话
AI图像生成是一个快速发展的领域,新的技术和方法不断涌现。本教程提供的是一个坚实的基础,但真正的掌握需要大量的实践和探索。
记住:最好的学习方式是动手实践。不要害怕犯错,每个错误都是学习的机会。从小项目开始,逐步挑战更复杂的任务,你会惊讶于自己的进步速度!
祝你在AI图像生成的旅程中取得成功!🚀
参考资料
[1]: Ho, J., Jain, A., & Abbeel, P. (2020). Denoising Diffusion Probabilistic Models. arXiv:2006.11239
[2]: Nichol, A., & Dhariwal, P. (2021). Improved Denoising Diffusion Probabilistic Models. arXiv:2102.09672
[4]: Lipman, Y., et al. (2023). Flow Matching for Generative Modeling. arXiv:2210.02747
[5]: Kingma, D. P., & Welling, M. (2013). Auto-Encoding Variational Bayes. arXiv:1312.6114
[6]: Razavi, A., et al. (2019). Generating Diverse High-Fidelity Images with VQ-VAE-2. arXiv:1906.00446
[7]: Black Forest Labs. (2024). FLUX.2 Technical Report
[8]: Vaswani, A., et al. (2017). Attention Is All You Need. arXiv:1706.03762
[9]: Dosovitskiy, A., et al. (2020). An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale. arXiv:2010.11929
[10]: Bahdanau, D., et al. (2014). Neural Machine Translation by Jointly Learning to Align and Translate. arXiv:1409.0473
[11]: Lin, Z., et al. (2017). A Structured Self-attentive Sentence Embedding. arXiv:1703.03130
[12]: Peebles, W., & Xie, S. (2023). Scalable Diffusion Models with Transformers. arXiv:2212.09748
[13]: Esser, P., et al. (2024). Scaling Rectified Flow Transformers for High-Resolution Image Synthesis. arXiv:2403.03206
[14]: DiffSynth-Studio Documentation. https://diffsynth-studio-doc.readthedocs.io/
[15]: Loshchilov, I., & Hutter, F. (2017). SGDR: Stochastic Gradient Descent with Warm Restarts. arXiv:1608.03983
这篇扩充后的教程涵盖了从基础概念到高级应用的完整知识体系,适合初学者系统学习AI图像生成技术。每个概念都配有详细的解释、代码示例和实用技巧,确保读者不仅能理解理论,还能动手实践。
更多推荐


所有评论(0)