Commit 76cddd9

Darijan Gudelj authored and facebook-github-bot committed
Elementwise decoder
Summary: TensoRF applies relu or softmax after the density grid. This diff adds the ability to replicate that.

Reviewed By: bottler

Differential Revision: D40023228

fbshipit-source-id: 9f19868cd68460af98ab6e61c7f708158c26dc08
1 parent a607dd0 commit 76cddd9

File tree

2 files changed (+31, −37 lines)


pytorch3d/implicitron/models/implicit_function/decoding_functions.py

Lines changed: 31 additions & 3 deletions

```diff
@@ -54,15 +54,43 @@ def forward(
 
 
 @registry.register
-class IdentityDecoder(DecoderFunctionBase):
+class ElementwiseDecoder(DecoderFunctionBase):
     """
-    Decoding function which returns its input.
+    Decoding function which scales the input, adds a shift, and then applies
+    `relu`, `softplus`, `sigmoid` or nothing to the result:
+    `result = operation(input * scale + shift)`
+
+    Members:
+        scale: a scalar by which the input is multiplied before being shifted.
+            Defaults to 1.
+        shift: a scalar which is added to the scaled input before performing
+            the operation. Defaults to 0.
+        operation: which operation to perform on the transformed input. Options are:
+            `relu`, `softplus`, `sigmoid` and `identity`. Defaults to `identity`.
     """
 
+    scale: float = 1
+    shift: float = 0
+    operation: str = "identity"
+
+    def __post_init__(self):
+        super().__post_init__()
+        if self.operation not in ["relu", "softplus", "sigmoid", "identity"]:
+            raise ValueError(
+                "`operation` can only be `relu`, `softplus`, `sigmoid` or `identity`."
+            )
+
     def forward(
         self, features: torch.Tensor, z: Optional[torch.Tensor] = None
     ) -> torch.Tensor:
-        return features
+        transformed_input = features * self.scale + self.shift
+        if self.operation == "softplus":
+            return torch.nn.functional.softplus(transformed_input)
+        if self.operation == "relu":
+            return torch.nn.functional.relu(transformed_input)
+        if self.operation == "sigmoid":
+            return torch.nn.functional.sigmoid(transformed_input)
+        return transformed_input
 
 
 class MLPWithInputSkips(Configurable, torch.nn.Module):
```

tests/implicitron/test_decoding_functions.py

Lines changed: 0 additions & 34 deletions
This file was deleted.
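The decoder's behavior can be exercised without pytorch3d. Below is a minimal standalone sketch that mirrors the `result = operation(input * scale + shift)` rule from the diff; the function name `elementwise_decode` is hypothetical, and the pytorch3d `DecoderFunctionBase`/registry machinery is deliberately omitted — only `torch` is assumed.

```python
import torch


def elementwise_decode(
    features: torch.Tensor,
    scale: float = 1.0,
    shift: float = 0.0,
    operation: str = "identity",
) -> torch.Tensor:
    """Apply result = operation(features * scale + shift), as in the diff above."""
    if operation not in ("relu", "softplus", "sigmoid", "identity"):
        raise ValueError(
            "`operation` can only be `relu`, `softplus`, `sigmoid` or `identity`."
        )
    transformed = features * scale + shift
    if operation == "softplus":
        return torch.nn.functional.softplus(transformed)
    if operation == "relu":
        return torch.nn.functional.relu(transformed)
    if operation == "sigmoid":
        # torch.sigmoid is the non-deprecated spelling of F.sigmoid
        return torch.sigmoid(transformed)
    return transformed


# E.g. a TensoRF-style density activation: relu applied after the density grid.
density = elementwise_decode(torch.tensor([-1.0, 0.0, 2.0]), operation="relu")
# → tensor([0., 0., 2.])
```

With the default `operation="identity"`, the decoder reduces to an affine transform, which matches the old `IdentityDecoder` when `scale=1` and `shift=0`.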
