Skip to content

Commit 2edb93d

Browse files
bottlerfacebook-github-bot
authored andcommitted
chunked_inputs
Summary: Make method for SDF's use of object mask more general, so that a renderer can be given per-pixel values. Reviewed By: shapovalov Differential Revision: D35247412 fbshipit-source-id: 6aeccb1d0b5f1265a3f692a1453407a07e51a33c
1 parent 41c594c commit 2edb93d

File tree

2 files changed

+56
-17
lines changed

2 files changed

+56
-17
lines changed

pytorch3d/implicitron/models/base.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -396,20 +396,20 @@ def forward(
396396
for func in self._implicit_functions:
397397
func.bind_args(**custom_args)
398398

399-
object_mask: Optional[torch.Tensor] = None
399+
chunked_renderer_inputs = {}
400400
if fg_probability is not None:
401401
sampled_fb_prob = rend_utils.ndc_grid_sample(
402402
fg_probability[:n_targets], ray_bundle.xys, mode="nearest"
403403
)
404-
object_mask = sampled_fb_prob > 0.5
404+
chunked_renderer_inputs["object_mask"] = sampled_fb_prob > 0.5
405405

406406
# (5)-(6) Implicit function evaluation and Rendering
407407
rendered = self._render(
408408
ray_bundle=ray_bundle,
409409
sampling_mode=sampling_mode,
410410
evaluation_mode=evaluation_mode,
411411
implicit_functions=self._implicit_functions,
412-
object_mask=object_mask,
412+
chunked_inputs=chunked_renderer_inputs,
413413
)
414414

415415
# Unbind the custom arguments to prevent pytorch from storing
@@ -501,7 +501,6 @@ def visualize(
501501
Helper function to visualize the predictions generated
502502
in the forward pass.
503503
504-
505504
Args:
506505
viz: Visdom connection object
507506
visdom_env_imgs: name of visdom environment for the images.
@@ -521,29 +520,32 @@ def _render(
521520
self,
522521
*,
523522
ray_bundle: RayBundle,
524-
object_mask: Optional[torch.Tensor],
523+
chunked_inputs: Dict[str, torch.Tensor],
525524
sampling_mode: RenderSamplingMode,
526525
**kwargs,
527526
) -> RendererOutput:
528527
"""
529528
Args:
530529
ray_bundle: A `RayBundle` object containing the parametrizations of the
531530
sampled rendering rays.
532-
object_mask: A tensor of shape `(B, 3, H, W)` denoting the silhouette of the object
533-
in the image. This is required for the SignedDistanceFunctionRenderer.
531+
chunked_inputs: A collection of tensor of shape `(B, _, H, W)`. E.g.
532+
SignedDistanceFunctionRenderer requires "object_mask", shape
533+
(B, 1, H, W), the silhouette of the object in the image. When
534+
chunking, they are passed to the renderer as shape
535+
`(B, _, chunksize)`.
534536
sampling_mode: The sampling method to use. Must be a value from the
535537
RenderSamplingMode Enum.
538+
536539
Returns:
537540
An instance of RendererOutput
538-
539541
"""
540542
if sampling_mode == RenderSamplingMode.FULL_GRID and self.chunk_size_grid > 0:
541543
return _apply_chunked(
542544
self.renderer,
543545
_chunk_generator(
544546
self.chunk_size_grid,
545547
ray_bundle,
546-
object_mask,
548+
chunked_inputs,
547549
self.tqdm_trigger_threshold,
548550
**kwargs,
549551
),
@@ -553,7 +555,7 @@ def _render(
553555
# pyre-fixme[29]: `BaseRenderer` is not a function.
554556
return self.renderer(
555557
ray_bundle=ray_bundle,
556-
object_mask=object_mask,
558+
**chunked_inputs,
557559
**kwargs,
558560
)
559561

@@ -837,7 +839,7 @@ def _tensor_collator(batch, new_dims) -> torch.Tensor:
837839
def _chunk_generator(
838840
chunk_size: int,
839841
ray_bundle: RayBundle,
840-
object_mask: Optional[torch.Tensor],
842+
chunked_inputs: Dict[str, torch.Tensor],
841843
tqdm_trigger_threshold: int,
842844
*args,
843845
**kwargs,
@@ -880,8 +882,6 @@ def _chunk_generator(
880882
xys=ray_bundle.xys.reshape(batch_size, -1, 2)[:, start_idx:end_idx],
881883
)
882884
extra_args = kwargs.copy()
883-
if object_mask is not None:
884-
extra_args["object_mask"] = object_mask.reshape(batch_size, -1, 1)[
885-
:, start_idx:end_idx
886-
]
885+
for k, v in chunked_inputs.items():
886+
extra_args[k] = v.flatten(2)[:, :, start_idx:end_idx]
887887
yield [ray_bundle_chunk, *args], extra_args

tests/implicitron/test_forward_pass.py

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,13 @@
99
import torch
1010
from pytorch3d.implicitron.models.base import GenericModel
1111
from pytorch3d.implicitron.models.renderer.base import EvaluationMode
12-
from pytorch3d.implicitron.tools.config import expand_args_fields
12+
from pytorch3d.implicitron.tools.config import expand_args_fields, get_default_args
1313
from pytorch3d.renderer.cameras import look_at_view_transform, PerspectiveCameras
1414

1515

1616
class TestGenericModel(unittest.TestCase):
1717
def test_gm(self):
18-
# Simple test of a forward pass of the default GenericModel.
18+
# Simple test of a forward and backward pass of the default GenericModel.
1919
device = torch.device("cuda:1")
2020
expand_args_fields(GenericModel)
2121
model = GenericModel()
@@ -51,6 +51,7 @@ def test_gm(self):
5151
**defaulted_args,
5252
)
5353
self.assertGreater(train_preds["objective"].item(), 0)
54+
train_preds["objective"].backward()
5455

5556
model.eval()
5657
with torch.no_grad():
@@ -65,3 +66,41 @@ def test_gm(self):
6566
eval_preds["images_render"].shape,
6667
(1, 3, model.render_image_height, model.render_image_width),
6768
)
69+
70+
def test_idr(self):
71+
# Forward pass of GenericModel with IDR.
72+
device = torch.device("cuda:1")
73+
args = get_default_args(GenericModel)
74+
args.renderer_class_type = "SignedDistanceFunctionRenderer"
75+
args.implicit_function_class_type = "IdrFeatureField"
76+
args.implicit_function_IdrFeatureField_args.n_harmonic_functions_xyz = 6
77+
78+
model = GenericModel(**args)
79+
model.to(device)
80+
81+
n_train_cameras = 2
82+
R, T = look_at_view_transform(azim=torch.rand(n_train_cameras) * 360)
83+
cameras = PerspectiveCameras(R=R, T=T, device=device)
84+
85+
defaulted_args = {
86+
"depth_map": None,
87+
"mask_crop": None,
88+
"sequence_name": None,
89+
}
90+
91+
target_image_rgb = torch.rand(
92+
(n_train_cameras, 3, model.render_image_height, model.render_image_width),
93+
device=device,
94+
)
95+
fg_probability = torch.rand(
96+
(n_train_cameras, 1, model.render_image_height, model.render_image_width),
97+
device=device,
98+
)
99+
train_preds = model(
100+
camera=cameras,
101+
evaluation_mode=EvaluationMode.TRAINING,
102+
image_rgb=target_image_rgb,
103+
fg_probability=fg_probability,
104+
**defaulted_args,
105+
)
106+
self.assertGreater(train_preds["objective"].item(), 0)

0 commit comments

Comments
 (0)