Skip to content

Commit e9fb6c2

Browse files
Pyre Bot Jrfacebook-github-bot
Pyre Bot Jr
authored andcommitted
Add annotations to vision/fair/pytorch3d
Reviewed By: shannonzhu Differential Revision: D33970393 fbshipit-source-id: 9b4dfaccfc3793fd37705a923d689cb14c9d26ba
1 parent c2862ff commit e9fb6c2

File tree

21 files changed

+65
-49
lines changed

21 files changed

+65
-49
lines changed

pytorch3d/datasets/r2n2/utils.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,9 @@ def collate_batched_R2N2(batch: List[Dict]): # pragma: no cover
9898
return collated_dict
9999

100100

101-
def compute_extrinsic_matrix(azimuth, elevation, distance): # pragma: no cover
101+
def compute_extrinsic_matrix(
102+
azimuth: float, elevation: float, distance: float
103+
): # pragma: no cover
102104
"""
103105
Copied from meshrcnn codebase:
104106
https://github.com/facebookresearch/meshrcnn/blob/main/shapenet/utils/coords.py#L96
@@ -138,6 +140,7 @@ def compute_extrinsic_matrix(azimuth, elevation, distance): # pragma: no cover
138140
# rotates the model 90 degrees about the x axis. To compensate for this quirk we
139141
# roll that rotation into the extrinsic matrix here
140142
rot = torch.tensor([[1, 0, 0, 0], [0, 0, -1, 0], [0, 1, 0, 0], [0, 0, 0, 1]])
143+
# pyre-fixme[16]: `Tensor` has no attribute `mm`.
141144
RT = RT.mm(rot.to(RT))
142145

143146
return RT
@@ -384,7 +387,7 @@ def voxelize(voxel_coords, P, V): # pragma: no cover
384387
return voxels
385388

386389

387-
def project_verts(verts, P, eps=1e-1): # pragma: no cover
390+
def project_verts(verts, P, eps: float = 1e-1): # pragma: no cover
388391
"""
389392
Copied from meshrcnn codebase:
390393
https://github.com/facebookresearch/meshrcnn/blob/main/shapenet/utils/coords.py#L159

pytorch3d/io/obj_io.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
)
3333

3434

35-
def _format_faces_indices(faces_indices, max_index, device, pad_value=None):
35+
def _format_faces_indices(faces_indices, max_index: int, device, pad_value=None):
3636
"""
3737
Format indices and check for invalid values. Indices can refer to
3838
values in one of the face properties: vertices, textures or normals.
@@ -57,6 +57,7 @@ def _format_faces_indices(faces_indices, max_index, device, pad_value=None):
5757
)
5858

5959
if pad_value is not None:
60+
# pyre-fixme[28]: Unexpected keyword argument `dim`.
6061
mask = faces_indices.eq(pad_value).all(dim=-1)
6162

6263
# Change to 0 based indexing.
@@ -66,14 +67,15 @@ def _format_faces_indices(faces_indices, max_index, device, pad_value=None):
6667
faces_indices[(faces_indices < 0)] += max_index
6768

6869
if pad_value is not None:
70+
# pyre-fixme[61]: `mask` is undefined, or not always defined.
6971
faces_indices[mask] = pad_value
7072

7173
return _check_faces_indices(faces_indices, max_index, pad_value)
7274

7375

7476
def load_obj(
7577
f,
76-
load_textures=True,
78+
load_textures: bool = True,
7779
create_texture_atlas: bool = False,
7880
texture_atlas_size: int = 4,
7981
texture_wrap: Optional[str] = "repeat",
@@ -351,7 +353,7 @@ def _parse_face(
351353
faces_normals_idx,
352354
faces_textures_idx,
353355
faces_materials_idx,
354-
):
356+
) -> None:
355357
face = tokens[1:]
356358
face_list = [f.split("/") for f in face]
357359
face_verts = []
@@ -546,7 +548,7 @@ def _load_materials(
546548
def _load_obj(
547549
f_obj,
548550
*,
549-
data_dir,
551+
data_dir: str,
550552
load_textures: bool = True,
551553
create_texture_atlas: bool = False,
552554
texture_atlas_size: int = 4,

pytorch3d/io/ply_io.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -463,7 +463,9 @@ def _read_ply_element_ascii(f, definition: _PlyElementType):
463463
return data
464464

465465

466-
def _read_raw_array(f, aim: str, length: int, dtype: type = np.uint8, dtype_size=1):
466+
def _read_raw_array(
467+
f, aim: str, length: int, dtype: type = np.uint8, dtype_size: int = 1
468+
):
467469
"""
468470
Read [length] elements from a file.
469471

pytorch3d/io/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def nullcontext(x):
2828
PathOrStr = Union[pathlib.Path, str]
2929

3030

31-
def _open_file(f, path_manager: PathManager, mode="r") -> ContextManager[IO]:
31+
def _open_file(f, path_manager: PathManager, mode: str = "r") -> ContextManager[IO]:
3232
if isinstance(f, str):
3333
f = path_manager.open(f, mode)
3434
return contextlib.closing(f)

pytorch3d/loss/chamfer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
def _validate_chamfer_reduction_inputs(
1616
batch_reduction: Union[str, None], point_reduction: str
17-
):
17+
) -> None:
1818
"""Check the requested reductions are valid.
1919
2020
Args:

pytorch3d/ops/knn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def knn_points(
106106
version: int = -1,
107107
return_nn: bool = False,
108108
return_sorted: bool = True,
109-
):
109+
) -> _KNN:
110110
"""
111111
K-Nearest neighbors on point clouds.
112112

pytorch3d/ops/points_normals.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ def estimate_pointcloud_local_coord_frames(
166166
return curvatures, local_coord_frames
167167

168168

169-
def _disambiguate_vector_directions(pcl, knns, vecs):
169+
def _disambiguate_vector_directions(pcl, knns, vecs: float) -> float:
170170
"""
171171
Disambiguates normal directions according to [1].
172172
@@ -180,6 +180,7 @@ def _disambiguate_vector_directions(pcl, knns, vecs):
180180
# each element of the neighborhood
181181
df = knns - pcl[:, :, None]
182182
# projection of the difference on the principal direction
183+
# pyre-fixme[16]: `float` has no attribute `__getitem__`.
183184
proj = (vecs[:, :, None] * df).sum(3)
184185
# check how many projections are positive
185186
n_pos = (proj > 0).type_as(knns).sum(2, keepdim=True)

pytorch3d/ops/points_to_volumes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -479,7 +479,7 @@ def _check_points_to_volumes_inputs(
479479
volume_features: torch.Tensor,
480480
grid_sizes: torch.LongTensor,
481481
mask: Optional[torch.Tensor] = None,
482-
):
482+
) -> None:
483483

484484
max_grid_size = grid_sizes.max(dim=0).values
485485
if torch.prod(max_grid_size) > volume_densities.shape[1]:

pytorch3d/ops/subdivide_meshes.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -400,7 +400,7 @@ def create_verts_index(verts_per_mesh, edges_per_mesh, device=None):
400400
return verts_idx
401401

402402

403-
def create_faces_index(faces_per_mesh, device=None):
403+
def create_faces_index(faces_per_mesh: int, device=None):
404404
"""
405405
Helper function to group the faces indices for each mesh. New faces are
406406
stacked at the end of the original faces tensor, so in order to have
@@ -417,7 +417,9 @@ def create_faces_index(faces_per_mesh, device=None):
417417
"""
418418
# e.g. faces_per_mesh = [2, 5, 3]
419419

420+
# pyre-fixme[16]: `int` has no attribute `sum`.
420421
F = faces_per_mesh.sum() # e.g. 10
422+
# pyre-fixme[16]: `int` has no attribute `cumsum`.
421423
faces_per_mesh_cumsum = faces_per_mesh.cumsum(dim=0) # (N,) e.g. (2, 7, 10)
422424

423425
switch1_idx = faces_per_mesh_cumsum.clone()

pytorch3d/ops/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ def convert_pointclouds_to_tensor(pcl: Union[torch.Tensor, "Pointclouds"]):
150150
return X, num_points
151151

152152

153-
def is_pointclouds(pcl: Union[torch.Tensor, "Pointclouds"]):
153+
def is_pointclouds(pcl: Union[torch.Tensor, "Pointclouds"]) -> bool:
154154
"""Checks whether the input `pcl` is an instance of `Pointclouds`
155155
by checking the existence of `points_padded` and `num_points_per_cloud`
156156
functions.

pytorch3d/renderer/cameras.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -427,10 +427,10 @@ def __getitem__(
427427

428428

429429
def OpenGLPerspectiveCameras(
430-
znear=1.0,
431-
zfar=100.0,
432-
aspect_ratio=1.0,
433-
fov=60.0,
430+
znear: float = 1.0,
431+
zfar: float = 100.0,
432+
aspect_ratio: float = 1.0,
433+
fov: float = 60.0,
434434
degrees: bool = True,
435435
R: torch.Tensor = _R,
436436
T: torch.Tensor = _T,
@@ -709,12 +709,12 @@ def in_ndc(self):
709709

710710

711711
def OpenGLOrthographicCameras(
712-
znear=1.0,
713-
zfar=100.0,
714-
top=1.0,
715-
bottom=-1.0,
716-
left=-1.0,
717-
right=1.0,
712+
znear: float = 1.0,
713+
zfar: float = 100.0,
714+
top: float = 1.0,
715+
bottom: float = -1.0,
716+
left: float = -1.0,
717+
right: float = 1.0,
718718
scale_xyz=((1.0, 1.0, 1.0),), # (1, 3)
719719
R: torch.Tensor = _R,
720720
T: torch.Tensor = _T,
@@ -956,7 +956,7 @@ def in_ndc(self):
956956

957957

958958
def SfMPerspectiveCameras(
959-
focal_length=1.0,
959+
focal_length: float = 1.0,
960960
principal_point=((0.0, 0.0),),
961961
R: torch.Tensor = _R,
962962
T: torch.Tensor = _T,
@@ -1194,7 +1194,7 @@ def in_ndc(self):
11941194

11951195

11961196
def SfMOrthographicCameras(
1197-
focal_length=1.0,
1197+
focal_length: float = 1.0,
11981198
principal_point=((0.0, 0.0),),
11991199
R: torch.Tensor = _R,
12001200
T: torch.Tensor = _T,
@@ -1645,9 +1645,9 @@ def look_at_rotation(
16451645

16461646

16471647
def look_at_view_transform(
1648-
dist=1.0,
1649-
elev=0.0,
1650-
azim=0.0,
1648+
dist: float = 1.0,
1649+
elev: float = 0.0,
1650+
azim: float = 0.0,
16511651
degrees: bool = True,
16521652
eye: Optional[Sequence] = None,
16531653
at=((0, 0, 0),), # (1, 3)

pytorch3d/renderer/implicit/raymarching.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ def forward(
162162
return opacities
163163

164164

165-
def _shifted_cumprod(x, shift=1):
165+
def _shifted_cumprod(x, shift: int = 1):
166166
"""
167167
Computes `torch.cumprod(x, dim=-1)` and prepends `shift` number of
168168
ones and removes `shift` trailing elements to/from the last dimension
@@ -177,7 +177,7 @@ def _shifted_cumprod(x, shift=1):
177177

178178
def _check_density_bounds(
179179
rays_densities: torch.Tensor, bounds: Tuple[float, float] = (0.0, 1.0)
180-
):
180+
) -> None:
181181
"""
182182
Checks whether the elements of `rays_densities` range within `bounds`.
183183
If not issues a warning.
@@ -197,7 +197,7 @@ def _check_raymarcher_inputs(
197197
features_can_be_none: bool = False,
198198
z_can_be_none: bool = False,
199199
density_1d: bool = True,
200-
):
200+
) -> None:
201201
"""
202202
Checks the validity of the inputs to raymarching algorithms.
203203
"""

pytorch3d/renderer/implicit/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def _validate_ray_bundle_variables(
9898
rays_origins: torch.Tensor,
9999
rays_directions: torch.Tensor,
100100
rays_lengths: torch.Tensor,
101-
):
101+
) -> None:
102102
"""
103103
Validate the shapes of RayBundle variables
104104
`rays_origins`, `rays_directions`, and `rays_lengths`.

pytorch3d/renderer/lighting.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -323,7 +323,7 @@ def specular(self, normals, points, camera_position, shininess) -> torch.Tensor:
323323
return torch.zeros_like(points)
324324

325325

326-
def _validate_light_properties(obj):
326+
def _validate_light_properties(obj) -> None:
327327
props = ("ambient_color", "diffuse_color", "specular_color")
328328
for n in props:
329329
t = getattr(obj, n)

pytorch3d/renderer/mesh/shader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,7 @@ def TexturedSoftPhongShader(
301301
lights: Optional[TensorProperties] = None,
302302
materials: Optional[Materials] = None,
303303
blend_params: Optional[BlendParams] = None,
304-
):
304+
) -> SoftPhongShader:
305305
"""
306306
TexturedSoftPhongShader class has been DEPRECATED. Use SoftPhongShader instead.
307307
Preserving TexturedSoftPhongShader as a function for backwards compatibility.

pytorch3d/structures/meshes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1557,7 +1557,7 @@ def sample_textures(self, fragments):
15571557
raise ValueError("Meshes does not have textures")
15581558

15591559

1560-
def join_meshes_as_batch(meshes: List[Meshes], include_textures: bool = True):
1560+
def join_meshes_as_batch(meshes: List[Meshes], include_textures: bool = True) -> Meshes:
15611561
"""
15621562
Merge multiple Meshes objects, i.e. concatenate the meshes objects. They
15631563
must all be on the same device. If include_textures is true, they must all

pytorch3d/structures/pointclouds.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1224,7 +1224,7 @@ def inside_box(self, box):
12241224
return coord_inside.all(dim=-1)
12251225

12261226

1227-
def join_pointclouds_as_batch(pointclouds: Sequence[Pointclouds]):
1227+
def join_pointclouds_as_batch(pointclouds: Sequence[Pointclouds]) -> Pointclouds:
12281228
"""
12291229
Merge a list of Pointclouds objects into a single batched Pointclouds
12301230
object. All pointclouds must be on the same device.

pytorch3d/transforms/math.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import torch
1111

1212

13-
DEFAULT_ACOS_BOUND = 1.0 - 1e-4
13+
DEFAULT_ACOS_BOUND: float = 1.0 - 1e-4
1414

1515

1616
def acos_linear_extrapolation(

pytorch3d/transforms/transform3d.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -754,7 +754,7 @@ def _broadcast_bmm(a, b):
754754

755755

756756
@torch.no_grad()
757-
def _check_valid_rotation_matrix(R, tol: float = 1e-7):
757+
def _check_valid_rotation_matrix(R, tol: float = 1e-7) -> None:
758758
"""
759759
Determine if R is a valid rotation matrix by checking it satisfies the
760760
following conditions:

0 commit comments

Comments
 (0)