Commit a6dada3

shapovalov authored and facebook-github-bot committed

Extracted ImplicitronModelBase and unified API for GenericModel and ModelDBIR

Summary: To avoid model_zoo, we need to make GenericModel pluggable. I also align the creation APIs for convenience.

Reviewed By: bottler, davnov134
Differential Revision: D35933093
fbshipit-source-id: 8228926528eb41a795fbfbe32304b8019197e2b1

1 parent 5c59841
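
To sketch what "pluggable" means here: ImplicitronModelBase (added below) is a ReplaceableBase, so alternative models can be registered with implicitron's config system and swapped in for GenericModel. A minimal sketch, assuming the `registry` decorator from pytorch3d.implicitron.tools.config (the same module that provides ReplaceableBase); ConstantModel is a hypothetical illustration, not part of this commit:

    # A minimal sketch, assuming registry.register from
    # pytorch3d.implicitron.tools.config; ConstantModel is hypothetical.
    from typing import Any, Dict, Optional

    import torch
    from pytorch3d.implicitron.models.base_model import (
        ImplicitronModelBase,
        ImplicitronRender,
    )
    from pytorch3d.implicitron.tools.config import registry


    @registry.register
    class ConstantModel(ImplicitronModelBase):
        def forward(
            self, *, image_rgb: Optional[torch.Tensor], **kwargs
        ) -> Dict[str, Any]:
            # Render a constant gray image the size of the first (target) view.
            assert image_rgb is not None
            target = image_rgb[:1]
            render = ImplicitronRender(
                image_render=torch.full_like(target, 0.5),
                depth_render=torch.ones_like(target[:, :1]),
                mask_render=torch.ones_like(target[:, :1]),
            )
            # All models must return their render under this key.
            return {"implicitron_render": render}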

11 files changed: +280 -176 lines

projects/implicitron_trainer/experiment.py

Lines changed: 3 additions & 3 deletions

@@ -71,7 +71,7 @@
     ImplicitronDataset,
 )
 from pytorch3d.implicitron.evaluation import evaluate_new_view_synthesis as evaluate
-from pytorch3d.implicitron.models.base import EvaluationMode, GenericModel
+from pytorch3d.implicitron.models.generic_model import EvaluationMode, GenericModel
 from pytorch3d.implicitron.tools import model_io, vis_utils
 from pytorch3d.implicitron.tools.config import (
     enable_get_default_args,
@@ -615,11 +615,11 @@ def run_eval(cfg, model, all_source_cameras, loader, task, device):
         preds = model(
             **{**frame_data_for_eval, "evaluation_mode": EvaluationMode.EVALUATION}
         )
-        nvs_prediction = copy.deepcopy(preds["nvs_prediction"])
+        implicitron_render = copy.deepcopy(preds["implicitron_render"])
         per_batch_eval_results.append(
             evaluate.eval_batch(
                 frame_data,
-                nvs_prediction,
+                implicitron_render,
                 bg_color="black",
                 lpips_model=lpips_model,
                 source_cameras=all_source_cameras,
projects/implicitron_trainer/visualize_reconstruction.py

Lines changed: 1 addition & 1 deletion

@@ -29,7 +29,7 @@
     ImplicitronDataset,
 )
 from pytorch3d.implicitron.dataset.utils import is_train_frame
-from pytorch3d.implicitron.models.base import EvaluationMode
+from pytorch3d.implicitron.models.base_model import EvaluationMode
 from pytorch3d.implicitron.tools.configurable import get_default_args
 from pytorch3d.implicitron.tools.eval_video_trajectory import (
     generate_eval_video_cameras,

pytorch3d/implicitron/eval_demo.py

Lines changed: 6 additions & 7 deletions

@@ -5,10 +5,9 @@
 # LICENSE file in the root directory of this source tree.


-import copy
 import dataclasses
 import os
-from typing import cast, Optional
+from typing import cast, Optional, Tuple

 import lpips
 import torch
@@ -76,7 +75,7 @@ def main() -> None:

 def evaluate_dbir_for_category(
     category: str = "apple",
-    bg_color: float = 0.0,
+    bg_color: Tuple[float, float, float] = (0.0, 0.0, 0.0),
     task: str = "singlesequence",
     single_sequence_id: Optional[int] = None,
     num_workers: int = 16,
@@ -141,8 +140,9 @@ def evaluate_dbir_for_category(
         raise ValueError("Image size should be set in the dataset")

     # init the simple DBIR model
-    model = ModelDBIR(
-        image_size=image_size,
+    model = ModelDBIR(  # pyre-ignore[28]: c’tor implicitly overridden
+        render_image_width=image_size,
+        render_image_height=image_size,
         bg_color=bg_color,
         max_points=int(1e5),
     )
@@ -157,11 +157,10 @@ def evaluate_dbir_for_category(
     for frame_data in tqdm(test_dataloader):
         frame_data = dataclass_to_cuda_(frame_data)
         preds = model(**dataclasses.asdict(frame_data))
-        nvs_prediction = copy.deepcopy(preds["nvs_prediction"])
         per_batch_eval_results.append(
             eval_batch(
                 frame_data,
-                nvs_prediction,
+                preds["implicitron_render"],
                 bg_color=bg_color,
                 lpips_model=lpips_model,
                 source_cameras=all_source_cameras,
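
For reference, a usage sketch of the aligned ModelDBIR creation API shown above. The module path of ModelDBIR and the concrete sizes are assumptions for illustration; the parameter names are taken from the diff:

    # A minimal sketch of the new ModelDBIR creation API; the module path and
    # sizes are assumed for illustration.
    import dataclasses

    from pytorch3d.implicitron.models.model_dbir import ModelDBIR

    model = ModelDBIR(
        render_image_width=800,
        render_image_height=800,
        bg_color=(0.0, 0.0, 0.0),  # now an RGB tuple rather than a single float
        max_points=int(1e5),
    )
    # frame_data is assumed to come from an evaluation dataloader, as above.
    preds = model(**dataclasses.asdict(frame_data))
    render = preds["implicitron_render"]  # an ImplicitronRender instance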

pytorch3d/implicitron/evaluation/evaluate_new_view_synthesis.py

Lines changed: 32 additions & 57 deletions

@@ -9,12 +9,14 @@
 import warnings
 from collections import OrderedDict
 from dataclasses import dataclass, field
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Sequence, Union

 import numpy as np
 import torch
+import torch.nn.functional as F
 from pytorch3d.implicitron.dataset.implicitron_dataset import FrameData
 from pytorch3d.implicitron.dataset.utils import is_known_frame, is_train_frame
+from pytorch3d.implicitron.models.base_model import ImplicitronRender
 from pytorch3d.implicitron.tools import vis_utils
 from pytorch3d.implicitron.tools.camera_utils import volumetric_camera_overlaps
 from pytorch3d.implicitron.tools.image_utils import mask_background
@@ -31,18 +33,6 @@
 EVAL_N_SRC_VIEWS = [1, 3, 5, 7, 9]


-@dataclass
-class NewViewSynthesisPrediction:
-    """
-    Holds the tensors that describe a result of synthesizing new views.
-    """
-
-    depth_render: Optional[torch.Tensor] = None
-    image_render: Optional[torch.Tensor] = None
-    mask_render: Optional[torch.Tensor] = None
-    camera_distance: Optional[torch.Tensor] = None
-
-
 @dataclass
 class _Visualizer:
     image_render: torch.Tensor
@@ -145,8 +135,8 @@ def show_depth(

 def eval_batch(
     frame_data: FrameData,
-    nvs_prediction: NewViewSynthesisPrediction,
-    bg_color: Union[torch.Tensor, str, float] = "black",
+    implicitron_render: ImplicitronRender,
+    bg_color: Union[torch.Tensor, Sequence, str, float] = "black",
     mask_thr: float = 0.5,
     lpips_model=None,
     visualize: bool = False,
@@ -162,14 +152,14 @@ def eval_batch(
     is True), a new-view synthesis method (NVS) is tasked to generate new views
     of the scene from the viewpoint of the target views (for which
     frame_data.frame_type.endswith('known') is False). The resulting
-    synthesized new views, stored in `nvs_prediction`, are compared to the
+    synthesized new views, stored in `implicitron_render`, are compared to the
     target ground truth in `frame_data` in terms of geometry and appearance,
     resulting in a dictionary of metrics returned by the `eval_batch` function.

     Args:
         frame_data: A FrameData object containing the input to the new view
             synthesis method.
-        nvs_prediction: The data describing the synthesized new views.
+        implicitron_render: The data describing the synthesized new views.
         bg_color: The background color of the generated new views and the
             ground truth.
         lpips_model: A pre-trained model for evaluating the LPIPS metric.
@@ -184,26 +174,39 @@ def eval_batch(
         ValueError if frame_data does not have frame_type, camera, or image_rgb
         ValueError if the batch has a mix of training and test samples
         ValueError if the batch frames are not [unseen, known, known, ...]
-        ValueError if one of the required fields in nvs_prediction is missing
+        ValueError if one of the required fields in implicitron_render is missing
     """
-    REQUIRED_NVS_PREDICTION_FIELDS = ["mask_render", "image_render", "depth_render"]
     frame_type = frame_data.frame_type
     if frame_type is None:
         raise ValueError("Frame type has not been set.")

     # we check that all those fields are not None but Pyre can't infer that properly
-    # TODO: assign to local variables
+    # TODO: assign to local variables and simplify the code.
     if frame_data.image_rgb is None:
         raise ValueError("Image is not in the evaluation batch.")

     if frame_data.camera is None:
         raise ValueError("Camera is not in the evaluation batch.")

-    if any(not hasattr(nvs_prediction, k) for k in REQUIRED_NVS_PREDICTION_FIELDS):
-        raise ValueError("One of the required predicted fields is missing")
+    # eval all results in the resolution of the frame_data image
+    image_resol = tuple(frame_data.image_rgb.shape[2:])
+
+    # Post-process the render:
+    # 1) check implicitron_render for Nones,
+    # 2) obtain copies to make sure we don't edit the original data,
+    # 3) take only the 1st (target) image,
+    # 4) resize to match the ground-truth resolution.
+    cloned_render: Dict[str, torch.Tensor] = {}
+    for k in ["mask_render", "image_render", "depth_render"]:
+        field = getattr(implicitron_render, k)
+        if field is None:
+            raise ValueError(f"A required predicted field {k} is missing")
+
+        imode = "bilinear" if k == "image_render" else "nearest"
+        cloned_render[k] = (
+            F.interpolate(field[:1], size=image_resol, mode=imode).detach().clone()
+        )

-    # obtain copies to make sure we dont edit the original data
-    nvs_prediction = copy.deepcopy(nvs_prediction)
     frame_data = copy.deepcopy(frame_data)

     # mask the ground truth depth in case frame_data contains the depth mask
@@ -226,9 +229,6 @@ def eval_batch(
             + " a target view while the rest should be source views."
         )  # TODO: do we need to enforce this?

-    # take only the first (target image)
-    for k in REQUIRED_NVS_PREDICTION_FIELDS:
-        setattr(nvs_prediction, k, getattr(nvs_prediction, k)[:1])
     for k in [
         "depth_map",
         "image_rgb",
@@ -242,10 +242,6 @@ def eval_batch(
     if frame_data.depth_map is None or frame_data.depth_map.sum() <= 0:
         warnings.warn("Empty or missing depth map in evaluation!")

-    # eval all results in the resolution of the frame_data image
-    # pyre-fixme[16]: `Optional` has no attribute `shape`.
-    image_resol = list(frame_data.image_rgb.shape[2:])
-
     # threshold the masks to make ground truth binary masks
     mask_fg, mask_crop = [
         (getattr(frame_data, k) >= mask_thr) for k in ("fg_probability", "mask_crop")
@@ -258,29 +254,14 @@ def eval_batch(
         bg_color=bg_color,
     )

-    # resize to the target resolution
-    for k in REQUIRED_NVS_PREDICTION_FIELDS:
-        imode = "bilinear" if k == "image_render" else "nearest"
-        val = getattr(nvs_prediction, k)
-        setattr(
-            nvs_prediction,
-            k,
-            # pyre-fixme[6]: Expected `Optional[int]` for 2nd param but got
-            # `List[typing.Any]`.
-            torch.nn.functional.interpolate(val, size=image_resol, mode=imode),
-        )
-
     # clamp predicted images
-    # pyre-fixme[16]: `Optional` has no attribute `clamp`.
-    image_render = nvs_prediction.image_render.clamp(0.0, 1.0)
+    image_render = cloned_render["image_render"].clamp(0.0, 1.0)

     if visualize:
         visualizer = _Visualizer(
             image_render=image_render,
             image_rgb_masked=image_rgb_masked,
-            # pyre-fixme[6]: Expected `Tensor` for 3rd param but got
-            # `Optional[torch.Tensor]`.
-            depth_render=nvs_prediction.depth_render,
+            depth_render=cloned_render["depth_render"],
             # pyre-fixme[6]: Expected `Tensor` for 4th param but got
             # `Optional[torch.Tensor]`.
             depth_map=frame_data.depth_map,
@@ -292,9 +273,7 @@ def eval_batch(
     results: Dict[str, Any] = {}

     results["iou"] = iou(
-        # pyre-fixme[6]: Expected `Tensor` for 1st param but got
-        # `Optional[torch.Tensor]`.
-        nvs_prediction.mask_render,
+        cloned_render["mask_render"],
         mask_fg,
         mask=mask_crop,
     )
@@ -321,11 +300,7 @@ def eval_batch(
         if name_postfix == "_fg":
             # only record depth metrics for the foreground
             _, abs_ = eval_depth(
-                # pyre-fixme[6]: Expected `Tensor` for 1st param but got
-                # `Optional[torch.Tensor]`.
-                nvs_prediction.depth_render,
-                # pyre-fixme[6]: Expected `Tensor` for 2nd param but got
-                # `Optional[torch.Tensor]`.
+                cloned_render["depth_render"],
                 frame_data.depth_map,
                 get_best_scale=True,
                 mask=loss_mask_now,
@@ -343,7 +318,7 @@ def eval_batch(
     if lpips_model is not None:
         im1, im2 = [
             2.0 * im.clamp(0.0, 1.0) - 1.0
-            for im in (image_rgb_masked, nvs_prediction.image_render)
+            for im in (image_rgb_masked, cloned_render["image_render"])
         ]
         results["lpips"] = lpips_model.forward(im1, im2).item()
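To make the new calling convention concrete, a minimal sketch of driving eval_batch with an ImplicitronRender; the tensor shapes are arbitrary example values, and frame_data is assumed to come from an evaluation dataloader as in eval_demo.py above:

    # A minimal sketch of the new eval_batch contract; shapes are arbitrary
    # and frame_data is assumed to come from an evaluation dataloader.
    import torch
    from pytorch3d.implicitron.evaluation.evaluate_new_view_synthesis import (
        eval_batch,
    )
    from pytorch3d.implicitron.models.base_model import ImplicitronRender

    render = ImplicitronRender(
        image_render=torch.rand(1, 3, 128, 128),
        depth_render=torch.rand(1, 1, 128, 128),
        mask_render=torch.ones(1, 1, 128, 128),
    )
    # eval_batch now resizes the render to frame_data.image_rgb's resolution
    # internally, so the render resolution need not match the ground truth.
    results = eval_batch(frame_data, render, bg_color="black")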

pytorch3d/implicitron/models/base_model.py

Lines changed: 87 additions & 0 deletions

@@ -0,0 +1,87 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional
+
+import torch
+from pytorch3d.implicitron.tools.config import ReplaceableBase
+from pytorch3d.renderer.cameras import CamerasBase
+
+from .renderer.base import EvaluationMode
+
+
+@dataclass
+class ImplicitronRender:
+    """
+    Holds the tensors that describe a result of rendering.
+    """
+
+    depth_render: Optional[torch.Tensor] = None
+    image_render: Optional[torch.Tensor] = None
+    mask_render: Optional[torch.Tensor] = None
+    camera_distance: Optional[torch.Tensor] = None
+
+    def clone(self) -> "ImplicitronRender":
+        def safe_clone(t: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
+            return t.detach().clone() if t is not None else None
+
+        return ImplicitronRender(
+            depth_render=safe_clone(self.depth_render),
+            image_render=safe_clone(self.image_render),
+            mask_render=safe_clone(self.mask_render),
+            camera_distance=safe_clone(self.camera_distance),
+        )
+
+
+class ImplicitronModelBase(ReplaceableBase):
+    """Replaceable abstract base for all image generation / rendering models.
+
+    The `forward()` method produces a render with a depth map.
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def forward(
+        self,
+        *,  # force keyword-only arguments
+        image_rgb: Optional[torch.Tensor],
+        camera: CamerasBase,
+        fg_probability: Optional[torch.Tensor],
+        mask_crop: Optional[torch.Tensor],
+        depth_map: Optional[torch.Tensor],
+        sequence_name: Optional[List[str]],
+        evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION,
+        **kwargs,
+    ) -> Dict[str, Any]:
+        """
+        Args:
+            image_rgb: A tensor of shape `(B, 3, H, W)` containing a batch of
+                rgb images; the first `min(B, n_train_target_views)` images are
+                considered targets and are used to supervise the renders; the
+                rest correspond to the source viewpoints from which features
+                will be extracted.
+            camera: An instance of CamerasBase containing a batch of `B`
+                cameras corresponding to the viewpoints of the target images,
+                from which the rays will be sampled, and the source images,
+                which will be used for intersecting with the target rays.
+            fg_probability: A tensor of shape `(B, 1, H, W)` containing a batch
+                of foreground masks.
+            mask_crop: A binary tensor of shape `(B, 1, H, W)` denoting valid
+                regions in the input images (i.e. regions that do not
+                correspond to, e.g., zero-padding). When the `RaySampler`'s
+                sampling mode is set to "mask_sample", rays will be sampled in
+                the non-zero regions.
+            depth_map: A tensor of shape `(B, 1, H, W)` containing a batch of
+                depth maps.
+            sequence_name: A list of `B` strings corresponding to the sequence
+                names from which the images `image_rgb` were extracted. They
+                are used to match target frames with relevant source frames.
+            evaluation_mode: one of EvaluationMode.TRAINING or
+                EvaluationMode.EVALUATION which determines the settings used
+                for rendering.
+
+        Returns:
+            preds: A dictionary containing all outputs of the forward pass.
+                All models should output an instance of `ImplicitronRender`
+                in `preds["implicitron_render"]`.
+        """
+        raise NotImplementedError()