1
1
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
2
2
3
+ from typing import Union
4
+
3
5
import torch
4
6
import torch .nn .functional as F
5
- from pytorch3d .ops .nearest_neighbor_points import nn_points_idx
7
+ from pytorch3d .ops .knn import knn_gather , knn_points
8
+ from pytorch3d .structures .pointclouds import Pointclouds
6
9
7
10
8
- def _validate_chamfer_reduction_inputs (batch_reduction : str , point_reduction : str ):
11
+ def _validate_chamfer_reduction_inputs (
12
+ batch_reduction : Union [str , None ], point_reduction : str
13
+ ):
9
14
"""Check the requested reductions are valid.
10
15
11
16
Args:
12
17
batch_reduction: Reduction operation to apply for the loss across the
13
- batch, can be one of ["none", " mean", "sum"].
18
+ batch, can be one of ["mean", "sum"] or None .
14
19
point_reduction: Reduction operation to apply for the loss across the
15
- points, can be one of ["none", "mean", "sum"].
20
+ points, can be one of ["mean", "sum"].
21
+ """
22
+ if batch_reduction is not None and batch_reduction not in ["mean" , "sum" ]:
23
+ raise ValueError ('batch_reduction must be one of ["mean", "sum"] or None' )
24
+ if point_reduction not in ["mean" , "sum" ]:
25
+ raise ValueError ('point_reduction must be one of ["mean", "sum"]' )
26
+
27
+
28
+ def _handle_pointcloud_input (
29
+ points : Union [torch .Tensor , Pointclouds ],
30
+ lengths : Union [torch .Tensor , None ],
31
+ normals : Union [torch .Tensor , None ],
32
+ ):
33
+ """
34
+ If points is an instance of Pointclouds, retrieve the padded points tensor
35
+ along with the number of points per batch and the padded normals.
36
+ Otherwise, return the input points (and normals) with the number of points per cloud
37
+ set to the size of the second dimension of `points`.
16
38
"""
17
- if batch_reduction not in ["none" , "mean" , "sum" ]:
18
- raise ValueError ('batch_reduction must be one of ["none", "mean", "sum"]' )
19
- if point_reduction not in ["none" , "mean" , "sum" ]:
20
- raise ValueError ('point_reduction must be one of ["none", "mean", "sum"]' )
21
- if batch_reduction == "none" and point_reduction == "none" :
22
- raise ValueError ('batch_reduction and point_reduction cannot both be "none".' )
39
+ if isinstance (points , Pointclouds ):
40
+ X = points .points_padded ()
41
+ lengths = points .num_points_per_cloud ()
42
+ normals = points .normals_padded () # either a tensor or None
43
+ elif torch .is_tensor (points ):
44
+ if points .ndim != 3 :
45
+ raise ValueError ("Expected points to be of shape (N, P, D)" )
46
+ X = points
47
+ if lengths is not None and (
48
+ lengths .ndim != 1 or lengths .shape [0 ] != X .shape [0 ]
49
+ ):
50
+ raise ValueError ("Expected lengths to be of shape (N,)" )
51
+ if lengths is None :
52
+ lengths = torch .full (
53
+ (X .shape [0 ],), X .shape [1 ], dtype = torch .int64 , device = points .device
54
+ )
55
+ if normals is not None and normals .ndim != 3 :
56
+ raise ValueError ("Expected normals to be of shape (N, P, 3" )
57
+ else :
58
+ raise ValueError (
59
+ "The input pointclouds should be either "
60
+ + "Pointclouds objects or torch.Tensor of shape "
61
+ + "(minibatch, num_points, 3)."
62
+ )
63
+ return X , lengths , normals
23
64
24
65
25
66
def chamfer_distance (
26
67
x ,
27
68
y ,
69
+ x_lengths = None ,
70
+ y_lengths = None ,
28
71
x_normals = None ,
29
72
y_normals = None ,
30
73
weights = None ,
31
- batch_reduction : str = "mean" ,
74
+ batch_reduction : Union [ str , None ] = "mean" ,
32
75
point_reduction : str = "mean" ,
33
76
):
34
77
"""
35
78
Chamfer distance between two pointclouds x and y.
36
79
37
80
Args:
38
- x: FloatTensor of shape (N, P1, D) representing a batch of point clouds
39
- with P1 points in each batch element, batch size N and feature
40
- dimension D.
41
- y: FloatTensor of shape (N, P2, D) representing a batch of point clouds
42
- with P2 points in each batch element, batch size N and feature
43
- dimension D.
81
+ x: FloatTensor of shape (N, P1, D) or a Pointclouds object representing
82
+ a batch of point clouds with at most P1 points in each batch element,
83
+ batch size N and feature dimension D.
84
+ y: FloatTensor of shape (N, P2, D) or a Pointclouds object representing
85
+ a batch of point clouds with at most P2 points in each batch element,
86
+ batch size N and feature dimension D.
87
+ x_lengths: Optional LongTensor of shape (N,) giving the number of points in each
88
+ cloud in x.
89
+ y_lengths: Optional LongTensor of shape (N,) giving the number of points in each
90
+ cloud in x.
44
91
x_normals: Optional FloatTensor of shape (N, P1, D).
45
92
y_normals: Optional FloatTensor of shape (N, P2, D).
46
93
weights: Optional FloatTensor of shape (N,) giving weights for
47
94
batch elements for reduction operation.
48
95
batch_reduction: Reduction operation to apply for the loss across the
49
- batch, can be one of ["none", " mean", "sum"].
96
+ batch, can be one of ["mean", "sum"] or None .
50
97
point_reduction: Reduction operation to apply for the loss across the
51
- points, can be one of ["none", " mean", "sum"].
98
+ points, can be one of ["mean", "sum"].
52
99
53
100
Returns:
54
101
2-element tuple containing
@@ -61,16 +108,31 @@ def chamfer_distance(
61
108
"""
62
109
_validate_chamfer_reduction_inputs (batch_reduction , point_reduction )
63
110
111
+ x , x_lengths , x_normals = _handle_pointcloud_input (x , x_lengths , x_normals )
112
+ y , y_lengths , y_normals = _handle_pointcloud_input (y , y_lengths , y_normals )
113
+
114
+ return_normals = x_normals is not None and y_normals is not None
115
+
64
116
N , P1 , D = x .shape
65
117
P2 = y .shape [1 ]
66
118
119
+ # Check if inputs are heterogeneous and create a lengths mask.
120
+ is_x_heterogeneous = ~ (x_lengths == P1 ).all ()
121
+ is_y_heterogeneous = ~ (y_lengths == P2 ).all ()
122
+ x_mask = (
123
+ torch .arange (P1 , device = x .device )[None ] >= x_lengths [:, None ]
124
+ ) # shape [N, P1]
125
+ y_mask = (
126
+ torch .arange (P2 , device = y .device )[None ] >= y_lengths [:, None ]
127
+ ) # shape [N, P2]
128
+
67
129
if y .shape [0 ] != N or y .shape [2 ] != D :
68
130
raise ValueError ("y does not have the correct shape." )
69
131
if weights is not None :
70
132
if weights .size (0 ) != N :
71
133
raise ValueError ("weights must be of shape (N,)." )
72
134
if not (weights >= 0 ).all ():
73
- raise ValueError ("weights can not be nonnegative ." )
135
+ raise ValueError ("weights cannot be negative ." )
74
136
if weights .sum () == 0.0 :
75
137
weights = weights .view (N , 1 )
76
138
if batch_reduction in ["mean" , "sum" ]:
@@ -80,46 +142,60 @@ def chamfer_distance(
80
142
)
81
143
return ((x .sum ((1 , 2 )) * weights ) * 0.0 , (x .sum ((1 , 2 )) * weights ) * 0.0 )
82
144
83
- return_normals = x_normals is not None and y_normals is not None
84
145
cham_norm_x = x .new_zeros (())
85
146
cham_norm_y = x .new_zeros (())
86
147
87
- x_near , xidx_near , x_normals_near = nn_points_idx (x , y , y_normals )
88
- y_near , yidx_near , y_normals_near = nn_points_idx (y , x , x_normals )
148
+ x_dists , x_idx = knn_points (x , y , lengths1 = x_lengths , lengths2 = y_lengths , K = 1 )
149
+ y_dists , y_idx = knn_points (y , x , lengths1 = y_lengths , lengths2 = x_lengths , K = 1 )
89
150
90
- cham_x = (x - x_near ).norm (dim = 2 , p = 2 ) ** 2.0 # (N, P1)
91
- cham_y = (y - y_near ).norm (dim = 2 , p = 2 ) ** 2.0 # (N, P2)
151
+ cham_x = x_dists [..., 0 ] # (N, P1)
152
+ cham_y = y_dists [..., 0 ] # (N, P2)
153
+
154
+ if is_x_heterogeneous :
155
+ cham_x [x_mask ] = 0.0
156
+ if is_y_heterogeneous :
157
+ cham_y [y_mask ] = 0.0
92
158
93
159
if weights is not None :
94
160
cham_x *= weights .view (N , 1 )
95
161
cham_y *= weights .view (N , 1 )
96
162
97
163
if return_normals :
164
+ # Gather the normals using the indices and keep only value for k=0
165
+ x_normals_near = knn_gather (y_normals , x_idx , y_lengths )[..., 0 , :]
166
+ y_normals_near = knn_gather (x_normals , y_idx , x_lengths )[..., 0 , :]
167
+
98
168
cham_norm_x = 1 - torch .abs (
99
169
F .cosine_similarity (x_normals , x_normals_near , dim = 2 , eps = 1e-6 )
100
170
)
101
171
cham_norm_y = 1 - torch .abs (
102
172
F .cosine_similarity (y_normals , y_normals_near , dim = 2 , eps = 1e-6 )
103
173
)
174
+
175
+ if is_x_heterogeneous :
176
+ cham_norm_x [x_mask ] = 0.0
177
+ if is_y_heterogeneous :
178
+ cham_norm_y [y_mask ] = 0.0
179
+
104
180
if weights is not None :
105
181
cham_norm_x *= weights .view (N , 1 )
106
182
cham_norm_y *= weights .view (N , 1 )
107
183
108
- if point_reduction != "none" :
109
- # If not 'none' then either 'sum' or 'mean'.
110
- cham_x = cham_x .sum (1 ) # (N,)
111
- cham_y = cham_y .sum (1 ) # (N,)
184
+ # Apply point reduction
185
+ cham_x = cham_x .sum (1 ) # (N,)
186
+ cham_y = cham_y .sum (1 ) # (N,)
187
+ if return_normals :
188
+ cham_norm_x = cham_norm_x .sum (1 ) # (N,)
189
+ cham_norm_y = cham_norm_y .sum (1 ) # (N,)
190
+ if point_reduction == "mean" :
191
+ cham_x /= x_lengths
192
+ cham_y /= y_lengths
112
193
if return_normals :
113
- cham_norm_x = cham_norm_x .sum (1 ) # (N,)
114
- cham_norm_y = cham_norm_y .sum (1 ) # (N,)
115
- if point_reduction == "mean" :
116
- cham_x /= P1
117
- cham_y /= P2
118
- if return_normals :
119
- cham_norm_x /= P1
120
- cham_norm_y /= P2
194
+ cham_norm_x /= x_lengths
195
+ cham_norm_y /= y_lengths
121
196
122
- if batch_reduction != "none" :
197
+ if batch_reduction is not None :
198
+ # batch_reduction == "sum"
123
199
cham_x = cham_x .sum ()
124
200
cham_y = cham_y .sum ()
125
201
if return_normals :
0 commit comments