Skip to content

Commit 9577198

Browse files
Krzysztof Chalupkafacebook-github-bot
Krzysztof Chalupka
authored andcommitted
Make PyTorch3D C++17 incompatible again :(
Summary: D38919607 (c4545a7) and D38858887 (06cbba2) were premature, turns out CUDA 10.2 doesn't support C++17. Reviewed By: bottler Differential Revision: D39156205 fbshipit-source-id: 5e2e84cc4a57d1113a915166631651d438540d56
1 parent 1530a66 commit 9577198

File tree

3 files changed

+14
-8
lines changed

3 files changed

+14
-8
lines changed

pytorch3d/csrc/iou_box3d/iou_box3d.cu

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,10 @@ __global__ void IoUBox3DKernel(
2929
const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
3030
const size_t stride = gridDim.x * blockDim.x;
3131

32-
std::array<FaceVerts, NUM_TRIS> box1_tris{};
33-
std::array<FaceVerts, NUM_TRIS> box2_tris{};
34-
std::array<FaceVerts, NUM_PLANES> box1_planes{};
35-
std::array<FaceVerts, NUM_PLANES> box2_planes{};
32+
FaceVerts box1_tris[NUM_TRIS];
33+
FaceVerts box2_tris[NUM_TRIS];
34+
FaceVerts box1_planes[NUM_PLANES];
35+
FaceVerts box2_planes[NUM_PLANES];
3636

3737
for (size_t i = tid; i < N * M; i += stride) {
3838
const size_t n = i / M; // box1 index

pytorch3d/csrc/iou_box3d/iou_utils.cuh

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ __device__ inline float3 FaceNormal(
175175
auto normal = float3();
176176
auto maxDist = -1;
177177
for (auto v1 = vertices.begin(); v1 != vertices.end() - 1; ++v1) {
178-
for (auto v2 = std::next(v1); v2 != vertices.end(); ++v2) {
178+
for (auto v2 = v1 + 1; v2 != vertices.end(); ++v2) {
179179
const auto v1ToCenter = *v1 - faceCenter;
180180
const auto v2ToCenter = *v2 - faceCenter;
181181
const auto dist = norm(cross(v1ToCenter, v2ToCenter));
@@ -472,8 +472,10 @@ __device__ inline bool IsCoplanarTriTri(
472472
const bool check1 = abs(dot(tri1_n, tri2_n)) > 1 - dEpsilon;
473473

474474
// Compute most distant points
475-
const auto [v1m, v2m] =
475+
const auto v1mAndv2m =
476476
ArgMaxVerts({tri1.v0, tri1.v1, tri1.v2}, {tri2.v0, tri2.v1, tri2.v2});
477+
const auto v1m = std::get<0>(v1mAndv2m);
478+
const auto v2m = std::get<1>(v1mAndv2m);
477479

478480
float3 n12m = v1m - v2m;
479481
n12m = n12m / fmaxf(norm(n12m), kEpsilon);
@@ -506,8 +508,10 @@ __device__ inline bool IsCoplanarTriPlane(
506508
const bool check1 = abs(dot(nt, normal)) > 1 - dEpsilon;
507509

508510
// Compute most distant points
509-
const auto [v1m, v2m] = ArgMaxVerts(
511+
const auto v1mAndv2m = ArgMaxVerts(
510512
{tri.v0, tri.v1, tri.v2}, {plane.v0, plane.v1, plane.v2, plane.v3});
513+
const auto v1m = std::get<0>(v1mAndv2m);
514+
const auto v2m = std::get<1>(v1mAndv2m);
511515

512516
float3 n12m = v1m - v2m;
513517
n12m = n12m / fmaxf(norm(n12m), kEpsilon);

setup.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def get_extensions():
4949
source_cuda = glob.glob(os.path.join(extensions_dir, "**", "*.cu"), recursive=True)
5050
extension = CppExtension
5151

52-
extra_compile_args = {"cxx": []}
52+
extra_compile_args = {"cxx": ["-std=c++14"]}
5353
define_macros = []
5454
include_dirs = [extensions_dir]
5555

@@ -73,6 +73,8 @@ def get_extensions():
7373
"-D__CUDA_NO_HALF_CONVERSIONS__",
7474
"-D__CUDA_NO_HALF2_OPERATORS__",
7575
]
76+
if os.name != "nt":
77+
nvcc_args.append("-std=c++14")
7678
if cub_home is None:
7779
prefix = os.environ.get("CONDA_PREFIX", None)
7880
if prefix is not None and os.path.isdir(prefix + "/include/cub"):

0 commit comments

Comments
 (0)