Skip to content

Commit 28ccdb7

Browse files
nikhilaravi and facebook-github-bot
authored and committed
Enable __getitem__ for Cameras to return an instance of Cameras
Summary: Added a custom `__getitem__` method to `CamerasBase` which returns an instance of the appropriate camera instead of the `TensorAccessor` class. Long term we should deprecate the `TensorAccessor` and the `__getitem__` method on `TensorProperties`. FB: In the next diff I will update the uses of `select_cameras` in implicitron. Reviewed By: bottler. Differential Revision: D33185885. fbshipit-source-id: c31995d0eb126981e91ba61a6151d5404b263f67
1 parent cc3259b commit 28ccdb7

File tree

3 files changed

+224
-13
lines changed

3 files changed

+224
-13
lines changed

pytorch3d/renderer/cameras.py

Lines changed: 105 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -75,6 +75,10 @@ class CamerasBase(TensorProperties):
7575
boolean argument of the function.
7676
"""
7777

78+
# Used in __getitem__ to index the relevant fields
79+
# When creating a new camera, this should be set in the __init__
80+
_FIELDS: Tuple = ()
81+
7882
def get_projection_transform(self):
7983
"""
8084
Calculate the projective transformation matrix.
@@ -362,6 +366,55 @@ def get_image_size(self):
362366
"""
363367
return self.image_size if hasattr(self, "image_size") else None
364368

369+
def __getitem__(
370+
self, index: Union[int, List[int], torch.LongTensor]
371+
) -> "CamerasBase":
372+
"""
373+
Override for the __getitem__ method in TensorProperties which needs to be
374+
refactored.
375+
376+
Args:
377+
index: an int/list/long tensor used to index all the fields in the cameras given by
378+
self._FIELDS.
379+
Returns:
380+
if `index` is an index int/list/long tensor return an instance of the current
381+
cameras class with only the values at the selected index.
382+
"""
383+
384+
kwargs = {}
385+
386+
if not isinstance(index, (int, list, torch.LongTensor)):
387+
msg = "Invalid index type, expected int, List[int] or torch.LongTensor; got %r"
388+
raise ValueError(msg % type(index))
389+
390+
if isinstance(index, int):
391+
index = [index]
392+
393+
if max(index) >= len(self):
394+
raise ValueError(f"Index {max(index)} is out of bounds for select cameras")
395+
396+
for field in self._FIELDS:
397+
val = getattr(self, field, None)
398+
if val is None:
399+
continue
400+
401+
# e.g. "in_ndc" is set as attribute "_in_ndc" on the class
402+
# but provided as "in_ndc" on initialization
403+
if field.startswith("_"):
404+
field = field[1:]
405+
406+
if isinstance(val, (str, bool)):
407+
kwargs[field] = val
408+
elif isinstance(val, torch.Tensor):
409+
# In the init, all inputs will be converted to
410+
# tensors before setting as attributes
411+
kwargs[field] = val[index]
412+
else:
413+
raise ValueError(f"Field {field} type is not supported for indexing")
414+
415+
kwargs["device"] = self.device
416+
return self.__class__(**kwargs)
417+
365418

366419
############################################################
367420
# Field of View Camera Classes #
@@ -434,6 +487,18 @@ class FoVPerspectiveCameras(CamerasBase):
434487
for rasterization.
435488
"""
436489

490+
# For __getitem__
491+
_FIELDS = (
492+
"K",
493+
"znear",
494+
"zfar",
495+
"aspect_ratio",
496+
"fov",
497+
"R",
498+
"T",
499+
"degrees",
500+
)
501+
437502
def __init__(
438503
self,
439504
znear=1.0,
@@ -590,7 +655,7 @@ def unproject_points(
590655
xy_depth: torch.Tensor,
591656
world_coordinates: bool = True,
592657
scaled_depth_input: bool = False,
593-
**kwargs
658+
**kwargs,
594659
) -> torch.Tensor:
595660
""">!
596661
FoV cameras further allow for passing depth in world units
@@ -681,6 +746,20 @@ class FoVOrthographicCameras(CamerasBase):
681746
The definition of the parameters follow the OpenGL orthographic camera.
682747
"""
683748

749+
# For __getitem__
750+
_FIELDS = (
751+
"K",
752+
"znear",
753+
"zfar",
754+
"R",
755+
"T",
756+
"max_y",
757+
"min_y",
758+
"max_x",
759+
"min_x",
760+
"scale_xyz",
761+
)
762+
684763
def __init__(
685764
self,
686765
znear=1.0,
@@ -819,7 +898,7 @@ def unproject_points(
819898
xy_depth: torch.Tensor,
820899
world_coordinates: bool = True,
821900
scaled_depth_input: bool = False,
822-
**kwargs
901+
**kwargs,
823902
) -> torch.Tensor:
824903
""">!
825904
FoV cameras further allow for passing depth in world units
@@ -907,6 +986,17 @@ class PerspectiveCameras(CamerasBase):
907986
If parameters are specified in screen space, `in_ndc` must be set to False.
908987
"""
909988

989+
# For __getitem__
990+
_FIELDS = (
991+
"K",
992+
"R",
993+
"T",
994+
"focal_length",
995+
"principal_point",
996+
"_in_ndc", # arg is in_ndc but attribute set as _in_ndc
997+
"image_size",
998+
)
999+
9101000
def __init__(
9111001
self,
9121002
focal_length=1.0,
@@ -1007,7 +1097,7 @@ def unproject_points(
10071097
xy_depth: torch.Tensor,
10081098
world_coordinates: bool = True,
10091099
from_ndc: bool = False,
1010-
**kwargs
1100+
**kwargs,
10111101
) -> torch.Tensor:
10121102
"""
10131103
Args:
@@ -1126,6 +1216,17 @@ class OrthographicCameras(CamerasBase):
11261216
If parameters are specified in screen space, `in_ndc` must be set to False.
11271217
"""
11281218

1219+
# For __getitem__
1220+
_FIELDS = (
1221+
"K",
1222+
"R",
1223+
"T",
1224+
"focal_length",
1225+
"principal_point",
1226+
"_in_ndc",
1227+
"image_size",
1228+
)
1229+
11291230
def __init__(
11301231
self,
11311232
focal_length=1.0,
@@ -1225,7 +1326,7 @@ def unproject_points(
12251326
xy_depth: torch.Tensor,
12261327
world_coordinates: bool = True,
12271328
from_ndc: bool = False,
1228-
**kwargs
1329+
**kwargs,
12291330
) -> torch.Tensor:
12301331
"""
12311332
Args:

pytorch3d/renderer/utils.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -155,7 +155,7 @@ def __getitem__(self, index: Union[int, slice]) -> TensorAccessor:
155155
Returns:
156156
if `index` is an index int/slice return a TensorAccessor class
157157
with getattribute/setattribute methods which return/update the value
158-
at the index in the original camera.
158+
at the index in the original class.
159159
"""
160160
if isinstance(index, (int, slice)):
161161
return TensorAccessor(class_object=self, index=index)

tests/test_cameras.py

Lines changed: 118 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -783,18 +783,53 @@ def test_camera_class_init(self):
783783
self.assertTrue(cam.znear.shape == (2,))
784784
self.assertTrue(cam.zfar.shape == (2,))
785785

786-
# update znear element 1
787-
cam[1].znear = 20.0
788-
self.assertTrue(cam.znear[1] == 20.0)
789-
790-
# Get item and get value
791-
c0 = cam[0]
792-
self.assertTrue(c0.zfar == 100.0)
793-
794786
# Test to
795787
new_cam = cam.to(device=device)
796788
self.assertTrue(new_cam.device == device)
797789

790+
def test_getitem(self):
791+
R_matrix = torch.randn((6, 3, 3))
792+
cam = FoVPerspectiveCameras(znear=10.0, zfar=100.0, R=R_matrix)
793+
794+
# Check get item returns an instance of the same class
795+
# with all the same keys
796+
c0 = cam[0]
797+
self.assertTrue(isinstance(c0, FoVPerspectiveCameras))
798+
self.assertEqual(cam.__dict__.keys(), c0.__dict__.keys())
799+
800+
# Check all fields correct in get item with int index
801+
self.assertEqual(len(c0), 1)
802+
self.assertClose(c0.zfar, torch.tensor([100.0]))
803+
self.assertClose(c0.znear, torch.tensor([10.0]))
804+
self.assertClose(c0.R, R_matrix[0:1, ...])
805+
self.assertEqual(c0.device, torch.device("cpu"))
806+
807+
# Check list(int) index
808+
c012 = cam[[0, 1, 2]]
809+
self.assertEqual(len(c012), 3)
810+
self.assertClose(c012.zfar, torch.tensor([100.0] * 3))
811+
self.assertClose(c012.znear, torch.tensor([10.0] * 3))
812+
self.assertClose(c012.R, R_matrix[0:3, ...])
813+
814+
# Check torch.LongTensor index
815+
index = torch.tensor([1, 3, 5], dtype=torch.int64)
816+
c135 = cam[index]
817+
self.assertEqual(len(c135), 3)
818+
self.assertClose(c135.zfar, torch.tensor([100.0] * 3))
819+
self.assertClose(c135.znear, torch.tensor([10.0] * 3))
820+
self.assertClose(c135.R, R_matrix[[1, 3, 5], ...])
821+
822+
# Check errors with get item
823+
with self.assertRaisesRegex(ValueError, "out of bounds"):
824+
cam[6]
825+
826+
with self.assertRaisesRegex(ValueError, "Invalid index type"):
827+
cam[slice(0, 1)]
828+
829+
with self.assertRaisesRegex(ValueError, "Invalid index type"):
830+
index = torch.tensor([1, 3, 5], dtype=torch.float32)
831+
cam[index]
832+
798833
def test_get_full_transform(self):
799834
cam = FoVPerspectiveCameras()
800835
T = torch.tensor([0.0, 0.0, 1.0]).view(1, -1)
@@ -919,6 +954,30 @@ def test_perspective_type(self):
919954
self.assertFalse(cam.is_perspective())
920955
self.assertEqual(cam.get_znear(), 1.0)
921956

957+
def test_getitem(self):
958+
R_matrix = torch.randn((6, 3, 3))
959+
scale = torch.tensor([[1.0, 1.0, 1.0]], requires_grad=True)
960+
cam = FoVOrthographicCameras(
961+
znear=10.0, zfar=100.0, R=R_matrix, scale_xyz=scale
962+
)
963+
964+
# Check get item returns an instance of the same class
965+
# with all the same keys
966+
c0 = cam[0]
967+
self.assertTrue(isinstance(c0, FoVOrthographicCameras))
968+
self.assertEqual(cam.__dict__.keys(), c0.__dict__.keys())
969+
970+
# Check torch.LongTensor index
971+
index = torch.tensor([1, 3, 5], dtype=torch.int64)
972+
c135 = cam[index]
973+
self.assertEqual(len(c135), 3)
974+
self.assertClose(c135.zfar, torch.tensor([100.0] * 3))
975+
self.assertClose(c135.znear, torch.tensor([10.0] * 3))
976+
self.assertClose(c135.min_x, torch.tensor([-1.0] * 3))
977+
self.assertClose(c135.max_x, torch.tensor([1.0] * 3))
978+
self.assertClose(c135.R, R_matrix[[1, 3, 5], ...])
979+
self.assertClose(c135.scale_xyz, scale.expand(3, -1))
980+
922981

923982
############################################################
924983
# Orthographic Camera #
@@ -976,6 +1035,30 @@ def test_perspective_type(self):
9761035
self.assertFalse(cam.is_perspective())
9771036
self.assertIsNone(cam.get_znear())
9781037

1038+
def test_getitem(self):
1039+
R_matrix = torch.randn((6, 3, 3))
1040+
principal_point = torch.randn((6, 2, 1))
1041+
focal_length = 5.0
1042+
cam = OrthographicCameras(
1043+
R=R_matrix,
1044+
focal_length=focal_length,
1045+
principal_point=principal_point,
1046+
)
1047+
1048+
# Check get item returns an instance of the same class
1049+
# with all the same keys
1050+
c0 = cam[0]
1051+
self.assertTrue(isinstance(c0, OrthographicCameras))
1052+
self.assertEqual(cam.__dict__.keys(), c0.__dict__.keys())
1053+
1054+
# Check torch.LongTensor index
1055+
index = torch.tensor([1, 3, 5], dtype=torch.int64)
1056+
c135 = cam[index]
1057+
self.assertEqual(len(c135), 3)
1058+
self.assertClose(c135.focal_length, torch.tensor([5.0] * 3))
1059+
self.assertClose(c135.R, R_matrix[[1, 3, 5], ...])
1060+
self.assertClose(c135.principal_point, principal_point[[1, 3, 5], ...])
1061+
9791062

9801063
############################################################
9811064
# Perspective Camera #
@@ -1027,3 +1110,30 @@ def test_perspective_type(self):
10271110
cam = PerspectiveCameras(focal_length=5.0, principal_point=((2.5, 2.5),))
10281111
self.assertTrue(cam.is_perspective())
10291112
self.assertIsNone(cam.get_znear())
1113+
1114+
def test_getitem(self):
1115+
R_matrix = torch.randn((6, 3, 3))
1116+
principal_point = torch.randn((6, 2, 1))
1117+
focal_length = 5.0
1118+
cam = PerspectiveCameras(
1119+
R=R_matrix,
1120+
focal_length=focal_length,
1121+
principal_point=principal_point,
1122+
)
1123+
1124+
# Check get item returns an instance of the same class
1125+
# with all the same keys
1126+
c0 = cam[0]
1127+
self.assertTrue(isinstance(c0, PerspectiveCameras))
1128+
self.assertEqual(cam.__dict__.keys(), c0.__dict__.keys())
1129+
1130+
# Check torch.LongTensor index
1131+
index = torch.tensor([1, 3, 5], dtype=torch.int64)
1132+
c135 = cam[index]
1133+
self.assertEqual(len(c135), 3)
1134+
self.assertClose(c135.focal_length, torch.tensor([5.0] * 3))
1135+
self.assertClose(c135.R, R_matrix[[1, 3, 5], ...])
1136+
self.assertClose(c135.principal_point, principal_point[[1, 3, 5], ...])
1137+
1138+
# Check in_ndc is handled correctly
1139+
self.assertEqual(cam._in_ndc, c0._in_ndc)

0 commit comments

Comments
 (0)