pypto.distributed.shmem_store

【免费下载链接】pypto PyPTO(发音: pai p-t-o):Parallel Tensor/Tile Operation编程范式。 【免费下载链接】pypto 项目地址: https://gitcode.com/cann/pypto

产品支持情况

产品 是否支持
Atlas A3 推理系列产品
Atlas A2 推理系列产品

功能说明

以 offsets 指定的 shared memory tensor 索引位置为基准,将输入的 Tensor 赋值到 shared memory tensor 的对应区域。

函数原型

shmem_store(
    src: Tensor,
    offsets: list[Union[int, SymbolicScalar]],
    dst: ShmemTensor,
    dst_pe: Union[int, SymbolicScalar],
    *,
    put_op: AtomicType = AtomicType.SET,
    pred: list[Tensor] = None,
) -> Tensor

参数说明

参数名 输入/输出 说明
src 输入 源操作数。
支持的数据类型为:DT_INT32,DT_FP16,DT_FP32,DT_BF16。
不支持空 Tensor;Shape 支持 2 - 4 维;Shape Size 不大于 2147483647(即 INT32_MAX)。
支持的数据格式为 ND。
offsets 输入 dst 的偏移量。
支持 int 或 SymbolicScalar 类型的列表。
offsets 的维度应与 dst 的维度一致,且每个维度的偏移量值应小于 dst 对应维度的大小。
dst 输入 目的操作数,一个 shared memory tensor,其形状为src.shape。
dst_pe 输入 shared memory tensor 所属的 pe。
支持的数据类型为 int 或 SymbolicScalar 类型。
0 <= pe < n_pes。
put_op 输入 数据传输时应用的原子操作类型。
支持的数据类型为: AtomicType.SET,AtomicType.ADD。
默认为 AtomicType.SET 类型。
pred 输入 用于控制操作执行的依赖关系张量列表。
对数据类型无要求。
不支持空 Tensor。

返回值说明

返回输出 Tensor:用于表示操作完成的依赖关系。

约束说明

  1. pred 不能包含 src,即 src 不可出现在 pred 中。
  2. src 的 dtype 必须和 dst 的 dtype 一致。

调用示例

TileShape 设置示例

说明:调用该接口前,应通过 set_vec_tile_shapes 设置 TileShape。TileShape 维度应和 src 一致。

  • 示例 1:输入的 shape 为 [m, n],TileShape 设置为 [m1, n1],则 m1,n1 分别用于切分 m,n 轴。

    pypto.set_vec_tile_shapes(4, 8)
    

接口调用示例

  • 示例 1:先创建一个 shared memory tensor。将输入数据赋值到 pe = 2 的 shared memory tensor 的指定区域,并与该视图原本的数据进行累加操作。注意,shared memory tensor 的 dtype 和 输入数据的 dtype 必须一致。

    input_tensor = pypto.tensor([16, 64], pypto.DT_BF16, "input_tensor")
    shmem_shape = input_tensor.shape
    shmem_tensor = pypto.distributed.create_shmem_tensor(group_name="tp", n_pes=8, dtype=pypto.DT_FP16, shape=shmem_shape)
    pypto.set_vec_tile_shapes(16, 64)
    store_out = pypto.experimental.shmem_store(
        src=input_tensor,
        offsets=[0, 0],
        dst=shmem_tensor,
        dst_pe=2,
        put_op=pypto.AtomicType.ADD,
    )
    
  • 示例 2:先创建一个 shared memory tensor。将输入数据赋值到 pe = 3 的 shared memory tensor 的指定区域,并覆盖该视图原本的数据。

    input_tensor = pypto.tensor([16, 64], pypto.DT_BF16, "input_tensor")
    shmem_shape = input_tensor.shape
    shmem_tensor = pypto.distributed.create_shmem_tensor(group_name="tp", n_pes=8, dtype=pypto.DT_FP16, shape=shmem_shape)
    pypto.set_vec_tile_shapes(16, 64)
    store_out = pypto.experimental.shmem_store(
        src=input_tensor,
        offsets=[0, 0],
        dst=shmem_tensor,
        dst_pe=3,
        put_op=pypto.AtomicType.SET,
    )
    

【免费下载链接】pypto PyPTO(发音: pai p-t-o):Parallel Tensor/Tile Operation编程范式。 【免费下载链接】pypto 项目地址: https://gitcode.com/cann/pypto

Logo

脑启社区是一个专注类脑智能领域的开发者社区。欢迎加入社区,共建类脑智能生态。社区为开发者提供了丰富的开源类脑工具软件、类脑算法模型及数据集、类脑知识库、类脑技术培训课程以及类脑应用案例等资源。

更多推荐