@@ -783,18 +783,53 @@ def test_camera_class_init(self):
783
783
self .assertTrue (cam .znear .shape == (2 ,))
784
784
self .assertTrue (cam .zfar .shape == (2 ,))
785
785
786
- # update znear element 1
787
- cam [1 ].znear = 20.0
788
- self .assertTrue (cam .znear [1 ] == 20.0 )
789
-
790
- # Get item and get value
791
- c0 = cam [0 ]
792
- self .assertTrue (c0 .zfar == 100.0 )
793
-
794
786
# Test to
795
787
new_cam = cam .to (device = device )
796
788
self .assertTrue (new_cam .device == device )
797
789
790
+ def test_getitem (self ):
791
+ R_matrix = torch .randn ((6 , 3 , 3 ))
792
+ cam = FoVPerspectiveCameras (znear = 10.0 , zfar = 100.0 , R = R_matrix )
793
+
794
+ # Check get item returns an instance of the same class
795
+ # with all the same keys
796
+ c0 = cam [0 ]
797
+ self .assertTrue (isinstance (c0 , FoVPerspectiveCameras ))
798
+ self .assertEqual (cam .__dict__ .keys (), c0 .__dict__ .keys ())
799
+
800
+ # Check all fields correct in get item with int index
801
+ self .assertEqual (len (c0 ), 1 )
802
+ self .assertClose (c0 .zfar , torch .tensor ([100.0 ]))
803
+ self .assertClose (c0 .znear , torch .tensor ([10.0 ]))
804
+ self .assertClose (c0 .R , R_matrix [0 :1 , ...])
805
+ self .assertEqual (c0 .device , torch .device ("cpu" ))
806
+
807
+ # Check list(int) index
808
+ c012 = cam [[0 , 1 , 2 ]]
809
+ self .assertEqual (len (c012 ), 3 )
810
+ self .assertClose (c012 .zfar , torch .tensor ([100.0 ] * 3 ))
811
+ self .assertClose (c012 .znear , torch .tensor ([10.0 ] * 3 ))
812
+ self .assertClose (c012 .R , R_matrix [0 :3 , ...])
813
+
814
+ # Check torch.LongTensor index
815
+ index = torch .tensor ([1 , 3 , 5 ], dtype = torch .int64 )
816
+ c135 = cam [index ]
817
+ self .assertEqual (len (c135 ), 3 )
818
+ self .assertClose (c135 .zfar , torch .tensor ([100.0 ] * 3 ))
819
+ self .assertClose (c135 .znear , torch .tensor ([10.0 ] * 3 ))
820
+ self .assertClose (c135 .R , R_matrix [[1 , 3 , 5 ], ...])
821
+
822
+ # Check errors with get item
823
+ with self .assertRaisesRegex (ValueError , "out of bounds" ):
824
+ cam [6 ]
825
+
826
+ with self .assertRaisesRegex (ValueError , "Invalid index type" ):
827
+ cam [slice (0 , 1 )]
828
+
829
+ with self .assertRaisesRegex (ValueError , "Invalid index type" ):
830
+ index = torch .tensor ([1 , 3 , 5 ], dtype = torch .float32 )
831
+ cam [index ]
832
+
798
833
def test_get_full_transform (self ):
799
834
cam = FoVPerspectiveCameras ()
800
835
T = torch .tensor ([0.0 , 0.0 , 1.0 ]).view (1 , - 1 )
@@ -919,6 +954,30 @@ def test_perspective_type(self):
919
954
self .assertFalse (cam .is_perspective ())
920
955
self .assertEqual (cam .get_znear (), 1.0 )
921
956
957
+ def test_getitem (self ):
958
+ R_matrix = torch .randn ((6 , 3 , 3 ))
959
+ scale = torch .tensor ([[1.0 , 1.0 , 1.0 ]], requires_grad = True )
960
+ cam = FoVOrthographicCameras (
961
+ znear = 10.0 , zfar = 100.0 , R = R_matrix , scale_xyz = scale
962
+ )
963
+
964
+ # Check get item returns an instance of the same class
965
+ # with all the same keys
966
+ c0 = cam [0 ]
967
+ self .assertTrue (isinstance (c0 , FoVOrthographicCameras ))
968
+ self .assertEqual (cam .__dict__ .keys (), c0 .__dict__ .keys ())
969
+
970
+ # Check torch.LongTensor index
971
+ index = torch .tensor ([1 , 3 , 5 ], dtype = torch .int64 )
972
+ c135 = cam [index ]
973
+ self .assertEqual (len (c135 ), 3 )
974
+ self .assertClose (c135 .zfar , torch .tensor ([100.0 ] * 3 ))
975
+ self .assertClose (c135 .znear , torch .tensor ([10.0 ] * 3 ))
976
+ self .assertClose (c135 .min_x , torch .tensor ([- 1.0 ] * 3 ))
977
+ self .assertClose (c135 .max_x , torch .tensor ([1.0 ] * 3 ))
978
+ self .assertClose (c135 .R , R_matrix [[1 , 3 , 5 ], ...])
979
+ self .assertClose (c135 .scale_xyz , scale .expand (3 , - 1 ))
980
+
922
981
923
982
############################################################
924
983
# Orthographic Camera #
@@ -976,6 +1035,30 @@ def test_perspective_type(self):
976
1035
self .assertFalse (cam .is_perspective ())
977
1036
self .assertIsNone (cam .get_znear ())
978
1037
1038
+ def test_getitem (self ):
1039
+ R_matrix = torch .randn ((6 , 3 , 3 ))
1040
+ principal_point = torch .randn ((6 , 2 , 1 ))
1041
+ focal_length = 5.0
1042
+ cam = OrthographicCameras (
1043
+ R = R_matrix ,
1044
+ focal_length = focal_length ,
1045
+ principal_point = principal_point ,
1046
+ )
1047
+
1048
+ # Check get item returns an instance of the same class
1049
+ # with all the same keys
1050
+ c0 = cam [0 ]
1051
+ self .assertTrue (isinstance (c0 , OrthographicCameras ))
1052
+ self .assertEqual (cam .__dict__ .keys (), c0 .__dict__ .keys ())
1053
+
1054
+ # Check torch.LongTensor index
1055
+ index = torch .tensor ([1 , 3 , 5 ], dtype = torch .int64 )
1056
+ c135 = cam [index ]
1057
+ self .assertEqual (len (c135 ), 3 )
1058
+ self .assertClose (c135 .focal_length , torch .tensor ([5.0 ] * 3 ))
1059
+ self .assertClose (c135 .R , R_matrix [[1 , 3 , 5 ], ...])
1060
+ self .assertClose (c135 .principal_point , principal_point [[1 , 3 , 5 ], ...])
1061
+
979
1062
980
1063
############################################################
981
1064
# Perspective Camera #
@@ -1027,3 +1110,30 @@ def test_perspective_type(self):
1027
1110
cam = PerspectiveCameras (focal_length = 5.0 , principal_point = ((2.5 , 2.5 ),))
1028
1111
self .assertTrue (cam .is_perspective ())
1029
1112
self .assertIsNone (cam .get_znear ())
1113
+
1114
+ def test_getitem (self ):
1115
+ R_matrix = torch .randn ((6 , 3 , 3 ))
1116
+ principal_point = torch .randn ((6 , 2 , 1 ))
1117
+ focal_length = 5.0
1118
+ cam = PerspectiveCameras (
1119
+ R = R_matrix ,
1120
+ focal_length = focal_length ,
1121
+ principal_point = principal_point ,
1122
+ )
1123
+
1124
+ # Check get item returns an instance of the same class
1125
+ # with all the same keys
1126
+ c0 = cam [0 ]
1127
+ self .assertTrue (isinstance (c0 , PerspectiveCameras ))
1128
+ self .assertEqual (cam .__dict__ .keys (), c0 .__dict__ .keys ())
1129
+
1130
+ # Check torch.LongTensor index
1131
+ index = torch .tensor ([1 , 3 , 5 ], dtype = torch .int64 )
1132
+ c135 = cam [index ]
1133
+ self .assertEqual (len (c135 ), 3 )
1134
+ self .assertClose (c135 .focal_length , torch .tensor ([5.0 ] * 3 ))
1135
+ self .assertClose (c135 .R , R_matrix [[1 , 3 , 5 ], ...])
1136
+ self .assertClose (c135 .principal_point , principal_point [[1 , 3 , 5 ], ...])
1137
+
1138
+ # Check in_ndc is handled correctly
1139
+ self .assertEqual (cam ._in_ndc , c0 ._in_ndc )
0 commit comments