Skip to content

Commit c4545a7

Browse files
Krzysztof Chalupkafacebook-github-bot
Krzysztof Chalupka
authored andcommitted
Add structured bindings to iou to prove that we're C++17-friendly. Also other minor improvements to bbox iou
Summary: Recently we removed C++14-only compilation, should work. Reviewed By: bottler Differential Revision: D38919607 fbshipit-source-id: 6a26fa7713f7ba2163364ccc673ad774aa3a5adb
1 parent 5e7707b commit c4545a7

File tree

2 files changed

+61
-165
lines changed

2 files changed

+61
-165
lines changed

pytorch3d/csrc/iou_box3d/iou_box3d.cu

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,17 @@ __global__ void IoUBox3DKernel(
2929
const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
3030
const size_t stride = gridDim.x * blockDim.x;
3131

32+
std::array<FaceVerts, NUM_TRIS> box1_tris{};
33+
std::array<FaceVerts, NUM_TRIS> box2_tris{};
34+
std::array<FaceVerts, NUM_PLANES> box1_planes{};
35+
std::array<FaceVerts, NUM_PLANES> box2_planes{};
36+
3237
for (size_t i = tid; i < N * M; i += stride) {
3338
const size_t n = i / M; // box1 index
3439
const size_t m = i % M; // box2 index
3540

3641
// Convert to array of structs of face vertices i.e. effectively (F, 3, 3)
3742
// FaceVerts is a data type defined in iou_utils.cuh
38-
FaceVerts box1_tris[NUM_TRIS];
39-
FaceVerts box2_tris[NUM_TRIS];
4043
GetBoxTris(boxes1[n], box1_tris);
4144
GetBoxTris(boxes2[m], box2_tris);
4245

@@ -46,9 +49,7 @@ __global__ void IoUBox3DKernel(
4649
const float3 box2_center = BoxCenter(boxes2[m]);
4750

4851
// Convert to an array of face vertices
49-
FaceVerts box1_planes[NUM_PLANES];
5052
GetBoxPlanes(boxes1[n], box1_planes);
51-
FaceVerts box2_planes[NUM_PLANES];
5253
GetBoxPlanes(boxes2[m], box2_planes);
5354

5455
// Get Box Volumes

pytorch3d/csrc/iou_box3d/iou_utils.cuh

Lines changed: 56 additions & 161 deletions
Original file line numberDiff line numberDiff line change
@@ -39,25 +39,25 @@ const int MAX_TRIS = 100;
3939
// We will use struct arrays for representing
4040
// the data for each box and intersecting
4141
// triangles
42-
typedef struct {
42+
struct FaceVerts {
4343
float3 v0;
4444
float3 v1;
4545
float3 v2;
4646
float3 v3; // Can be empty for triangles
47-
} FaceVerts;
47+
};
4848

49-
typedef struct {
49+
struct FaceVertsIdx {
5050
int v0;
5151
int v1;
5252
int v2;
5353
int v3; // Can be empty for triangles
54-
} FaceVertsIdx;
54+
};
5555

5656
// This is used when deciding which faces to
5757
// keep that are not coplanar
58-
typedef struct {
58+
struct Keep {
5959
bool keep;
60-
} Keep;
60+
};
6161

6262
__device__ FaceVertsIdx _PLANES[] = {
6363
{0, 1, 2, 3},
@@ -128,64 +128,64 @@ __device__ inline void GetBoxPlanes(
128128
}
129129
}
130130

131-
// The normal of a plane spanned by vectors e0 and e1
131+
// The geometric center of a list of vertices.
132132
//
133133
// Args
134-
// e0, e1: float3 vectors defining a plane
134+
// vertices: A list of float3 vertices {v0, ..., vN}.
135135
//
136136
// Returns
137-
// float3: normal of the plane
137+
// float3: Geometric center of the vertices.
138138
//
139-
__device__ inline float3 GetNormal(const float3 e0, const float3 e1) {
140-
float3 n = cross(e0, e1);
141-
n = n / std::fmaxf(norm(n), kEpsilon);
142-
return n;
139+
__device__ inline float3 FaceCenter(
140+
std::initializer_list<const float3> vertices) {
141+
auto sumVertices = float3{};
142+
for (const auto& vertex : vertices) {
143+
sumVertices = sumVertices + vertex;
144+
}
145+
return sumVertices / vertices.size();
143146
}
144147

145-
// The center of a triangle defined by vertices (v0, v1, v2)
148+
// The normal of a plane spanned by vectors e0 and e1
146149
//
147150
// Args
148-
// v0, v1, v2: float3 coordinates of the vertices of the triangle
151+
// e0, e1: float3 vectors defining a plane
149152
//
150153
// Returns
151-
// float3: center of the triangle
154+
// float3: normal of the plane
152155
//
153-
__device__ inline float3
154-
TriCenter(const float3 v0, const float3 v1, const float3 v2) {
155-
float3 ctr = (v0 + v1 + v2) / 3.0f;
156-
return ctr;
156+
__device__ inline float3 GetNormal(const float3 e0, const float3 e1) {
157+
float3 n = cross(e0, e1);
158+
n = n / std::fmaxf(norm(n), kEpsilon);
159+
return n;
157160
}
158161

159-
// The normal of the triangle defined by vertices (v0, v1, v2)
162+
// The normal of a face with vertices (v0, v1, v2) or (v0, ..., v3).
160163
// We find the "best" edges connecting the face center to the vertices,
161164
// such that the cross product between the edges is maximized.
162165
//
163166
// Args
164-
// v0, v1, v2: float3 coordinates of the vertices of the face
167+
// vertices: a list of float3 coordinates of the vertices.
165168
//
166169
// Returns
167-
// float3: normal for the face
170+
// float3: center of the plane
168171
//
169-
__device__ inline float3
170-
TriNormal(const float3 v0, const float3 v1, const float3 v2) {
171-
const float3 ctr = TriCenter(v0, v1, v2);
172-
173-
const float d01 = norm(cross(v0 - ctr, v1 - ctr));
174-
const float d02 = norm(cross(v0 - ctr, v2 - ctr));
175-
const float d12 = norm(cross(v1 - ctr, v2 - ctr));
176-
177-
float3 n = GetNormal(v0 - ctr, v1 - ctr);
178-
float max_dist = d01;
179-
180-
if (d02 > max_dist) {
181-
max_dist = d02;
182-
n = GetNormal(v0 - ctr, v2 - ctr);
183-
}
184-
if (d12 > max_dist) {
185-
n = GetNormal(v1 - ctr, v2 - ctr);
172+
__device__ inline float3 FaceNormal(
173+
std::initializer_list<const float3> vertices) {
174+
const auto faceCenter = FaceCenter(vertices);
175+
auto normal = float3();
176+
auto maxDist = -1;
177+
for (auto v1 = vertices.begin(); v1 != vertices.end() - 1; ++v1) {
178+
for (auto v2 = std::next(v1); v2 != vertices.end(); ++v2) {
179+
const auto v1ToCenter = *v1 - faceCenter;
180+
const auto v2ToCenter = *v2 - faceCenter;
181+
const auto dist = norm(cross(v1ToCenter, v2ToCenter));
182+
if (dist > maxDist) {
183+
normal = GetNormal(v1ToCenter, v2ToCenter);
184+
maxDist = dist;
185+
}
186+
}
186187
}
187-
188-
return n;
188+
return normal;
189189
}
190190

191191
// The area of the face defined by vertices (v0, v1, v2)
@@ -201,79 +201,10 @@ TriNormal(const float3 v0, const float3 v1, const float3 v2) {
201201
//
202202
__device__ inline float FaceArea(const FaceVerts& tri) {
203203
// Get verts for face 1
204-
const float3 v0 = tri.v0;
205-
const float3 v1 = tri.v1;
206-
const float3 v2 = tri.v2;
207-
const float3 n = cross(v1 - v0, v2 - v0);
204+
const float3 n = cross(tri.v1 - tri.v0, tri.v2 - tri.v0);
208205
return norm(n) / 2.0;
209206
}
210207

211-
// The center of a plane defined by vertices (v0, v1, v2, v3)
212-
//
213-
// Args
214-
// v0, v1, v2, v3: float3 coordinates of the vertices of the plane
215-
//
216-
// Returns
217-
// float3: center of the plane
218-
//
219-
__device__ inline float3 PlaneCenter(
220-
const float3 v0,
221-
const float3 v1,
222-
const float3 v2,
223-
const float3 v3) {
224-
float3 ctr = (v0 + v1 + v2 + v3) / 4.0f;
225-
return ctr;
226-
}
227-
228-
// The normal of a planar face with vertices (v0, v1, v2, v3)
229-
// We find the "best" edges connecting the face center to the vertices,
230-
// such that the cross product between the edges is maximized.
231-
//
232-
// Args
233-
// e0, e1: float3 coordinates of the vertices of the plane
234-
//
235-
// Returns
236-
// float3: center of the plane
237-
//
238-
__device__ inline float3 PlaneNormal(
239-
const float3 v0,
240-
const float3 v1,
241-
const float3 v2,
242-
const float3 v3) {
243-
const float3 ctr = PlaneCenter(v0, v1, v2, v3);
244-
245-
const float d01 = norm(cross(v0 - ctr, v1 - ctr));
246-
const float d02 = norm(cross(v0 - ctr, v2 - ctr));
247-
const float d03 = norm(cross(v0 - ctr, v3 - ctr));
248-
const float d12 = norm(cross(v1 - ctr, v2 - ctr));
249-
const float d13 = norm(cross(v1 - ctr, v3 - ctr));
250-
const float d23 = norm(cross(v2 - ctr, v3 - ctr));
251-
252-
float max_dist = d01;
253-
float3 n = GetNormal(v0 - ctr, v1 - ctr);
254-
255-
if (d02 > max_dist) {
256-
max_dist = d02;
257-
n = GetNormal(v0 - ctr, v2 - ctr);
258-
}
259-
if (d03 > max_dist) {
260-
max_dist = d03;
261-
n = GetNormal(v0 - ctr, v3 - ctr);
262-
}
263-
if (d12 > max_dist) {
264-
max_dist = d12;
265-
n = GetNormal(v1 - ctr, v2 - ctr);
266-
}
267-
if (d13 > max_dist) {
268-
max_dist = d13;
269-
n = GetNormal(v1 - ctr, v3 - ctr);
270-
}
271-
if (d23 > max_dist) {
272-
n = GetNormal(v2 - ctr, v3 - ctr);
273-
}
274-
return n;
275-
}
276-
277208
// The normal of a box plane defined by the verts in `plane` such that it
278209
// points toward the centroid of the box given by `center`.
279210
//
@@ -290,17 +221,12 @@ template <typename FaceVertsPlane>
290221
__device__ inline float3 PlaneNormalDirection(
291222
const FaceVertsPlane& plane,
292223
const float3& center) {
293-
// The plane's vertices
294-
const float3 v0 = plane.v0;
295-
const float3 v1 = plane.v1;
296-
const float3 v2 = plane.v2;
297-
const float3 v3 = plane.v3;
298-
299224
// The plane's center
300-
const float3 plane_center = PlaneCenter(v0, v1, v2, v3);
225+
const float3 plane_center =
226+
FaceCenter({plane.v0, plane.v1, plane.v2, plane.v3});
301227

302228
// The plane's normal
303-
float3 n = PlaneNormal(v0, v1, v2, v3);
229+
float3 n = FaceNormal({plane.v0, plane.v1, plane.v2, plane.v3});
304230

305231
// We project the center on the plane defined by (v0, v1, v2, v3)
306232
// We can write center = plane_center + a * e0 + b * e1 + c * n
@@ -442,14 +368,8 @@ __device__ inline float3 PolyhedronCenter(
442368
//
443369
__device__ inline bool
444370
IsInside(const FaceVerts& plane, const float3& normal, const float3& point) {
445-
// Vertices of the plane
446-
const float3 v0 = plane.v0;
447-
const float3 v1 = plane.v1;
448-
const float3 v2 = plane.v2;
449-
const float3 v3 = plane.v3;
450-
451371
// The center of the plane
452-
const float3 plane_ctr = PlaneCenter(v0, v1, v2, v3);
372+
const float3 plane_ctr = FaceCenter({plane.v0, plane.v1, plane.v2, plane.v3});
453373

454374
// Every point p can be written as p = plane_ctr + a e0 + b e1 + c n
455375
// Solving for c:
@@ -478,14 +398,8 @@ __device__ inline float3 PlaneEdgeIntersection(
478398
const float3& normal,
479399
const float3& p0,
480400
const float3& p1) {
481-
// Vertices of the plane
482-
const float3 v0 = plane.v0;
483-
const float3 v1 = plane.v1;
484-
const float3 v2 = plane.v2;
485-
const float3 v3 = plane.v3;
486-
487401
// The center of the plane
488-
const float3 plane_ctr = PlaneCenter(v0, v1, v2, v3);
402+
const float3 plane_ctr = FaceCenter({plane.v0, plane.v1, plane.v2, plane.v3});
489403

490404
// The point of intersection can be parametrized
491405
// p = p0 + a (p1 - p0) where a in [0, 1]
@@ -548,30 +462,18 @@ __device__ inline std::tuple<float3, float3> ArgMaxVerts(
548462
__device__ inline bool IsCoplanarTriTri(
549463
const FaceVerts& tri1,
550464
const FaceVerts& tri2) {
551-
// Get verts for face 1
552-
const float3 v0_1 = tri1.v0;
553-
const float3 v1_1 = tri1.v1;
554-
const float3 v2_1 = tri1.v2;
465+
const float3 tri1_ctr = FaceCenter({tri1.v0, tri1.v1, tri1.v2});
466+
const float3 tri1_n = FaceNormal({tri1.v0, tri1.v1, tri1.v2});
555467

556-
const float3 tri1_ctr = TriCenter(v0_1, v1_1, v2_1);
557-
const float3 tri1_n = TriNormal(v0_1, v1_1, v2_1);
558-
559-
// Get verts for face 2
560-
const float3 v0_2 = tri2.v0;
561-
const float3 v1_2 = tri2.v1;
562-
const float3 v2_2 = tri2.v2;
563-
564-
const float3 tri2_ctr = TriCenter(v0_2, v1_2, v2_2);
565-
const float3 tri2_n = TriNormal(v0_2, v1_2, v2_2);
468+
const float3 tri2_ctr = FaceCenter({tri2.v0, tri2.v1, tri2.v2});
469+
const float3 tri2_n = FaceNormal({tri2.v0, tri2.v1, tri2.v2});
566470

567471
// Check if parallel
568472
const bool check1 = abs(dot(tri1_n, tri2_n)) > 1 - dEpsilon;
569473

570474
// Compute most distant points
571-
auto argvs =
475+
const auto [v1m, v2m] =
572476
ArgMaxVerts({tri1.v0, tri1.v1, tri1.v2}, {tri2.v0, tri2.v1, tri2.v2});
573-
const float3 v1m = std::get<0>(argvs);
574-
const float3 v2m = std::get<1>(argvs);
575477

576478
float3 n12m = v1m - v2m;
577479
n12m = n12m / fmaxf(norm(n12m), kEpsilon);
@@ -597,22 +499,15 @@ __device__ inline bool IsCoplanarTriPlane(
597499
const FaceVerts& tri,
598500
const FaceVerts& plane,
599501
const float3& normal) {
600-
// Get verts for tri
601-
const float3 v0t = tri.v0;
602-
const float3 v1t = tri.v1;
603-
const float3 v2t = tri.v2;
604-
605-
const float3 tri_ctr = TriCenter(v0t, v1t, v2t);
606-
const float3 nt = TriNormal(v0t, v1t, v2t);
502+
const float3 tri_ctr = FaceCenter({tri.v0, tri.v1, tri.v2});
503+
const float3 nt = FaceNormal({tri.v0, tri.v1, tri.v2});
607504

608505
// check if parallel
609506
const bool check1 = abs(dot(nt, normal)) > 1 - dEpsilon;
610507

611508
// Compute most distant points
612-
auto argvs = ArgMaxVerts(
509+
const auto [v1m, v2m] = ArgMaxVerts(
613510
{tri.v0, tri.v1, tri.v2}, {plane.v0, plane.v1, plane.v2, plane.v3});
614-
const float3 v1m = std::get<0>(argvs);
615-
const float3 v2m = std::get<1>(argvs);
616511

617512
float3 n12m = v1m - v2m;
618513
n12m = n12m / fmaxf(norm(n12m), kEpsilon);

0 commit comments

Comments
 (0)