Commit b7e3b7b

bottler authored and facebook-github-bot committed
rendered_mesh_dataset improvements
Summary: Allow choosing the device and the distance
Reviewed By: shapovalov
Differential Revision: D42451605
fbshipit-source-id: 214f02d09da94eb127b3cc308d5bae800dc7b9e2
1 parent acc60db commit b7e3b7b

File tree

2 files changed: +17, −7 lines changed


projects/implicitron_trainer/tests/experiment.yaml

Lines changed: 2 additions & 0 deletions
@@ -103,8 +103,10 @@ data_source_ImplicitronDataSource_args:
 num_views: 40
 data_file: null
 azimuth_range: 180.0
+distance: 2.7
 resolution: 128
 use_point_light: true
+gpu_idx: 0
 path_manager_factory_class_type: PathManagerFactory
 path_manager_factory_PathManagerFactory_args:
   silence_logs: true
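The two new settings can also be passed in Python when building the provider directly. A minimal sketch, not from this commit, assuming the default cow mesh asset is reachable through the path manager (expand_args_fields and get_dataset_map are the usual implicitron configurable machinery):

    from pytorch3d.implicitron.dataset.rendered_mesh_dataset_map_provider import (
        RenderedMeshDatasetMapProvider,
    )
    from pytorch3d.implicitron.tools.config import expand_args_fields

    # expand_args_fields turns the configurable into a dataclass, so the
    # new fields can be supplied as keyword arguments.
    expand_args_fields(RenderedMeshDatasetMapProvider)
    provider = RenderedMeshDatasetMapProvider(
        num_views=40,
        distance=2.7,  # new: camera-to-origin distance
        gpu_idx=0,     # new: CUDA device used for rendering; None forces CPU
    )
    datasets = provider.get_dataset_map()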

pytorch3d/implicitron/dataset/rendered_mesh_dataset_map_provider.py

Lines changed: 15 additions & 7 deletions
@@ -49,7 +49,7 @@ class RenderedMeshDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13
     if one is available, the data it produces is on the CPU just like
     the data returned by implicitron's other dataset map providers.
     This is because both datasets and models can be large, so implicitron's
-    GenericModel.forward (etc) expects data on the CPU and only moves
+    training loop expects data on the CPU and only moves
     what it needs to the device.

     For a more detailed explanation of this code, please refer to the
@@ -61,16 +61,23 @@ class RenderedMeshDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13
             the cow mesh in the same repo as this code.
         azimuth_range: number of degrees on each side of the start position to
             take samples
+        distance: distance from camera centres to the origin.
         resolution: the common height and width of the output images.
         use_point_light: whether to use a particular point light as opposed
             to ambient white.
+        gpu_idx: which gpu to use for rendering the mesh.
+        path_manager_factory: (Optional) An object that generates an instance of
+            PathManager that can translate provided file paths.
+        path_manager_factory_class_type: The class type of `path_manager_factory`.
     """

     num_views: int = 40
     data_file: Optional[str] = None
     azimuth_range: float = 180
+    distance: float = 2.7
     resolution: int = 128
     use_point_light: bool = True
+    gpu_idx: Optional[int] = 0
     path_manager_factory: PathManagerFactory
     path_manager_factory_class_type: str = "PathManagerFactory"


@@ -85,8 +92,8 @@ def get_all_train_cameras(self) -> CamerasBase:
     def __post_init__(self) -> None:
         super().__init__()
         run_auto_creation(self)
-        if torch.cuda.is_available():
-            device = torch.device("cuda:0")
+        if torch.cuda.is_available() and self.gpu_idx is not None:
+            device = torch.device(f"cuda:{self.gpu_idx}")
         else:
             device = torch.device("cpu")
         if self.data_file is None:
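The new behaviour in isolation: gpu_idx picks the CUDA device, and None now forces CPU rendering even when CUDA is available. A sketch with a hypothetical helper name (_choose_device does not exist in the file; the logic lives inline in __post_init__ as shown above):

    import torch

    def _choose_device(gpu_idx):
        # Mirrors the new logic: gpu_idx=None opts out of CUDA rendering
        # even when a GPU is present.
        if torch.cuda.is_available() and gpu_idx is not None:
            return torch.device(f"cuda:{gpu_idx}")
        return torch.device("cpu")

    assert _choose_device(None) == torch.device("cpu")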
@@ -106,13 +113,13 @@ def __post_init__(self) -> None:
             num_views=self.num_views,
             mesh=mesh,
             azimuth_range=self.azimuth_range,
+            distance=self.distance,
             resolution=self.resolution,
             device=device,
             use_point_light=self.use_point_light,
         )
         # pyre-ignore[16]
         self.poses = poses.cpu()
-        expand_args_fields(SingleSceneDataset)
         # pyre-ignore[16]
         self.train_dataset = SingleSceneDataset(  # pyre-ignore[28]
             object_name="cow",
@@ -130,6 +137,7 @@ def _generate_cow_renders(
     num_views: int,
     mesh: Meshes,
     azimuth_range: float,
+    distance: float,
     resolution: int,
     device: torch.device,
     use_point_light: bool,
@@ -168,11 +176,11 @@ def _generate_cow_renders(
     else:
         lights = AmbientLights(device=device)

-    # Initialize an OpenGL perspective camera that represents a batch of different
+    # Initialize a perspective camera that represents a batch of different
     # viewing angles. All the cameras helper methods support mixed type inputs and
-    # broadcasting. So we can view the camera from the a distance of dist=2.7, and
+    # broadcasting. So we can view the camera from a fixed distance, and
     # then specify elevation and azimuth angles for each viewpoint as tensors.
-    R, T = look_at_view_transform(dist=2.7, elev=elev, azim=azim)
+    R, T = look_at_view_transform(dist=distance, elev=elev, azim=azim)
     cameras = FoVPerspectiveCameras(device=device, R=R, T=T)

     # Define the settings for rasterization and shading.
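For context, a sketch of how the threaded-through distance reaches the cameras. The elevation/azimuth sampling below is an assumption for illustration, not copied from this file:

    import torch
    from pytorch3d.renderer import FoVPerspectiveCameras, look_at_view_transform

    num_views, azimuth_range, distance = 40, 180.0, 2.7
    elev = torch.zeros(num_views)  # assumed: fixed elevation
    azim = torch.linspace(-azimuth_range, azimuth_range, num_views)
    # A single scalar distance broadcasts against the per-view angle tensors,
    # so one `dist` value places every camera on the same sphere around the origin.
    R, T = look_at_view_transform(dist=distance, elev=elev, azim=azim)
    cameras = FoVPerspectiveCameras(R=R, T=T)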
