pytorch小记(二十四):PyTorch 中的 `torch.full` 全面指南
在深度学习中,有时需要创建一个所有元素都相同的张量,例如作为常数初始值或掩码。PyTorch 提供了 `torch.full` 接口,功能灵活且参数丰富。下面我们按模块逐一展开。
·
pytorch小记(二十四):PyTorch 中的 `torch.full` 全面指南
PyTorch 中的 torch.full 全面指南
在深度学习中,有时需要创建一个所有元素都相同的张量,例如作为常数初始值或掩码。PyTorch 提供了 torch.full 接口,功能灵活且参数丰富。下面我们按模块逐一展开。
一、接口定义
torch.full(*sizes,
fill_value,
out=None,
dtype=None,
layout=torch.strided,
device=None,
requires_grad=False) → Tensor
-
返回值:形状由
*sizes决定,所有位置都填fill_value的张量。 -
等价签名:
torch.full(size: Tuple[int, ...], fill_value: Number, out: Tensor = None, dtype: torch.dtype = None, layout: torch.layout = torch.strided, device: torch.device = None, requires_grad: bool = False) → Tensor
二、参数详解
| 参数 | 说明 | 示例 |
|---|---|---|
*sizes |
张量形状:多个位置参数(如 2,3,4),或一个整型元组 (2,3) |
torch.full(2,3,4, fill_value=5) |
fill_value |
要填充的值,常见为标量(int、float、bool) |
fill_value=7 |
out |
可选,已有张量写入结果,in-place;keyword-only | torch.full((2,2), 9, out=my_tensor) |
dtype |
输出数据类型;若与 fill_value 类型不匹配,会做转换 |
dtype=torch.float64 |
layout |
存储布局,默认为 torch.strided(稠密张量) |
layout=torch.strided |
device |
输出张量设备,如 "cpu"、"cuda:0" |
device='cuda:0' |
requires_grad |
是否开启梯度追踪(常用于可学习参数) | requires_grad=True |
注意:
out、dtype、layout、device、requires_grad均为 关键字参数,必须以key=value形式传入,否则会被误认为是形状(sizes)的一部分。
三、常见使用场景
-
创建常数张量
作为偏置、掩码或特殊标志值:bias = torch.full((batch_size, num_features), 0.1) mask = torch.full((H, W), True, dtype=torch.bool) -
初始化权重
固定常数初始化:self.weight = torch.full((out_channels, in_channels), fill_value=0.01, requires_grad=True) -
占位符
在复杂流程中预分配内存:out = torch.empty(3, 3) const = torch.full((3,3), 5, out=out) # 直接写入 out -
与其他 API 配合
# full_like:沿用现有张量形状 ref = torch.zeros(2,4) filled = torch.full_like(ref, 3.14) # 结果 shape=(2,4), dtype=float32
四、具体示例与输出
以下示例固定随机种子(对 full 无影响,仅为演示一致性),并展示每步输出。
import torch
torch.manual_seed(0)
# 示例 1:最基础的 (2,3) 常数张量
a = torch.full(2, 7)
# 等价于 torch.full((2,), 7)
print("a:", a)
# 输出:
# a: tensor([7, 7])
# 示例 2:二维常数(位置参数 vs tuple)
b = torch.full(2, 3, fill_value=-1) # shape=(2,), fill_value=-1
print("\nb:", b)
# b: tensor([-1, -1])
c = torch.full((2,3), 5)
print("\nc:", c)
# c:
# tensor([[5, 5, 5],
# [5, 5, 5]])
d = torch.full(2, 3, fill_value=9) # 填充 9
print("\nd:", d)
# d: tensor([9, 9])
# 示例 3:指定 dtype 和 device
e = torch.full((2,2), 3.14, dtype=torch.float64)
print("\ne:", e, "\ne.dtype =", e.dtype)
# e:
# tensor([[3.1400, 3.1400],
# [3.1400, 3.1400]], dtype=torch.float64)
# (假设有 GPU 环境)
# f = torch.full((1,3), 0, device='cuda:0')
# print("\nf.device =", f.device)
# 示例 4:使用 out 关键字
out = torch.empty(2,2)
torch.full((2,2), 42, out=out)
print("\nout(after full):", out)
# out:
# tensor([[42, 42],
# [42, 42]])
# 示例 5:requires_grad=True
g = torch.full((3,), 1.0, requires_grad=True)
print("\ng:", g, "; requires_grad =", g.requires_grad)
# g: tensor([1., 1., 1.], requires_grad=True)
五、关键字参数设计原理
在 Python 里,当函数签名中出现 *sizes 时,所有位置参数都会被收集到 sizes 这个元组里,作为张量的形状。如果把 out、dtype 等也当位置参数传入,就会被当作形状维度导致类型错误或逻辑混乱。因此,PyTorch 将它们设计成 keyword-only arguments,只允许 out=…、dtype=…、device=… 等形式出现,保证了接口的清晰与安全。
总结
- 接口灵活:
torch.full(2,3, fill_value=val)或torch.full((2,3), val)都可; - 关键字参数:
out、dtype、layout、device、requires_grad强制使用key=value形式; - 常见场景:常数张量、权重初始化、占位符分配等;
- 示例丰富:提供了形状、数据类型、设备、in-place 输出、梯度追踪等全方位示例。
更多推荐


所有评论(0)