Skip to content

Commit e4a9fb3

Browse files
authored
Bring Flax attention naming in sync with PyTorch (#2511)
Bring flax attention naming in sync with PyTorch.
1 parent eadf0e2 commit e4a9fb3

File tree

1 file changed

+10
-6
lines changed

1 file changed

+10
-6
lines changed

src/diffusers/models/attention_flax.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import jax.numpy as jnp
1717

1818

19-
class FlaxAttentionBlock(nn.Module):
19+
class FlaxCrossAttention(nn.Module):
2020
r"""
2121
A Flax multi-head attention module as described in: https://arxiv.org/abs/1706.03762
2222
@@ -118,10 +118,10 @@ class FlaxBasicTransformerBlock(nn.Module):
118118

119119
def setup(self):
120120
# self attention (or cross_attention if only_cross_attention is True)
121-
self.attn1 = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
121+
self.attn1 = FlaxCrossAttention(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
122122
# cross attention
123-
self.attn2 = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
124-
self.ff = FlaxGluFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype)
123+
self.attn2 = FlaxCrossAttention(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
124+
self.ff = FlaxFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype)
125125
self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
126126
self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
127127
self.norm3 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
@@ -242,10 +242,14 @@ def __call__(self, hidden_states, context, deterministic=True):
242242
return hidden_states
243243

244244

245-
class FlaxGluFeedForward(nn.Module):
245+
class FlaxFeedForward(nn.Module):
246246
r"""
247-
Flax module that encapsulates two Linear layers separated by a gated linear unit activation from:
247+
Flax module that encapsulates two Linear layers separated by a non-linearity. It is the counterpart of PyTorch's
248+
[`FeedForward`] class, with the following simplifications:
249+
- The activation function is currently hardcoded to a gated linear unit from:
248250
https://arxiv.org/abs/2002.05202
251+
- `dim_out` is equal to `dim`.
252+
- The number of hidden dimensions is hardcoded to `dim * 4` in [`FlaxGELU`].
249253
250254
Parameters:
251255
dim (:obj:`int`):

0 commit comments

Comments
 (0)