Skip to content

Commit 6f2212d

Browse files
bottlerfacebook-github-bot
authored andcommitted
use less thrust, maybe help Windows
Summary: I think we include more thrust than needed, and maybe removing it will help things like #1610 with DebugSyncStream errors on Windows. Reviewed By: shapovalov Differential Revision: D48949888 fbshipit-source-id: add889c0acf730a039dc9ffd6bbcc24ded20ef27
1 parent a3d99ca commit 6f2212d

File tree

3 files changed

+18
-43
lines changed

3 files changed

+18
-43
lines changed

pytorch3d/csrc/iou_box3d/iou_box3d.cu

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@
1212
#include <math.h>
1313
#include <stdio.h>
1414
#include <stdlib.h>
15-
#include <thrust/device_vector.h>
16-
#include <thrust/tuple.h>
1715
#include "iou_box3d/iou_utils.cuh"
1816

1917
// Parallelize over N*M computations which can each be done

pytorch3d/csrc/iou_box3d/iou_utils.cuh

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
#include <float.h>
1010
#include <math.h>
11-
#include <thrust/device_vector.h>
1211
#include <cstdio>
1312
#include "utils/float_math.cuh"
1413

pytorch3d/csrc/marching_cubes/marching_cubes.cu

Lines changed: 18 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,6 @@
99
#include <ATen/ATen.h>
1010
#include <ATen/cuda/CUDAContext.h>
1111
#include <c10/cuda/CUDAGuard.h>
12-
#include <thrust/device_vector.h>
13-
#include <thrust/scan.h>
1412
#include <cstdio>
1513
#include "marching_cubes/tables.h"
1614

@@ -40,20 +38,6 @@ through" each cube in the grid.
4038
// EPS: Used to indicate if two float values are close
4139
__constant__ const float EPSILON = 1e-5;
4240

43-
// Thrust wrapper for exclusive scan
44-
//
45-
// Args:
46-
// output: pointer to on-device output array
47-
// input: pointer to on-device input array, where scan is performed
48-
// numElements: number of elements for the input array
49-
//
50-
void ThrustScanWrapper(int* output, int* input, int numElements) {
51-
thrust::exclusive_scan(
52-
thrust::device_ptr<int>(input),
53-
thrust::device_ptr<int>(input + numElements),
54-
thrust::device_ptr<int>(output));
55-
}
56-
5741
// Linearly interpolate the position where an isosurface cuts an edge
5842
// between two vertices, based on their scalar values
5943
//
@@ -455,19 +439,24 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCuda(
455439
grid.x = 65535;
456440
}
457441

442+
using at::indexing::None;
443+
using at::indexing::Slice;
444+
458445
auto d_voxelVerts =
459-
at::zeros({numVoxels}, at::TensorOptions().dtype(at::kInt))
446+
at::zeros({numVoxels + 1}, at::TensorOptions().dtype(at::kInt))
460447
.to(vol.device());
448+
auto d_voxelVerts_ = d_voxelVerts.index({Slice(1, None)});
461449
auto d_voxelOccupied =
462-
at::zeros({numVoxels}, at::TensorOptions().dtype(at::kInt))
450+
at::zeros({numVoxels + 1}, at::TensorOptions().dtype(at::kInt))
463451
.to(vol.device());
452+
auto d_voxelOccupied_ = d_voxelOccupied.index({Slice(1, None)});
464453

465454
// Execute "ClassifyVoxelKernel" kernel to precompute
466455
// two arrays - d_voxelOccupied and d_voxelVertices to global memory,
467456
// which stores the occupancy state and number of voxel vertices per voxel.
468457
ClassifyVoxelKernel<<<grid, threads, 0, stream>>>(
469-
d_voxelVerts.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
470-
d_voxelOccupied.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
458+
d_voxelVerts_.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
459+
d_voxelOccupied_.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
471460
vol.packed_accessor32<float, 3, at::RestrictPtrTraits>(),
472461
isolevel);
473462
AT_CUDA_CHECK(cudaGetLastError());
@@ -477,18 +466,12 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCuda(
477466
// count for voxels in the grid and compute the number of active voxels.
478467
// If the number of active voxels is 0, return zero tensor for verts and
479468
// faces.
480-
auto d_voxelOccupiedScan =
481-
at::zeros({numVoxels}, at::TensorOptions().dtype(at::kInt))
482-
.to(vol.device());
483-
ThrustScanWrapper(
484-
d_voxelOccupiedScan.data_ptr<int>(),
485-
d_voxelOccupied.data_ptr<int>(),
486-
numVoxels);
469+
470+
auto d_voxelOccupiedScan = at::cumsum(d_voxelOccupied, 0);
471+
auto d_voxelOccupiedScan_ = d_voxelOccupiedScan.index({Slice(1, None)});
487472

488473
// number of active voxels
489-
int lastElement = d_voxelVerts[numVoxels - 1].cpu().item<int>();
490-
int lastScan = d_voxelOccupiedScan[numVoxels - 1].cpu().item<int>();
491-
int activeVoxels = lastElement + lastScan;
474+
int activeVoxels = d_voxelOccupiedScan[numVoxels].cpu().item<int>();
492475

493476
const int device_id = vol.device().index();
494477
auto opt = at::TensorOptions().dtype(at::kInt).device(at::kCUDA, device_id);
@@ -509,22 +492,17 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCuda(
509492
CompactVoxelsKernel<<<grid, threads, 0, stream>>>(
510493
d_compVoxelArray.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
511494
d_voxelOccupied.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
512-
d_voxelOccupiedScan.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
495+
d_voxelOccupiedScan_.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
513496
numVoxels);
514497
AT_CUDA_CHECK(cudaGetLastError());
515498
cudaDeviceSynchronize();
516499

517500
// Scan d_voxelVerts array to generate offsets of vertices for each voxel
518-
auto d_voxelVertsScan = at::zeros({numVoxels}, opt);
519-
ThrustScanWrapper(
520-
d_voxelVertsScan.data_ptr<int>(),
521-
d_voxelVerts.data_ptr<int>(),
522-
numVoxels);
501+
auto d_voxelVertsScan = at::cumsum(d_voxelVerts, 0);
502+
auto d_voxelVertsScan_ = d_voxelVertsScan.index({Slice(1, None)});
523503

524504
// total number of vertices
525-
lastElement = d_voxelVerts[numVoxels - 1].cpu().item<int>();
526-
lastScan = d_voxelVertsScan[numVoxels - 1].cpu().item<int>();
527-
int totalVerts = lastElement + lastScan;
505+
int totalVerts = d_voxelVertsScan[numVoxels].cpu().item<int>();
528506

529507
// Execute "GenerateFacesKernel" kernel
530508
// This runs only on the occupied voxels.
@@ -544,7 +522,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCuda(
544522
faces.packed_accessor<int64_t, 2, at::RestrictPtrTraits>(),
545523
ids.packed_accessor<int64_t, 1, at::RestrictPtrTraits>(),
546524
d_compVoxelArray.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
547-
d_voxelVertsScan.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
525+
d_voxelVertsScan_.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
548526
activeVoxels,
549527
vol.packed_accessor32<float, 3, at::RestrictPtrTraits>(),
550528
faceTable.packed_accessor32<int, 2, at::RestrictPtrTraits>(),

0 commit comments

Comments
 (0)