PyTorch 中的 torch.full 全面指南

在深度学习中,有时需要创建一个所有元素都相同的张量,例如作为常数初始值或掩码。PyTorch 提供了 torch.full 接口,功能灵活且参数丰富。下面我们按模块逐一展开。


一、接口定义

torch.full(*sizes,
           fill_value,
           out=None,
           dtype=None,
           layout=torch.strided,
           device=None,
           requires_grad=False) → Tensor
  • 返回值:形状由 *sizes 决定,所有位置都填 fill_value 的张量。

  • 等价签名

    torch.full(size: Tuple[int, ...],
               fill_value: Number,
               out: Tensor = None,
               dtype: torch.dtype = None,
               layout: torch.layout = torch.strided,
               device: torch.device = None,
               requires_grad: bool = False) → Tensor
    

二、参数详解

参数 说明 示例
*sizes 张量形状:多个位置参数(如 2,3,4),或一个整型元组 (2,3) torch.full(2,3,4, fill_value=5)
fill_value 要填充的值,常见为标量(intfloatbool fill_value=7
out 可选,已有张量写入结果,in-place;keyword-only torch.full((2,2), 9, out=my_tensor)
dtype 输出数据类型;若与 fill_value 类型不匹配,会做转换 dtype=torch.float64
layout 存储布局,默认为 torch.strided(稠密张量) layout=torch.strided
device 输出张量设备,如 "cpu""cuda:0" device='cuda:0'
requires_grad 是否开启梯度追踪(常用于可学习参数) requires_grad=True

注意

  • outdtypelayoutdevicerequires_grad 均为 关键字参数,必须以 key=value 形式传入,否则会被误认为是形状(sizes)的一部分。

三、常见使用场景

  1. 创建常数张量
    作为偏置、掩码或特殊标志值:

    bias = torch.full((batch_size, num_features), 0.1)
    mask = torch.full((H, W), True, dtype=torch.bool)
    
  2. 初始化权重
    固定常数初始化:

    self.weight = torch.full((out_channels, in_channels), fill_value=0.01, requires_grad=True)
    
  3. 占位符
    在复杂流程中预分配内存:

    out = torch.empty(3, 3)
    const = torch.full((3,3), 5, out=out)  # 直接写入 out
    
  4. 与其他 API 配合

    # full_like:沿用现有张量形状
    ref = torch.zeros(2,4)
    filled = torch.full_like(ref, 3.14)  # 结果 shape=(2,4), dtype=float32
    

四、具体示例与输出

以下示例固定随机种子(对 full 无影响,仅为演示一致性),并展示每步输出。

import torch
torch.manual_seed(0)

# 示例 1:最基础的 (2,3) 常数张量
a = torch.full(2, 7)  
# 等价于 torch.full((2,), 7)
print("a:", a)  
# 输出:
# a: tensor([7, 7])

# 示例 2:二维常数(位置参数 vs tuple)
b = torch.full(2, 3, fill_value=-1)  # shape=(2,), fill_value=-1
print("\nb:", b)
# b: tensor([-1, -1])

c = torch.full((2,3), 5)
print("\nc:", c)
# c:
# tensor([[5, 5, 5],
#         [5, 5, 5]])

d = torch.full(2, 3, fill_value=9)  # 填充 9
print("\nd:", d)
# d: tensor([9, 9])

# 示例 3:指定 dtype 和 device
e = torch.full((2,2), 3.14, dtype=torch.float64)
print("\ne:", e, "\ne.dtype =", e.dtype)
# e:
# tensor([[3.1400, 3.1400],
#         [3.1400, 3.1400]], dtype=torch.float64)

# (假设有 GPU 环境)
# f = torch.full((1,3), 0, device='cuda:0')
# print("\nf.device =", f.device)

# 示例 4:使用 out 关键字
out = torch.empty(2,2)
torch.full((2,2), 42, out=out)
print("\nout(after full):", out)
# out:
# tensor([[42, 42],
#         [42, 42]])

# 示例 5:requires_grad=True
g = torch.full((3,), 1.0, requires_grad=True)
print("\ng:", g, "; requires_grad =", g.requires_grad)
# g: tensor([1., 1., 1.], requires_grad=True)

五、关键字参数设计原理

在 Python 里,当函数签名中出现 *sizes 时,所有位置参数都会被收集到 sizes 这个元组里,作为张量的形状。如果把 outdtype 等也当位置参数传入,就会被当作形状维度导致类型错误或逻辑混乱。因此,PyTorch 将它们设计成 keyword-only arguments,只允许 out=…dtype=…device=… 等形式出现,保证了接口的清晰与安全。


总结

  • 接口灵活torch.full(2,3, fill_value=val)torch.full((2,3), val) 都可;
  • 关键字参数outdtypelayoutdevicerequires_grad 强制使用 key=value 形式;
  • 常见场景:常数张量、权重初始化、占位符分配等;
  • 示例丰富:提供了形状、数据类型、设备、in-place 输出、梯度追踪等全方位示例。
Logo

脑启社区是一个专注类脑智能领域的开发者社区。欢迎加入社区,共建类脑智能生态。社区为开发者提供了丰富的开源类脑工具软件、类脑算法模型及数据集、类脑知识库、类脑技术培训课程以及类脑应用案例等资源。

更多推荐