
Commit b87058c

bottler authored and facebook-github-bot committed
fix recent lint
Summary: lint clean again

Reviewed By: patricklabatut

Differential Revision: D20868775

fbshipit-source-id: ade4301c1012c5c6943186432465215701d635a9
1 parent 90dc7a0 commit b87058c

6 files changed, +29 -60 lines changed

pytorch3d/ops/points_alignment.py

Lines changed: 6 additions & 10 deletions
@@ -1,12 +1,12 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
 
 import warnings
-from typing import List, Optional, Tuple, Union
-import torch
+from typing import List, Tuple, Union
 
-from pytorch3d.structures.pointclouds import Pointclouds
-from pytorch3d.structures import utils as strutil
+import torch
 from pytorch3d.ops import utils as oputil
+from pytorch3d.structures import utils as strutil
+from pytorch3d.structures.pointclouds import Pointclouds
 
 
 def corresponding_points_alignment(
@@ -77,9 +77,7 @@ def corresponding_points_alignment(
             weights = strutil.list_to_padded(weights)[..., 0]
 
         if Xt.shape[:2] != weights.shape:
-            raise ValueError(
-                "weights should have the same first two dimensions as X."
-            )
+            raise ValueError("weights should have the same first two dimensions as X.")
 
     b, n, dim = Xt.shape
 
@@ -120,9 +118,7 @@ def corresponding_points_alignment(
     U, S, V = torch.svd(XYcov)
 
     # identity matrix used for fixing reflections
-    E = torch.eye(dim, dtype=XYcov.dtype, device=XYcov.device)[None].repeat(
-        b, 1, 1
-    )
+    E = torch.eye(dim, dtype=XYcov.dtype, device=XYcov.device)[None].repeat(b, 1, 1)
 
     if not allow_reflection:
         # reflection test:
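For context, a minimal usage sketch of corresponding_points_alignment, following the call pattern exercised in tests/test_points_alignment.py further down (three returned values R, T, s and an optional per-point weights argument). The synthetic data and shapes here are illustrative assumptions, not part of this commit.

# Illustrative sketch (not from this commit): align a batch of point sets
# X to Y = X R + T and recover the transform.
import torch

from pytorch3d.ops import points_alignment

batch, n_points, dim = 4, 100, 3
X = torch.randn(batch, n_points, dim)
R = torch.eye(dim)[None].repeat(batch, 1, 1)  # identity rotations
T = torch.full((batch, dim), 0.5)             # constant translation
Y = torch.bmm(X, R) + T[:, None]

# weights must match X's first two dimensions (batch, n_points); otherwise
# the ValueError reformatted in the hunk above is raised.
weights = torch.rand(batch, n_points)

R_est, T_est, s_est = points_alignment.corresponding_points_alignment(
    X, Y, weights, estimate_scale=False, allow_reflection=False
)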

pytorch3d/ops/utils.py

Lines changed: 3 additions & 4 deletions
@@ -27,7 +27,7 @@ def wmean(
         * if `weights` is None => `mean(x, dim)`,
         * otherwise => `sum(x*w, dim) / max{sum(w, dim), eps}`.
     """
-    args = dict(dim=dim, keepdim=keepdim)
+    args = {"dim": dim, "keepdim": keepdim}
 
     if weight is None:
         return x.mean(**args)
@@ -38,7 +38,6 @@ def wmean(
     ):
         raise ValueError("wmean: weights are not compatible with the tensor")
 
-    return (
-        (x * weight[..., None]).sum(**args)
-        / weight[..., None].sum(**args).clamp(eps)
+    return (x * weight[..., None]).sum(**args) / weight[..., None].sum(**args).clamp(
+        eps
     )
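The docstring lines above spell out the weighted-mean formula; here is a small sketch of what that means in practice (illustrative, not part of this commit; the weight/dim/keepdim keywords follow the call visible in tests/test_ops_utils.py below).

# Sketch of wmean's documented behaviour:
#   weight is None -> x.mean(dim)
#   otherwise      -> sum(x * w, dim) / max(sum(w, dim), eps)
import torch

from pytorch3d.ops import utils as oputil

x = torch.randn(5, 100, 3)  # (batch, points, dim)
w = torch.rand(5, 100)      # one non-negative weight per point

m = oputil.wmean(x, weight=w, dim=-2, keepdim=False)
manual = (x * w[..., None]).sum(dim=-2) / w.sum(dim=-1, keepdim=True).clamp(
    1e-9  # eps clamp (assumed value; a no-op for these weights)
)
assert torch.allclose(m, manual)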

tests/bm_points_alignment.py

Lines changed: 1 addition & 1 deletion
@@ -3,8 +3,8 @@
 
 from copy import deepcopy
 from itertools import product
-from fvcore.common.benchmark import benchmark
 
+from fvcore.common.benchmark import benchmark
 from test_points_alignment import TestCorrespondingPointsAlignment
 
 
tests/common_testing.py

Lines changed: 2 additions & 5 deletions
@@ -1,8 +1,7 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
 
-from typing import Optional
-
 import unittest
+from typing import Optional
 
 import numpy as np
 import torch
@@ -57,7 +56,5 @@ def assertClose(
                 input, other, rtol=rtol, atol=atol, equal_nan=equal_nan
             )
         else:
-            close = np.allclose(
-                input, other, rtol=rtol, atol=atol, equal_nan=equal_nan
-            )
+            close = np.allclose(input, other, rtol=rtol, atol=atol, equal_nan=equal_nan)
         self.assertTrue(close, msg)

tests/test_ops_utils.py

Lines changed: 2 additions & 4 deletions
@@ -3,11 +3,10 @@
 
 import numpy as np
 import torch
-
 from common_testing import TestCaseMixin
-
 from pytorch3d.ops import utils as oputil
 
+
 class TestOpsUtils(TestCaseMixin, unittest.TestCase):
     def setUp(self) -> None:
         super().setUp()
@@ -62,8 +61,7 @@ def test_wmean(self):
         # test dim
         weight = torch.rand(x.shape[0], n_points, device=device)
         weight_np = np.tile(
-            weight[:, :, None].cpu().data.numpy(),
-            (1, 1, x_np.shape[-1]),
+            weight[:, :, None].cpu().data.numpy(), (1, 1, x_np.shape[-1])
         )
         mean = oputil.wmean(x, dim=0, weight=weight, keepdim=False)
         mean_gt = np.average(x_np, axis=0, weights=weight_np)

tests/test_points_alignment.py

Lines changed: 15 additions & 36 deletions
@@ -2,12 +2,11 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
 
 
-import numpy as np
 import unittest
-import torch
 
+import numpy as np
+import torch
 from common_testing import TestCaseMixin
-
 from pytorch3d.ops import points_alignment
 from pytorch3d.structures.pointclouds import Pointclouds
 from pytorch3d.transforms import rotation_conversions
@@ -54,18 +53,14 @@ def random_rotation(batch_size, dim, device=None):
         # generate random rotation matrices with orthogonalization of
         # random normal square matrices, followed by a transformation
         # that ensures determinant(R)==1
-        H = torch.randn(
-            batch_size, dim, dim, dtype=torch.float32, device=device
-        )
+        H = torch.randn(batch_size, dim, dim, dtype=torch.float32, device=device)
         U, _, V = torch.svd(H)
         E = torch.eye(dim, dtype=torch.float32, device=device)[None].repeat(
             batch_size, 1, 1
         )
         E[:, -1, -1] = torch.det(torch.bmm(U, V.transpose(2, 1)))
         R = torch.bmm(torch.bmm(U, E), V.transpose(2, 1))
-        assert torch.allclose(
-            torch.det(R), R.new_ones(batch_size), atol=1e-4
-        )
+        assert torch.allclose(torch.det(R), R.new_ones(batch_size), atol=1e-4)
 
         return R
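A brief note on the determinant fix in random_rotation above (my gloss, not part of the commit): U and V come from an SVD of a random matrix, so both are orthogonal and det(U Vᵀ) = ±1. E is the identity with its last diagonal entry replaced by det(U Vᵀ), hence

    det(R) = det(U) · det(E) · det(Vᵀ) = det(U Vᵀ) · det(U Vᵀ) = (±1)² = 1,

which is exactly what the trailing assert verifies up to a 1e-4 tolerance.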

@@ -94,19 +89,13 @@ def init_point_cloud(
                     dtype=torch.int64,
                 )
                 X_list = [
-                    torch.randn(
-                        int(n_pt), dim, device=device, dtype=torch.float32
-                    )
+                    torch.randn(int(n_pt), dim, device=device, dtype=torch.float32)
                     for n_pt in n_points_per_batch
                 ]
                 X = Pointclouds(X_list)
             else:
                 X = torch.randn(
-                    batch_size,
-                    n_points,
-                    dim,
-                    device=device,
-                    dtype=torch.float32,
+                    batch_size, n_points, dim, device=device, dtype=torch.float32
                 )
                 X = Pointclouds(list(X))
         else:
@@ -143,11 +132,7 @@ def generate_random_reflection(batch_size=10, dim=3, device=None):
         # randomly select one of the dimensions to reflect for each
         # element in the batch
         dim_to_reflect = torch.randint(
-            low=0,
-            high=dim,
-            size=(batch_size,),
-            device=device,
-            dtype=torch.int64,
+            low=0, high=dim, size=(batch_size,), device=device, dtype=torch.int64
         )
 
         # convert dim_to_reflect to a batch of reflection matrices M
@@ -211,8 +196,7 @@ def corresponding_points_alignment(
            weights *= (weights * template.size()[1] > 0.3).to(weights)
            if use_pointclouds:  # convert to List[Tensor]
                weights = [
-                    w[:npts]
-                    for w, npts in zip(weights, X.num_points_per_cloud())
+                    w[:npts] for w, npts in zip(weights, X.num_points_per_cloud())
                ]
 
        torch.cuda.synchronize()
@@ -255,7 +239,7 @@ def test_corresponding_points_alignment(self, batch_size=10):
                use_point_clouds_cases = (
                    (True, False) if dim == 3 and n_points > 3 else (False,)
                )
-                for random_weights in (False, True,):
+                for random_weights in (False, True):
                    for use_pointclouds in use_point_clouds_cases:
                        for estimate_scale in (False, True):
                            for reflect in (False, True):
@@ -325,8 +309,7 @@ def _test_single_corresponding_points_alignment(
            weights *= (weights * template.size()[1] > 0.3).to(weights)
            if use_pointclouds:  # convert to List[Tensor]
                weights = [
-                    w[:npts]
-                    for w, npts in zip(weights, X.num_points_per_cloud())
+                    w[:npts] for w, npts in zip(weights, X.num_points_per_cloud())
                ]
 
        # apply the generated transformation to the generated
@@ -374,9 +357,9 @@ def align_and_get_mse(weights_):
 
                X_t_est = _apply_pcl_transformation(X_noisy, R_n, T_n, s=s_n)
 
-                return (
-                    ((X_t_est - X_t) * weights[..., None]) ** 2
-                ).sum(dim=(1, 2)) / weights.sum(dim=-1)
+                return (((X_t_est - X_t) * weights[..., None]) ** 2).sum(
+                    dim=(1, 2)
+                ) / weights.sum(dim=-1)
 
            # check that using weights leads to lower weighted_MSE(X_noisy, X_t)
            self.assertTrue(
@@ -386,9 +369,7 @@ def align_and_get_mse(weights_):
        if reflect and not allow_reflection:
            # check that all rotations have det=1
            self._assert_all_close(
-                torch.det(R_est),
-                R_est.new_ones(batch_size),
-                assert_error_message,
+                torch.det(R_est), R_est.new_ones(batch_size), assert_error_message
            )
 
        else:
@@ -430,6 +411,4 @@ def _assert_all_close(self, a_, b_, err_message, weights=None, atol=1e-6):
        if weights is None:
            self.assertClose(a_, b_, atol=atol, msg=err_message)
        else:
-            self.assertClose(
-                a_ * weights, b_ * weights, atol=atol, msg=err_message
-            )
+            self.assertClose(a_ * weights, b_ * weights, atol=atol, msg=err_message)
