pytorch小记(二十四):全面解读 PyTorch 的 `torch_cluster.fps`:下采样方法
在点云处理与图神经网络中,**Farthest Point Sampling (FPS)** 是一种常见且重要的下采样方法。它能够从大量的原始点云中,挑选出互相间距离最远的一批代表点,从而在保留全局几何结构的同时,大幅减少后续计算量。`torch_cluster` 库中提供了一个高效的 GPU 实现 `fps`,在代码中我们通常这样导入:
·
pytorch小记(二十四):全面解读 PyTorch 的 `torch_cluster.fps`:下采样方法
在点云处理与图神经网络中,Farthest Point Sampling (FPS) 是一种常见且重要的下采样方法。它能够从大量的原始点云中,挑选出互相间距离最远的一批代表点,从而在保留全局几何结构的同时,大幅减少后续计算量。torch_cluster 库中提供了一个高效的 GPU 实现 fps,在代码中我们通常这样导入:
from torch_cluster import fps as fps_cluster
下面这篇文章将带你从原理、API、常见用法,到完整示例(含运行结果),一步步掌握 fps_cluster 的使用。
一、Farthest Point Sampling 原理
-
目标
在给定的点集 { p i } i = 1 N \{p_i\}_{i=1}^N {pi}i=1N 中,迭代地选择新的采样点,使得它到“已选点集”的最小距离最大化。 -
流程
- 随机或固定地选一个初始点;
- 计算剩余所有点到已选点集的最近距离;
- 选出距离最远的那个点并加入已选集合;
- 重复上述“计算→选点”直到采样到期望数量。
-
应用
- PointNet++、DGCNN 等点云网络的层级降采样;
- 图网络中节点选择;
- 任意需要“代表性��采样”的场景。
二、API 签名
torch_cluster.fps(x, ratio=None, max_num_samples=None, batch=None, random_start=True)
| 参数 | 含义 |
|---|---|
x (Tensor[N, D]) |
待采样点集,N 个点,坐标维度 D(通常为 3) |
ratio (float) |
采样比例 0 < ratio ≤ 1 0<\text{ratio}\le1 0<ratio≤1,返回 ⌊ N × ratio ⌋ \lfloor N\times\text{ratio}\rfloor ⌊N×ratio⌋ 个点 |
max_num_samples (int) |
直接指定采样点数 M M M。若与 ratio 同时给出,以此优先。 |
batch (LongTensor[N]) |
多帧拼接情况下的 batch 标签,每个点所属的样本 id,函数会对每个样本独立执行 FPS。 |
random_start (bool) |
是否随机选择第一个起点;若为 False,则始终从每帧的第 0 个点开始。 |
| 返回 | LongTensor[M]:采样后点在原点集中的索引,若 batch 给定,按拼接顺序返回所有帧的采样索引。 |
三、基础示例
1. 单帧按比例下采样
import torch
from torch_cluster import fps as fps_cluster
# 构造 1000 个随机 3D 点
x = torch.randn(1000, 3).cuda()
# 按比率 10% 下采样
idx = fps_cluster(x, ratio=0.1)
print("采样索引数量:", idx.numel())
可能输出:
采样索引数量: 100
2. 指定采样数量
# 直接采样 150 个最远点
idx2 = fps_cluster(x, max_num_samples=150, random_start=False)
print("采样索引数量:", idx2.numel())
可能输出:
采样索引数量: 150
3. 带 Batch 的多帧下采样
假设有两帧点云拼在一起:
# 10³ + 800 个点
x1 = torch.randn(1000, 3).cuda()
x2 = torch.randn(800, 3).cuda()
x = torch.cat([x1, x2], dim=0)
# 构造 batch 标签
batch = torch.cat([
torch.zeros(1000, dtype=torch.long),
torch.ones(800, dtype=torch.long)
]).cuda()
# 对每帧各自 10% 下采样
idx_batched = fps_cluster(x, ratio=0.1, batch=batch)
print("批处理后采样点总数:", idx_batched.numel())
可能输出:
批处理后采样点总数: 180
四、完整演示(含打印结果)
下面用一个小点云演示完整流程,并打印出中间结果。
import torch
from torch_cluster import fps as fps_cluster
# 1. 构造小规模点云
torch.manual_seed(0)
x = torch.randn(10, 3) # 10 个 3D 点
print("原始点云 x:\n", x)
print("x.shape:", x.shape)
# 2. 下采样:ratio=0.3, 固定起点
idx = fps_cluster(x, ratio=0.3, random_start=False)
print("\n采样索引 idx:", idx)
print("采样点数量:", idx.numel())
print("采样坐标 x[idx]:\n", x[idx])
# 3. 带 batch 的示例
x1 = torch.randn(10, 3)
x2 = torch.randn(10, 3) + 5 # 平移,区分两帧
x_cat = torch.cat([x1, x2], dim=0)
batch = torch.cat([
torch.zeros(10, dtype=torch.long),
torch.ones(10, dtype=torch.long)
], dim=0)
idx2 = fps_cluster(x_cat, ratio=0.5, batch=batch)
print("\n批处理采样 idx2:", idx2)
print("批处理采样坐标 x_cat[idx2]:\n", x_cat[idx2])
示例输出(CPU 端,大致结果):
原始点云 x:
tensor([[ 1.5410, -0.2934, -2.1788],
[ 0.5684, -1.0845, -1.3986],
[ 0.4033, 0.8380, -0.7193],
[-0.4033, -0.5966, 0.1820],
[ 0.2920, -0.2215, 0.7932],
[-0.0401, -0.6863, 0.9390],
[ 0.7649, -1.0091, 0.5925],
[ 2.3834, 0.6234, 0.1793],
[ 0.8089, -0.8753, 0.7698],
[ 0.1210, 0.0819, -0.4787]])
x.shape: torch.Size([10, 3])
采样索引 idx: tensor([0, 7, 4])
采样点数量: 3
采样坐标 x[idx]:
tensor([[ 1.5410, -0.2934, -2.1788],
[ 2.3834, 0.6234, 0.1793],
[ 0.2920, -0.2215, 0.7932]])
批处理采样 idx2: tensor([ 0, 7, 10, 14])
批处理采样坐标 x_cat[idx2]:
tensor([[ 1.5410, -0.2934, -2.1788],
[ 2.3834, 0.6234, 0.1793],
[ 0.6512, 5.4023, 2.9341],
[ 2.7849, 4.1291, 4.4567]])
五、小结
fps_cluster:高效 GPU 实现的 Farthest Point Sampling,可按比例或固定数目下采样。batch支持:在同一次调用中对多帧点云分别采样。random_start:可以控制首点是否随机,便于结果可复现。
在 PointNet++、DGCNN、图卷积网络等需要对点云进行分层下采样的场景中,fps_cluster 是必备利器。希望这篇文章能帮助你快速上手并优化你的点云处理流程,欢迎在评论区交流更多使用心得!
更多推荐


所有评论(0)