Commit 33aeafa

Bill Yang authored and facebook-github-bot committed
add activation on last logic (#2924)
Summary:
Pull Request resolved: #2924

# context

Adds an `activation_on_last` flag for parity, making the module more configurable. `activation_on_last` toggles whether the given (or default) activation function is applied to the last layer of the MLP. Typically, all layers of the MLP have the activation function applied; if the flag is set to False, the last layer will not, so users can optionally consume the raw MLP output for their own customized needs.

Reviewed By: TroyGarden

Differential Revision: D73691616

fbshipit-source-id: 87a720e9b7e10f2bbb478b68a562a1cd90a36199
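As a rough usage sketch of the new flag (the `MLP` constructor arguments other than `activation_on_last` follow the existing signature shown in the diff below; the sizes here are illustrative):

import torch
from torchrec.modules.mlp import MLP

# Default behavior: the activation (ReLU here) is applied to every layer,
# including the last one.
mlp = MLP(in_size=16, layer_sizes=[32, 8], activation=torch.relu)

# With activation_on_last=False, the final layer's output is returned raw,
# e.g. to feed a custom head or a loss that expects pre-activation values.
mlp_raw = MLP(in_size=16, layer_sizes=[32, 8], activation_on_last=False)

out = mlp_raw(torch.randn(4, 16))  # shape (4, 8); no activation on the last layer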
1 parent aa82c8e · commit 33aeafa

1 file changed: 11 additions & 2 deletions


torchrec/modules/mlp.py

Lines changed: 11 additions & 2 deletions
@@ -128,6 +128,7 @@ def __init__(
         ] = torch.relu,
         device: Optional[torch.device] = None,
         dtype: torch.dtype = torch.float32,
+        activation_on_last: bool = True,
     ) -> None:
         super().__init__()

@@ -143,7 +144,11 @@ def __init__(
                     layer_sizes[i - 1] if i > 0 else in_size,
                     layer_sizes[i],
                     bias=bias,
-                    activation=extract_module_or_tensor_callable(activation),
+                    activation=(
+                        torch.nn.Identity()
+                        if not activation_on_last and i == len(layer_sizes) - 1
+                        else extract_module_or_tensor_callable(activation)
+                    ),
                     device=device,
                     dtype=dtype,
                 )
@@ -158,7 +163,11 @@ def __init__(
                     layer_sizes[i - 1] if i > 0 else in_size,
                     layer_sizes[i],
                     bias=bias,
-                    activation=SwishLayerNorm(layer_sizes[i], device=device),
+                    activation=(
+                        torch.nn.Identity()
+                        if not activation_on_last and i == len(layer_sizes) - 1
+                        else SwishLayerNorm(layer_sizes[i], device=device)
+                    ),
                     device=device,
                 )
                 for i in range(len(layer_sizes))
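Both hunks apply the same selection pattern. As a standalone sketch of that logic (plain ReLU stands in here for the module's actual activation handling), the conditional picks Identity only for the final layer when the flag is off:

import torch

layer_sizes = [32, 16, 8]
activation_on_last = False

activations = [
    torch.nn.Identity()
    if not activation_on_last and i == len(layer_sizes) - 1
    else torch.nn.ReLU()
    for i in range(len(layer_sizes))
]
# activations == [ReLU(), ReLU(), Identity()]: only the last layer skips ReLU.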
