Skip to content

Commit bd04ffa

Browse files
nikhilaravifacebook-github-bot
authored andcommitted
Farthest point sampling CUDA
Summary: CUDA implementation of farthest point sampling algorithm. ## Visual comparison Compared to random sampling, farthest point sampling gives better coverage of the shape. {F658631262} ## Reduction Parallelized block reduction to find the max value at each iteration happens as follows: 1. First split the points into two equal sized parts (e.g. for a list with 8 values): `[20, 27, 6, 8 | 11, 10, 2, 33]` 2. Use half of the threads (4 threads) to compare pairs of elements from each half (e.g. elements [0, 4], [1, 5], etc.) and store the result in the first half of the list: `[20, 27, 6, 33 | 11, 10, 2, 33]` 3. Now we no longer care about the second part but again divide the first part into two `[20, 27 | 6, 33 | -, -, -, -]` Now we can use 2 threads to compare the 4 elements 4. Finally we have gotten down to a single pair `[20 | 33 | -, - | -, -, -, -]` Use 1 thread to compare the remaining two elements 5. The max will now be at thread id = 0 `[33 | - | -, - | -, -, -, -]` The reduction will give the farthest point for the selected batch index at this iteration. Reviewed By: bottler, jcjohnson Differential Revision: D30401803 fbshipit-source-id: 525bd5ae27c4b13b501812cfe62306bb003827d2
1 parent d9f7611 commit bd04ffa

11 files changed

+441
-33
lines changed

pytorch3d/csrc/ext.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@
2626
#include "point_mesh/point_mesh_cuda.h"
2727
#include "rasterize_meshes/rasterize_meshes.h"
2828
#include "rasterize_points/rasterize_points.h"
29-
#include "sample_pdf/sample_pdf.h"
3029
#include "sample_farthest_points/sample_farthest_points.h"
30+
#include "sample_pdf/sample_pdf.h"
3131

3232
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
3333
m.def("face_areas_normals_forward", &FaceAreasNormalsForward);

pytorch3d/csrc/point_mesh/point_mesh_cuda.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ __global__ void DistanceForwardKernel(
121121
// Unroll the last 6 iterations of the loop since they will happen
122122
// synchronized within a single warp.
123123
if (tid < 32)
124-
WarpReduce<float>(min_dists, min_idxs, tid);
124+
WarpReduceMin<float>(min_dists, min_idxs, tid);
125125

126126
// Finally thread 0 writes the result to the output buffer.
127127
if (tid == 0) {

pytorch3d/csrc/sample_farthest_points/sample_farthest_points.cpp

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ at::Tensor FarthestPointSamplingCpu(
1515
const at::Tensor& points,
1616
const at::Tensor& lengths,
1717
const at::Tensor& K,
18-
const bool random_start_point) {
18+
const at::Tensor& start_idxs) {
1919
// Get constants
2020
const int64_t N = points.size(0);
2121
const int64_t P = points.size(1);
@@ -32,6 +32,7 @@ at::Tensor FarthestPointSamplingCpu(
3232
auto lengths_a = lengths.accessor<int64_t, 1>();
3333
auto k_a = K.accessor<int64_t, 1>();
3434
auto sampled_indices_a = sampled_indices.accessor<int64_t, 2>();
35+
auto start_idxs_a = start_idxs.accessor<int64_t, 1>();
3536

3637
// Initialize a mask to prevent duplicates
3738
// If true, the point has already been selected.
@@ -41,20 +42,15 @@ at::Tensor FarthestPointSamplingCpu(
4142
// distances from each point to any of the previously selected points
4243
std::vector<float> dists(P, std::numeric_limits<float>::max());
4344

44-
// Initialize random number generation for random starting points
45-
std::random_device rd;
46-
std::default_random_engine eng(rd());
47-
4845
for (int64_t n = 0; n < N; ++n) {
4946
// Resize and reset points mask and distances for each batch
5047
selected_points_mask.resize(lengths_a[n]);
5148
dists.resize(lengths_a[n]);
5249
std::fill(selected_points_mask.begin(), selected_points_mask.end(), false);
5350
std::fill(dists.begin(), dists.end(), std::numeric_limits<float>::max());
5451

55-
// Select a starting point index and save it
56-
std::uniform_int_distribution<int> distr(0, lengths_a[n] - 1);
57-
int64_t last_idx = random_start_point ? distr(eng) : 0;
52+
// Get the starting point index and save it
53+
int64_t last_idx = start_idxs_a[n];
5854
sampled_indices_a[n][0] = last_idx;
5955

6056
// Set the value of the mask at this point to false
Lines changed: 252 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,252 @@
1+
/*
2+
* Copyright (c) Facebook, Inc. and its affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <ATen/ATen.h>
10+
#include <ATen/cuda/CUDAContext.h>
11+
#include <c10/cuda/CUDAGuard.h>
12+
#include <math.h>
13+
#include <stdio.h>
14+
#include <stdlib.h>
15+
#include "utils/pytorch3d_cutils.h"
16+
#include "utils/warp_reduce.cuh"
17+
18+
template <unsigned int block_size>
19+
__global__ void FarthestPointSamplingKernel(
20+
// clang-format off
21+
const at::PackedTensorAccessor64<float, 3, at::RestrictPtrTraits> points,
22+
const at::PackedTensorAccessor64<int64_t, 1, at::RestrictPtrTraits> lengths,
23+
const at::PackedTensorAccessor64<int64_t, 1, at::RestrictPtrTraits> K,
24+
at::PackedTensorAccessor64<int64_t, 2, at::RestrictPtrTraits> idxs,
25+
at::PackedTensorAccessor64<float, 2, at::RestrictPtrTraits> min_point_dist,
26+
const at::PackedTensorAccessor64<int64_t, 1, at::RestrictPtrTraits> start_idxs
27+
// clang-format on
28+
) {
29+
// Get constants
30+
const int64_t N = points.size(0);
31+
const int64_t P = points.size(1);
32+
const int64_t D = points.size(2);
33+
34+
// Create single shared memory buffer which is split and cast to different
35+
// types: dists/dists_idx are used to save the maximum distances seen by the
36+
// points processed by any one thread and the associated point indices.
37+
// These values only need to be accessed by other threads in this block which
38+
// are processing the same batch and not by other blocks.
39+
extern __shared__ char shared_buf[];
40+
float* dists = (float*)shared_buf; // block_size floats
41+
int64_t* dists_idx = (int64_t*)&dists[block_size]; // block_size int64_t
42+
43+
// Get batch index and thread index
44+
const int64_t batch_idx = blockIdx.x;
45+
const size_t tid = threadIdx.x;
46+
47+
// If K is greater than the number of points in the pointcloud
48+
// we only need to iterate until the smaller value is reached.
49+
const int64_t k_n = min(K[batch_idx], lengths[batch_idx]);
50+
51+
// Write the first selected point to global memory in the first thread
52+
int64_t selected = start_idxs[batch_idx];
53+
if (tid == 0)
54+
idxs[batch_idx][0] = selected;
55+
56+
// Iterate to find k_n sampled points
57+
for (int64_t k = 1; k < k_n; ++k) {
58+
// Keep track of the maximum of the minimum distance to previously selected
59+
// points seen by this thread
60+
int64_t max_dist_idx = 0;
61+
float max_dist = -1.0;
62+
63+
// Iterate through all the points in this pointcloud. For already selected
64+
// points, the minimum distance to the set of previously selected points
65+
// will be 0.0 so they won't be selected again.
66+
for (int64_t p = tid; p < lengths[batch_idx]; p += block_size) {
67+
// Calculate the distance to the last selected point
68+
float dist2 = 0.0;
69+
for (int64_t d = 0; d < D; ++d) {
70+
float diff = points[batch_idx][selected][d] - points[batch_idx][p][d];
71+
dist2 += (diff * diff);
72+
}
73+
74+
// If the distance of point p to the last selected point is
75+
// less than the previous minimum distance of p to the set of selected
76+
// points, then updated the corresponding value in min_point_dist
77+
// so it always contains the min distance.
78+
const float p_min_dist = min(dist2, min_point_dist[batch_idx][p]);
79+
min_point_dist[batch_idx][p] = p_min_dist;
80+
81+
// Update the max distance and point idx for this thread.
82+
max_dist_idx = (p_min_dist > max_dist) ? p : max_dist_idx;
83+
max_dist = (p_min_dist > max_dist) ? p_min_dist : max_dist;
84+
}
85+
86+
// After going through all points for this thread, save the max
87+
// point and idx seen by this thread. Each thread sees P/block_size points.
88+
dists[tid] = max_dist;
89+
dists_idx[tid] = max_dist_idx;
90+
// Sync to ensure all threads in the block have updated their max point.
91+
__syncthreads();
92+
93+
// Parallelized block reduction to find the max point seen by
94+
// all the threads in this block for iteration k.
95+
// Each block represents one batch element so we can use a divide/conquer
96+
// approach to find the max, syncing all threads after each step.
97+
98+
for (int s = block_size / 2; s > 0; s >>= 1) {
99+
if (tid < s) {
100+
// Compare the best point seen by two threads and update the shared
101+
// memory at the location of the first thread index with the max out
102+
// of the two threads.
103+
if (dists[tid] < dists[tid + s]) {
104+
dists[tid] = dists[tid + s];
105+
dists_idx[tid] = dists_idx[tid + s];
106+
}
107+
}
108+
__syncthreads();
109+
}
110+
111+
// TODO(nikhilar): As reduction proceeds, the number of “active” threads
112+
// decreases. When tid < 32, there should only be one warp left which could
113+
// be unrolled.
114+
115+
// The overall max after reducing will be saved
116+
// at the location of tid = 0.
117+
selected = dists_idx[0];
118+
119+
if (tid == 0) {
120+
// Write the farthest point for iteration k to global memory
121+
idxs[batch_idx][k] = selected;
122+
}
123+
}
124+
}
125+
126+
at::Tensor FarthestPointSamplingCuda(
127+
const at::Tensor& points, // (N, P, 3)
128+
const at::Tensor& lengths, // (N,)
129+
const at::Tensor& K, // (N,)
130+
const at::Tensor& start_idxs) {
131+
// Check inputs are on the same device
132+
at::TensorArg p_t{points, "points", 1}, lengths_t{lengths, "lengths", 2},
133+
k_t{K, "K", 3}, start_idxs_t{start_idxs, "start_idxs", 4};
134+
at::CheckedFrom c = "FarthestPointSamplingCuda";
135+
at::checkAllSameGPU(c, {p_t, lengths_t, k_t, start_idxs_t});
136+
at::checkAllSameType(c, {lengths_t, k_t, start_idxs_t});
137+
138+
// Set the device for the kernel launch based on the device of points
139+
at::cuda::CUDAGuard device_guard(points.device());
140+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
141+
142+
TORCH_CHECK(
143+
points.size(0) == lengths.size(0),
144+
"Point and lengths must have the same batch dimension");
145+
146+
TORCH_CHECK(
147+
points.size(0) == K.size(0),
148+
"Points and K must have the same batch dimension");
149+
150+
const int64_t N = points.size(0);
151+
const int64_t P = points.size(1);
152+
const int64_t max_K = at::max(K).item<int64_t>();
153+
154+
// Initialize the output tensor with the sampled indices
155+
auto idxs = at::full({N, max_K}, -1, lengths.options());
156+
auto min_point_dist = at::full({N, P}, 1e10, points.options());
157+
158+
if (N == 0 || P == 0) {
159+
AT_CUDA_CHECK(cudaGetLastError());
160+
return idxs;
161+
}
162+
163+
// Set the number of blocks to the batch size so that the
164+
// block reduction step can be done for each pointcloud
165+
// to find the max distance point in the pointcloud at each iteration.
166+
const size_t blocks = N;
167+
168+
// Set the threads to the nearest power of 2 of the number of
169+
// points in the pointcloud (up to the max threads in a block).
170+
// This will ensure each thread processes the minimum necessary number of
171+
// points (P/threads).
172+
const int points_pow_2 = std::log(static_cast<double>(P)) / std::log(2.0);
173+
const size_t threads = max(min(1 << points_pow_2, MAX_THREADS_PER_BLOCK), 1);
174+
175+
// Create the accessors
176+
auto points_a = points.packed_accessor64<float, 3, at::RestrictPtrTraits>();
177+
auto lengths_a =
178+
lengths.packed_accessor64<int64_t, 1, at::RestrictPtrTraits>();
179+
auto K_a = K.packed_accessor64<int64_t, 1, at::RestrictPtrTraits>();
180+
auto idxs_a = idxs.packed_accessor64<int64_t, 2, at::RestrictPtrTraits>();
181+
auto start_idxs_a =
182+
start_idxs.packed_accessor64<int64_t, 1, at::RestrictPtrTraits>();
183+
auto min_point_dist_a =
184+
min_point_dist.packed_accessor64<float, 2, at::RestrictPtrTraits>();
185+
186+
// Initialize the shared memory which will be used to store the
187+
// distance/index of the best point seen by each thread.
188+
size_t shared_mem = threads * sizeof(float) + threads * sizeof(int64_t);
189+
// TODO: using shared memory for min_point_dist gives an ~2x speed up
190+
// compared to using a global (N, P) shaped tensor, however for
191+
// larger pointclouds this may exceed the shared memory limit per block.
192+
// If a speed up is required for smaller pointclouds, then the storage
193+
// could be switched to shared memory if the required total shared memory is
194+
// within the memory limit per block.
195+
196+
// Support a case for all powers of 2 up to MAX_THREADS_PER_BLOCK possible per
197+
// block.
198+
switch (threads) {
199+
case 1024:
200+
FarthestPointSamplingKernel<1024>
201+
<<<blocks, threads, shared_mem, stream>>>(
202+
points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a);
203+
break;
204+
case 512:
205+
FarthestPointSamplingKernel<512><<<blocks, threads, shared_mem, stream>>>(
206+
points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a);
207+
break;
208+
case 256:
209+
FarthestPointSamplingKernel<256><<<blocks, threads, shared_mem, stream>>>(
210+
points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a);
211+
break;
212+
case 128:
213+
FarthestPointSamplingKernel<128><<<blocks, threads, shared_mem, stream>>>(
214+
points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a);
215+
break;
216+
case 64:
217+
FarthestPointSamplingKernel<64><<<blocks, threads, shared_mem, stream>>>(
218+
points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a);
219+
break;
220+
case 32:
221+
FarthestPointSamplingKernel<32><<<blocks, threads, shared_mem, stream>>>(
222+
points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a);
223+
break;
224+
case 16:
225+
FarthestPointSamplingKernel<16><<<blocks, threads, shared_mem, stream>>>(
226+
points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a);
227+
break;
228+
case 8:
229+
FarthestPointSamplingKernel<8><<<blocks, threads, shared_mem, stream>>>(
230+
points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a);
231+
break;
232+
case 4:
233+
FarthestPointSamplingKernel<4><<<threads, threads, shared_mem, stream>>>(
234+
points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a);
235+
break;
236+
case 2:
237+
FarthestPointSamplingKernel<2><<<threads, threads, shared_mem, stream>>>(
238+
points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a);
239+
break;
240+
case 1:
241+
FarthestPointSamplingKernel<1><<<threads, threads, shared_mem, stream>>>(
242+
points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a);
243+
break;
244+
default:
245+
FarthestPointSamplingKernel<1024>
246+
<<<blocks, threads, shared_mem, stream>>>(
247+
points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a);
248+
}
249+
250+
AT_CUDA_CHECK(cudaGetLastError());
251+
return idxs;
252+
}

pytorch3d/csrc/sample_farthest_points/sample_farthest_points.h

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,28 +29,44 @@
2929
// K: a tensor of length (N,) giving the number of
3030
// samples to select for each element in the batch.
3131
// The number of samples is typically << P.
32-
// random_start_point: bool, if True, a random point is selected as the
33-
// starting point for iterative sampling.
32+
// start_idxs: (N,) long Tensor giving the index of the first point to
33+
// sample. Default is all 0. When a random start point is required,
34+
// start_idxs should be set to a random value in [0, lengths[n] - 1]
35+
// for batch element n.
3436
// Returns:
3537
// selected_indices: (N, K) array of selected indices. If the values in
3638
// K are not all the same, then the shape will be (N, max(K)), and
3739
// padded with -1 for batch elements where k_i < max(K). The selected
3840
// points are gathered in the pytorch autograd wrapper.
3941

42+
at::Tensor FarthestPointSamplingCuda(
43+
const at::Tensor& points,
44+
const at::Tensor& lengths,
45+
const at::Tensor& K,
46+
const at::Tensor& start_idxs);
47+
4048
at::Tensor FarthestPointSamplingCpu(
4149
const at::Tensor& points,
4250
const at::Tensor& lengths,
4351
const at::Tensor& K,
44-
const bool random_start_point);
52+
const at::Tensor& start_idxs);
4553

4654
// Exposed implementation.
4755
at::Tensor FarthestPointSampling(
4856
const at::Tensor& points,
4957
const at::Tensor& lengths,
5058
const at::Tensor& K,
51-
const bool random_start_point) {
59+
const at::Tensor& start_idxs) {
5260
if (points.is_cuda() || lengths.is_cuda() || K.is_cuda()) {
53-
AT_ERROR("CUDA implementation not yet supported");
61+
#ifdef WITH_CUDA
62+
CHECK_CUDA(points);
63+
CHECK_CUDA(lengths);
64+
CHECK_CUDA(K);
65+
CHECK_CUDA(start_idxs);
66+
return FarthestPointSamplingCuda(points, lengths, K, start_idxs);
67+
#else
68+
AT_ERROR("Not compiled with GPU support.");
69+
#endif
5470
}
55-
return FarthestPointSamplingCpu(points, lengths, K, random_start_point);
71+
return FarthestPointSamplingCpu(points, lengths, K, start_idxs);
5672
}

pytorch3d/csrc/utils/pytorch3d_cutils.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,6 @@
1515
#define CHECK_CONTIGUOUS_CUDA(x) \
1616
CHECK_CUDA(x); \
1717
CHECK_CONTIGUOUS(x)
18+
19+
// Max possible threads per block
20+
const int MAX_THREADS_PER_BLOCK = 1024;

0 commit comments

Comments
 (0)