 
 import numpy as np
 import torch
+from pytorch3d.common.datatypes import Device
 from pytorch3d.renderer.camera_utils import join_cameras_as_batch
 from pytorch3d.renderer.cameras import (
     camera_position_from_spherical_angles,
@@ -149,14 +150,17 @@ def ndc_to_screen_points_naive(points, imsize):
 
 
 def init_random_cameras(
-    cam_type: typing.Type[CamerasBase], batch_size: int, random_z: bool = False
+    cam_type: typing.Type[CamerasBase],
+    batch_size: int,
+    random_z: bool = False,
+    device: Device = "cpu",
 ):
     cam_params = {}
     T = torch.randn(batch_size, 3) * 0.03
     if not random_z:
         T[:, 2] = 4
     R = so3_exp_map(torch.randn(batch_size, 3) * 3.0)
-    cam_params = {"R": R, "T": T}
+    cam_params = {"R": R, "T": T, "device": device}
     if cam_type in (OpenGLPerspectiveCameras, OpenGLOrthographicCameras):
         cam_params["znear"] = torch.rand(batch_size) * 10 + 0.1
         cam_params["zfar"] = torch.rand(batch_size) * 4 + 1 + cam_params["znear"]
@@ -613,15 +617,33 @@ def test_unproject_points(self, batch_size=50, num_points=100):
         self.assertTrue(torch.allclose(xyz_unproj, matching_xyz, atol=1e-4))
 
     @staticmethod
-    def unproject_points(cam_type, batch_size=50, num_points=100):
+    def unproject_points(
+        cam_type, batch_size=50, num_points=100, device: Device = "cpu"
+    ):
         """
         Checks that an unprojection of a randomly projected point cloud
         stays the same.
         """
+        if device == "cuda":
+            device = torch.device("cuda:0")
+        else:
+            device = torch.device("cpu")
+
+        str2cls = {  # noqa
+            "OpenGLOrthographicCameras": OpenGLOrthographicCameras,
+            "OpenGLPerspectiveCameras": OpenGLPerspectiveCameras,
+            "SfMOrthographicCameras": SfMOrthographicCameras,
+            "SfMPerspectiveCameras": SfMPerspectiveCameras,
+            "FoVOrthographicCameras": FoVOrthographicCameras,
+            "FoVPerspectiveCameras": FoVPerspectiveCameras,
+            "OrthographicCameras": OrthographicCameras,
+            "PerspectiveCameras": PerspectiveCameras,
+            "FishEyeCameras": FishEyeCameras,
+        }
 
         def run_cameras():
             # init the cameras
-            cameras = init_random_cameras(cam_type, batch_size)
+            cameras = init_random_cameras(str2cls[cam_type], batch_size, device=device)
             # xyz - the ground truth point cloud
             xyz = torch.randn(num_points, 3) * 0.3
             xyz = cameras.unproject_points(xyz, scaled_depth_input=True)
@@ -666,15 +688,33 @@ def test_project_points_screen(self, batch_size=50, num_points=100):
         self.assertClose(xyz_project_screen, xyz_project_screen_naive, atol=1e-4)
 
     @staticmethod
-    def transform_points(cam_type, batch_size=50, num_points=100):
+    def transform_points(
+        cam_type, batch_size=50, num_points=100, device: Device = "cpu"
+    ):
         """
         Checks that an unprojection of a randomly projected point cloud
         stays the same.
         """
 
+        if device == "cuda":
+            device = torch.device("cuda:0")
+        else:
+            device = torch.device("cpu")
+        str2cls = {  # noqa
+            "OpenGLOrthographicCameras": OpenGLOrthographicCameras,
+            "OpenGLPerspectiveCameras": OpenGLPerspectiveCameras,
+            "SfMOrthographicCameras": SfMOrthographicCameras,
+            "SfMPerspectiveCameras": SfMPerspectiveCameras,
+            "FoVOrthographicCameras": FoVOrthographicCameras,
+            "FoVPerspectiveCameras": FoVPerspectiveCameras,
+            "OrthographicCameras": OrthographicCameras,
+            "PerspectiveCameras": PerspectiveCameras,
+            "FishEyeCameras": FishEyeCameras,
+        }
+
         def run_cameras():
             # init the cameras
-            cameras = init_random_cameras(cam_type, batch_size)
+            cameras = init_random_cameras(str2cls[cam_type], batch_size, device=device)
             # xyz - the ground truth point cloud
             xy = torch.randn(num_points, 2) * 2.0 - 1.0
             z = torch.randn(num_points, 1) * 3.0 + 1.0