
Commit 790eb8c

nikhilaravi authored and facebook-github-bot committed
Chamfer for Pointclouds object
Summary: Allow Pointclouds objects and heterogeneous data to be provided for Chamfer loss. Remove "none" as an option for point_reduction because it doesn't make sense and in the current implementation is effectively the same as "sum". Possible improvement: create specialised operations for sum and cosine_similarity of padded tensors, to avoid having to create masks. sum would be useful elsewhere.

Reviewed By: gkioxari

Differential Revision: D20816301

fbshipit-source-id: 0f32073210225d157c029d80de450eecdb64f4d2
1 parent 677b0bd commit 790eb8c
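
With this change, chamfer_distance accepts either a Pointclouds object or a padded FloatTensor with optional per-cloud lengths. A minimal usage sketch (the sizes and random data below are illustrative, not part of the commit):

import torch
from pytorch3d.loss import chamfer_distance
from pytorch3d.structures import Pointclouds

# A heterogeneous batch: two clouds with different numbers of points.
x = Pointclouds(points=[torch.rand(100, 3), torch.rand(60, 3)])
y = Pointclouds(points=[torch.rand(80, 3), torch.rand(120, 3)])

# Pointclouds inputs: per-cloud lengths are read from the structures.
loss, loss_normals = chamfer_distance(x, y)

# Equivalent padded-tensor form with explicit lengths.
x_pad = x.points_padded()            # (2, 100, 3), zero-padded
y_pad = y.points_padded()            # (2, 120, 3), zero-padded
x_len = x.num_points_per_cloud()     # tensor([100, 60])
y_len = y.num_points_per_cloud()     # tensor([80, 120])
loss_pad, _ = chamfer_distance(x_pad, y_pad, x_lengths=x_len, y_lengths=y_len)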

3 files changed: +680 −168 lines changed


pytorch3d/loss/chamfer.py

Lines changed: 114 additions & 38 deletions
@@ -1,54 +1,101 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
 
+from typing import Union
+
 import torch
 import torch.nn.functional as F
-from pytorch3d.ops.nearest_neighbor_points import nn_points_idx
+from pytorch3d.ops.knn import knn_gather, knn_points
+from pytorch3d.structures.pointclouds import Pointclouds
 
 
-def _validate_chamfer_reduction_inputs(batch_reduction: str, point_reduction: str):
+def _validate_chamfer_reduction_inputs(
+    batch_reduction: Union[str, None], point_reduction: str
+):
     """Check the requested reductions are valid.
 
     Args:
         batch_reduction: Reduction operation to apply for the loss across the
-            batch, can be one of ["none", "mean", "sum"].
+            batch, can be one of ["mean", "sum"] or None.
         point_reduction: Reduction operation to apply for the loss across the
-            points, can be one of ["none", "mean", "sum"].
+            points, can be one of ["mean", "sum"].
+    """
+    if batch_reduction is not None and batch_reduction not in ["mean", "sum"]:
+        raise ValueError('batch_reduction must be one of ["mean", "sum"] or None')
+    if point_reduction not in ["mean", "sum"]:
+        raise ValueError('point_reduction must be one of ["mean", "sum"]')
+
+
+def _handle_pointcloud_input(
+    points: Union[torch.Tensor, Pointclouds],
+    lengths: Union[torch.Tensor, None],
+    normals: Union[torch.Tensor, None],
+):
+    """
+    If points is an instance of Pointclouds, retrieve the padded points tensor
+    along with the number of points per batch and the padded normals.
+    Otherwise, return the input points (and normals) with the number of points per cloud
+    set to the size of the second dimension of `points`.
     """
-    if batch_reduction not in ["none", "mean", "sum"]:
-        raise ValueError('batch_reduction must be one of ["none", "mean", "sum"]')
-    if point_reduction not in ["none", "mean", "sum"]:
-        raise ValueError('point_reduction must be one of ["none", "mean", "sum"]')
-    if batch_reduction == "none" and point_reduction == "none":
-        raise ValueError('batch_reduction and point_reduction cannot both be "none".')
+    if isinstance(points, Pointclouds):
+        X = points.points_padded()
+        lengths = points.num_points_per_cloud()
+        normals = points.normals_padded()  # either a tensor or None
+    elif torch.is_tensor(points):
+        if points.ndim != 3:
+            raise ValueError("Expected points to be of shape (N, P, D)")
+        X = points
+        if lengths is not None and (
+            lengths.ndim != 1 or lengths.shape[0] != X.shape[0]
+        ):
+            raise ValueError("Expected lengths to be of shape (N,)")
+        if lengths is None:
+            lengths = torch.full(
+                (X.shape[0],), X.shape[1], dtype=torch.int64, device=points.device
+            )
+        if normals is not None and normals.ndim != 3:
+            raise ValueError("Expected normals to be of shape (N, P, 3)")
+    else:
+        raise ValueError(
+            "The input pointclouds should be either "
+            + "Pointclouds objects or torch.Tensor of shape "
+            + "(minibatch, num_points, 3)."
+        )
+    return X, lengths, normals
 
 
 def chamfer_distance(
     x,
     y,
+    x_lengths=None,
+    y_lengths=None,
     x_normals=None,
     y_normals=None,
     weights=None,
-    batch_reduction: str = "mean",
+    batch_reduction: Union[str, None] = "mean",
     point_reduction: str = "mean",
 ):
     """
     Chamfer distance between two pointclouds x and y.
 
     Args:
-        x: FloatTensor of shape (N, P1, D) representing a batch of point clouds
-            with P1 points in each batch element, batch size N and feature
-            dimension D.
-        y: FloatTensor of shape (N, P2, D) representing a batch of point clouds
-            with P2 points in each batch element, batch size N and feature
-            dimension D.
+        x: FloatTensor of shape (N, P1, D) or a Pointclouds object representing
+            a batch of point clouds with at most P1 points in each batch element,
+            batch size N and feature dimension D.
+        y: FloatTensor of shape (N, P2, D) or a Pointclouds object representing
+            a batch of point clouds with at most P2 points in each batch element,
+            batch size N and feature dimension D.
+        x_lengths: Optional LongTensor of shape (N,) giving the number of points in each
+            cloud in x.
+        y_lengths: Optional LongTensor of shape (N,) giving the number of points in each
+            cloud in y.
         x_normals: Optional FloatTensor of shape (N, P1, D).
         y_normals: Optional FloatTensor of shape (N, P2, D).
         weights: Optional FloatTensor of shape (N,) giving weights for
            batch elements for reduction operation.
         batch_reduction: Reduction operation to apply for the loss across the
-            batch, can be one of ["none", "mean", "sum"].
+            batch, can be one of ["mean", "sum"] or None.
         point_reduction: Reduction operation to apply for the loss across the
-            points, can be one of ["none", "mean", "sum"].
+            points, can be one of ["mean", "sum"].
 
     Returns:
         2-element tuple containing
@@ -61,16 +108,31 @@ def chamfer_distance(
     """
     _validate_chamfer_reduction_inputs(batch_reduction, point_reduction)
 
+    x, x_lengths, x_normals = _handle_pointcloud_input(x, x_lengths, x_normals)
+    y, y_lengths, y_normals = _handle_pointcloud_input(y, y_lengths, y_normals)
+
+    return_normals = x_normals is not None and y_normals is not None
+
     N, P1, D = x.shape
     P2 = y.shape[1]
 
+    # Check if inputs are heterogeneous and create a lengths mask.
+    is_x_heterogeneous = ~(x_lengths == P1).all()
+    is_y_heterogeneous = ~(y_lengths == P2).all()
+    x_mask = (
+        torch.arange(P1, device=x.device)[None] >= x_lengths[:, None]
+    )  # shape [N, P1]
+    y_mask = (
+        torch.arange(P2, device=y.device)[None] >= y_lengths[:, None]
+    )  # shape [N, P2]
+
     if y.shape[0] != N or y.shape[2] != D:
         raise ValueError("y does not have the correct shape.")
     if weights is not None:
         if weights.size(0) != N:
             raise ValueError("weights must be of shape (N,).")
         if not (weights >= 0).all():
-            raise ValueError("weights can not be nonnegative.")
+            raise ValueError("weights cannot be negative.")
         if weights.sum() == 0.0:
             weights = weights.view(N, 1)
             if batch_reduction in ["mean", "sum"]:
@@ -80,46 +142,60 @@ def chamfer_distance(
                 )
             return ((x.sum((1, 2)) * weights) * 0.0, (x.sum((1, 2)) * weights) * 0.0)
 
-    return_normals = x_normals is not None and y_normals is not None
     cham_norm_x = x.new_zeros(())
     cham_norm_y = x.new_zeros(())
 
-    x_near, xidx_near, x_normals_near = nn_points_idx(x, y, y_normals)
-    y_near, yidx_near, y_normals_near = nn_points_idx(y, x, x_normals)
+    x_dists, x_idx = knn_points(x, y, lengths1=x_lengths, lengths2=y_lengths, K=1)
+    y_dists, y_idx = knn_points(y, x, lengths1=y_lengths, lengths2=x_lengths, K=1)
 
-    cham_x = (x - x_near).norm(dim=2, p=2) ** 2.0  # (N, P1)
-    cham_y = (y - y_near).norm(dim=2, p=2) ** 2.0  # (N, P2)
+    cham_x = x_dists[..., 0]  # (N, P1)
+    cham_y = y_dists[..., 0]  # (N, P2)
+
+    if is_x_heterogeneous:
+        cham_x[x_mask] = 0.0
+    if is_y_heterogeneous:
+        cham_y[y_mask] = 0.0
 
     if weights is not None:
         cham_x *= weights.view(N, 1)
         cham_y *= weights.view(N, 1)
 
     if return_normals:
+        # Gather the normals using the indices and keep only value for k=0
+        x_normals_near = knn_gather(y_normals, x_idx, y_lengths)[..., 0, :]
+        y_normals_near = knn_gather(x_normals, y_idx, x_lengths)[..., 0, :]
+
         cham_norm_x = 1 - torch.abs(
             F.cosine_similarity(x_normals, x_normals_near, dim=2, eps=1e-6)
         )
         cham_norm_y = 1 - torch.abs(
            F.cosine_similarity(y_normals, y_normals_near, dim=2, eps=1e-6)
         )
+
+        if is_x_heterogeneous:
+            cham_norm_x[x_mask] = 0.0
+        if is_y_heterogeneous:
+            cham_norm_y[y_mask] = 0.0
+
         if weights is not None:
             cham_norm_x *= weights.view(N, 1)
             cham_norm_y *= weights.view(N, 1)
 
-    if point_reduction != "none":
-        # If not 'none' then either 'sum' or 'mean'.
-        cham_x = cham_x.sum(1)  # (N,)
-        cham_y = cham_y.sum(1)  # (N,)
+    # Apply point reduction
+    cham_x = cham_x.sum(1)  # (N,)
+    cham_y = cham_y.sum(1)  # (N,)
+    if return_normals:
+        cham_norm_x = cham_norm_x.sum(1)  # (N,)
+        cham_norm_y = cham_norm_y.sum(1)  # (N,)
+    if point_reduction == "mean":
+        cham_x /= x_lengths
+        cham_y /= y_lengths
         if return_normals:
-            cham_norm_x = cham_norm_x.sum(1)  # (N,)
-            cham_norm_y = cham_norm_y.sum(1)  # (N,)
-        if point_reduction == "mean":
-            cham_x /= P1
-            cham_y /= P2
-            if return_normals:
-                cham_norm_x /= P1
-                cham_norm_y /= P2
+            cham_norm_x /= x_lengths
+            cham_norm_y /= y_lengths
 
-    if batch_reduction != "none":
+    if batch_reduction is not None:
+        # batch_reduction == "sum"
         cham_x = cham_x.sum()
         cham_y = cham_y.sum()
         if return_normals:
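
Heterogeneous batches are handled above by zeroing the per-point loss at padded positions and dividing by the true cloud lengths rather than the padded sizes P1/P2. A small standalone sketch of that masking step, with illustrative values (not taken from the commit):

import torch

# Per-cloud lengths for a batch of 2 clouds padded to P = 5 points each.
lengths = torch.tensor([5, 3])
P = 5

# True where an entry is padding, mirroring x_mask / y_mask in the diff above.
mask = torch.arange(P)[None] >= lengths[:, None]
# tensor([[False, False, False, False, False],
#         [False, False, False,  True,  True]])

per_point_loss = torch.rand(2, P)
per_point_loss[mask] = 0.0                        # padded points contribute nothing
per_cloud_mean = per_point_loss.sum(1) / lengths  # mean over real points only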

tests/bm_chamfer.py

Lines changed: 20 additions & 4 deletions
@@ -1,5 +1,6 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
 
+from itertools import product
 
 import torch
 from fvcore.common.benchmark import benchmark
@@ -20,8 +21,23 @@ def bm_chamfer() -> None:
     )
 
     if torch.cuda.is_available():
-        kwargs_list = kwargs_list_naive + [
-            {"batch_size": 1, "P1": 1000, "P2": 3000, "return_normals": False},
-            {"batch_size": 1, "P1": 1000, "P2": 30000, "return_normals": True},
-        ]
+        kwargs_list = []
+        batch_size = [1, 32]
+        P1 = [32, 1000, 10000]
+        P2 = [64, 3000, 30000]
+        return_normals = [True, False]
+        homogeneous = [True, False]
+        test_cases = product(batch_size, P1, P2, return_normals, homogeneous)
+
+        for case in test_cases:
+            b, p1, p2, n, h = case
+            kwargs_list.append(
+                {
+                    "batch_size": b,
+                    "P1": p1,
+                    "P2": p2,
+                    "return_normals": n,
+                    "homogeneous": h,
+                }
+            )
     benchmark(TestChamfer.chamfer_with_init, "CHAMFER", kwargs_list, warmup_iters=1)
