|
4 | 4 | # This source code is licensed under the BSD-style license found in the
|
5 | 5 | # LICENSE file in the root directory of this source tree.
|
6 | 6 |
|
| 7 | +import os |
7 | 8 | import unittest
|
8 | 9 |
|
9 | 10 | import torch
|
| 11 | +from omegaconf import DictConfig, OmegaConf |
10 | 12 | from pytorch3d.implicitron.models.generic_model import GenericModel
|
11 | 13 | from pytorch3d.implicitron.models.renderer.base import EvaluationMode
|
12 | 14 | from pytorch3d.implicitron.tools.config import expand_args_fields, get_default_args
|
13 | 15 | from pytorch3d.renderer.cameras import look_at_view_transform, PerspectiveCameras
|
14 | 16 |
|
15 | 17 |
|
| 18 | +if os.environ.get("FB_TEST", False): |
| 19 | + from common_testing import get_pytorch3d_dir |
| 20 | +else: |
| 21 | + from tests.common_testing import get_pytorch3d_dir |
| 22 | + |
| 23 | +IMPLICITRON_CONFIGS_DIR = ( |
| 24 | + get_pytorch3d_dir() / "projects" / "implicitron_trainer" / "configs" |
| 25 | +) |
| 26 | + |
| 27 | + |
16 | 28 | class TestGenericModel(unittest.TestCase):
|
| 29 | + def setUp(self): |
| 30 | + torch.manual_seed(42) |
| 31 | + |
17 | 32 | def test_gm(self):
|
18 | 33 | # Simple test of a forward and backward pass of the default GenericModel.
|
19 | 34 | device = torch.device("cuda:1")
|
20 | 35 | expand_args_fields(GenericModel)
|
21 | 36 | model = GenericModel()
|
22 | 37 | model.to(device)
|
| 38 | + self._one_model_test(model, device) |
| 39 | + |
| 40 | + def test_all_gm_configs(self): |
| 41 | + # Tests all model settings in the implicitron_trainer config folder. |
| 42 | + device = torch.device("cuda:0") |
| 43 | + config_files = [] |
| 44 | + |
| 45 | + for pattern in ("repro_singleseq*.yaml", "repro_multiseq*.yaml"): |
| 46 | + config_files.extend( |
| 47 | + [ |
| 48 | + f |
| 49 | + for f in IMPLICITRON_CONFIGS_DIR.glob(pattern) |
| 50 | + if not f.name.endswith("_base.yaml") |
| 51 | + ] |
| 52 | + ) |
| 53 | + |
| 54 | + for config_file in config_files: |
| 55 | + with self.subTest(name=config_file.stem): |
| 56 | + cfg = _load_model_config_from_yaml(str(config_file)) |
| 57 | + model = GenericModel(**cfg) |
| 58 | + model.to(device) |
| 59 | + self._one_model_test(model, device, eval_test=True) |
| 60 | + |
| 61 | + def _one_model_test( |
| 62 | + self, |
| 63 | + model, |
| 64 | + device, |
| 65 | + n_train_cameras: int = 5, |
| 66 | + eval_test: bool = True, |
| 67 | + ): |
23 | 68 |
|
24 |
| - n_train_cameras = 2 |
25 | 69 | R, T = look_at_view_transform(azim=torch.rand(n_train_cameras) * 360)
|
26 | 70 | cameras = PerspectiveCameras(R=R, T=T, device=device)
|
27 | 71 |
|
28 |
| - # TODO: make these default to None? |
29 |
| - defaulted_args = { |
30 |
| - "fg_probability": None, |
31 |
| - "depth_map": None, |
32 |
| - "mask_crop": None, |
33 |
| - "sequence_name": None, |
| 72 | + N, H, W = n_train_cameras, model.render_image_height, model.render_image_width |
| 73 | + |
| 74 | + random_args = { |
| 75 | + "camera": cameras, |
| 76 | + "fg_probability": _random_input_tensor(N, 1, H, W, True, device), |
| 77 | + "depth_map": _random_input_tensor(N, 1, H, W, False, device) + 0.1, |
| 78 | + "mask_crop": _random_input_tensor(N, 1, H, W, True, device), |
| 79 | + "sequence_name": ["sequence"] * N, |
| 80 | + "image_rgb": _random_input_tensor(N, 3, H, W, False, device), |
34 | 81 | }
|
35 | 82 |
|
36 |
| - with self.assertWarnsRegex(UserWarning, "No main objective found"): |
37 |
| - model( |
38 |
| - camera=cameras, |
39 |
| - evaluation_mode=EvaluationMode.TRAINING, |
40 |
| - **defaulted_args, |
41 |
| - image_rgb=None, |
42 |
| - ) |
43 |
| - target_image_rgb = torch.rand( |
44 |
| - (n_train_cameras, 3, model.render_image_height, model.render_image_width), |
45 |
| - device=device, |
46 |
| - ) |
| 83 | + # training foward pass |
| 84 | + model.train() |
47 | 85 | train_preds = model(
|
48 |
| - camera=cameras, |
| 86 | + **random_args, |
49 | 87 | evaluation_mode=EvaluationMode.TRAINING,
|
50 |
| - image_rgb=target_image_rgb, |
51 |
| - **defaulted_args, |
52 | 88 | )
|
53 | 89 | self.assertGreater(train_preds["objective"].item(), 0)
|
54 | 90 | train_preds["objective"].backward()
|
55 | 91 |
|
56 |
| - model.eval() |
57 |
| - with torch.no_grad(): |
58 |
| - # TODO: perhaps this warning should be skipped in eval mode? |
59 |
| - with self.assertWarnsRegex(UserWarning, "No main objective found"): |
| 92 | + if eval_test: |
| 93 | + model.eval() |
| 94 | + with torch.no_grad(): |
60 | 95 | eval_preds = model(
|
61 |
| - camera=cameras[0], |
62 |
| - **defaulted_args, |
63 |
| - image_rgb=None, |
| 96 | + **random_args, |
| 97 | + evaluation_mode=EvaluationMode.EVALUATION, |
| 98 | + ) |
| 99 | + self.assertEqual( |
| 100 | + eval_preds["images_render"].shape, |
| 101 | + (1, 3, model.render_image_height, model.render_image_width), |
64 | 102 | )
|
65 |
| - self.assertEqual( |
66 |
| - eval_preds["images_render"].shape, |
67 |
| - (1, 3, model.render_image_height, model.render_image_width), |
68 |
| - ) |
69 | 103 |
|
70 | 104 | def test_idr(self):
|
71 | 105 | # Forward pass of GenericModel with IDR.
|
@@ -104,3 +138,44 @@ def test_idr(self):
|
104 | 138 | **defaulted_args,
|
105 | 139 | )
|
106 | 140 | self.assertGreater(train_preds["objective"].item(), 0)
|
| 141 | + |
| 142 | + |
| 143 | +def _random_input_tensor( |
| 144 | + N: int, |
| 145 | + C: int, |
| 146 | + H: int, |
| 147 | + W: int, |
| 148 | + is_binary: bool, |
| 149 | + device: torch.device, |
| 150 | +) -> torch.Tensor: |
| 151 | + T = torch.rand(N, C, H, W, device=device) |
| 152 | + if is_binary: |
| 153 | + T = (T > 0.5).float() |
| 154 | + return T |
| 155 | + |
| 156 | + |
| 157 | +def _load_model_config_from_yaml(config_path, strict=True) -> DictConfig: |
| 158 | + default_cfg = get_default_args(GenericModel) |
| 159 | + cfg = _load_model_config_from_yaml_rec(default_cfg, config_path) |
| 160 | + return cfg |
| 161 | + |
| 162 | + |
| 163 | +def _load_model_config_from_yaml_rec(cfg: DictConfig, config_path: str) -> DictConfig: |
| 164 | + cfg_loaded = OmegaConf.load(config_path) |
| 165 | + if "generic_model_args" in cfg_loaded: |
| 166 | + cfg_model_loaded = cfg_loaded.generic_model_args |
| 167 | + else: |
| 168 | + cfg_model_loaded = None |
| 169 | + defaults = cfg_loaded.pop("defaults", None) |
| 170 | + if defaults is not None: |
| 171 | + for default_name in defaults: |
| 172 | + if default_name in ("_self_", "default_config"): |
| 173 | + continue |
| 174 | + default_name = os.path.splitext(default_name)[0] |
| 175 | + defpath = os.path.join(os.path.dirname(config_path), default_name + ".yaml") |
| 176 | + cfg = _load_model_config_from_yaml_rec(cfg, defpath) |
| 177 | + if cfg_model_loaded is not None: |
| 178 | + cfg = OmegaConf.merge(cfg, cfg_model_loaded) |
| 179 | + elif cfg_model_loaded is not None: |
| 180 | + cfg = OmegaConf.merge(cfg, cfg_model_loaded) |
| 181 | + return cfg |
0 commit comments