@@ -396,20 +396,20 @@ def forward(
396
396
for func in self ._implicit_functions :
397
397
func .bind_args (** custom_args )
398
398
399
- object_mask : Optional [ torch . Tensor ] = None
399
+ chunked_renderer_inputs = {}
400
400
if fg_probability is not None :
401
401
sampled_fb_prob = rend_utils .ndc_grid_sample (
402
402
fg_probability [:n_targets ], ray_bundle .xys , mode = "nearest"
403
403
)
404
- object_mask = sampled_fb_prob > 0.5
404
+ chunked_renderer_inputs [ " object_mask" ] = sampled_fb_prob > 0.5
405
405
406
406
# (5)-(6) Implicit function evaluation and Rendering
407
407
rendered = self ._render (
408
408
ray_bundle = ray_bundle ,
409
409
sampling_mode = sampling_mode ,
410
410
evaluation_mode = evaluation_mode ,
411
411
implicit_functions = self ._implicit_functions ,
412
- object_mask = object_mask ,
412
+ chunked_inputs = chunked_renderer_inputs ,
413
413
)
414
414
415
415
# Unbind the custom arguments to prevent pytorch from storing
@@ -501,7 +501,6 @@ def visualize(
501
501
Helper function to visualize the predictions generated
502
502
in the forward pass.
503
503
504
-
505
504
Args:
506
505
viz: Visdom connection object
507
506
visdom_env_imgs: name of visdom environment for the images.
@@ -521,29 +520,32 @@ def _render(
521
520
self ,
522
521
* ,
523
522
ray_bundle : RayBundle ,
524
- object_mask : Optional [ torch .Tensor ],
523
+ chunked_inputs : Dict [ str , torch .Tensor ],
525
524
sampling_mode : RenderSamplingMode ,
526
525
** kwargs ,
527
526
) -> RendererOutput :
528
527
"""
529
528
Args:
530
529
ray_bundle: A `RayBundle` object containing the parametrizations of the
531
530
sampled rendering rays.
532
- object_mask: A tensor of shape `(B, 3, H, W)` denoting the silhouette of the object
533
- in the image. This is required for the SignedDistanceFunctionRenderer.
531
+ chunked_inputs: A collection of tensor of shape `(B, _, H, W)`. E.g.
532
+ SignedDistanceFunctionRenderer requires "object_mask", shape
533
+ (B, 1, H, W), the silhouette of the object in the image. When
534
+ chunking, they are passed to the renderer as shape
535
+ `(B, _, chunksize)`.
534
536
sampling_mode: The sampling method to use. Must be a value from the
535
537
RenderSamplingMode Enum.
538
+
536
539
Returns:
537
540
An instance of RendererOutput
538
-
539
541
"""
540
542
if sampling_mode == RenderSamplingMode .FULL_GRID and self .chunk_size_grid > 0 :
541
543
return _apply_chunked (
542
544
self .renderer ,
543
545
_chunk_generator (
544
546
self .chunk_size_grid ,
545
547
ray_bundle ,
546
- object_mask ,
548
+ chunked_inputs ,
547
549
self .tqdm_trigger_threshold ,
548
550
** kwargs ,
549
551
),
@@ -553,7 +555,7 @@ def _render(
553
555
# pyre-fixme[29]: `BaseRenderer` is not a function.
554
556
return self .renderer (
555
557
ray_bundle = ray_bundle ,
556
- object_mask = object_mask ,
558
+ ** chunked_inputs ,
557
559
** kwargs ,
558
560
)
559
561
@@ -837,7 +839,7 @@ def _tensor_collator(batch, new_dims) -> torch.Tensor:
837
839
def _chunk_generator (
838
840
chunk_size : int ,
839
841
ray_bundle : RayBundle ,
840
- object_mask : Optional [ torch .Tensor ],
842
+ chunked_inputs : Dict [ str , torch .Tensor ],
841
843
tqdm_trigger_threshold : int ,
842
844
* args ,
843
845
** kwargs ,
@@ -880,8 +882,6 @@ def _chunk_generator(
880
882
xys = ray_bundle .xys .reshape (batch_size , - 1 , 2 )[:, start_idx :end_idx ],
881
883
)
882
884
extra_args = kwargs .copy ()
883
- if object_mask is not None :
884
- extra_args ["object_mask" ] = object_mask .reshape (batch_size , - 1 , 1 )[
885
- :, start_idx :end_idx
886
- ]
885
+ for k , v in chunked_inputs .items ():
886
+ extra_args [k ] = v .flatten (2 )[:, :, start_idx :end_idx ]
887
887
yield [ray_bundle_chunk , * args ], extra_args
0 commit comments