在深度学习与科学计算中,我们经常需要在某个区间内生成等间隔的数值序列,比如采样时间轴、插值混合、位置编码、网格采样等场景。PyTorch 为此提供了一个高效便捷的函数——torch.linspace。本文将从函数签名、基本用法,到进阶参数和实战案例,详细讲解 torch.linspace 的原理与使用技巧,帮助你在项目中灵活运用。


一、函数签名与参数详解

torch.linspace(
    start: float,
    end: float,
    steps: int = 100,
    *,
    out: Tensor = None,
    dtype: Optional[torch.dtype] = None,
    layout: torch.layout = torch.strided,
    device: Optional[torch.device] = None,
    requires_grad: bool = False
) → Tensor
  • start:序列起始值(包含)。
  • end:序列结束值(包含)。
  • steps:总共要采样的点数,默认为 100。
  • dtype:输出张量的数据类型,若不指定则默认为 torch.float32
  • device:张量所在设备,如 CPU 或 GPU('cpu''cuda:0')。
  • requires_grad:是否开启梯度追踪,默认为 False

注意:torch.linspace 返回的是一维张量,元素均匀分布在 [start, end] 区间。


二、基础示例:快速上手

import torch

# 在 [0, 1] 区间内等间隔生成 5 个值
x = torch.linspace(0, 1, steps=5)
print(x)
# tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000])
  • 共 5 个数,每两个相邻值间隔 0.25。
  • 默认 dtype = torch.float32,device = CPU。

三、覆盖默认属性:dtype、device、梯度

  1. 指定数据类型

    y = torch.linspace(-1, 1, steps=7, dtype=torch.float64)
    print(y.dtype)  # torch.float64
    
  2. 指定设备

    if torch.cuda.is_available():
        z = torch.linspace(0, 10, steps=11, device='cuda:0')
        print(z.device)  # cuda:0
    
  3. 开启梯度

    t = torch.linspace(0, 1, steps=10, requires_grad=True)
    # 随后对 t 做操作并反向求导时,它会保留梯度
    

四、与 torch.randn 的对比

函数 参数 生成内容 典型场景
torch.randn(size) size 正态分布随机数 随机噪声、权重初始化
torch.linspace start, end, steps 等差序列 采样时间轴、插值、位置编码、网格

如果要自己计算等差值,也可以使用 Python 原生或 NumPy,但 torch.linspace 可直接生成 GPU 张量,并支持自动梯度和 dtype/device 管理。


五、典型应用场景

1. 插值 / 线性混合

在两个张量之间做平滑过渡:

import torch

a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([4.0, 5.0, 6.0])
# 5 步插值
alphas = torch.linspace(0, 1, steps=5).unsqueeze(1)  # [5,1]
mixes = a * (1 - alphas) + b * alphas
print(mixes)
# tensor([[1.0000, 2.0000, 3.0000],
#         [1.7500, 2.7500, 3.7500],
#         [2.5000, 3.5000, 4.5000],
#         [3.2500, 4.2500, 5.2500],
#         [4.0000, 5.0000, 6.0000]])

2. 位置编码(Transformer)

Transformer 中的正余弦位置编码需要一个从 0-log(10000) 的等间隔向量:

import math
import torch

d_model = 64
# 生成 [0, -ln(10000)] 的 d_model/2 长度等差序列
div_term = torch.exp(torch.linspace(0, -math.log(10000.0), d_model // 2))
print(div_term.shape)  # torch.Size([32])

3. 网格采样(Meshgrid)

搭配 torch.meshgrid 可生成二维或多维网格坐标:

import torch

x = torch.linspace(-1, 1, steps=100)
y = torch.linspace(-1, 1, steps=100)
# 生成 100×100 网格
X, Y = torch.meshgrid(x, y, indexing='ij')
print(X.shape, Y.shape)  # torch.Size([100, 100]) torch.Size([100, 100])

4. 时间/坐标轴采样

在信号处理或动画渲染中,需要等时间间隔的时间点:

fs = 44100            # 采样率
n_samples = 1024
# 从 0 开始,每个采样点间隔 1/fs,总共 1024 点
t = torch.linspace(0, (n_samples - 1) / fs, steps=n_samples)

六、进阶提示

  • out 参数:可以提前准备一个张量,避免新分配,提高性能。
  • memory_format:在需要非连续布局时(如 channels-last)可用。
  • 结合自动求导:当 requires_grad=True 时,对等差序列做的所有运算都可参与反向传播。

七、小结

  • 核心用途:在 [start, end] 区间内生成 steps 个等间隔数值,返回一维张量。
  • 优势:无需手动计算间隔;可直接生成指定 dtype、device、supports gradient;支持 GPU 加速。
  • 常见场景:插值混合、位置编码、网格采样、坐标/时间轴生成等。
Logo

脑启社区是一个专注类脑智能领域的开发者社区。欢迎加入社区,共建类脑智能生态。社区为开发者提供了丰富的开源类脑工具软件、类脑算法模型及数据集、类脑知识库、类脑技术培训课程以及类脑应用案例等资源。

更多推荐