
Commit c2d876c

Darijan Gudelj authored and facebook-github-bot committed
voxel grid implicit function
Summary: The voxel grid implicit function, its members, and their internal workings.
Reviewed By: kjchalup
Differential Revision: D38829764
fbshipit-source-id: 28394fe7819e311ed52c9defc9a1b29f37fbc495
1 parent d6a197b commit c2d876c

File tree

3 files changed: +40 additions, -7 deletions

3 files changed

+40
-7
lines changed

pytorch3d/implicitron/models/implicit_function/decoding_functions.py

Lines changed: 19 additions & 0 deletions
@@ -18,6 +18,8 @@
 
 import torch
 
+from omegaconf import DictConfig
+
 from pytorch3d.implicitron.tools.config import (
     Configurable,
     registry,
@@ -179,8 +181,11 @@ def forward(self, x: torch.Tensor, z: Optional[torch.Tensor] = None):
 class MLPDecoder(DecoderFunctionBase):
     """
     Decoding function which uses `MLPWithInputSkips` to convert the embedding to output.
+    If using the Implicitron config system, the `input_dim` of the `network` is changed
+    to the value of the `input_dim` member and `input_skips` is removed.
     """
 
+    input_dim: int = 3
     network: MLPWithInputSkips
 
     def __post_init__(self):
@@ -192,6 +197,20 @@ def forward(
     ) -> torch.Tensor:
         return self.network(features, z)
 
+    @classmethod
+    def network_tweak_args(cls, type, args: DictConfig) -> None:
+        """
+        Special method to stop get_default_args from exposing the member's `input_dim`.
+        """
+        args.pop("input_dim", None)
+
+    def create_network_impl(self, type, args: DictConfig) -> None:
+        """
+        Set the input dimension of the `network` to the input dimension of the
+        decoding function.
+        """
+        self.network = MLPWithInputSkips(input_dim=self.input_dim, **args)
+
 
 class TransformerWithInputSkips(torch.nn.Module):
     def __init__(
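
The two hooks added to `MLPDecoder` follow a simple pattern: `network_tweak_args` hides the nested network's `input_dim` from the exposed configuration, and `create_network_impl` injects the decoder's own `input_dim` when the network is built, so the two values can never disagree. Below is a minimal, self-contained sketch of that pattern; `TinyMLP` and the hand-rolled config wiring are illustrative stand-ins, not the real Implicitron config machinery or the `MLPWithInputSkips` constructor.

import torch
from omegaconf import DictConfig, OmegaConf


class TinyMLP(torch.nn.Module):
    # Hypothetical stand-in for MLPWithInputSkips, kept deliberately small.
    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int) -> None:
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(input_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


# Default args the config system would expose for the nested network.
network_args: DictConfig = OmegaConf.create(
    {"input_dim": 39, "hidden_dim": 64, "output_dim": 256}
)

# network_tweak_args: drop the nested `input_dim` so that only the decoder's
# own `input_dim` field remains configurable.
network_args.pop("input_dim", None)

# create_network_impl: build the network with the decoder's `input_dim`.
decoder_input_dim = 3
network = TinyMLP(input_dim=decoder_input_dim, **network_args)

print(network(torch.randn(2, decoder_input_dim)).shape)  # torch.Size([2, 256])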

pytorch3d/implicitron/models/implicit_function/voxel_grid.py

Lines changed: 20 additions & 6 deletions
@@ -65,7 +65,7 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module):
     padding: str = "zeros"
     mode: str = "bilinear"
     n_features: int = 1
-    resolution: Tuple[int, int, int] = (64, 64, 64)
+    resolution: Tuple[int, int, int] = (128, 128, 128)
 
     def __post_init__(self):
         super().__init__()
@@ -507,8 +507,7 @@ class VoxelGridModule(Configurable, torch.nn.Module):
     voxel_grid_class_type: str = "FullResolutionVoxelGrid"
     voxel_grid: VoxelGridBase
 
-    # pyre-fixme[8]: Attribute has type `Tuple[float, float, float]`; used as `float`.
-    extents: Tuple[float, float, float] = 1.0
+    extents: Tuple[float, float, float] = (1.0, 1.0, 1.0)
     translation: Tuple[float, float, float] = (0.0, 0.0, 0.0)
 
     init_std: float = 0.1
@@ -552,13 +551,28 @@ def forward(self, points: torch.Tensor) -> torch.Tensor:
             grid_sizes=(2, 2, 2),
             # The locator object uses (x, y, z) convention for the
             # voxel size and translation.
-            voxel_size=self.extents,
-            volume_translation=self.translation,
-            # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C._TensorBase...
+            voxel_size=tuple(self.extents),
+            volume_translation=tuple(self.translation),
+            # pyre-ignore[29]
             device=next(self.params.values()).device,
         )
         # pyre-fixme[29]: `Union[torch._tensor.Tensor,
         # torch.nn.modules.module.Module]` is not a function.
         grid_values = self.voxel_grid.values_type(**self.params)
         # voxel grids operate with extra n_grids dimension, which we fix to one
         return self.voxel_grid.evaluate_world(points[None], grid_values, locator)[0]
+
+    @staticmethod
+    def get_output_dim(args: DictConfig) -> int:
+        """
+        Utility to help predict the shape of the output of `forward`.
+
+        Args:
+            args: DictConfig which would be used to initialize the object
+        Returns:
+            int: the length of the last dimension of the output tensor
+        """
+        grid = registry.get(VoxelGridBase, args["voxel_grid_class_type"])
+        return grid.get_output_dim(
+            args["voxel_grid_" + args["voxel_grid_class_type"] + "_args"]
+        )
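
The new static `get_output_dim` lets callers size downstream layers before constructing a `VoxelGridModule`: it looks up the configured grid class in the registry and delegates to that class. A small usage sketch, under the assumption that `get_default_args` (from `pytorch3d.implicitron.tools.config`) fills in the nested `voxel_grid_<ClassName>_args` blocks that the method reads:

from pytorch3d.implicitron.models.implicit_function.voxel_grid import VoxelGridModule
from pytorch3d.implicitron.tools.config import get_default_args

# Build the default config for the module; the nested per-class args blocks
# (e.g. voxel_grid_FullResolutionVoxelGrid_args) are assumed to come from here.
args = get_default_args(VoxelGridModule)
args.voxel_grid_class_type = "FullResolutionVoxelGrid"

# Length of the last dimension that forward() will produce for these args.
out_dim = VoxelGridModule.get_output_dim(args)
print(out_dim)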

pytorch3d/renderer/implicit/harmonic_embedding.py

Lines changed: 1 addition & 1 deletion
@@ -88,7 +88,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             embedding: a harmonic embedding of `x`
                 of shape [..., (n_harmonic_functions * 2 + int(append_input)) * dim]
         """
-        embed = (x[..., None] * self._frequencies).view(*x.shape[:-1], -1)
+        embed = (x[..., None] * self._frequencies).reshape(*x.shape[:-1], -1)
         embed = torch.cat(
             (embed.sin(), embed.cos(), x)
             if self.append_input
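
The `view` to `reshape` change is the standard fix when a tensor may not be laid out contiguously: `Tensor.view` requires a stride-compatible (typically contiguous) input and raises otherwise, while `reshape` silently falls back to a copy. A small illustration of that general difference (not necessarily the exact case this commit hit):

import torch

t = torch.arange(6).reshape(2, 3).t()  # transposing makes the tensor non-contiguous
print(t.is_contiguous())               # False

flat = t.reshape(-1)   # fine: reshape copies when a zero-copy view is impossible
# flat = t.view(-1)    # RuntimeError: view size is not compatible with input
#                      # tensor's size and stride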
