Skip to content

Commit 103da63

Browse files
nikhilaravifacebook-github-bot
authored andcommitted
Ball Query
Summary: Implementation of ball query from PointNet++. This function is similar to KNN (find the neighbors in p2 for all points in p1). These are the key differences: - It will return the **first** K neighbors within a specified radius as opposed to the **closest** K neighbors. - As all the points in p2 do not need to be considered to find the closest K, the algorithm is much faster than KNN when p2 has a large number of points. - The neighbors are not sorted - Due to the radius threshold it is not guaranteed that there will be K neighbors even if there are more than K points in p2. - The padding value for `idx` is -1 instead of 0. # Note: - Some of the code is very similar to KNN so it could be possible to modify the KNN forward kernels to support ball query. - Some users might want to use kNN with ball query - for this we could provide a wrapper function around the current `knn_points` which enables applying the radius threshold afterwards as an alternative. This could be called `ball_query_knn`. Reviewed By: jcjohnson Differential Revision: D30261362 fbshipit-source-id: 66b6a7e0114beff7164daf7eba21546ff41ec450
1 parent e5c58a8 commit 103da63

File tree

10 files changed

+709
-1
lines changed

10 files changed

+709
-1
lines changed
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
/*
2+
* Copyright (c) Facebook, Inc. and its affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <ATen/ATen.h>
10+
#include <ATen/cuda/CUDAContext.h>
11+
#include <c10/cuda/CUDAGuard.h>
12+
#include <math.h>
13+
#include <stdio.h>
14+
#include <stdlib.h>
15+
#include "utils/pytorch3d_cutils.h"
16+
17+
// A chunk of work is blocksize-many points of P1.
18+
// The number of potential chunks to do is N*(1+(P1-1)/blocksize)
19+
// call (1+(P1-1)/blocksize) chunks_per_cloud
20+
// These chunks are divided among the gridSize-many blocks.
21+
// In block b, we work on chunks b, b+gridSize, b+2*gridSize etc .
22+
// In chunk i, we work on cloud i/chunks_per_cloud on points starting from
23+
// blocksize*(i%chunks_per_cloud).
24+
25+
template <typename scalar_t>
26+
__global__ void BallQueryKernel(
27+
const at::PackedTensorAccessor64<scalar_t, 3, at::RestrictPtrTraits> p1,
28+
const at::PackedTensorAccessor64<scalar_t, 3, at::RestrictPtrTraits> p2,
29+
const at::PackedTensorAccessor64<int64_t, 1, at::RestrictPtrTraits>
30+
lengths1,
31+
const at::PackedTensorAccessor64<int64_t, 1, at::RestrictPtrTraits>
32+
lengths2,
33+
at::PackedTensorAccessor64<int64_t, 3, at::RestrictPtrTraits> idxs,
34+
at::PackedTensorAccessor64<scalar_t, 3, at::RestrictPtrTraits> dists,
35+
const int64_t K,
36+
const float radius2) {
37+
const int64_t N = p1.size(0);
38+
const int64_t chunks_per_cloud = (1 + (p1.size(1) - 1) / blockDim.x);
39+
const int64_t chunks_to_do = N * chunks_per_cloud;
40+
const int D = p1.size(2);
41+
42+
for (int64_t chunk = blockIdx.x; chunk < chunks_to_do; chunk += gridDim.x) {
43+
const int64_t n = chunk / chunks_per_cloud; // batch_index
44+
const int64_t start_point = blockDim.x * (chunk % chunks_per_cloud);
45+
int64_t i = start_point + threadIdx.x;
46+
47+
// Check if point is valid in heterogeneous tensor
48+
if (i >= lengths1[n]) {
49+
continue;
50+
}
51+
52+
// Iterate over points in p2 until desired count is reached or
53+
// all points have been considered
54+
for (int64_t j = 0, count = 0; j < lengths2[n] && count < K; ++j) {
55+
// Calculate the distance between the points
56+
scalar_t dist2 = 0.0;
57+
for (int d = 0; d < D; ++d) {
58+
scalar_t diff = p1[n][i][d] - p2[n][j][d];
59+
dist2 += (diff * diff);
60+
}
61+
62+
if (dist2 < radius2) {
63+
// If the point is within the radius
64+
// Set the value of the index to the point index
65+
idxs[n][i][count] = j;
66+
dists[n][i][count] = dist2;
67+
68+
// increment the number of selected samples for the point i
69+
++count;
70+
}
71+
}
72+
}
73+
}
74+
75+
std::tuple<at::Tensor, at::Tensor> BallQueryCuda(
76+
const at::Tensor& p1, // (N, P1, 3)
77+
const at::Tensor& p2, // (N, P2, 3)
78+
const at::Tensor& lengths1, // (N,)
79+
const at::Tensor& lengths2, // (N,)
80+
int K,
81+
float radius) {
82+
// Check inputs are on the same device
83+
at::TensorArg p1_t{p1, "p1", 1}, p2_t{p2, "p2", 2},
84+
lengths1_t{lengths1, "lengths1", 3}, lengths2_t{lengths2, "lengths2", 4};
85+
at::CheckedFrom c = "BallQueryCuda";
86+
at::checkAllSameGPU(c, {p1_t, p2_t, lengths1_t, lengths2_t});
87+
at::checkAllSameType(c, {p1_t, p2_t});
88+
89+
// Set the device for the kernel launch based on the device of p1
90+
at::cuda::CUDAGuard device_guard(p1.device());
91+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
92+
93+
TORCH_CHECK(
94+
p2.size(2) == p1.size(2), "Point sets must have the same last dimension");
95+
96+
const int N = p1.size(0);
97+
const int P1 = p1.size(1);
98+
const int64_t K_64 = K;
99+
const float radius2 = radius * radius;
100+
101+
// Output tensor with indices of neighbors for each point in p1
102+
auto long_dtype = lengths1.options().dtype(at::kLong);
103+
auto idxs = at::full({N, P1, K}, -1, long_dtype);
104+
auto dists = at::zeros({N, P1, K}, p1.options());
105+
106+
if (idxs.numel() == 0) {
107+
AT_CUDA_CHECK(cudaGetLastError());
108+
return std::make_tuple(idxs, dists);
109+
}
110+
111+
const size_t blocks = 256;
112+
const size_t threads = 256;
113+
114+
AT_DISPATCH_FLOATING_TYPES(
115+
p1.scalar_type(), "ball_query_kernel_cuda", ([&] {
116+
BallQueryKernel<<<blocks, threads, 0, stream>>>(
117+
p1.packed_accessor64<float, 3, at::RestrictPtrTraits>(),
118+
p2.packed_accessor64<float, 3, at::RestrictPtrTraits>(),
119+
lengths1.packed_accessor64<int64_t, 1, at::RestrictPtrTraits>(),
120+
lengths2.packed_accessor64<int64_t, 1, at::RestrictPtrTraits>(),
121+
idxs.packed_accessor64<int64_t, 3, at::RestrictPtrTraits>(),
122+
dists.packed_accessor64<float, 3, at::RestrictPtrTraits>(),
123+
K_64,
124+
radius2);
125+
}));
126+
127+
AT_CUDA_CHECK(cudaGetLastError());
128+
129+
return std::make_tuple(idxs, dists);
130+
}
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
/*
2+
* Copyright (c) Facebook, Inc. and its affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
#include <torch/extension.h>
11+
#include <tuple>
12+
#include "utils/pytorch3d_cutils.h"
13+
14+
// Compute indices of K neighbors in pointcloud p2 to points
15+
// in pointcloud p1 which fall within a specified radius
16+
//
17+
// Args:
18+
// p1: FloatTensor of shape (N, P1, D) giving a batch of pointclouds each
19+
// containing P1 points of dimension D.
20+
// p2: FloatTensor of shape (N, P2, D) giving a batch of pointclouds each
21+
// containing P2 points of dimension D.
22+
// lengths1: LongTensor, shape (N,), giving actual length of each P1 cloud.
23+
// lengths2: LongTensor, shape (N,), giving actual length of each P2 cloud.
24+
// K: Integer giving the upper bound on the number of samples to take
25+
// within the radius
26+
// radius: the radius around each point within which the neighbors need to be
27+
// located
28+
//
29+
// Returns:
30+
// p1_neighbor_idx: LongTensor of shape (N, P1, K), where
31+
// p1_neighbor_idx[n, i, k] = j means that the kth
32+
// neighbor to p1[n, i] in the cloud p2[n] is p2[n, j].
33+
// This is padded with -1s both where a cloud in p2 has fewer than
34+
// S points and where a cloud in p1 has fewer than P1 points and
35+
// also if there are fewer than K points which satisfy the radius
36+
// threshold.
37+
//
38+
// p1_neighbor_dists: FloatTensor of shape (N, P1, K) containing the squared
39+
// distance from each point p1[n, p, :] to its K neighbors
40+
// p2[n, p1_neighbor_idx[n, p, k], :].
41+
42+
// CPU implementation
43+
std::tuple<at::Tensor, at::Tensor> BallQueryCpu(
44+
const at::Tensor& p1,
45+
const at::Tensor& p2,
46+
const at::Tensor& lengths1,
47+
const at::Tensor& lengths2,
48+
const int K,
49+
const float radius);
50+
51+
// CUDA implementation
52+
std::tuple<at::Tensor, at::Tensor> BallQueryCuda(
53+
const at::Tensor& p1,
54+
const at::Tensor& p2,
55+
const at::Tensor& lengths1,
56+
const at::Tensor& lengths2,
57+
const int K,
58+
const float radius);
59+
60+
// Implementation which is exposed
61+
// Note: the backward pass reuses the KNearestNeighborBackward kernel
62+
inline std::tuple<at::Tensor, at::Tensor> BallQuery(
63+
const at::Tensor& p1,
64+
const at::Tensor& p2,
65+
const at::Tensor& lengths1,
66+
const at::Tensor& lengths2,
67+
int K,
68+
float radius) {
69+
if (p1.is_cuda() || p2.is_cuda()) {
70+
#ifdef WITH_CUDA
71+
CHECK_CUDA(p1);
72+
CHECK_CUDA(p2);
73+
return BallQueryCuda(
74+
p1.contiguous(),
75+
p2.contiguous(),
76+
lengths1.contiguous(),
77+
lengths2.contiguous(),
78+
K,
79+
radius);
80+
#else
81+
AT_ERROR("Not compiled with GPU support.");
82+
#endif
83+
}
84+
return BallQueryCpu(
85+
p1.contiguous(),
86+
p2.contiguous(),
87+
lengths1.contiguous(),
88+
lengths2.contiguous(),
89+
K,
90+
radius);
91+
}
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
/*
2+
* Copyright (c) Facebook, Inc. and its affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <torch/extension.h>
10+
#include <queue>
11+
#include <tuple>
12+
13+
std::tuple<at::Tensor, at::Tensor> BallQueryCpu(
14+
const at::Tensor& p1,
15+
const at::Tensor& p2,
16+
const at::Tensor& lengths1,
17+
const at::Tensor& lengths2,
18+
int K,
19+
float radius) {
20+
const int N = p1.size(0);
21+
const int P1 = p1.size(1);
22+
const int D = p1.size(2);
23+
24+
auto long_opts = lengths1.options().dtype(torch::kInt64);
25+
torch::Tensor idxs = torch::full({N, P1, K}, -1, long_opts);
26+
torch::Tensor dists = torch::full({N, P1, K}, 0, p1.options());
27+
const float radius2 = radius * radius;
28+
29+
auto p1_a = p1.accessor<float, 3>();
30+
auto p2_a = p2.accessor<float, 3>();
31+
auto lengths1_a = lengths1.accessor<int64_t, 1>();
32+
auto lengths2_a = lengths2.accessor<int64_t, 1>();
33+
auto idxs_a = idxs.accessor<int64_t, 3>();
34+
auto dists_a = dists.accessor<float, 3>();
35+
36+
for (int n = 0; n < N; ++n) {
37+
const int64_t length1 = lengths1_a[n];
38+
const int64_t length2 = lengths2_a[n];
39+
for (int64_t i = 0; i < length1; ++i) {
40+
for (int64_t j = 0, count = 0; j < length2 && count < K; ++j) {
41+
float dist2 = 0;
42+
for (int d = 0; d < D; ++d) {
43+
float diff = p1_a[n][i][d] - p2_a[n][j][d];
44+
dist2 += diff * diff;
45+
}
46+
if (dist2 < radius2) {
47+
dists_a[n][i][count] = dist2;
48+
idxs_a[n][i][count] = j;
49+
++count;
50+
}
51+
}
52+
}
53+
}
54+
return std::make_tuple(idxs, dists);
55+
}

pytorch3d/csrc/ext.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
// clang-format on
1313
#include "./pulsar/pytorch/renderer.h"
1414
#include "./pulsar/pytorch/tensor_util.h"
15+
#include "ball_query/ball_query.h"
1516
#include "blending/sigmoid_alpha_blend.h"
1617
#include "compositing/alpha_composite.h"
1718
#include "compositing/norm_weighted_sum.h"
@@ -38,6 +39,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
3839
#endif
3940
m.def("knn_points_idx", &KNearestNeighborIdx);
4041
m.def("knn_points_backward", &KNearestNeighborBackward);
42+
43+
// Ball Query
44+
m.def("ball_query", &BallQuery);
4145
m.def(
4246
"mesh_normal_consistency_find_verts", &MeshNormalConsistencyFindVertices);
4347
m.def("gather_scatter", &GatherScatter);

pytorch3d/csrc/knn/knn.cu

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -477,6 +477,10 @@ __global__ void KNearestNeighborBackwardKernel(
477477
const float grad_dist = grad_dists[n * P1 * K + p1_idx * K + k];
478478
// index of point in p2 corresponding to the k-th nearest neighbor
479479
const size_t p2_idx = idxs[n * P1 * K + p1_idx * K + k];
480+
// If the index is the pad value of -1 then ignore it
481+
if (p2_idx == -1) {
482+
continue;
483+
}
480484
const float diff = 2.0 * grad_dist *
481485
(p1[n * P1 * D + p1_idx * D + d] - p2[n * P2 * D + p2_idx * D + d]);
482486
atomicAdd(grad_p1 + n * P1 * D + p1_idx * D + d, diff);

pytorch3d/csrc/knn/knn_cpu.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,10 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCpu(
9999
for (int64_t i1 = 0; i1 < length1; ++i1) {
100100
for (int64_t k = 0; k < length2; ++k) {
101101
const int64_t i2 = idxs_a[n][i1][k];
102+
// If the index is the pad value of -1 then ignore it
103+
if (i2 == -1) {
104+
continue;
105+
}
102106
for (int64_t d = 0; d < D; ++d) {
103107
const float diff =
104108
2.0f * grad_dists_a[n][i1][k] * (p1_a[n][i1][d] - p2_a[n][i2][d]);

pytorch3d/ops/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
from .ball_query import ball_query
78
from .cameras_alignment import corresponding_cameras_alignment
89
from .cubify import cubify
910
from .graph_conv import GraphConv
@@ -34,5 +35,4 @@
3435
)
3536
from .vert_align import vert_align
3637

37-
3838
__all__ = [k for k in globals().keys() if not k.startswith("_")]

0 commit comments

Comments
 (0)