|
| 1 | +/* |
| 2 | + * Copyright (c) Facebook, Inc. and its affiliates. |
| 3 | + * All rights reserved. |
| 4 | + * |
| 5 | + * This source code is licensed under the BSD-style license found in the |
| 6 | + * LICENSE file in the root directory of this source tree. |
| 7 | + */ |
| 8 | + |
| 9 | +#include <ATen/ATen.h> |
| 10 | +#include <ATen/cuda/CUDAContext.h> |
| 11 | +#include <c10/cuda/CUDAGuard.h> |
| 12 | +#include <math.h> |
| 13 | +#include <stdio.h> |
| 14 | +#include <stdlib.h> |
| 15 | +#include "utils/pytorch3d_cutils.h" |
| 16 | +#include "utils/warp_reduce.cuh" |
| 17 | + |
| 18 | +template <unsigned int block_size> |
| 19 | +__global__ void FarthestPointSamplingKernel( |
| 20 | + // clang-format off |
| 21 | + const at::PackedTensorAccessor64<float, 3, at::RestrictPtrTraits> points, |
| 22 | + const at::PackedTensorAccessor64<int64_t, 1, at::RestrictPtrTraits> lengths, |
| 23 | + const at::PackedTensorAccessor64<int64_t, 1, at::RestrictPtrTraits> K, |
| 24 | + at::PackedTensorAccessor64<int64_t, 2, at::RestrictPtrTraits> idxs, |
| 25 | + at::PackedTensorAccessor64<float, 2, at::RestrictPtrTraits> min_point_dist, |
| 26 | + const at::PackedTensorAccessor64<int64_t, 1, at::RestrictPtrTraits> start_idxs |
| 27 | + // clang-format on |
| 28 | +) { |
| 29 | + // Get constants |
| 30 | + const int64_t N = points.size(0); |
| 31 | + const int64_t P = points.size(1); |
| 32 | + const int64_t D = points.size(2); |
| 33 | + |
| 34 | + // Create single shared memory buffer which is split and cast to different |
| 35 | + // types: dists/dists_idx are used to save the maximum distances seen by the |
| 36 | + // points processed by any one thread and the associated point indices. |
| 37 | + // These values only need to be accessed by other threads in this block which |
| 38 | + // are processing the same batch and not by other blocks. |
| 39 | + extern __shared__ char shared_buf[]; |
| 40 | + float* dists = (float*)shared_buf; // block_size floats |
| 41 | + int64_t* dists_idx = (int64_t*)&dists[block_size]; // block_size int64_t |
| 42 | + |
| 43 | + // Get batch index and thread index |
| 44 | + const int64_t batch_idx = blockIdx.x; |
| 45 | + const size_t tid = threadIdx.x; |
| 46 | + |
| 47 | + // If K is greater than the number of points in the pointcloud |
| 48 | + // we only need to iterate until the smaller value is reached. |
| 49 | + const int64_t k_n = min(K[batch_idx], lengths[batch_idx]); |
| 50 | + |
| 51 | + // Write the first selected point to global memory in the first thread |
| 52 | + int64_t selected = start_idxs[batch_idx]; |
| 53 | + if (tid == 0) |
| 54 | + idxs[batch_idx][0] = selected; |
| 55 | + |
| 56 | + // Iterate to find k_n sampled points |
| 57 | + for (int64_t k = 1; k < k_n; ++k) { |
| 58 | + // Keep track of the maximum of the minimum distance to previously selected |
| 59 | + // points seen by this thread |
| 60 | + int64_t max_dist_idx = 0; |
| 61 | + float max_dist = -1.0; |
| 62 | + |
| 63 | + // Iterate through all the points in this pointcloud. For already selected |
| 64 | + // points, the minimum distance to the set of previously selected points |
| 65 | + // will be 0.0 so they won't be selected again. |
| 66 | + for (int64_t p = tid; p < lengths[batch_idx]; p += block_size) { |
| 67 | + // Calculate the distance to the last selected point |
| 68 | + float dist2 = 0.0; |
| 69 | + for (int64_t d = 0; d < D; ++d) { |
| 70 | + float diff = points[batch_idx][selected][d] - points[batch_idx][p][d]; |
| 71 | + dist2 += (diff * diff); |
| 72 | + } |
| 73 | + |
| 74 | + // If the distance of point p to the last selected point is |
| 75 | + // less than the previous minimum distance of p to the set of selected |
| 76 | + // points, then updated the corresponding value in min_point_dist |
| 77 | + // so it always contains the min distance. |
| 78 | + const float p_min_dist = min(dist2, min_point_dist[batch_idx][p]); |
| 79 | + min_point_dist[batch_idx][p] = p_min_dist; |
| 80 | + |
| 81 | + // Update the max distance and point idx for this thread. |
| 82 | + max_dist_idx = (p_min_dist > max_dist) ? p : max_dist_idx; |
| 83 | + max_dist = (p_min_dist > max_dist) ? p_min_dist : max_dist; |
| 84 | + } |
| 85 | + |
| 86 | + // After going through all points for this thread, save the max |
| 87 | + // point and idx seen by this thread. Each thread sees P/block_size points. |
| 88 | + dists[tid] = max_dist; |
| 89 | + dists_idx[tid] = max_dist_idx; |
| 90 | + // Sync to ensure all threads in the block have updated their max point. |
| 91 | + __syncthreads(); |
| 92 | + |
| 93 | + // Parallelized block reduction to find the max point seen by |
| 94 | + // all the threads in this block for iteration k. |
| 95 | + // Each block represents one batch element so we can use a divide/conquer |
| 96 | + // approach to find the max, syncing all threads after each step. |
| 97 | + |
| 98 | + for (int s = block_size / 2; s > 0; s >>= 1) { |
| 99 | + if (tid < s) { |
| 100 | + // Compare the best point seen by two threads and update the shared |
| 101 | + // memory at the location of the first thread index with the max out |
| 102 | + // of the two threads. |
| 103 | + if (dists[tid] < dists[tid + s]) { |
| 104 | + dists[tid] = dists[tid + s]; |
| 105 | + dists_idx[tid] = dists_idx[tid + s]; |
| 106 | + } |
| 107 | + } |
| 108 | + __syncthreads(); |
| 109 | + } |
| 110 | + |
| 111 | + // TODO(nikhilar): As reduction proceeds, the number of “active” threads |
| 112 | + // decreases. When tid < 32, there should only be one warp left which could |
| 113 | + // be unrolled. |
| 114 | + |
| 115 | + // The overall max after reducing will be saved |
| 116 | + // at the location of tid = 0. |
| 117 | + selected = dists_idx[0]; |
| 118 | + |
| 119 | + if (tid == 0) { |
| 120 | + // Write the farthest point for iteration k to global memory |
| 121 | + idxs[batch_idx][k] = selected; |
| 122 | + } |
| 123 | + } |
| 124 | +} |
| 125 | + |
| 126 | +at::Tensor FarthestPointSamplingCuda( |
| 127 | + const at::Tensor& points, // (N, P, 3) |
| 128 | + const at::Tensor& lengths, // (N,) |
| 129 | + const at::Tensor& K, // (N,) |
| 130 | + const at::Tensor& start_idxs) { |
| 131 | + // Check inputs are on the same device |
| 132 | + at::TensorArg p_t{points, "points", 1}, lengths_t{lengths, "lengths", 2}, |
| 133 | + k_t{K, "K", 3}, start_idxs_t{start_idxs, "start_idxs", 4}; |
| 134 | + at::CheckedFrom c = "FarthestPointSamplingCuda"; |
| 135 | + at::checkAllSameGPU(c, {p_t, lengths_t, k_t, start_idxs_t}); |
| 136 | + at::checkAllSameType(c, {lengths_t, k_t, start_idxs_t}); |
| 137 | + |
| 138 | + // Set the device for the kernel launch based on the device of points |
| 139 | + at::cuda::CUDAGuard device_guard(points.device()); |
| 140 | + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); |
| 141 | + |
| 142 | + TORCH_CHECK( |
| 143 | + points.size(0) == lengths.size(0), |
| 144 | + "Point and lengths must have the same batch dimension"); |
| 145 | + |
| 146 | + TORCH_CHECK( |
| 147 | + points.size(0) == K.size(0), |
| 148 | + "Points and K must have the same batch dimension"); |
| 149 | + |
| 150 | + const int64_t N = points.size(0); |
| 151 | + const int64_t P = points.size(1); |
| 152 | + const int64_t max_K = at::max(K).item<int64_t>(); |
| 153 | + |
| 154 | + // Initialize the output tensor with the sampled indices |
| 155 | + auto idxs = at::full({N, max_K}, -1, lengths.options()); |
| 156 | + auto min_point_dist = at::full({N, P}, 1e10, points.options()); |
| 157 | + |
| 158 | + if (N == 0 || P == 0) { |
| 159 | + AT_CUDA_CHECK(cudaGetLastError()); |
| 160 | + return idxs; |
| 161 | + } |
| 162 | + |
| 163 | + // Set the number of blocks to the batch size so that the |
| 164 | + // block reduction step can be done for each pointcloud |
| 165 | + // to find the max distance point in the pointcloud at each iteration. |
| 166 | + const size_t blocks = N; |
| 167 | + |
| 168 | + // Set the threads to the nearest power of 2 of the number of |
| 169 | + // points in the pointcloud (up to the max threads in a block). |
| 170 | + // This will ensure each thread processes the minimum necessary number of |
| 171 | + // points (P/threads). |
| 172 | + const int points_pow_2 = std::log(static_cast<double>(P)) / std::log(2.0); |
| 173 | + const size_t threads = max(min(1 << points_pow_2, MAX_THREADS_PER_BLOCK), 1); |
| 174 | + |
| 175 | + // Create the accessors |
| 176 | + auto points_a = points.packed_accessor64<float, 3, at::RestrictPtrTraits>(); |
| 177 | + auto lengths_a = |
| 178 | + lengths.packed_accessor64<int64_t, 1, at::RestrictPtrTraits>(); |
| 179 | + auto K_a = K.packed_accessor64<int64_t, 1, at::RestrictPtrTraits>(); |
| 180 | + auto idxs_a = idxs.packed_accessor64<int64_t, 2, at::RestrictPtrTraits>(); |
| 181 | + auto start_idxs_a = |
| 182 | + start_idxs.packed_accessor64<int64_t, 1, at::RestrictPtrTraits>(); |
| 183 | + auto min_point_dist_a = |
| 184 | + min_point_dist.packed_accessor64<float, 2, at::RestrictPtrTraits>(); |
| 185 | + |
| 186 | + // Initialize the shared memory which will be used to store the |
| 187 | + // distance/index of the best point seen by each thread. |
| 188 | + size_t shared_mem = threads * sizeof(float) + threads * sizeof(int64_t); |
| 189 | + // TODO: using shared memory for min_point_dist gives an ~2x speed up |
| 190 | + // compared to using a global (N, P) shaped tensor, however for |
| 191 | + // larger pointclouds this may exceed the shared memory limit per block. |
| 192 | + // If a speed up is required for smaller pointclouds, then the storage |
| 193 | + // could be switched to shared memory if the required total shared memory is |
| 194 | + // within the memory limit per block. |
| 195 | + |
| 196 | + // Support a case for all powers of 2 up to MAX_THREADS_PER_BLOCK possible per |
| 197 | + // block. |
| 198 | + switch (threads) { |
| 199 | + case 1024: |
| 200 | + FarthestPointSamplingKernel<1024> |
| 201 | + <<<blocks, threads, shared_mem, stream>>>( |
| 202 | + points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a); |
| 203 | + break; |
| 204 | + case 512: |
| 205 | + FarthestPointSamplingKernel<512><<<blocks, threads, shared_mem, stream>>>( |
| 206 | + points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a); |
| 207 | + break; |
| 208 | + case 256: |
| 209 | + FarthestPointSamplingKernel<256><<<blocks, threads, shared_mem, stream>>>( |
| 210 | + points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a); |
| 211 | + break; |
| 212 | + case 128: |
| 213 | + FarthestPointSamplingKernel<128><<<blocks, threads, shared_mem, stream>>>( |
| 214 | + points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a); |
| 215 | + break; |
| 216 | + case 64: |
| 217 | + FarthestPointSamplingKernel<64><<<blocks, threads, shared_mem, stream>>>( |
| 218 | + points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a); |
| 219 | + break; |
| 220 | + case 32: |
| 221 | + FarthestPointSamplingKernel<32><<<blocks, threads, shared_mem, stream>>>( |
| 222 | + points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a); |
| 223 | + break; |
| 224 | + case 16: |
| 225 | + FarthestPointSamplingKernel<16><<<blocks, threads, shared_mem, stream>>>( |
| 226 | + points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a); |
| 227 | + break; |
| 228 | + case 8: |
| 229 | + FarthestPointSamplingKernel<8><<<blocks, threads, shared_mem, stream>>>( |
| 230 | + points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a); |
| 231 | + break; |
| 232 | + case 4: |
| 233 | + FarthestPointSamplingKernel<4><<<threads, threads, shared_mem, stream>>>( |
| 234 | + points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a); |
| 235 | + break; |
| 236 | + case 2: |
| 237 | + FarthestPointSamplingKernel<2><<<threads, threads, shared_mem, stream>>>( |
| 238 | + points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a); |
| 239 | + break; |
| 240 | + case 1: |
| 241 | + FarthestPointSamplingKernel<1><<<threads, threads, shared_mem, stream>>>( |
| 242 | + points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a); |
| 243 | + break; |
| 244 | + default: |
| 245 | + FarthestPointSamplingKernel<1024> |
| 246 | + <<<blocks, threads, shared_mem, stream>>>( |
| 247 | + points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a); |
| 248 | + } |
| 249 | + |
| 250 | + AT_CUDA_CHECK(cudaGetLastError()); |
| 251 | + return idxs; |
| 252 | +} |
0 commit comments