2
2
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
3
3
4
4
5
- import numpy as np
6
5
import unittest
7
- import torch
8
6
7
+ import numpy as np
8
+ import torch
9
9
from common_testing import TestCaseMixin
10
-
11
10
from pytorch3d .ops import points_alignment
12
11
from pytorch3d .structures .pointclouds import Pointclouds
13
12
from pytorch3d .transforms import rotation_conversions
@@ -54,18 +53,14 @@ def random_rotation(batch_size, dim, device=None):
54
53
# generate random rotation matrices with orthogonalization of
55
54
# random normal square matrices, followed by a transformation
56
55
# that ensures determinant(R)==1
57
- H = torch .randn (
58
- batch_size , dim , dim , dtype = torch .float32 , device = device
59
- )
56
+ H = torch .randn (batch_size , dim , dim , dtype = torch .float32 , device = device )
60
57
U , _ , V = torch .svd (H )
61
58
E = torch .eye (dim , dtype = torch .float32 , device = device )[None ].repeat (
62
59
batch_size , 1 , 1
63
60
)
64
61
E [:, - 1 , - 1 ] = torch .det (torch .bmm (U , V .transpose (2 , 1 )))
65
62
R = torch .bmm (torch .bmm (U , E ), V .transpose (2 , 1 ))
66
- assert torch .allclose (
67
- torch .det (R ), R .new_ones (batch_size ), atol = 1e-4
68
- )
63
+ assert torch .allclose (torch .det (R ), R .new_ones (batch_size ), atol = 1e-4 )
69
64
70
65
return R
71
66
@@ -94,19 +89,13 @@ def init_point_cloud(
94
89
dtype = torch .int64 ,
95
90
)
96
91
X_list = [
97
- torch .randn (
98
- int (n_pt ), dim , device = device , dtype = torch .float32
99
- )
92
+ torch .randn (int (n_pt ), dim , device = device , dtype = torch .float32 )
100
93
for n_pt in n_points_per_batch
101
94
]
102
95
X = Pointclouds (X_list )
103
96
else :
104
97
X = torch .randn (
105
- batch_size ,
106
- n_points ,
107
- dim ,
108
- device = device ,
109
- dtype = torch .float32 ,
98
+ batch_size , n_points , dim , device = device , dtype = torch .float32
110
99
)
111
100
X = Pointclouds (list (X ))
112
101
else :
@@ -143,11 +132,7 @@ def generate_random_reflection(batch_size=10, dim=3, device=None):
143
132
# randomly select one of the dimensions to reflect for each
144
133
# element in the batch
145
134
dim_to_reflect = torch .randint (
146
- low = 0 ,
147
- high = dim ,
148
- size = (batch_size ,),
149
- device = device ,
150
- dtype = torch .int64 ,
135
+ low = 0 , high = dim , size = (batch_size ,), device = device , dtype = torch .int64
151
136
)
152
137
153
138
# convert dim_to_reflect to a batch of reflection matrices M
@@ -211,8 +196,7 @@ def corresponding_points_alignment(
211
196
weights *= (weights * template .size ()[1 ] > 0.3 ).to (weights )
212
197
if use_pointclouds : # convert to List[Tensor]
213
198
weights = [
214
- w [:npts ]
215
- for w , npts in zip (weights , X .num_points_per_cloud ())
199
+ w [:npts ] for w , npts in zip (weights , X .num_points_per_cloud ())
216
200
]
217
201
218
202
torch .cuda .synchronize ()
@@ -255,7 +239,7 @@ def test_corresponding_points_alignment(self, batch_size=10):
255
239
use_point_clouds_cases = (
256
240
(True , False ) if dim == 3 and n_points > 3 else (False ,)
257
241
)
258
- for random_weights in (False , True , ):
242
+ for random_weights in (False , True ):
259
243
for use_pointclouds in use_point_clouds_cases :
260
244
for estimate_scale in (False , True ):
261
245
for reflect in (False , True ):
@@ -325,8 +309,7 @@ def _test_single_corresponding_points_alignment(
325
309
weights *= (weights * template .size ()[1 ] > 0.3 ).to (weights )
326
310
if use_pointclouds : # convert to List[Tensor]
327
311
weights = [
328
- w [:npts ]
329
- for w , npts in zip (weights , X .num_points_per_cloud ())
312
+ w [:npts ] for w , npts in zip (weights , X .num_points_per_cloud ())
330
313
]
331
314
332
315
# apply the generated transformation to the generated
@@ -374,9 +357,9 @@ def align_and_get_mse(weights_):
374
357
375
358
X_t_est = _apply_pcl_transformation (X_noisy , R_n , T_n , s = s_n )
376
359
377
- return (
378
- (( X_t_est - X_t ) * weights [..., None ]) ** 2
379
- ). sum ( dim = ( 1 , 2 )) / weights .sum (dim = - 1 )
360
+ return ((( X_t_est - X_t ) * weights [..., None ]) ** 2 ). sum (
361
+ dim = ( 1 , 2 )
362
+ ) / weights .sum (dim = - 1 )
380
363
381
364
# check that using weights leads to lower weighted_MSE(X_noisy, X_t)
382
365
self .assertTrue (
@@ -386,9 +369,7 @@ def align_and_get_mse(weights_):
386
369
if reflect and not allow_reflection :
387
370
# check that all rotations have det=1
388
371
self ._assert_all_close (
389
- torch .det (R_est ),
390
- R_est .new_ones (batch_size ),
391
- assert_error_message ,
372
+ torch .det (R_est ), R_est .new_ones (batch_size ), assert_error_message
392
373
)
393
374
394
375
else :
@@ -430,6 +411,4 @@ def _assert_all_close(self, a_, b_, err_message, weights=None, atol=1e-6):
430
411
if weights is None :
431
412
self .assertClose (a_ , b_ , atol = atol , msg = err_message )
432
413
else :
433
- self .assertClose (
434
- a_ * weights , b_ * weights , atol = atol , msg = err_message
435
- )
414
+ self .assertClose (a_ * weights , b_ * weights , atol = atol , msg = err_message )
0 commit comments