Skip to content

Commit d9f7611

Browse files
nikhilaravifacebook-github-bot
authored andcommitted
Farthest point sampling C++
Summary: C++ implementation of iterative farthest point sampling. Reviewed By: jcjohnson Differential Revision: D30349887 fbshipit-source-id: d25990f857752633859fe00283e182858a870269
1 parent 3b7d78c commit d9f7611

File tree

6 files changed

+346
-19
lines changed

6 files changed

+346
-19
lines changed

pytorch3d/csrc/ext.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include "rasterize_meshes/rasterize_meshes.h"
2828
#include "rasterize_points/rasterize_points.h"
2929
#include "sample_pdf/sample_pdf.h"
30+
#include "sample_farthest_points/sample_farthest_points.h"
3031

3132
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
3233
m.def("face_areas_normals_forward", &FaceAreasNormalsForward);
@@ -40,9 +41,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
4041
#endif
4142
m.def("knn_points_idx", &KNearestNeighborIdx);
4243
m.def("knn_points_backward", &KNearestNeighborBackward);
43-
44-
// Ball Query
4544
m.def("ball_query", &BallQuery);
45+
m.def("sample_farthest_points", &FarthestPointSampling);
4646
m.def(
4747
"mesh_normal_consistency_find_verts", &MeshNormalConsistencyFindVertices);
4848
m.def("gather_scatter", &GatherScatter);
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
/*
2+
* Copyright (c) Facebook, Inc. and its affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <torch/extension.h>
10+
#include <iterator>
11+
#include <random>
12+
#include <vector>
13+
14+
at::Tensor FarthestPointSamplingCpu(
15+
const at::Tensor& points,
16+
const at::Tensor& lengths,
17+
const at::Tensor& K,
18+
const bool random_start_point) {
19+
// Get constants
20+
const int64_t N = points.size(0);
21+
const int64_t P = points.size(1);
22+
const int64_t D = points.size(2);
23+
const int64_t max_K = torch::max(K).item<int64_t>();
24+
25+
// Initialize an output array for the sampled indices
26+
// of shape (N, max_K)
27+
auto long_opts = lengths.options();
28+
torch::Tensor sampled_indices = torch::full({N, max_K}, -1, long_opts);
29+
30+
// Create accessors for all tensors
31+
auto points_a = points.accessor<float, 3>();
32+
auto lengths_a = lengths.accessor<int64_t, 1>();
33+
auto k_a = K.accessor<int64_t, 1>();
34+
auto sampled_indices_a = sampled_indices.accessor<int64_t, 2>();
35+
36+
// Initialize a mask to prevent duplicates
37+
// If true, the point has already been selected.
38+
std::vector<unsigned char> selected_points_mask(P, false);
39+
40+
// Initialize to infinity a vector of
41+
// distances from each point to any of the previously selected points
42+
std::vector<float> dists(P, std::numeric_limits<float>::max());
43+
44+
// Initialize random number generation for random starting points
45+
std::random_device rd;
46+
std::default_random_engine eng(rd());
47+
48+
for (int64_t n = 0; n < N; ++n) {
49+
// Resize and reset points mask and distances for each batch
50+
selected_points_mask.resize(lengths_a[n]);
51+
dists.resize(lengths_a[n]);
52+
std::fill(selected_points_mask.begin(), selected_points_mask.end(), false);
53+
std::fill(dists.begin(), dists.end(), std::numeric_limits<float>::max());
54+
55+
// Select a starting point index and save it
56+
std::uniform_int_distribution<int> distr(0, lengths_a[n] - 1);
57+
int64_t last_idx = random_start_point ? distr(eng) : 0;
58+
sampled_indices_a[n][0] = last_idx;
59+
60+
// Set the value of the mask at this point to false
61+
selected_points_mask[last_idx] = true;
62+
63+
// For heterogeneous pointclouds, use the minimum of the
64+
// length for that cloud compared to K as the number of
65+
// points to sample
66+
const int64_t batch_k = std::min(lengths_a[n], k_a[n]);
67+
68+
// Iteratively select batch_k points per batch
69+
for (int64_t k = 1; k < batch_k; ++k) {
70+
// Iterate through all the points
71+
for (int64_t p = 0; p < lengths_a[n]; ++p) {
72+
if (selected_points_mask[p]) {
73+
// For already selected points set the distance to 0.0
74+
dists[p] = 0.0;
75+
continue;
76+
}
77+
78+
// Calculate the distance to the last selected point
79+
float dist2 = 0.0;
80+
for (int64_t d = 0; d < D; ++d) {
81+
float diff = points_a[n][last_idx][d] - points_a[n][p][d];
82+
dist2 += diff * diff;
83+
}
84+
85+
// If the distance of this point to the last selected point is closer
86+
// than the distance to any of the previously selected points, then
87+
// update this distance
88+
if (dist2 < dists[p]) {
89+
dists[p] = dist2;
90+
}
91+
}
92+
93+
// The aim is to pick the point that has the largest
94+
// nearest neighbour distance to any of the already selected points
95+
auto itr = std::max_element(dists.begin(), dists.end());
96+
last_idx = std::distance(dists.begin(), itr);
97+
98+
// Save selected point
99+
sampled_indices_a[n][k] = last_idx;
100+
101+
// Set the mask value to true to prevent duplicates.
102+
selected_points_mask[last_idx] = true;
103+
}
104+
}
105+
106+
return sampled_indices;
107+
}
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
/*
2+
* Copyright (c) Facebook, Inc. and its affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
#include <torch/extension.h>
11+
#include <tuple>
12+
#include "utils/pytorch3d_cutils.h"
13+
14+
// Iterative farthest point sampling algorithm [1] to subsample a set of
15+
// K points from a given pointcloud. At each iteration, a point is selected
16+
// which has the largest nearest neighbor distance to any of the
17+
// already selected points.
18+
19+
// Farthest point sampling provides more uniform coverage of the input
20+
// point cloud compared to uniform random sampling.
21+
22+
// [1] Charles R. Qi et al, "PointNet++: Deep Hierarchical Feature Learning
23+
// on Point Sets in a Metric Space", NeurIPS 2017.
24+
25+
// Args:
26+
// points: (N, P, D) float32 Tensor containing the batch of pointclouds.
27+
// lengths: (N,) long Tensor giving the number of points in each pointcloud
28+
// (to support heterogeneous batches of pointclouds).
29+
// K: a tensor of length (N,) giving the number of
30+
// samples to select for each element in the batch.
31+
// The number of samples is typically << P.
32+
// random_start_point: bool, if True, a random point is selected as the
33+
// starting point for iterative sampling.
34+
// Returns:
35+
// selected_indices: (N, K) array of selected indices. If the values in
36+
// K are not all the same, then the shape will be (N, max(K), D), and
37+
// padded with -1 for batch elements where k_i < max(K). The selected
38+
// points are gathered in the pytorch autograd wrapper.
39+
40+
at::Tensor FarthestPointSamplingCpu(
41+
const at::Tensor& points,
42+
const at::Tensor& lengths,
43+
const at::Tensor& K,
44+
const bool random_start_point);
45+
46+
// Exposed implementation.
47+
at::Tensor FarthestPointSampling(
48+
const at::Tensor& points,
49+
const at::Tensor& lengths,
50+
const at::Tensor& K,
51+
const bool random_start_point) {
52+
if (points.is_cuda() || lengths.is_cuda() || K.is_cuda()) {
53+
AT_ERROR("CUDA implementation not yet supported");
54+
}
55+
return FarthestPointSamplingCpu(points, lengths, K, random_start_point);
56+
}

pytorch3d/ops/sample_farthest_points.py

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,12 @@
88
from typing import Optional, Tuple, Union, List
99

1010
import torch
11+
from pytorch3d import _C
1112

1213
from .utils import masked_gather
1314

1415

15-
def sample_farthest_points_naive(
16+
def sample_farthest_points(
1617
points: torch.Tensor,
1718
lengths: Optional[torch.Tensor] = None,
1819
K: Union[int, List, torch.Tensor] = 50,
@@ -34,7 +35,7 @@ def sample_farthest_points_naive(
3435
points: (N, P, D) array containing the batch of pointclouds
3536
lengths: (N,) number of points in each pointcloud (to support heterogeneous
3637
batches of pointclouds)
37-
K: samples you want in each sampled point cloud (this is typically << P). If
38+
K: samples required in each sampled point cloud (this is typically << P). If
3839
K is an int then the same number of samples are selected for each
3940
pointcloud in the batch. If K is a tensor is should be length (N,)
4041
giving the number of samples to select for each element in the batch
@@ -52,6 +53,50 @@ def sample_farthest_points_naive(
5253
N, P, D = points.shape
5354
device = points.device
5455

56+
# Validate inputs
57+
if lengths is None:
58+
lengths = torch.full((N,), P, dtype=torch.int64, device=device)
59+
60+
if lengths.shape[0] != N:
61+
raise ValueError("points and lengths must have same batch dimension.")
62+
63+
# TODO: support providing K as a ratio of the total number of points instead of as an int
64+
if isinstance(K, int):
65+
K = torch.full((N,), K, dtype=torch.int64, device=device)
66+
elif isinstance(K, list):
67+
K = torch.tensor(K, dtype=torch.int64, device=device)
68+
69+
if K.shape[0] != N:
70+
raise ValueError("K and points must have the same batch dimension")
71+
72+
# Check dtypes are correct and convert if necessary
73+
if not (points.dtype == torch.float32):
74+
points = points.to(torch.float32)
75+
if not (lengths.dtype == torch.int64):
76+
lengths = lengths.to(torch.int64)
77+
if not (K.dtype == torch.int64):
78+
K = K.to(torch.int64)
79+
80+
with torch.no_grad():
81+
# pyre-fixme[16]: `pytorch3d_._C` has no attribute `sample_farthest_points`.
82+
idx = _C.sample_farthest_points(points, lengths, K, random_start_point)
83+
sampled_points = masked_gather(points, idx)
84+
85+
return sampled_points, idx
86+
87+
88+
def sample_farthest_points_naive(
89+
points: torch.Tensor,
90+
lengths: Optional[torch.Tensor] = None,
91+
K: Union[int, List, torch.Tensor] = 50,
92+
random_start_point: bool = False,
93+
) -> Tuple[torch.Tensor, torch.Tensor]:
94+
"""
95+
Same Args/Returns as sample_farthest_points
96+
"""
97+
N, P, D = points.shape
98+
device = points.device
99+
55100
# Validate inputs
56101
if lengths is None:
57102
lengths = torch.full((N,), P, dtype=torch.int64, device=device)

tests/bm_sample_farthest_points.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# Copyright (c) Facebook, Inc. and its affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from itertools import product
8+
9+
from fvcore.common.benchmark import benchmark
10+
from test_sample_farthest_points import TestFPS
11+
12+
13+
def bm_fps() -> None:
14+
kwargs_list = []
15+
backends = ["cpu", "cuda:0"]
16+
Ns = [8, 32]
17+
Ps = [64, 256]
18+
Ds = [3]
19+
Ks = [24]
20+
test_cases = product(Ns, Ps, Ds, Ks, backends)
21+
for case in test_cases:
22+
N, P, D, K, d = case
23+
kwargs_list.append({"N": N, "P": P, "D": D, "K": K, "device": d})
24+
25+
benchmark(
26+
TestFPS.sample_farthest_points_naive,
27+
"FPS_NAIVE_PYTHON",
28+
kwargs_list,
29+
warmup_iters=1,
30+
)
31+
32+
kwargs_list = [k for k in kwargs_list if k["device"] == "cpu"]
33+
benchmark(TestFPS.sample_farthest_points, "FPS_CPU", kwargs_list, warmup_iters=1)
34+
35+
36+
if __name__ == "__main__":
37+
bm_fps()

0 commit comments

Comments
 (0)