Skip to content

Commit 2374d19

Browse files
davnov134facebook-github-bot
authored andcommitted
Test all CO3D model configs in test_forward_pass
Summary: Tests all possible model configs in test_forward_pass.py Reviewed By: shapovalov Differential Revision: D35851507 fbshipit-source-id: 4860ee1d37cf17a2faab5fc14d4b2ba0b96c4b8b
1 parent 1f39537 commit 2374d19

File tree

1 file changed

+107
-32
lines changed

1 file changed

+107
-32
lines changed

tests/implicitron/test_forward_pass.py

Lines changed: 107 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -4,68 +4,102 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import os
78
import unittest
89

910
import torch
11+
from omegaconf import DictConfig, OmegaConf
1012
from pytorch3d.implicitron.models.generic_model import GenericModel
1113
from pytorch3d.implicitron.models.renderer.base import EvaluationMode
1214
from pytorch3d.implicitron.tools.config import expand_args_fields, get_default_args
1315
from pytorch3d.renderer.cameras import look_at_view_transform, PerspectiveCameras
1416

1517

18+
if os.environ.get("FB_TEST", False):
19+
from common_testing import get_pytorch3d_dir
20+
else:
21+
from tests.common_testing import get_pytorch3d_dir
22+
23+
IMPLICITRON_CONFIGS_DIR = (
24+
get_pytorch3d_dir() / "projects" / "implicitron_trainer" / "configs"
25+
)
26+
27+
1628
class TestGenericModel(unittest.TestCase):
29+
def setUp(self):
30+
torch.manual_seed(42)
31+
1732
def test_gm(self):
1833
# Simple test of a forward and backward pass of the default GenericModel.
1934
device = torch.device("cuda:1")
2035
expand_args_fields(GenericModel)
2136
model = GenericModel()
2237
model.to(device)
38+
self._one_model_test(model, device)
39+
40+
def test_all_gm_configs(self):
41+
# Tests all model settings in the implicitron_trainer config folder.
42+
device = torch.device("cuda:0")
43+
config_files = []
44+
45+
for pattern in ("repro_singleseq*.yaml", "repro_multiseq*.yaml"):
46+
config_files.extend(
47+
[
48+
f
49+
for f in IMPLICITRON_CONFIGS_DIR.glob(pattern)
50+
if not f.name.endswith("_base.yaml")
51+
]
52+
)
53+
54+
for config_file in config_files:
55+
with self.subTest(name=config_file.stem):
56+
cfg = _load_model_config_from_yaml(str(config_file))
57+
model = GenericModel(**cfg)
58+
model.to(device)
59+
self._one_model_test(model, device, eval_test=True)
60+
61+
def _one_model_test(
62+
self,
63+
model,
64+
device,
65+
n_train_cameras: int = 5,
66+
eval_test: bool = True,
67+
):
2368

24-
n_train_cameras = 2
2569
R, T = look_at_view_transform(azim=torch.rand(n_train_cameras) * 360)
2670
cameras = PerspectiveCameras(R=R, T=T, device=device)
2771

28-
# TODO: make these default to None?
29-
defaulted_args = {
30-
"fg_probability": None,
31-
"depth_map": None,
32-
"mask_crop": None,
33-
"sequence_name": None,
72+
N, H, W = n_train_cameras, model.render_image_height, model.render_image_width
73+
74+
random_args = {
75+
"camera": cameras,
76+
"fg_probability": _random_input_tensor(N, 1, H, W, True, device),
77+
"depth_map": _random_input_tensor(N, 1, H, W, False, device) + 0.1,
78+
"mask_crop": _random_input_tensor(N, 1, H, W, True, device),
79+
"sequence_name": ["sequence"] * N,
80+
"image_rgb": _random_input_tensor(N, 3, H, W, False, device),
3481
}
3582

36-
with self.assertWarnsRegex(UserWarning, "No main objective found"):
37-
model(
38-
camera=cameras,
39-
evaluation_mode=EvaluationMode.TRAINING,
40-
**defaulted_args,
41-
image_rgb=None,
42-
)
43-
target_image_rgb = torch.rand(
44-
(n_train_cameras, 3, model.render_image_height, model.render_image_width),
45-
device=device,
46-
)
83+
# training foward pass
84+
model.train()
4785
train_preds = model(
48-
camera=cameras,
86+
**random_args,
4987
evaluation_mode=EvaluationMode.TRAINING,
50-
image_rgb=target_image_rgb,
51-
**defaulted_args,
5288
)
5389
self.assertGreater(train_preds["objective"].item(), 0)
5490
train_preds["objective"].backward()
5591

56-
model.eval()
57-
with torch.no_grad():
58-
# TODO: perhaps this warning should be skipped in eval mode?
59-
with self.assertWarnsRegex(UserWarning, "No main objective found"):
92+
if eval_test:
93+
model.eval()
94+
with torch.no_grad():
6095
eval_preds = model(
61-
camera=cameras[0],
62-
**defaulted_args,
63-
image_rgb=None,
96+
**random_args,
97+
evaluation_mode=EvaluationMode.EVALUATION,
98+
)
99+
self.assertEqual(
100+
eval_preds["images_render"].shape,
101+
(1, 3, model.render_image_height, model.render_image_width),
64102
)
65-
self.assertEqual(
66-
eval_preds["images_render"].shape,
67-
(1, 3, model.render_image_height, model.render_image_width),
68-
)
69103

70104
def test_idr(self):
71105
# Forward pass of GenericModel with IDR.
@@ -104,3 +138,44 @@ def test_idr(self):
104138
**defaulted_args,
105139
)
106140
self.assertGreater(train_preds["objective"].item(), 0)
141+
142+
143+
def _random_input_tensor(
144+
N: int,
145+
C: int,
146+
H: int,
147+
W: int,
148+
is_binary: bool,
149+
device: torch.device,
150+
) -> torch.Tensor:
151+
T = torch.rand(N, C, H, W, device=device)
152+
if is_binary:
153+
T = (T > 0.5).float()
154+
return T
155+
156+
157+
def _load_model_config_from_yaml(config_path, strict=True) -> DictConfig:
158+
default_cfg = get_default_args(GenericModel)
159+
cfg = _load_model_config_from_yaml_rec(default_cfg, config_path)
160+
return cfg
161+
162+
163+
def _load_model_config_from_yaml_rec(cfg: DictConfig, config_path: str) -> DictConfig:
164+
cfg_loaded = OmegaConf.load(config_path)
165+
if "generic_model_args" in cfg_loaded:
166+
cfg_model_loaded = cfg_loaded.generic_model_args
167+
else:
168+
cfg_model_loaded = None
169+
defaults = cfg_loaded.pop("defaults", None)
170+
if defaults is not None:
171+
for default_name in defaults:
172+
if default_name in ("_self_", "default_config"):
173+
continue
174+
default_name = os.path.splitext(default_name)[0]
175+
defpath = os.path.join(os.path.dirname(config_path), default_name + ".yaml")
176+
cfg = _load_model_config_from_yaml_rec(cfg, defpath)
177+
if cfg_model_loaded is not None:
178+
cfg = OmegaConf.merge(cfg, cfg_model_loaded)
179+
elif cfg_model_loaded is not None:
180+
cfg = OmegaConf.merge(cfg, cfg_model_loaded)
181+
return cfg

0 commit comments

Comments
 (0)