@@ -382,6 +382,44 @@ __global__ void GenerateFacesKernel(
382
382
} // end for grid-strided kernel
383
383
}
384
384
385
+ // ATen/Torch does not have an exclusive-scan operator. Additionally, in the
386
+ // code below we need to get the "total number of items to work on" after
387
+ // a scan, which with an inclusive-scan would simply be the value of the last
388
+ // element in the tensor.
389
+ //
390
+ // This utility function hits two birds with one stone, by running
391
+ // an inclusive-scan into a right-shifted view of a tensor that's
392
+ // allocated to be one element bigger than the input tensor.
393
+ //
394
+ // Note; return tensor is `int64_t` per element, even if the input
395
+ // tensor is only 32-bit. Also, the return tensor is one element bigger
396
+ // than the input one.
397
+ //
398
+ // Secondary optional argument is an output argument that gets the
399
+ // value of the last element of the return tensor (because you almost
400
+ // always need this CPU-side right after this function anyway).
401
+ static at::Tensor ExclusiveScanAndTotal (
402
+ const at::Tensor& inTensor,
403
+ int64_t * optTotal = nullptr ) {
404
+ const auto inSize = inTensor.sizes ()[0 ];
405
+ auto retTensor = at::zeros ({inSize + 1 }, at::kLong ).to (inTensor.device ());
406
+
407
+ using at::indexing::None;
408
+ using at::indexing::Slice;
409
+ auto rightShiftedView = retTensor.index ({Slice (1 , None)});
410
+
411
+ // Do an (inclusive-scan) cumulative sum in to the view that's
412
+ // shifted one element to the right...
413
+ at::cumsum_out (rightShiftedView, inTensor, 0 , at::kLong );
414
+
415
+ if (optTotal) {
416
+ *optTotal = retTensor[inSize].cpu ().item <int64_t >();
417
+ }
418
+
419
+ // ...so that the not-shifted tensor holds the exclusive-scan
420
+ return retTensor;
421
+ }
422
+
385
423
// Entrance for marching cubes cuda extension. Marching Cubes is an algorithm to
386
424
// create triangle meshes from an implicit function (one of the form f(x, y, z)
387
425
// = 0). It works by iteratively checking a grid of cubes superimposed over a
@@ -444,20 +482,18 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCuda(
444
482
using at::indexing::Slice;
445
483
446
484
auto d_voxelVerts =
447
- at::zeros ({numVoxels + 1 }, at::TensorOptions ().dtype (at::kInt ))
485
+ at::zeros ({numVoxels}, at::TensorOptions ().dtype (at::kInt ))
448
486
.to (vol.device ());
449
- auto d_voxelVerts_ = d_voxelVerts.index ({Slice (1 , None)});
450
487
auto d_voxelOccupied =
451
- at::zeros ({numVoxels + 1 }, at::TensorOptions ().dtype (at::kInt ))
488
+ at::zeros ({numVoxels}, at::TensorOptions ().dtype (at::kInt ))
452
489
.to (vol.device ());
453
- auto d_voxelOccupied_ = d_voxelOccupied.index ({Slice (1 , None)});
454
490
455
491
// Execute "ClassifyVoxelKernel" kernel to precompute
456
492
// two arrays - d_voxelOccupied and d_voxelVertices to global memory,
457
493
// which stores the occupancy state and number of voxel vertices per voxel.
458
494
ClassifyVoxelKernel<<<grid, threads, 0 , stream>>> (
459
- d_voxelVerts_ .packed_accessor32 <int , 1 , at::RestrictPtrTraits>(),
460
- d_voxelOccupied_ .packed_accessor32 <int , 1 , at::RestrictPtrTraits>(),
495
+ d_voxelVerts .packed_accessor32 <int , 1 , at::RestrictPtrTraits>(),
496
+ d_voxelOccupied .packed_accessor32 <int , 1 , at::RestrictPtrTraits>(),
461
497
vol.packed_accessor32 <float , 3 , at::RestrictPtrTraits>(),
462
498
isolevel);
463
499
AT_CUDA_CHECK (cudaGetLastError ());
@@ -467,12 +503,9 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCuda(
467
503
// count for voxels in the grid and compute the number of active voxels.
468
504
// If the number of active voxels is 0, return zero tensor for verts and
469
505
// faces.
470
-
471
- auto d_voxelOccupiedScan = at::cumsum (d_voxelOccupied, 0 );
472
- auto d_voxelOccupiedScan_ = d_voxelOccupiedScan.index ({Slice (1 , None)});
473
-
474
- // number of active voxels
475
- int64_t activeVoxels = d_voxelOccupiedScan[numVoxels].cpu ().item <int64_t >();
506
+ int64_t activeVoxels = 0 ;
507
+ auto d_voxelOccupiedScan =
508
+ ExclusiveScanAndTotal (d_voxelOccupied, &activeVoxels);
476
509
477
510
const int device_id = vol.device ().index ();
478
511
auto opt = at::TensorOptions ().dtype (at::kInt ).device (at::kCUDA , device_id);
@@ -487,24 +520,21 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCuda(
487
520
return std::make_tuple (verts, faces, ids);
488
521
}
489
522
490
- // Execute "CompactVoxelsKernel" kernel to compress voxels for accleration .
523
+ // Execute "CompactVoxelsKernel" kernel to compress voxels for acceleration .
491
524
// This allows us to run triangle generation on only the occupied voxels.
492
525
auto d_compVoxelArray = at::zeros ({activeVoxels}, opt);
493
526
CompactVoxelsKernel<<<grid, threads, 0 , stream>>> (
494
527
d_compVoxelArray.packed_accessor32 <int , 1 , at::RestrictPtrTraits>(),
495
528
d_voxelOccupied.packed_accessor32 <int , 1 , at::RestrictPtrTraits>(),
496
- d_voxelOccupiedScan_
529
+ d_voxelOccupiedScan
497
530
.packed_accessor32 <int64_t , 1 , at::RestrictPtrTraits>(),
498
531
numVoxels);
499
532
AT_CUDA_CHECK (cudaGetLastError ());
500
533
cudaDeviceSynchronize ();
501
534
502
535
// Scan d_voxelVerts array to generate offsets of vertices for each voxel
503
- auto d_voxelVertsScan = at::cumsum (d_voxelVerts, 0 );
504
- auto d_voxelVertsScan_ = d_voxelVertsScan.index ({Slice (1 , None)});
505
-
506
- // total number of vertices
507
- int64_t totalVerts = d_voxelVertsScan[numVoxels].cpu ().item <int64_t >();
536
+ int64_t totalVerts = 0 ;
537
+ auto d_voxelVertsScan = ExclusiveScanAndTotal (d_voxelVerts, &totalVerts);
508
538
509
539
// Execute "GenerateFacesKernel" kernel
510
540
// This runs only on the occupied voxels.
@@ -524,7 +554,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCuda(
524
554
faces.packed_accessor <int64_t , 2 , at::RestrictPtrTraits>(),
525
555
ids.packed_accessor <int64_t , 1 , at::RestrictPtrTraits>(),
526
556
d_compVoxelArray.packed_accessor32 <int , 1 , at::RestrictPtrTraits>(),
527
- d_voxelVertsScan_ .packed_accessor32 <int64_t , 1 , at::RestrictPtrTraits>(),
557
+ d_voxelVertsScan .packed_accessor32 <int64_t , 1 , at::RestrictPtrTraits>(),
528
558
activeVoxels,
529
559
vol.packed_accessor32 <float , 3 , at::RestrictPtrTraits>(),
530
560
faceTable.packed_accessor32 <int , 2 , at::RestrictPtrTraits>(),
0 commit comments