Skip to content

Commit 56d3465

Browse files
Darijan Gudeljfacebook-github-bot
Darijan Gudelj
authored andcommitted
scaffold
Summary: Forward method is sped up using the scaffold, a low resolution voxel grid which is used to filter out the points in empty space. These points will be predicted as having 0 density and (0, 0, 0) color. The points which were not evaluated as empty space will be passed through the steps outlined above. Reviewed By: kjchalup Differential Revision: D39579671 fbshipit-source-id: 8eab8bb43ef77c2a73557efdb725e99a6c60d415
1 parent 95a2acf commit 56d3465

File tree

3 files changed

+173
-45
lines changed

3 files changed

+173
-45
lines changed

pytorch3d/implicitron/models/implicit_function/voxel_grid.py

Lines changed: 157 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,9 @@
1515
1616
"""
1717

18+
from collections.abc import Mapping
1819
from dataclasses import dataclass, field
19-
from typing import Callable, ClassVar, Dict, List, Optional, Tuple, Type
20+
from typing import Callable, ClassVar, Dict, Iterator, List, Optional, Tuple, Type
2021

2122
import torch
2223
from omegaconf import DictConfig
@@ -164,8 +165,9 @@ def get_output_dim(args: DictConfig) -> int:
164165

165166
def change_resolution(
166167
self,
167-
epoch: int,
168168
grid_values: VoxelGridValuesBase,
169+
epoch: int,
170+
*,
169171
mode: str = "linear",
170172
align_corners: bool = True,
171173
antialias: bool = False,
@@ -177,8 +179,8 @@ def change_resolution(
177179
epoch: current training epoch, used to see if the grid needs regridding
178180
grid_values: instance of self.values_type which contains
179181
the voxel grid which will be interpolated to create the new grid
180-
wanted_resolution: tuple of (x, y, z) resolutions which determine
181-
new grid's resolution
182+
epoch: epoch which is used to get the resolution of the new
183+
`grid_values` using `self.resolution_changes`.
182184
align_corners: as for torch.nn.functional.interpolate
183185
mode: as for torch.nn.functional.interpolate
184186
'nearest' | 'bicubic' | 'linear' | 'area' | 'nearest-exact'.
@@ -225,11 +227,17 @@ def change_individual_resolution(tensor, wanted_resolution):
225227
# pyre-ignore[29]
226228
return self.values_type(**params), True
227229

228-
def get_resolution_change_epochs(self) -> List[int]:
230+
def get_resolution_change_epochs(self) -> Tuple[int, ...]:
229231
"""
230232
Returns the epochs at which this grid should change resolution.
231233
"""
232-
return list(self.resolution_changes.keys())
234+
return tuple(self.resolution_changes.keys())
235+
236+
def get_align_corners(self) -> bool:
    """
    Reports the `align_corners` setting of this voxel grid.

    Returns:
        The value stored in `self.align_corners`.
    """
    uses_aligned_corners = self.align_corners
    return uses_aligned_corners
233241

234242

235243
@dataclass
@@ -583,6 +591,8 @@ class VoxelGridModule(Configurable, torch.nn.Module):
583591
"""
584592
A wrapper torch.nn.Module for the VoxelGrid classes, which
585593
contains parameters that are needed to train the VoxelGrid classes.
594+
Can contain the parameters for the voxel grid as pytorch parameters
595+
or as registered buffers.
586596
587597
Members:
588598
voxel_grid_class_type: The name of the class to use for voxel_grid,
@@ -596,17 +606,21 @@ class VoxelGridModule(Configurable, torch.nn.Module):
596606
with mean=init_mean and std=init_std. Default 0.1
597607
init_mean: Parameters are initialized using the gaussian distribution
598608
with mean=init_mean and std=init_std. Default 0.
609+
hold_voxel_grid_as_parameters: if True, components of the underlying voxel grids
610+
will be saved as parameters and therefore be trainable. Default True.
599611
"""
600612

601613
voxel_grid_class_type: str = "FullResolutionVoxelGrid"
602614
voxel_grid: VoxelGridBase
603615

604-
extents: Tuple[float, float, float] = (1.0, 1.0, 1.0)
616+
extents: Tuple[float, float, float] = (2.0, 2.0, 2.0)
605617
translation: Tuple[float, float, float] = (0.0, 0.0, 0.0)
606618

607619
init_std: float = 0.1
608620
init_mean: float = 0
609621

622+
hold_voxel_grid_as_parameters: bool = True
623+
610624
def __post_init__(self):
611625
super().__init__()
612626
run_auto_creation(self)
@@ -619,7 +633,8 @@ def __post_init__(self):
619633
)
620634
for name, shape in shapes.items()
621635
}
622-
self.params = torch.nn.ParameterDict(params)
636+
637+
self.set_voxel_grid_parameters(self.voxel_grid.values_type(**params))
623638
self._register_load_state_dict_pre_hook(self._create_parameters_with_new_size)
624639

625640
def forward(self, points: torch.Tensor) -> torch.Tensor:
@@ -632,31 +647,29 @@ def forward(self, points: torch.Tensor) -> torch.Tensor:
632647
Returns:
633648
torch.Tensor of shape (..., n_features)
634649
"""
635-
locator = VolumeLocator(
636-
batch_size=1,
637-
# The resolution of the voxel grid does not need to be known
638-
# to the locator object. It is easiest to fix the resolution of the locator.
639-
# In particular we fix it to (2,2,2) so that there is exactly one voxel of the
640-
# desired size. The locator object uses (z, y, x) convention for the grid_size,
641-
# and this module uses (x, y, z) convention so the order has to be reversed
642-
# (irrelevant in this case since they are all equal).
643-
# It is (2, 2, 2) because the VolumeLocator object behaves like
644-
# align_corners=True, which means that the points are in the corners of
645-
# the volume. So in the grid of (2, 2, 2) there is only one voxel.
646-
grid_sizes=(2, 2, 2),
647-
# The locator object uses (x, y, z) convention for the
648-
# voxel size and translation.
649-
voxel_size=tuple(self.extents),
650-
volume_translation=tuple(self.translation),
651-
# pyre-ignore[29]
652-
device=next(val for val in self.params.values() if val is not None).device,
653-
)
650+
locator = self._get_volume_locator()
654651
# pyre-fixme[29]: `Union[torch._tensor.Tensor,
655652
# torch.nn.modules.module.Module]` is not a function.
656653
grid_values = self.voxel_grid.values_type(**self.params)
657654
# voxel grids operate with extra n_grids dimension, which we fix to one
658655
return self.voxel_grid.evaluate_world(points[None], grid_values, locator)[0]
659656

657+
def set_voxel_grid_parameters(self, params: VoxelGridValuesBase) -> None:
    """
    Replaces the values held for the underlying voxel grid.

    The fields of `params` are stored in `self.params`, either as a
    `torch.nn.ParameterDict` or as a `_RegistratedBufferDict`, depending on
    `self.hold_voxel_grid_as_parameters`.

    Args:
        params: values of type `self.voxel_grid.values_type` which will
            replace the current ones.
    """
    values_by_name = vars(params)
    # Buffers can only be registered on a module, so the non-parameter case
    # uses a small torch Module that behaves like a mapping.
    container_type = (
        torch.nn.ParameterDict
        if self.hold_voxel_grid_as_parameters
        else _RegistratedBufferDict
    )
    # pyre-ignore [16]
    self.params = container_type(values_by_name)
672+
660673
@staticmethod
661674
def get_output_dim(args: DictConfig) -> int:
662675
"""
@@ -672,12 +685,12 @@ def get_output_dim(args: DictConfig) -> int:
672685
args["voxel_grid_" + args["voxel_grid_class_type"] + "_args"]
673686
)
674687

675-
def subscribe_to_epochs(self) -> Tuple[List[int], Callable[[int], bool]]:
688+
def subscribe_to_epochs(self) -> Tuple[Tuple[int, ...], Callable[[int], bool]]:
676689
"""
677690
Method which expresses interest in subscribing to optimization epoch updates.
678691
679692
Returns:
680-
list of epochs on which to call a callable and callable to be called on
693+
tuple of epochs on which to call a callable and callable to be called on
681694
particular epoch. The callable returns True if parameter change has
682695
happened else False and it must be supplied with one argument, epoch.
683696
"""
@@ -697,13 +710,12 @@ def _apply_epochs(self, epoch: int) -> bool:
697710
"""
698711
# pyre-ignore[29]
699712
grid_values = self.voxel_grid.values_type(**self.params)
700-
grid_values, change = self.voxel_grid.change_resolution(epoch, grid_values)
713+
grid_values, change = self.voxel_grid.change_resolution(
714+
grid_values, epoch=epoch
715+
)
701716
if change:
702-
# pyre-ignore[16]
703-
self.params = torch.nn.ParameterDict(
704-
{name: tensor for name, tensor in vars(grid_values).items()}
705-
)
706-
return change
717+
self.set_voxel_grid_parameters(grid_values)
718+
return change and self.hold_voxel_grid_as_parameters
707719

708720
def _create_parameters_with_new_size(
709721
self,
@@ -749,5 +761,113 @@ def _create_parameters_with_new_size(
749761
key = prefix + "params." + name
750762
if key in state_dict:
751763
new_params[name] = torch.zeros_like(state_dict[key])
752-
# pyre-ignore[16]
753-
self.params = torch.nn.ParameterDict(new_params)
764+
# pyre-ignore[29]
765+
self.set_voxel_grid_parameters(self.voxel_grid.values_type(**new_params))
766+
767+
def get_device(self) -> torch.device:
    """
    Looks up the device of the values stored in `self.params`.

    Returns:
        The `torch.device` of the first entry of `self.params` that is not None.
    """
    # pyre-ignore[29]
    present_values = (value for value in self.params.values() if value is not None)
    first_value = next(present_values)
    return first_value.device
773+
774+
def _get_volume_locator(self) -> VolumeLocator:
    """
    Builds the `VolumeLocator` described by the `extents` and `translation` members.

    Returns:
        A `VolumeLocator` with batch size 1 placed on `self.get_device()`.
    """
    # The locator does not need to know the resolution of the voxel grid, so
    # the simplest choice is a fixed one. VolumeLocator behaves like
    # align_corners=True (points sit in the corners of the volume), hence a
    # (2, 2, 2) grid holds exactly one voxel of the desired size.
    # The locator uses the (z, y, x) convention for grid sizes while this
    # module uses (x, y, z); the order is irrelevant here since all are equal.
    single_voxel_grid_sizes = (2, 2, 2)
    # Voxel size and translation follow the (x, y, z) convention in the locator.
    voxel_size = tuple(self.extents)
    # `VolumeLocator` defines volume_translation as a vector from the origin of
    # the local coordinate frame to the origin of the world coordinate frame:
    #     x_world = x_local * extents/2 - translation
    # so the module's translation has to be negated to get the reverse.
    negated_translation = tuple(-offset for offset in self.translation)
    return VolumeLocator(
        batch_size=1,
        grid_sizes=single_voxel_grid_sizes,
        voxel_size=voxel_size,
        volume_translation=negated_translation,
        device=self.get_device(),
    )
800+
801+
def get_grid_points(self, epoch: int) -> torch.Tensor:
    """
    Returns a grid of points that represent centers of voxels of the
    underlying voxel grid at a specific epoch.

    The points span `self.extents` and are centred on the origin. When the
    voxel grid does not use `align_corners`, each extent is shrunk by a factor
    of `(resolution - 1) / resolution` along its own axis, which places the
    outermost points half a voxel inside the boundary.

    NOTE(review): `self.translation` is not applied here, so the returned
    points are centred on the origin -- confirm this is what callers expect
    from "world coordinates".

    Args:
        epoch: underlying voxel grids change resolution depending on the
            epoch, this argument is used to determine the resolution
            of the voxel grid at that epoch.

    Returns:
        tensor of shape [xresolution, yresolution, zresolution, 3] where
        xresolution, yresolution, zresolution are resolutions of the
        underlying voxel grid
    """
    xresolution, yresolution, zresolution = self.voxel_grid.get_resolution(epoch)
    width, height, depth = self.extents
    resolutions = (xresolution, yresolution, zresolution)
    extents = (width, height, depth)

    if not self.voxel_grid.get_align_corners():
        # Every axis must be scaled by its OWN resolution; using the x
        # resolution for all three misplaces points on non-cubic grids.
        extents = tuple(
            extent * (resolution - 1) / resolution if resolution > 1 else extent
            for extent, resolution in zip(extents, resolutions)
        )

    device = self.get_device()
    xs, ys, zs = (
        torch.linspace(-extent / 2, extent / 2, resolution, device=device)
        for extent, resolution in zip(extents, resolutions)
    )
    xmesh, ymesh, zmesh = torch.meshgrid(xs, ys, zs, indexing="ij")
    return torch.stack((xmesh, ymesh, zmesh), dim=3)
838+
839+
840+
class _RegistratedBufferDict(torch.nn.Module, Mapping):
841+
"""
842+
Mapping class and a torch.nn.Module that registeres its values
843+
with `self.register_buffer`. Can be indexed like a regular Python
844+
dictionary, but torch.Tensors it contains are properly registered, and will be visible
845+
by all Module methods. Supports only `torch.Tensor` as value and str as key.
846+
"""
847+
848+
def __init__(self, init_dict: Optional[Dict[str, torch.Tensor]] = None) -> None:
849+
"""
850+
Args:
851+
init_dict: dictionary which will be used to populate the object
852+
"""
853+
super().__init__()
854+
self._keys = set()
855+
if init_dict is not None:
856+
for k, v in init_dict.items():
857+
self[k] = v
858+
859+
def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
860+
return iter({k: self[k] for k in self._keys})
861+
862+
def __len__(self) -> int:
863+
return len(self._keys)
864+
865+
def __getitem__(self, key: str) -> torch.Tensor:
866+
return getattr(self, key)
867+
868+
def __setitem__(self, key, value) -> None:
869+
self._keys.add(key)
870+
self.register_buffer(key, value)
871+
872+
def __hash__(self) -> int:
873+
return hash(repr(self))

pytorch3d/structures/volumes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -653,7 +653,7 @@ def __init__(
653653
volume_translation: _Translation = (0.0, 0.0, 0.0),
654654
):
655655
"""
656-
**batch_size** : Batch size of the underlaying grids
656+
**batch_size** : Batch size of the underlying grids
657657
**grid_sizes** : Represents the resolutions of different grids in the batch. Can be
658658
a) tuple of form (H, W, D)
659659
b) list/tuple of length batch_size of lists/tuples of form (H, W, D)

tests/implicitron/test_voxel_grids.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,6 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
3535
one by one sample and comparing with the batched implementation.
3636
"""
3737

38-
def test_my_code(self):
39-
return
40-
4138
def get_random_normalized_points(
4239
self, n_grids, n_points=None, dimension=3
4340
) -> torch.Tensor:
@@ -293,6 +290,8 @@ def test_interpolation(self):
293290
padding_mode="zeros",
294291
mode="bilinear",
295292
),
293+
rtol=0.0001,
294+
atol=0.0001,
296295
)
297296
with self.subTest("2D interpolation"):
298297
points = self.get_random_normalized_points(
@@ -308,6 +307,8 @@ def test_interpolation(self):
308307
padding_mode="zeros",
309308
mode="bilinear",
310309
),
310+
rtol=0.0001,
311+
atol=0.0001,
311312
)
312313

313314
with self.subTest("3D interpolation"):
@@ -325,6 +326,7 @@ def test_interpolation(self):
325326
mode="bilinear",
326327
),
327328
rtol=0.0001,
329+
atol=0.0001,
328330
)
329331

330332
def test_floating_point_query(self):
@@ -378,7 +380,8 @@ def test_floating_point_query(self):
378380
assert torch.allclose(
379381
grid.evaluate_local(points, params),
380382
expected_result,
381-
rtol=0.00001,
383+
rtol=0.0001,
384+
atol=0.0001,
382385
), grid.evaluate_local(points, params)
383386
with self.subTest("CP"):
384387
grid = CPFactorizedVoxelGrid(
@@ -446,14 +449,16 @@ def test_floating_point_query(self):
446449
assert torch.allclose(
447450
grid.evaluate_local(points, params),
448451
expected_result_matrix,
449-
rtol=0.00001,
452+
rtol=0.0001,
453+
atol=0.0001,
450454
)
451455
del params.basis_matrix
452456
with self.subTest("CP with sum reduction"):
453457
assert torch.allclose(
454458
grid.evaluate_local(points, params),
455459
expected_result_sum,
456-
rtol=0.00001,
460+
rtol=0.0001,
461+
atol=0.0001,
457462
)
458463

459464
with self.subTest("VM"):
@@ -540,14 +545,16 @@ def test_floating_point_query(self):
540545
assert torch.allclose(
541546
grid.evaluate_local(points, params),
542547
expected_result_matrix,
543-
rtol=0.00001,
548+
rtol=0.0001,
549+
atol=0.0001,
544550
)
545551
del params.basis_matrix
546552
with self.subTest("VM with sum reduction"):
547553
assert torch.allclose(
548554
grid.evaluate_local(points, params),
549555
expected_result_sum,
550556
rtol=0.0001,
557+
atol=0.0001,
551558
), grid.evaluate_local(points, params)
552559

553560
def test_forward_with_small_init_std(self):
@@ -613,6 +620,7 @@ def test_voxel_grid_module_location(self, n_times=10):
613620
grid(world_point)[0, 0],
614621
grid.voxel_grid.evaluate_local(local_point[None], grid_values)[0, 0, 0],
615622
rtol=0.0001,
623+
atol=0.0001,
616624
)
617625

618626
def test_resolution_change(self, n_times=10):

0 commit comments

Comments
 (0)