QuantLightningIndexer

【免费下载链接】ops-transformer 本项目是CANN提供的transformer类大模型算子库,实现网络在NPU上加速计算。 【免费下载链接】ops-transformer 项目地址: https://gitcode.com/cann/ops-transformer

产品支持情况

产品 是否支持
Ascend 950PR/Ascend 950DT
Atlas A3 训练系列产品/Atlas A3 推理系列产品
Atlas A2 训练系列产品/Atlas A2 推理系列产品
Atlas 200I/500 A2 推理产品 ×
Atlas 推理系列加速卡产品 ×
Atlas 训练系列产品 ×

功能说明

  • 算子功能:QuantLightningIndexer是推理场景下,SparseFlashAttention(SFA)前处理的计算,选出关键的稀疏token,并对输入query和key进行量化实现存8算8,获取最大收益。

  • 计算公式: $$out = \text{Top-}k\left{[1]{1\times g}@\left[(W@[1]{1\times S_{k}})\odot\text{ReLU}\left(\left(Scale_Q@Scale_K^T\right)\odot\left(Q_{index}^{Quant}@{\left(K_{index}^{Quant}\right)}^T\right)\right)\right]\right}$$ 主要计算过程为:

    1. 将某个token对应的输入参数query($Q_{index}^{Quant}\in\R^{g\times d}$)乘以给定上下文key($K_{index}^{Quant}\in\R^{S_{k}\times d}$),得到相关性。
    2. 相关性结果与querykey对应的反量化系数query_dequant_scale($Scale_Q$)和key_dequant_scale($Scale_K^T$)相乘,通过激活函数$ReLU$过滤无效负相关信号后,得到当前Token与所有前序Token的相关性分数向量。
    3. 将其与权重系数weights($W$)相乘后,沿g的方向,选取前$Top-k$个索引值得到输出$out$,作为SparseFlashAttention的输入。

参数说明

说明:

  • query、key、weights、query_dequant_scale、key_dequant_scale参数维度含义:B(Batch Size)表示输入样本批量大小、S(Sequence Length)表示输入样本序列长度、H(Head Size)表示hidden层的大小、N(Head Num)表示多头数、D(Head Dim)表示hidden层最小的单元尺寸,且满足D=H/N、T表示所有Batch输入样本序列长度的累加和。
  • 使用S1和S2分别表示query和key的输入样本序列长度,N1和N2分别表示query和key对应的多头数,k表示最后选取的索引个数。参数query中的D和参数key中的D值相等为128。T1和T2分别表示query和key的输入样本序列长度的累加和。
参数名 输入/输出/属性 描述 数据类型 数据格式
query 输入
  • 公式中的输入Q。
  • 不支持非连续。
  • layout_query为BSND时,shape为(B,S1,N1,D)。layout_query为TND时,shape为(T1,N1,D)。
  • N1支持[1, 64]。
INT8、FLOAT8_E4M3、HIFLOAT8 ND
key 输入
  • 公式中的输入K。
  • 支持非连续。
  • layout_key为PA_BSND时,shape为(block_num, block_size, N2, D)。layout_kv为BSND时,shape为(B, S2, N2, D)。layout_kv为TND时,shape为(T2, N2, D)。
  • block_num为PageAttention时block总数,block_size为一个block的token数。
  • N2仅支持1。
INT8、FLOAT8_E4M3、HIFLOAT8 ND
weights 输入
  • 公式中的输入W。
  • 不支持非连续。
  • layout_query为BSND时,shape为(B,S1,N1)。layout_query为TND时,shape为(T1,N1)。
FLOAT16、BFLOAT16 ND
query_dequant_scale 输入
  • 公式中Query的反量化系数Scale_Q。
  • 不支持非连续。
  • layout_query为BSND时,shape为(B,S1,N1)。layout_query为TND时,shape为(T1,N1)。
FLOAT、FLOAT16 ND
key_dequant_scale 输入
  • 公式中Key的反量化系数Scale_K。
  • 支持非连续。
  • layout_key为BSND时,shape为(B,S2,N2)。layout_key为TND时,shape为(T2,N2)。
  • layout_key为PA_BSND时,shape为(block_num, block_size, N2)。
  • block_num为PageAttention时block总数,block_size为一个block的token数。
FLOAT、FLOAT16 ND
actual_seq_lengths_query 输入
  • 每个Batch中,Query的有效token数。
  • 不支持非连续。
  • shape为(B,)
  • 如果不指定seqlen可传入None,表示和query的shape的S长度相同。
  • 该入参中每个Batch的有效token数不超过query中的维度S大小且不小于0,支持长度为B的一维tensor。
  • 当layout_query为TND时,该入参必须传入,且以该入参元素的数量作为B值,该入参中每个元素的值表示当前batch与之前所有batch的token数总和,即前缀和,因此后一个元素的值必须大于等于前一个元素的值。
  • 不能出现负值。
INT32 ND
actual_seq_lengths_key 输入
  • 每个Batch中,Key的有效token数。
  • 不支持非连续。
  • shape为(B,)
  • 如果不指定seqlen可传入None,表示和key的shape的S长度相同。
  • 该参数中每个Batch的有效token数不超过key/value中的维度S大小且不小于0,支持长度为B的一维tensor。
  • 当layout_key为TND或PA_BSND时,该入参必须传入,layout_key为TND,该参数中每个元素的值表示当前batch与之前所有batch的token数总和,即前缀和,因此后一个元素的值必须大于等于前一个元素的值。
INT32 ND
block_table 输入
  • 表示PageAttention中KV存储使用的block映射表。
  • 不支持非连续。
  • shape支持(B,S2_max/block_size)
  • PageAttention场景下,block_table必须为二维,第一维长度需要等于B,第二维长度不能小于maxBlockNumPerSeq(maxBlockNumPerSeq为每个batch中最大actual_seq_lengths_key对应的block数量)
  • block_size取值为16的整数倍,最大支持到1024。
INT32 ND
query_quant_mode 属性
  • 用于标识输入Query的量化模式。
  • 当前支持Per-Token-Head量化模式。
  • 当前仅支持传入0。
INT64 -
key_quant_mode 属性
  • 用于标识输入Key的量化模式。
  • 当前支持Per-Token-Head量化模式。
  • 当前仅支持传入0。
INT64 -
layout_query 属性
  • 用于标识输入Query的数据排布格式。
  • 当前支持BSND、TND。
  • 默认值为BSND。
STRING -
layout_key 属性
  • 用于标识输入Key的数据排布格式。
  • 当前支持PA_BSND、BSND、TND。
  • 在非PageAttention场景下,layout_key应与layout_query保持一致。
  • 默认值为BSND。
STRING -
sparse_count 属性
  • topK阶段需要保留的block数量。支持[1, 2048]。
  • 默认值为2048。
INT32 -
sparse_mode 属性
  • 表示sparse的模式。
  • sparse_mode为0时,代表defaultMask模式。
  • sparse_mode为3时,代表rightDownCausal模式的mask,对应以右顶点为划分的下三角场景。
  • 默认值为3。
INT32 -
pre_tokens 属性 用于稀疏计算,表示attention需要和前几个Token计算关联。仅支持默认值2^63-1。 INT64 -
next_tokens 属性 用于稀疏计算,表示attention需要和后几个Token计算关联。仅支持默认值2^63-1。 INT64 -
key_stride0 属性
  • 表示key获取stride第0维的信息。
  • 默认值为-1。
INT64 -
key_dequant_scale_stride0 属性
  • 表示key_dequant_scale获取stride第0维的信息。
  • 默认值为-1。
INT64 -
sparse_indices 输出
  • 公式中的Indices输出。
  • layout_query为"BSND"时输出shape为[B, S1, N2, sparse_count]。layout_query为"TND"时输出shape为[T1, N2, sparse_count]。
INT32 ND

Atlas A3 训练系列产品/Atlas A3 推理系列产品:

  • query和key的数据类型支持INT8
  • 仅支持weights、query_dequant_scale、key_dequant_scale数据类型为FLOAT16、FLOAT16、FLOAT16

Ascend 950PR/Ascend 950DT:

  • query N1仅支持8、16、24、32、64。
  • query和key的数据类型支持FLOAT8_E4M3、HIFLOAT8、INT8
  • 当query和key的数据类型为FLOAT8_E4M3时,支持weights、query_dequant_scale、key_dequant_scale的数据类型为BFLOAT16、FLOAT、FLOATFLOAT16、FLOAT16、FLOAT16
  • 当query和key的数据类型为HIFLOAT8时,仅支持weights、query_dequant_scale、key_dequant_scale数据类型为BFLOAT16、FLOAT、FLOAT
  • 当query和key的数据类型为INT8时,仅支持weights、query_dequant_scale、key_dequant_scale数据类型为FLOAT16、FLOAT16、FLOAT16

约束说明

  • 该接口支持图模式。
  • 该接口要求$W \odot Scale_Q$的结果在float16的表示范围内。
  • 该接口的TopK过程对NAN排序是未定义行为。

调用示例

调用方式 样例代码 说明
图模式 test_npu_quant_lightning_indexer 通过算子IR构图方式调用npu_quant_lightning_indexer算子
aclnn接口 test_aclnn_quant_lightning_indexer 通过 aclnnQuantLightningIndexer 接口方式调用算子

【免费下载链接】ops-transformer 本项目是CANN提供的transformer类大模型算子库,实现网络在NPU上加速计算。 【免费下载链接】ops-transformer 项目地址: https://gitcode.com/cann/ops-transformer

Logo

脑启社区是一个专注类脑智能领域的开发者社区。欢迎加入社区,共建类脑智能生态。社区为开发者提供了丰富的开源类脑工具软件、类脑算法模型及数据集、类脑知识库、类脑技术培训课程以及类脑应用案例等资源。

更多推荐