pytorch小记(二十三):全面解读 PyTorch 的 `torch.linspace`:等差序列生成与典型应用
在深度学习与科学计算中,我们经常需要在某个区间内生成等间隔的数值序列,比如采样时间轴、插值混合、位置编码、网格采样等场景。PyTorch 为此提供了一个高效便捷的函数——`torch.linspace`。本文将从函数签名、基本用法,到进阶参数和实战案例,详细讲解 `torch.linspace` 的原理与使用技巧,帮助你在项目中灵活运用。
·
pytorch小记(二十三):全面解读 PyTorch 的 `torch.linspace`:等差序列生成与典型应用
在深度学习与科学计算中,我们经常需要在某个区间内生成等间隔的数值序列,比如采样时间轴、插值混合、位置编码、网格采样等场景。PyTorch 为此提供了一个高效便捷的函数——torch.linspace。本文将从函数签名、基本用法,到进阶参数和实战案例,详细讲解 torch.linspace 的原理与使用技巧,帮助你在项目中灵活运用。
一、函数签名与参数详解
torch.linspace(
start: float,
end: float,
steps: int = 100,
*,
out: Tensor = None,
dtype: Optional[torch.dtype] = None,
layout: torch.layout = torch.strided,
device: Optional[torch.device] = None,
requires_grad: bool = False
) → Tensor
start:序列起始值(包含)。end:序列结束值(包含)。steps:总共要采样的点数,默认为 100。dtype:输出张量的数据类型,若不指定则默认为torch.float32。device:张量所在设备,如 CPU 或 GPU('cpu'、'cuda:0')。requires_grad:是否开启梯度追踪,默认为False。
注意:
torch.linspace返回的是一维张量,元素均匀分布在[start, end]区间。
二、基础示例:快速上手
import torch
# 在 [0, 1] 区间内等间隔生成 5 个值
x = torch.linspace(0, 1, steps=5)
print(x)
# tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000])
- 共 5 个数,每两个相邻值间隔 0.25。
- 默认 dtype =
torch.float32,device = CPU。
三、覆盖默认属性:dtype、device、梯度
-
指定数据类型
y = torch.linspace(-1, 1, steps=7, dtype=torch.float64) print(y.dtype) # torch.float64 -
指定设备
if torch.cuda.is_available(): z = torch.linspace(0, 10, steps=11, device='cuda:0') print(z.device) # cuda:0 -
开启梯度
t = torch.linspace(0, 1, steps=10, requires_grad=True) # 随后对 t 做操作并反向求导时,它会保留梯度
四、与 torch.randn 的对比
| 函数 | 参数 | 生成内容 | 典型场景 |
|---|---|---|---|
torch.randn(size) |
size |
正态分布随机数 | 随机噪声、权重初始化 |
torch.linspace |
start, end, steps |
等差序列 | 采样时间轴、插值、位置编码、网格 |
如果要自己计算等差值,也可以使用 Python 原生或 NumPy,但 torch.linspace 可直接生成 GPU 张量,并支持自动梯度和 dtype/device 管理。
五、典型应用场景
1. 插值 / 线性混合
在两个张量之间做平滑过渡:
import torch
a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([4.0, 5.0, 6.0])
# 5 步插值
alphas = torch.linspace(0, 1, steps=5).unsqueeze(1) # [5,1]
mixes = a * (1 - alphas) + b * alphas
print(mixes)
# tensor([[1.0000, 2.0000, 3.0000],
# [1.7500, 2.7500, 3.7500],
# [2.5000, 3.5000, 4.5000],
# [3.2500, 4.2500, 5.2500],
# [4.0000, 5.0000, 6.0000]])
2. 位置编码(Transformer)
Transformer 中的正余弦位置编码需要一个从 0 到 -log(10000) 的等间隔向量:
import math
import torch
d_model = 64
# 生成 [0, -ln(10000)] 的 d_model/2 长度等差序列
div_term = torch.exp(torch.linspace(0, -math.log(10000.0), d_model // 2))
print(div_term.shape) # torch.Size([32])
3. 网格采样(Meshgrid)
搭配 torch.meshgrid 可生成二维或多维网格坐标:
import torch
x = torch.linspace(-1, 1, steps=100)
y = torch.linspace(-1, 1, steps=100)
# 生成 100×100 网格
X, Y = torch.meshgrid(x, y, indexing='ij')
print(X.shape, Y.shape) # torch.Size([100, 100]) torch.Size([100, 100])
4. 时间/坐标轴采样
在信号处理或动画渲染中,需要等时间间隔的时间点:
fs = 44100 # 采样率
n_samples = 1024
# 从 0 开始,每个采样点间隔 1/fs,总共 1024 点
t = torch.linspace(0, (n_samples - 1) / fs, steps=n_samples)
六、进阶提示
out参数:可以提前准备一个张量,避免新分配,提高性能。memory_format:在需要非连续布局时(如 channels-last)可用。- 结合自动求导:当
requires_grad=True时,对等差序列做的所有运算都可参与反向传播。
七、小结
- 核心用途:在
[start, end]区间内生成steps个等间隔数值,返回一维张量。 - 优势:无需手动计算间隔;可直接生成指定 dtype、device、supports gradient;支持 GPU 加速。
- 常见场景:插值混合、位置编码、网格采样、坐标/时间轴生成等。
更多推荐


所有评论(0)