Skip to content

Commit c3d7808

Browse files
bottlerfacebook-github-bot
authored andcommitted
register_buffer compatibility
Summary: In D30349234 (1b8d86a) we introduced persistent=False to some register_buffer calls, which depend on PyTorch 1.6. We go back to the old behaviour for PyTorch 1.5. Reviewed By: nikhilaravi Differential Revision: D30731327 fbshipit-source-id: ab02ef98ee87440ef02479b72f4872b562ab85b5
1 parent bbc7573 commit c3d7808

File tree

3 files changed

+14
-3
lines changed

3 files changed

+14
-3
lines changed

projects/nerf/nerf/harmonic_embedding.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,12 @@ def __init__(
6969
dtype=torch.float32,
7070
)
7171

72-
self.register_buffer("_frequencies", omega0 * frequencies, persistent=False)
72+
try:
73+
self.register_buffer("_frequencies", omega0 * frequencies, persistent=False)
74+
except TypeError:
75+
# workaround for pytorch<1.6
76+
self.register_buffer("_frequencies", omega0 * frequencies)
77+
7378
self.include_input = include_input
7479

7580
def forward(self, x: torch.Tensor) -> torch.Tensor:

pytorch3d/renderer/implicit/raysampling.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,10 @@ def __init__(
9696
),
9797
dim=-1,
9898
)
99-
self.register_buffer("_xy_grid", _xy_grid, persistent=False)
99+
try:
100+
self.register_buffer("_xy_grid", _xy_grid, persistent=False)
101+
except TypeError:
102+
self.register_buffer("_xy_grid", _xy_grid) # workaround for pytorch<1.6
100103

101104
def forward(self, cameras: CamerasBase, **kwargs) -> RayBundle:
102105
"""

tests/test_raysampling.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -426,7 +426,10 @@ def _check_raysampler_ray_directions(self, cameras, raysampler, ray_bundle):
426426
atol=1e-5,
427427
)
428428

429-
def test_load_state(self):
429+
@unittest.skipIf(
430+
torch.__version__[:4] == "1.5.", "non persistent buffer needs PyTorch 1.6"
431+
)
432+
def test_load_state_different_resolution(self):
430433
# check that we can load the state of one ray sampler into
431434
# another with different image size.
432435
module1 = NDCGridRaysampler(

0 commit comments

Comments
 (0)