diff --git a/src/diffusers/models/attention_flax.py b/src/diffusers/models/attention_flax.py
index 71106e05452c..9b1ebccb8ae9 100644
--- a/src/diffusers/models/attention_flax.py
+++ b/src/diffusers/models/attention_flax.py
@@ -16,7 +16,7 @@
 import jax.numpy as jnp
 
 
-class FlaxAttentionBlock(nn.Module):
+class FlaxCrossAttention(nn.Module):
     r"""
     A Flax multi-head attention module as described in: https://arxiv.org/abs/1706.03762
 
@@ -118,10 +118,10 @@ class FlaxBasicTransformerBlock(nn.Module):
 
     def setup(self):
         # self attention (or cross_attention if only_cross_attention is True)
-        self.attn1 = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
+        self.attn1 = FlaxCrossAttention(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
         # cross attention
-        self.attn2 = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
-        self.ff = FlaxGluFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype)
+        self.attn2 = FlaxCrossAttention(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
+        self.ff = FlaxFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype)
         self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
         self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
         self.norm3 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
@@ -242,10 +242,14 @@ def __call__(self, hidden_states, context, deterministic=True):
         return hidden_states
 
 
-class FlaxGluFeedForward(nn.Module):
+class FlaxFeedForward(nn.Module):
     r"""
-    Flax module that encapsulates two Linear layers separated by a gated linear unit activation from:
+    Flax module that encapsulates two Linear layers separated by a non-linearity. It is the counterpart of PyTorch's
+    [`FeedForward`] class, with the following simplifications:
+    - The activation function is currently hardcoded to a gated linear unit from:
     https://arxiv.org/abs/2002.05202
+    - `dim_out` is equal to `dim`.
+    - The number of hidden dimensions is hardcoded to `dim * 4` in [`FlaxGEGLU`].
 
     Parameters:
         dim (:obj:`int`):
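The updated docstring describes `FlaxFeedForward` as two Linear layers separated by a gated linear unit, with `dim_out` equal to `dim` and a hidden width hardcoded to `dim * 4`. The sketch below illustrates that description only; it is not the code from `attention_flax.py`. The split-and-gate body of `FlaxGEGLU` is an assumption based on the GEGLU formulation in the linked paper, and the attribute names (`proj`, `net_0`, `net_2`) are illustrative.

```python
# Illustrative sketch, not the diffusers implementation.
import flax.linen as nn
import jax
import jax.numpy as jnp


class FlaxGEGLU(nn.Module):
    """Gated GELU projection (https://arxiv.org/abs/2002.05202), assumed layout."""

    dim: int
    dropout: float = 0.0
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        inner_dim = self.dim * 4
        # One Dense produces both the value and the gate; they are split in __call__.
        self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype)

    def __call__(self, hidden_states, deterministic=True):
        hidden_states = self.proj(hidden_states)
        hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=-1)
        return hidden_linear * nn.gelu(hidden_gelu)


class FlaxFeedForward(nn.Module):
    """Two Linear layers separated by a GEGLU non-linearity; dim_out == dim."""

    dim: int
    dropout: float = 0.0
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # Gated projection up to dim * 4, then a plain projection back down to dim.
        self.net_0 = FlaxGEGLU(self.dim, self.dropout, self.dtype)
        self.net_2 = nn.Dense(self.dim, dtype=self.dtype)

    def __call__(self, hidden_states, deterministic=True):
        hidden_states = self.net_0(hidden_states, deterministic=deterministic)
        hidden_states = self.net_2(hidden_states)
        return hidden_states


if __name__ == "__main__":
    # Shape check: the output keeps the input's trailing dimension (dim_out == dim).
    model = FlaxFeedForward(dim=320)
    x = jnp.ones((1, 77, 320))
    params = model.init(jax.random.PRNGKey(0), x)
    print(model.apply(params, x).shape)  # (1, 77, 320)
```

Under these assumptions the only constructor knob is `dim`, which matches the simplifications listed in the new docstring relative to the PyTorch `FeedForward` class.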