从零实现 Attention 机制:深入理解 Transformer 的核心
本文带你从零开始,手把手实现完整的 Attention 机制,包括 Scaled Dot-Product Attention、Multi-Head Attention、Grouped Query Attention 和 KV Cache 优化。在深度学习领域,Transformer 架构已经成为了大语言模型(LLM)的基石。通过从零实现,我们不仅理解了原理,更掌握了实现细节和优化技巧。欢迎 sta

从零实现 Attention 机制:深入理解 Transformer 的核心
为什么 ChatGPT、LLaMA 这些大模型如此强大?答案就在 Attention 机制中。
本文带你从零开始,手把手实现完整的 Attention 机制,包括 Scaled Dot-Product Attention、Multi-Head Attention、Grouped Query Attention 和 KV Cache 优化。不仅理解原理,更要掌握实现细节!
💡开源代码工程:https://github.com/rixin2025/attention-from-scratch/tree/main
*** 欢迎star和讨论 ***
📑 目录
📖 前言
在深度学习领域,Transformer 架构已经成为了大语言模型(LLM)的基石。从 GPT 系列到 LLaMA、Mistral,几乎所有主流的大模型都基于 Transformer。而 Transformer 的核心,就是 Attention 机制。
然而,很多人在学习 Attention 时,往往只停留在公式层面:
Attention(Q, K, V) = softmax(QK^T / √d_k) @ V
真正理解 Attention,需要从代码实现开始!
本项目 attention-from-scratch 提供了完整的、可运行的 Attention 实现,帮助你:
- ✅ 深入理解 Attention 的数学原理和计算过程
- ✅ 掌握 Multi-Head Attention 的实现细节
- ✅ 理解 Grouped Query Attention (GQA) 的优化思想
- ✅ 学会 KV Cache 的性能优化技巧
- ✅ 为学习 TensorRT-LLM XQA 模块打下基础
🎯 项目亮点
1. 完整的实现体系
- 📝 从最基础的 Scaled Dot-Product Attention 开始
- 🔢 实现完整的 Multi-Head Attention 模块
- ⚡ 实现工业级的 Grouped Query Attention
- 🚀 实现高效的 KV Cache 优化
2. 交互式学习
- 📓 4 个精心设计的 Jupyter Notebook
- 📊 可视化演示和性能分析
- 💡 循序渐进的学习路径
3. 工程化实践
- ✅ 完整的单元测试覆盖
- 📦 清晰的代码结构和注释
- 🔧 可直接用于生产环境
🚀 快速开始
安装依赖
git clone https://github.com/rixin2025/attention-from-scratch.git
cd attention-from-scratch
pip install -r requirements.txt
运行示例
from src.attention import MultiHeadAttention
from src.gqa import GroupedQueryAttention
from src.kv_cache import MultiHeadAttentionWithCache
# Multi-Head Attention
mha = MultiHeadAttention(d_model=512, num_heads=8)
output, attn_weights = mha(x, x, x)
# Grouped Query Attention (更高效)
gqa = GroupedQueryAttention(d_model=512, num_q_heads=32, num_kv_heads=8)
output, attn_weights = gqa(x, x, x)
# 带 KV Cache 的优化版本
mha_cache = MultiHeadAttentionWithCache(
d_model=512, num_heads=8, max_seq_len=2048
)
output, attn_weights = mha_cache(x, x, x, use_cache=True)
📚 核心内容详解
1. Scaled Dot-Product Attention:Attention 的基础
核心公式:
Attention(Q, K, V) = softmax(QK^T / √d_k) @ V
为什么需要 scaling factor (√d_k)?
当 d_k 很大时,QK^T 的点积值会变得很大,导致 softmax 进入饱和区域,梯度变得很小。除以 √d_k 可以稳定训练过程。
实现要点:
def scaled_dot_product_attention(query, key, value, mask=None):
d_k = query.size(-1)
# 1. 计算注意力分数
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
# 2. 应用 mask(因果 mask 或 padding mask)
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
# 3. Softmax 归一化
attention_weights = F.softmax(scores, dim=-1)
# 4. 加权求和
output = torch.matmul(attention_weights, value)
return output, attention_weights
可视化说明:
输入: Q [batch, heads, seq_q, d_k]
K [batch, heads, seq_k, d_k]
V [batch, heads, seq_k, d_v]
步骤1: Q @ K^T → [batch, heads, seq_q, seq_k] (注意力分数矩阵)
步骤2: softmax(分数 / √d_k) → [batch, heads, seq_q, seq_k] (注意力权重)
步骤3: 权重 @ V → [batch, heads, seq_q, d_v] (输出)
2. Multi-Head Attention:并行计算多个注意力
核心思想:
- 将输入投影到多个子空间(多个头)
- 每个头独立计算 Attention
- 最后拼接所有头的输出
为什么需要多头?
不同的头可以关注不同的信息:
- 头1:关注语法关系
- 头2:关注语义关系
- 头3:关注长距离依赖
- …
实现架构:
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, num_heads):
# Q, K, V 的投影矩阵
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
self.W_o = nn.Linear(d_model, d_model) # 输出投影
def forward(self, query, key, value):
# 1. 投影并分割成多个头
Q = self.W_q(query).view(..., num_heads, d_k).transpose(1, 2)
K = self.W_k(key).view(..., num_heads, d_k).transpose(1, 2)
V = self.W_v(value).view(..., num_heads, d_k).transpose(1, 2)
# 2. 每个头独立计算 Attention
attn_output, _ = scaled_dot_product_attention(Q, K, V)
# 3. 拼接所有头
attn_output = attn_output.transpose(1, 2).contiguous().view(..., d_model)
# 4. 输出投影
output = self.W_o(attn_output)
return output
参数量分析:
- Q/K/V 投影:
3 × d_model × d_model - 输出投影:
d_model × d_model - 总计:
4 × d_model²
对于 d_model=4096,参数量约为 67M。
3. Grouped Query Attention (GQA):内存与性能的平衡
问题背景:
在推理阶段,需要缓存 Key 和 Value(KV Cache)。对于 MHA:
- 32 个 Q 头 → 32 个 K 头 + 32 个 V 头
- KV Cache 内存占用巨大!
GQA 的解决方案:
让多个 Q 头共享一组 KV 头:
- MHA: 32 Q 头 → 32 K 头 + 32 V 头 (1:1)
- GQA: 32 Q 头 → 8 K 头 + 8 V 头 (4:1)
- MQA: 32 Q 头 → 1 K 头 + 1 V 头 (32:1)
内存对比(batch=32, seq_len=2048, FP16):
| 类型 | KV 头数 | 内存占用 | 相对 MHA |
|---|---|---|---|
| MHA | 32 | 512 MB | 100% |
| GQA-8 | 8 | 128 MB | 25% |
| GQA-4 | 4 | 64 MB | 12.5% |
| MQA | 1 | 16 MB | 3.1% |
实现关键:
class GroupedQueryAttention(nn.Module):
def __init__(self, d_model, num_q_heads, num_kv_heads):
# Q 投影:d_model → num_q_heads * d_k
self.W_q = nn.Linear(d_model, num_q_heads * d_k)
# K, V 投影:d_model → num_kv_heads * d_k (更少!)
self.W_k = nn.Linear(d_model, num_kv_heads * d_k)
self.W_v = nn.Linear(d_model, num_kv_heads * d_k)
def forward(self, query, key, value):
# 1. 投影
Q = self.W_q(query).view(..., num_q_heads, d_k)
K = self.W_k(key).view(..., num_kv_heads, d_k)
V = self.W_v(value).view(..., num_kv_heads, d_k)
# 2. 扩展 K, V 以匹配 Q 的头数
# 每个 KV 头复制 num_groups 次
K = K.repeat_interleave(num_groups, dim=1) # [..., num_q_heads, d_k]
V = V.repeat_interleave(num_groups, dim=1)
# 3. 计算 Attention(与 MHA 相同)
output = scaled_dot_product_attention(Q, K, V)
return output
工业界应用:
- LLaMA 2 70B: 64 Q 头,8 KV 头 (8:1)
- Mistral 7B: 32 Q 头,8 KV 头 (4:1)
- PaLM: 128 Q 头,1 KV 头 (MQA)
性能权衡:
- 参数量减少 25%(32→8 KV 头)
- KV Cache 内存减少 75%
- 模型质量保持 98%(几乎无损)
4. KV Cache:推理加速的关键优化
问题背景:
在自回归生成中(如 GPT),每次只生成一个 token,但需要 attend 到所有历史 token。如果不使用缓存:
生成第1个token: 计算 Attention(1个token, 1个token)
生成第2个token: 计算 Attention(2个token, 2个token) ← 重复计算!
生成第3个token: 计算 Attention(3个token, 3个token) ← 重复计算!
...
时间复杂度:O(n²),其中 n 是序列长度。
KV Cache 的解决方案:
缓存已计算的 Key 和 Value,避免重复计算:
Prefill 阶段(处理 prompt):
- 计算所有 token 的 K, V
- 存入缓存
Decode 阶段(生成新 token):
- 只计算新 token 的 K, V
- 从缓存读取历史的 K, V
- 拼接后计算 Attention
时间复杂度:O(n)!
实现架构:
class KVCache:
def __init__(self, batch_size, num_heads, max_seq_len, head_dim):
# 预分配缓存空间
self.k_cache = torch.zeros(batch_size, num_heads, max_seq_len, head_dim)
self.v_cache = torch.zeros(batch_size, num_heads, max_seq_len, head_dim)
self.cache_len = 0
def update(self, key, value, start_pos=None):
# 增量更新缓存
end_pos = start_pos + key.size(2)
self.k_cache[:, :, start_pos:end_pos] = key
self.v_cache[:, :, start_pos:end_pos] = value
self.cache_len = end_pos
return self.k_cache[:, :, :end_pos], self.v_cache[:, :, :end_pos]
性能提升(prompt_len=100, gen_len=100):
| Prompt 长度 | 无缓存 (ms) | 有缓存 (ms) | 加速比 |
|---|---|---|---|
| 10 | 2.5 | 1.2 | 2.1x |
| 50 | 8.3 | 1.3 | 6.4x |
| 100 | 15.7 | 1.4 | 11.2x |
| 200 | 30.2 | 1.5 | 20.1x |
两个阶段:
-
Prefill 阶段:
# 处理完整 prompt,初始化缓存 prompt = tokenize("Hello, how are you?") output, _ = model(prompt, prompt, prompt, use_cache=True, start_pos=0) -
Decode 阶段:
# 逐个生成 token,使用缓存 for i in range(max_gen_len): new_token = generate_next_token() output, _ = model(new_token, new_token, new_token, use_cache=True, start_pos=cache_len)
📊 性能对比总结
参数量对比(d_model=4096, num_heads=32)
| 类型 | Q/K/V 头数 | 参数量 | 相对 MHA |
|---|---|---|---|
| MHA | 32/32/32 | 67.1M | 100% |
| GQA-8 | 32/8/8 | 50.3M | 75% |
| GQA-4 | 32/4/4 | 41.9M | 62% |
| MQA | 32/1/1 | 33.6M | 50% |
KV Cache 内存对比(batch=32, seq_len=2048, FP16)
| 类型 | KV 头数 | 内存 (MB) | 相对 MHA |
|---|---|---|---|
| MHA | 32 | 512 | 100% |
| GQA-8 | 8 | 128 | 25% |
| GQA-4 | 4 | 64 | 12.5% |
| MQA | 1 | 16 | 3.1% |
🎓 学习路径
阶段 1:基础理论(1-2 天)
📓 Notebook: 01_scaled_dot_product.ipynb
- 理解 Attention 的数学原理
- 实现基础的 Scaled Dot-Product Attention
- 理解 scaling factor 的作用
- 掌握 mask 的使用(因果 mask、padding mask)
关键问题:
- 为什么需要除以
√d_k? - 因果 mask 如何防止信息泄露?
- Attention 如何捕捉序列依赖?
阶段 2:多头机制(1-2 天)
📓 Notebook: 02_multi_head_attention.ipynb
- 理解多头的并行计算
- 实现完整的 Multi-Head Attention
- 分析参数量和计算复杂度
- 理解自注意力 vs 交叉注意力
关键问题:
- 为什么多头比单头效果好?
- 如何高效实现多头并行计算?
- 不同头关注的信息有什么不同?
阶段 3:GQA 优化(1-2 天)
📓 Notebook: 03_grouped_query_attention.ipynb
- 理解 MHA、MQA、GQA 的区别
- 实现 Grouped Query Attention
- 分析内存和性能权衡
- 了解工业界应用案例
关键问题:
- GQA 如何在质量和效率间平衡?
- 为什么 LLaMA 2、Mistral 都选择 GQA?
- KV Cache 内存如何计算?
阶段 4:KV Cache(1-2 天)
📓 Notebook: 04_kv_cache.ipynb
- 理解 Prefill 和 Decode 阶段
- 实现增量式 KV Cache 更新
- 分析性能提升
- 理解内存占用
关键问题:
- KV Cache 如何将 O(n²) 降到 O(n)?
- Prefill 和 Decode 阶段有什么区别?
- 如何管理 KV Cache 的内存?
🔧 实际应用场景
1. 大模型推理优化
在部署 LLaMA、Mistral 等大模型时:
- 使用 GQA 减少 KV Cache 内存
- 使用 KV Cache 加速生成
- 结合 FlashAttention 进一步优化
2. 理解 TensorRT-LLM XQA
本项目是 TensorRT-LLM XQA 模块的简化版:
- 本项目: Python 实现,易于理解
- XQA: CUDA 实现,高度优化
- 学习路径: 先理解本项目,再深入 XQA
3. 自定义 Attention 变体
基于本项目,可以轻松实现:
- FlashAttention(内存高效)
- Sparse Attention(稀疏注意力)
- Longformer Attention(长序列)
💡 核心代码片段
完整的 Attention 计算流程
# 1. Scaled Dot-Product Attention
def scaled_dot_product_attention(query, key, value, mask=None):
d_k = query.size(-1)
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
attention_weights = F.softmax(scores, dim=-1)
output = torch.matmul(attention_weights, value)
return output, attention_weights
# 2. Multi-Head Attention
class MultiHeadAttention(nn.Module):
def forward(self, query, key, value):
Q = self.W_q(query).view(..., num_heads, d_k).transpose(1, 2)
K = self.W_k(key).view(..., num_heads, d_k).transpose(1, 2)
V = self.W_v(value).view(..., num_heads, d_k).transpose(1, 2)
attn_output, _ = scaled_dot_product_attention(Q, K, V)
output = self.W_o(attn_output.transpose(1, 2).view(..., d_model))
return output
# 3. Grouped Query Attention
class GroupedQueryAttention(nn.Module):
def forward(self, query, key, value):
Q = self.W_q(query).view(..., num_q_heads, d_k)
K = self.W_k(key).view(..., num_kv_heads, d_k)
V = self.W_v(value).view(..., num_kv_heads, d_k)
# 扩展 KV 以匹配 Q
K = K.repeat_interleave(num_groups, dim=1)
V = V.repeat_interleave(num_groups, dim=1)
output = scaled_dot_product_attention(Q, K, V)
return output
# 4. KV Cache
class KVCache:
def update(self, key, value, start_pos):
end_pos = start_pos + key.size(2)
self.k_cache[:, :, start_pos:end_pos] = key
self.v_cache[:, :, start_pos:end_pos] = value
return self.k_cache[:, :, :end_pos], self.v_cache[:, :, :end_pos]
🚀 快速开始
1. 克隆项目
git clone https://github.com/rixin2025/attention-from-scratch.git
cd attention-from-scratch
2. 安装依赖
pip install -r requirements.txt
3. 运行 Notebooks
jupyter notebook notebooks/
学习顺序建议:
01_scaled_dot_product.ipynb- 理解基础02_multi_head_attention.ipynb- 理解多头03_grouped_query_attention.ipynb- 理解 GQA04_kv_cache.ipynb- 理解优化
4. 运行测试
pytest tests/ -v
📚 参考资料
- Attention Is All You Need - Transformer 原始论文
- GQA: Training Generalized Multi-Query Transformer Models - GQA 论文
- The Illustrated Transformer - 可视化 Transformer
- FlashAttention - 内存高效的 Attention
- vLLM - 高效的 LLM 推理框架
🤝 贡献
欢迎提交 Issue 和 Pull Request!
如果你觉得这个项目对你有帮助,请给一个 ⭐ Star,这是对我最大的鼓励!
📝 总结
Attention 机制是 Transformer 和大语言模型的核心。通过从零实现,我们不仅理解了原理,更掌握了实现细节和优化技巧。
关键要点:
- Scaled Dot-Product Attention 是基础,理解 scaling factor 的作用
- Multi-Head Attention 通过并行计算多个头,捕捉不同类型的信息
- Grouped Query Attention 在质量和效率间找到平衡,是工业界的主流选择
- KV Cache 将推理时间复杂度从 O(n²) 降到 O(n),是加速的关键
下一步:
- 深入学习 FlashAttention(python-cpp-cuda)
- 研究 PagedAttention
- cuda内存模型优化技巧
- 探索 TensorRT-LLM XQA 模块原理及工程化
- nv性能分析工具
- 实现自己的 Attention 变体
🔗 相关链接
- GitHub: https://github.com/rixin2025/attention-from-scratch
- Issues: 欢迎提交问题和建议
- Discussions: 欢迎讨论和分享经验
如果这篇文章对你有帮助,请给项目一个 ⭐ Star,让更多人看到!
让更多人了解 Attention 机制,一起推动 AI 技术的发展!
作者: jensen.li
创建时间: 2026-02-17
License: MIT
后续将基于 TensorRT-LLM XQA 模块更加深入分析,添加更多优化手段和工程化技巧介绍;欢迎 star 和讨论!
本文由mdnice多平台发布
更多推荐


所有评论(0)