CANN/ops-transformer量化闪电索引器
|产品| 是否支持 ||:----------------------------|:-----------:||<term>Ascend 950PR/Ascend 950DT</term>|√||<term>Atlas A3 训练系列产品/Atlas A3 推理系列产品</term>|√||<term>Atlas A2 训练系列产品
·
QuantLightningIndexer
产品支持情况
| 产品 | 是否支持 |
|---|---|
| Ascend 950PR/Ascend 950DT | √ |
| Atlas A3 训练系列产品/Atlas A3 推理系列产品 | √ |
| Atlas A2 训练系列产品/Atlas A2 推理系列产品 | √ |
| Atlas 200I/500 A2 推理产品 | × |
| Atlas 推理系列加速卡产品 | × |
| Atlas 训练系列产品 | × |
功能说明
-
算子功能:QuantLightningIndexer是推理场景下,SparseFlashAttention(SFA)前处理的计算,选出关键的稀疏token,并对输入query和key进行量化实现存8算8,获取最大收益。
-
计算公式: $$out = \text{Top-}k\left{[1]{1\times g}@\left[(W@[1]{1\times S_{k}})\odot\text{ReLU}\left(\left(Scale_Q@Scale_K^T\right)\odot\left(Q_{index}^{Quant}@{\left(K_{index}^{Quant}\right)}^T\right)\right)\right]\right}$$ 主要计算过程为:
- 将某个token对应的输入参数
query($Q_{index}^{Quant}\in\R^{g\times d}$)乘以给定上下文key($K_{index}^{Quant}\in\R^{S_{k}\times d}$),得到相关性。 - 相关性结果与
query和key对应的反量化系数query_dequant_scale($Scale_Q$)和key_dequant_scale($Scale_K^T$)相乘,通过激活函数$ReLU$过滤无效负相关信号后,得到当前Token与所有前序Token的相关性分数向量。 - 将其与权重系数
weights($W$)相乘后,沿g的方向,选取前$Top-k$个索引值得到输出$out$,作为SparseFlashAttention的输入。
- 将某个token对应的输入参数
参数说明
说明:
- query、key、weights、query_dequant_scale、key_dequant_scale参数维度含义:B(Batch Size)表示输入样本批量大小、S(Sequence Length)表示输入样本序列长度、H(Head Size)表示hidden层的大小、N(Head Num)表示多头数、D(Head Dim)表示hidden层最小的单元尺寸,且满足D=H/N、T表示所有Batch输入样本序列长度的累加和。
- 使用S1和S2分别表示query和key的输入样本序列长度,N1和N2分别表示query和key对应的多头数,k表示最后选取的索引个数。参数query中的D和参数key中的D值相等为128。T1和T2分别表示query和key的输入样本序列长度的累加和。
| 参数名 | 输入/输出/属性 | 描述 | 数据类型 | 数据格式 |
|---|---|---|---|---|
| query | 输入 |
|
INT8、FLOAT8_E4M3、HIFLOAT8 | ND |
| key | 输入 |
|
INT8、FLOAT8_E4M3、HIFLOAT8 | ND |
| weights | 输入 |
|
FLOAT16、BFLOAT16 | ND |
| query_dequant_scale | 输入 |
|
FLOAT、FLOAT16 | ND |
| key_dequant_scale | 输入 |
|
FLOAT、FLOAT16 | ND |
| actual_seq_lengths_query | 输入 |
|
INT32 | ND |
| actual_seq_lengths_key | 输入 |
|
INT32 | ND |
| block_table | 输入 |
|
INT32 | ND |
| query_quant_mode | 属性 |
|
INT64 | - |
| key_quant_mode | 属性 |
|
INT64 | - |
| layout_query | 属性 |
|
STRING | - |
| layout_key | 属性 |
|
STRING | - |
| sparse_count | 属性 |
|
INT32 | - |
| sparse_mode | 属性 |
|
INT32 | - |
| pre_tokens | 属性 | 用于稀疏计算,表示attention需要和前几个Token计算关联。仅支持默认值2^63-1。 | INT64 | - |
| next_tokens | 属性 | 用于稀疏计算,表示attention需要和后几个Token计算关联。仅支持默认值2^63-1。 | INT64 | - |
| key_stride0 | 属性 |
|
INT64 | - |
| key_dequant_scale_stride0 | 属性 |
|
INT64 | - |
| sparse_indices | 输出 |
|
INT32 | ND |
Atlas A3 训练系列产品/Atlas A3 推理系列产品:
- query和key的数据类型支持
INT8。 - 仅支持weights、query_dequant_scale、key_dequant_scale数据类型为
FLOAT16、FLOAT16、FLOAT16。
Ascend 950PR/Ascend 950DT:
- query N1仅支持8、16、24、32、64。
- query和key的数据类型支持
FLOAT8_E4M3、HIFLOAT8、INT8。 - 当query和key的数据类型为
FLOAT8_E4M3时,支持weights、query_dequant_scale、key_dequant_scale的数据类型为BFLOAT16、FLOAT、FLOAT或FLOAT16、FLOAT16、FLOAT16; - 当query和key的数据类型为
HIFLOAT8时,仅支持weights、query_dequant_scale、key_dequant_scale数据类型为BFLOAT16、FLOAT、FLOAT; - 当query和key的数据类型为
INT8时,仅支持weights、query_dequant_scale、key_dequant_scale数据类型为FLOAT16、FLOAT16、FLOAT16。
约束说明
- 该接口支持图模式。
- 该接口要求$W \odot Scale_Q$的结果在
float16的表示范围内。 - 该接口的TopK过程对NAN排序是未定义行为。
调用示例
| 调用方式 | 样例代码 | 说明 |
|---|---|---|
| 图模式 | test_npu_quant_lightning_indexer | 通过算子IR构图方式调用npu_quant_lightning_indexer算子 |
| aclnn接口 | test_aclnn_quant_lightning_indexer | 通过 aclnnQuantLightningIndexer 接口方式调用算子 |
更多推荐



所有评论(0)