
Commit 7566530

JaapSuter authored and facebook-github-bot committed
CUDA marching_cubes fix
Summary:
Fix an inclusive vs exclusive scan mix-up that was accidentally introduced when removing the Thrust dependency (`Thrust::exclusive_scan`) and reimplementing it using `at::cumsum` (which does an inclusive scan).

This fixes two GitHub-reported issues:

* #1731
* #1751

Reviewed By: bottler

Differential Revision: D54605545

fbshipit-source-id: da9e92f3f8a9a35f7b7191428d0b9a9ca03e0d4d
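For context on the mix-up described above, here is a small standalone C++ sketch (illustrative only, not part of the commit) contrasting the two scan flavors:

// Toy illustration: inclusive vs. exclusive prefix sums over per-voxel counts.
#include <cstdio>
#include <numeric>
#include <vector>

int main() {
  std::vector<int> counts = {3, 0, 6, 3}; // e.g. vertices emitted per voxel

  // Inclusive scan (what at::cumsum computes): element i includes counts[i].
  std::vector<int> inclusive(counts.size());
  std::partial_sum(counts.begin(), counts.end(), inclusive.begin());
  // inclusive == {3, 3, 9, 12}

  // Exclusive scan (what Thrust's exclusive_scan computed): element i is the
  // sum of everything *before* i, so it can be used directly as a write offset.
  std::vector<int> exclusive(counts.size(), 0);
  std::exclusive_scan(counts.begin(), counts.end(), exclusive.begin(), 0);
  // exclusive == {0, 3, 3, 9}

  for (size_t i = 0; i < counts.size(); ++i)
    std::printf("%d %d %d\n", counts[i], inclusive[i], exclusive[i]);
  return 0;
}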
1 parent a27755d commit 7566530

File tree

1 file changed: +50 -20 lines changed

pytorch3d/csrc/marching_cubes/marching_cubes.cu

Lines changed: 50 additions & 20 deletions
@@ -382,6 +382,44 @@ __global__ void GenerateFacesKernel(
   } // end for grid-strided kernel
 }
 
+// ATen/Torch does not have an exclusive-scan operator. Additionally, in the
+// code below we need to get the "total number of items to work on" after
+// a scan, which with an inclusive-scan would simply be the value of the last
+// element in the tensor.
+//
+// This utility function hits two birds with one stone, by running
+// an inclusive-scan into a right-shifted view of a tensor that's
+// allocated to be one element bigger than the input tensor.
+//
+// Note; return tensor is `int64_t` per element, even if the input
+// tensor is only 32-bit. Also, the return tensor is one element bigger
+// than the input one.
+//
+// Secondary optional argument is an output argument that gets the
+// value of the last element of the return tensor (because you almost
+// always need this CPU-side right after this function anyway).
+static at::Tensor ExclusiveScanAndTotal(
+    const at::Tensor& inTensor,
+    int64_t* optTotal = nullptr) {
+  const auto inSize = inTensor.sizes()[0];
+  auto retTensor = at::zeros({inSize + 1}, at::kLong).to(inTensor.device());
+
+  using at::indexing::None;
+  using at::indexing::Slice;
+  auto rightShiftedView = retTensor.index({Slice(1, None)});
+
+  // Do an (inclusive-scan) cumulative sum in to the view that's
+  // shifted one element to the right...
+  at::cumsum_out(rightShiftedView, inTensor, 0, at::kLong);
+
+  if (optTotal) {
+    *optTotal = retTensor[inSize].cpu().item<int64_t>();
+  }
+
+  // ...so that the not-shifted tensor holds the exclusive-scan
+  return retTensor;
+}
+
 // Entrance for marching cubes cuda extension. Marching Cubes is an algorithm to
 // create triangle meshes from an implicit function (one of the form f(x, y, z)
 // = 0). It works by iteratively checking a grid of cubes superimposed over a
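A rough usage sketch of the new helper (the input values and call site below are hypothetical, not taken from the commit or the repository's tests):

// Hypothetical usage sketch for ExclusiveScanAndTotal.
// With inTensor = [0, 1, 2, 3], the helper allocates a 5-element int64 tensor
// of zeros, cumsums into the view that skips element 0, and returns
// [0, 0, 1, 3, 6]: indices 0..3 hold the exclusive scan, index 4 the total.
auto counts = at::arange(4, at::kInt).to(at::kCUDA); // per-voxel counts: 0,1,2,3
int64_t total = 0;
auto offsets = ExclusiveScanAndTotal(counts, &total); // total == 6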
@@ -444,20 +482,18 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCuda(
   using at::indexing::Slice;
 
   auto d_voxelVerts =
-      at::zeros({numVoxels + 1}, at::TensorOptions().dtype(at::kInt))
+      at::zeros({numVoxels}, at::TensorOptions().dtype(at::kInt))
           .to(vol.device());
-  auto d_voxelVerts_ = d_voxelVerts.index({Slice(1, None)});
   auto d_voxelOccupied =
-      at::zeros({numVoxels + 1}, at::TensorOptions().dtype(at::kInt))
+      at::zeros({numVoxels}, at::TensorOptions().dtype(at::kInt))
           .to(vol.device());
-  auto d_voxelOccupied_ = d_voxelOccupied.index({Slice(1, None)});
 
   // Execute "ClassifyVoxelKernel" kernel to precompute
   // two arrays - d_voxelOccupied and d_voxelVertices to global memory,
   // which stores the occupancy state and number of voxel vertices per voxel.
   ClassifyVoxelKernel<<<grid, threads, 0, stream>>>(
-      d_voxelVerts_.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
-      d_voxelOccupied_.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
+      d_voxelVerts.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
+      d_voxelOccupied.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
       vol.packed_accessor32<float, 3, at::RestrictPtrTraits>(),
       isolevel);
   AT_CUDA_CHECK(cudaGetLastError());
@@ -467,12 +503,9 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCuda(
   // count for voxels in the grid and compute the number of active voxels.
   // If the number of active voxels is 0, return zero tensor for verts and
   // faces.
-
-  auto d_voxelOccupiedScan = at::cumsum(d_voxelOccupied, 0);
-  auto d_voxelOccupiedScan_ = d_voxelOccupiedScan.index({Slice(1, None)});
-
-  // number of active voxels
-  int64_t activeVoxels = d_voxelOccupiedScan[numVoxels].cpu().item<int64_t>();
+  int64_t activeVoxels = 0;
+  auto d_voxelOccupiedScan =
+      ExclusiveScanAndTotal(d_voxelOccupied, &activeVoxels);
 
   const int device_id = vol.device().index();
   auto opt = at::TensorOptions().dtype(at::kInt).device(at::kCUDA, device_id);
@@ -487,24 +520,21 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCuda(
     return std::make_tuple(verts, faces, ids);
   }
 
-  // Execute "CompactVoxelsKernel" kernel to compress voxels for accleration.
+  // Execute "CompactVoxelsKernel" kernel to compress voxels for acceleration.
   // This allows us to run triangle generation on only the occupied voxels.
   auto d_compVoxelArray = at::zeros({activeVoxels}, opt);
   CompactVoxelsKernel<<<grid, threads, 0, stream>>>(
       d_compVoxelArray.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
       d_voxelOccupied.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
-      d_voxelOccupiedScan_
+      d_voxelOccupiedScan
          .packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
       numVoxels);
   AT_CUDA_CHECK(cudaGetLastError());
   cudaDeviceSynchronize();
 
   // Scan d_voxelVerts array to generate offsets of vertices for each voxel
-  auto d_voxelVertsScan = at::cumsum(d_voxelVerts, 0);
-  auto d_voxelVertsScan_ = d_voxelVertsScan.index({Slice(1, None)});
-
-  // total number of vertices
-  int64_t totalVerts = d_voxelVertsScan[numVoxels].cpu().item<int64_t>();
+  int64_t totalVerts = 0;
+  auto d_voxelVertsScan = ExclusiveScanAndTotal(d_voxelVerts, &totalVerts);
 
   // Execute "GenerateFacesKernel" kernel
   // This runs only on the occupied voxels.
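Why the unshifted exclusive-scan tensors are what the downstream kernels want: with an exclusive scan, entry i is the number of items strictly before voxel i, i.e. a ready-made write offset. The sketch below shows that generic stream-compaction pattern; it is hypothetical (the actual CompactVoxelsKernel body is not part of this diff) and the kernel name is made up for illustration.

// Hypothetical sketch of the usual compaction pattern driven by an exclusive
// scan; NOT the actual CompactVoxelsKernel from the repository.
__global__ void CompactOccupiedSketch(
    int* __restrict__ compacted,               // size: activeVoxels
    const int* __restrict__ occupied,          // size: numVoxels, 0 or 1 each
    const int64_t* __restrict__ occupiedScan,  // size: numVoxels + 1, exclusive
    int numVoxels) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < numVoxels;
       i += gridDim.x * blockDim.x) {
    if (occupied[i]) {
      // occupiedScan[i] == number of occupied voxels strictly before i,
      // which is exactly this voxel's slot in the compacted array.
      compacted[occupiedScan[i]] = i;
    }
  }
}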
@@ -524,7 +554,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCuda(
       faces.packed_accessor<int64_t, 2, at::RestrictPtrTraits>(),
       ids.packed_accessor<int64_t, 1, at::RestrictPtrTraits>(),
       d_compVoxelArray.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
-      d_voxelVertsScan_.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
+      d_voxelVertsScan.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
       activeVoxels,
       vol.packed_accessor32<float, 3, at::RestrictPtrTraits>(),
       faceTable.packed_accessor32<int, 2, at::RestrictPtrTraits>(),
