Skip to content

Commit 1b8d86a

Browse files
bottlerfacebook-github-bot
authored andcommitted
(breaking) image_size-agnostic GridRaySampler
Summary: As suggested in #802. By not persisting the _xy_grid buffer, we can allow (in some cases) a model with one image_size to be loaded from a saved model which was trained at a different resolution. Also avoid persisting _frequencies in HarmonicEmbedding for similar reasons. BC-break: This will cause load_state_dict, in strict mode, to complain if you try to load an old model with the new code. Reviewed By: patricklabatut Differential Revision: D30349234 fbshipit-source-id: d6061d1e51c9f79a78d61a9f732c9a5dfadbbb47
1 parent 1251446 commit 1b8d86a

File tree

3 files changed

+24
-4
lines changed

3 files changed

+24
-4
lines changed

projects/nerf/nerf/harmonic_embedding.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ def __init__(
1414
omega0: float = 1.0,
1515
logspace: bool = True,
1616
include_input: bool = True,
17-
):
17+
) -> None:
1818
"""
1919
Given an input tensor `x` of shape [minibatch, ... , dim],
2020
the harmonic embedding layer converts each feature
@@ -69,10 +69,10 @@ def __init__(
6969
dtype=torch.float32,
7070
)
7171

72-
self.register_buffer("_frequencies", omega0 * frequencies)
72+
self.register_buffer("_frequencies", omega0 * frequencies, persistent=False)
7373
self.include_input = include_input
7474

75-
def forward(self, x: torch.Tensor):
75+
def forward(self, x: torch.Tensor) -> torch.Tensor:
7676
"""
7777
Args:
7878
x: tensor of shape [..., dim]

pytorch3d/renderer/implicit/raysampling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def __init__(
9696
),
9797
dim=-1,
9898
)
99-
self.register_buffer("_xy_grid", _xy_grid)
99+
self.register_buffer("_xy_grid", _xy_grid, persistent=False)
100100

101101
def forward(self, cameras: CamerasBase, **kwargs) -> RayBundle:
102102
"""

tests/test_raysampling.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -425,3 +425,23 @@ def _check_raysampler_ray_directions(self, cameras, raysampler, ray_bundle):
425425
ray_bundle_camera_fix_seed.directions.view(batch_size, -1, 3),
426426
atol=1e-5,
427427
)
428+
429+
def test_load_state(self):
430+
# check that we can load the state of one ray sampler into
431+
# another with different image size.
432+
module1 = NDCGridRaysampler(
433+
image_width=20,
434+
image_height=30,
435+
n_pts_per_ray=40,
436+
min_depth=1.2,
437+
max_depth=2.3,
438+
)
439+
module2 = NDCGridRaysampler(
440+
image_width=22,
441+
image_height=32,
442+
n_pts_per_ray=42,
443+
min_depth=1.2,
444+
max_depth=2.3,
445+
)
446+
state = module1.state_dict()
447+
module2.load_state_dict(state)

0 commit comments

Comments
 (0)