 import jax.numpy as jnp


-class FlaxAttentionBlock(nn.Module):
+class FlaxCrossAttention(nn.Module):
     r"""
     A Flax multi-head attention module as described in: https://arxiv.org/abs/1706.03762

@@ -118,10 +118,10 @@ class FlaxBasicTransformerBlock(nn.Module):

     def setup(self):
         # self attention (or cross_attention if only_cross_attention is True)
-        self.attn1 = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
+        self.attn1 = FlaxCrossAttention(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
         # cross attention
-        self.attn2 = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
-        self.ff = FlaxGluFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype)
+        self.attn2 = FlaxCrossAttention(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
+        self.ff = FlaxFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype)
         self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
         self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
         self.norm3 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
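The setup() above only declares the submodules; for readers skimming the diff, the sketch below shows how a pre-norm block like this is typically chained with residual connections, matching the __call__ signature visible in the next hunk header. It is an assumed reconstruction, not a copy of the method in attention_flax.py, and it omits the only_cross_attention branch mentioned in the comment.

    # Sketch of FlaxBasicTransformerBlock.__call__ (assumed wiring, not verbatim):
    def __call__(self, hidden_states, context, deterministic=True):
        # pre-norm self attention + residual
        residual = hidden_states
        hidden_states = self.attn1(self.norm1(hidden_states), deterministic=deterministic)
        hidden_states = hidden_states + residual

        # pre-norm cross attention over the conditioning embeddings + residual
        residual = hidden_states
        hidden_states = self.attn2(self.norm2(hidden_states), context, deterministic=deterministic)
        hidden_states = hidden_states + residual

        # pre-norm feed-forward (FlaxFeedForward after this rename) + residual
        residual = hidden_states
        hidden_states = self.ff(self.norm3(hidden_states), deterministic=deterministic)
        return hidden_states + residual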
@@ -242,10 +242,14 @@ def __call__(self, hidden_states, context, deterministic=True):
         return hidden_states


-class FlaxGluFeedForward(nn.Module):
+class FlaxFeedForward(nn.Module):
     r"""
-    Flax module that encapsulates two Linear layers separated by a gated linear unit activation from:
+    Flax module that encapsulates two Linear layers separated by a non-linearity. It is the counterpart of PyTorch's
+    [`FeedForward`] class, with the following simplifications:
+    - The activation function is currently hardcoded to a gated linear unit from:
     https://arxiv.org/abs/2002.05202
+    - `dim_out` is equal to `dim`.
+    - The number of hidden dimensions is hardcoded to `dim * 4` in [`FlaxGEGLU`].

     Parameters:
         dim (:obj:`int`):
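Taken together, the new docstring describes a feed-forward whose first projection is a gated linear unit that expands to dim * 4 and whose second projection maps back to dim. The sketch below is one way to realize that description in Flax; the submodule names net_0 and net_2 and the exact FlaxGEGLU signature are assumptions inferred from the docstring rather than a verbatim copy of attention_flax.py.

import flax.linen as nn
import jax.numpy as jnp


class FlaxGEGLU(nn.Module):
    # Gated linear unit variant from https://arxiv.org/abs/2002.05202:
    # one Dense layer produces 2 * inner_dim features, split into a linear
    # half and a GELU-gated half that are multiplied together.
    dim: int
    dropout: float = 0.0  # kept for API parity; unused in this sketch
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        inner_dim = self.dim * 4  # hidden width hardcoded to dim * 4, per the docstring
        self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype)

    def __call__(self, hidden_states, deterministic=True):
        hidden_states = self.proj(hidden_states)
        hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=-1)
        return hidden_linear * nn.gelu(hidden_gelu)


class FlaxFeedForward(nn.Module):
    # Two Linear layers separated by the gated non-linearity; dim_out == dim.
    dim: int
    dropout: float = 0.0
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.net_0 = FlaxGEGLU(self.dim, self.dropout, self.dtype)  # up-projection + gate
        self.net_2 = nn.Dense(self.dim, dtype=self.dtype)           # back down to dim

    def __call__(self, hidden_states, deterministic=True):
        hidden_states = self.net_0(hidden_states, deterministic=deterministic)
        return self.net_2(hidden_states)

Splitting a single Dense output into a linear half and a GELU-gated half keeps the gated projection to one matmul on the way up, which is why the first layer's width is 2 * (dim * 4).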