9
9
import warnings
10
10
from collections import OrderedDict
11
11
from dataclasses import dataclass , field
12
- from typing import Any , Dict , List , Optional , Union
12
+ from typing import Any , Dict , List , Optional , Sequence , Union
13
13
14
14
import numpy as np
15
15
import torch
16
+ import torch .nn .functional as F
16
17
from pytorch3d .implicitron .dataset .implicitron_dataset import FrameData
17
18
from pytorch3d .implicitron .dataset .utils import is_known_frame , is_train_frame
19
+ from pytorch3d .implicitron .models .base_model import ImplicitronRender
18
20
from pytorch3d .implicitron .tools import vis_utils
19
21
from pytorch3d .implicitron .tools .camera_utils import volumetric_camera_overlaps
20
22
from pytorch3d .implicitron .tools .image_utils import mask_background
31
33
EVAL_N_SRC_VIEWS = [1 , 3 , 5 , 7 , 9 ]
32
34
33
35
34
- @dataclass
35
- class NewViewSynthesisPrediction :
36
- """
37
- Holds the tensors that describe a result of synthesizing new views.
38
- """
39
-
40
- depth_render : Optional [torch .Tensor ] = None
41
- image_render : Optional [torch .Tensor ] = None
42
- mask_render : Optional [torch .Tensor ] = None
43
- camera_distance : Optional [torch .Tensor ] = None
44
-
45
-
46
36
@dataclass
47
37
class _Visualizer :
48
38
image_render : torch .Tensor
@@ -145,8 +135,8 @@ def show_depth(
145
135
146
136
def eval_batch (
147
137
frame_data : FrameData ,
148
- nvs_prediction : NewViewSynthesisPrediction ,
149
- bg_color : Union [torch .Tensor , str , float ] = "black" ,
138
+ implicitron_render : ImplicitronRender ,
139
+ bg_color : Union [torch .Tensor , Sequence , str , float ] = "black" ,
150
140
mask_thr : float = 0.5 ,
151
141
lpips_model = None ,
152
142
visualize : bool = False ,
@@ -162,14 +152,14 @@ def eval_batch(
162
152
is True), a new-view synthesis method (NVS) is tasked to generate new views
163
153
of the scene from the viewpoint of the target views (for which
164
154
frame_data.frame_type.endswith('known') is False). The resulting
165
- synthesized new views, stored in `nvs_prediction `, are compared to the
155
+ synthesized new views, stored in `implicitron_render `, are compared to the
166
156
target ground truth in `frame_data` in terms of geometry and appearance
167
157
resulting in a dictionary of metrics returned by the `eval_batch` function.
168
158
169
159
Args:
170
160
frame_data: A FrameData object containing the input to the new view
171
161
synthesis method.
172
- nvs_prediction : The data describing the synthesized new views.
162
+ implicitron_render : The data describing the synthesized new views.
173
163
bg_color: The background color of the generated new views and the
174
164
ground truth.
175
165
lpips_model: A pre-trained model for evaluating the LPIPS metric.
@@ -184,26 +174,39 @@ def eval_batch(
184
174
ValueError if frame_data does not have frame_type, camera, or image_rgb
185
175
ValueError if the batch has a mix of training and test samples
186
176
ValueError if the batch frames are not [unseen, known, known, ...]
187
- ValueError if one of the required fields in nvs_prediction is missing
177
+ ValueError if one of the required fields in implicitron_render is missing
188
178
"""
189
- REQUIRED_NVS_PREDICTION_FIELDS = ["mask_render" , "image_render" , "depth_render" ]
190
179
frame_type = frame_data .frame_type
191
180
if frame_type is None :
192
181
raise ValueError ("Frame type has not been set." )
193
182
194
183
# we check that all those fields are not None but Pyre can't infer that properly
195
- # TODO: assign to local variables
184
+ # TODO: assign to local variables and simplify the code.
196
185
if frame_data .image_rgb is None :
197
186
raise ValueError ("Image is not in the evaluation batch." )
198
187
199
188
if frame_data .camera is None :
200
189
raise ValueError ("Camera is not in the evaluation batch." )
201
190
202
- if any (not hasattr (nvs_prediction , k ) for k in REQUIRED_NVS_PREDICTION_FIELDS ):
203
- raise ValueError ("One of the required predicted fields is missing" )
191
+ # eval all results in the resolution of the frame_data image
192
+ image_resol = tuple (frame_data .image_rgb .shape [2 :])
193
+
194
+ # Post-process the render:
195
+ # 1) check implicitron_render for Nones,
196
+ # 2) obtain copies to make sure we don't edit the original data,
197
+ # 3) take only the 1st (target) image
198
+ # 4) resize to match ground-truth resolution
199
+ cloned_render : Dict [str , torch .Tensor ] = {}
200
+ for k in ["mask_render" , "image_render" , "depth_render" ]:
201
+ field = getattr (implicitron_render , k )
202
+ if field is None :
203
+ raise ValueError (f"A required predicted field { k } is missing" )
204
+
205
+ imode = "bilinear" if k == "image_render" else "nearest"
206
+ cloned_render [k ] = (
207
+ F .interpolate (field [:1 ], size = image_resol , mode = imode ).detach ().clone ()
208
+ )
204
209
205
- # obtain copies to make sure we dont edit the original data
206
- nvs_prediction = copy .deepcopy (nvs_prediction )
207
210
frame_data = copy .deepcopy (frame_data )
208
211
209
212
# mask the ground truth depth in case frame_data contains the depth mask
@@ -226,9 +229,6 @@ def eval_batch(
226
229
+ " a target view while the rest should be source views."
227
230
) # TODO: do we need to enforce this?
228
231
229
- # take only the first (target image)
230
- for k in REQUIRED_NVS_PREDICTION_FIELDS :
231
- setattr (nvs_prediction , k , getattr (nvs_prediction , k )[:1 ])
232
232
for k in [
233
233
"depth_map" ,
234
234
"image_rgb" ,
@@ -242,10 +242,6 @@ def eval_batch(
242
242
if frame_data .depth_map is None or frame_data .depth_map .sum () <= 0 :
243
243
warnings .warn ("Empty or missing depth map in evaluation!" )
244
244
245
- # eval all results in the resolution of the frame_data image
246
- # pyre-fixme[16]: `Optional` has no attribute `shape`.
247
- image_resol = list (frame_data .image_rgb .shape [2 :])
248
-
249
245
# threshold the masks to make ground truth binary masks
250
246
mask_fg , mask_crop = [
251
247
(getattr (frame_data , k ) >= mask_thr ) for k in ("fg_probability" , "mask_crop" )
@@ -258,29 +254,14 @@ def eval_batch(
258
254
bg_color = bg_color ,
259
255
)
260
256
261
- # resize to the target resolution
262
- for k in REQUIRED_NVS_PREDICTION_FIELDS :
263
- imode = "bilinear" if k == "image_render" else "nearest"
264
- val = getattr (nvs_prediction , k )
265
- setattr (
266
- nvs_prediction ,
267
- k ,
268
- # pyre-fixme[6]: Expected `Optional[int]` for 2nd param but got
269
- # `List[typing.Any]`.
270
- torch .nn .functional .interpolate (val , size = image_resol , mode = imode ),
271
- )
272
-
273
257
# clamp predicted images
274
- # pyre-fixme[16]: `Optional` has no attribute `clamp`.
275
- image_render = nvs_prediction .image_render .clamp (0.0 , 1.0 )
258
+ image_render = cloned_render ["image_render" ].clamp (0.0 , 1.0 )
276
259
277
260
if visualize :
278
261
visualizer = _Visualizer (
279
262
image_render = image_render ,
280
263
image_rgb_masked = image_rgb_masked ,
281
- # pyre-fixme[6]: Expected `Tensor` for 3rd param but got
282
- # `Optional[torch.Tensor]`.
283
- depth_render = nvs_prediction .depth_render ,
264
+ depth_render = cloned_render ["depth_render" ],
284
265
# pyre-fixme[6]: Expected `Tensor` for 4th param but got
285
266
# `Optional[torch.Tensor]`.
286
267
depth_map = frame_data .depth_map ,
@@ -292,9 +273,7 @@ def eval_batch(
292
273
results : Dict [str , Any ] = {}
293
274
294
275
results ["iou" ] = iou (
295
- # pyre-fixme[6]: Expected `Tensor` for 1st param but got
296
- # `Optional[torch.Tensor]`.
297
- nvs_prediction .mask_render ,
276
+ cloned_render ["mask_render" ],
298
277
mask_fg ,
299
278
mask = mask_crop ,
300
279
)
@@ -321,11 +300,7 @@ def eval_batch(
321
300
if name_postfix == "_fg" :
322
301
# only record depth metrics for the foreground
323
302
_ , abs_ = eval_depth (
324
- # pyre-fixme[6]: Expected `Tensor` for 1st param but got
325
- # `Optional[torch.Tensor]`.
326
- nvs_prediction .depth_render ,
327
- # pyre-fixme[6]: Expected `Tensor` for 2nd param but got
328
- # `Optional[torch.Tensor]`.
303
+ cloned_render ["depth_render" ],
329
304
frame_data .depth_map ,
330
305
get_best_scale = True ,
331
306
mask = loss_mask_now ,
@@ -343,7 +318,7 @@ def eval_batch(
343
318
if lpips_model is not None :
344
319
im1 , im2 = [
345
320
2.0 * im .clamp (0.0 , 1.0 ) - 1.0
346
- for im in (image_rgb_masked , nvs_prediction . image_render )
321
+ for im in (image_rgb_masked , cloned_render [ " image_render" ] )
347
322
]
348
323
results ["lpips" ] = lpips_model .forward (im1 , im2 ).item ()
349
324
0 commit comments