intermediate_source/TP_tutorial.rst
+33 -10 (33 additions & 10 deletions)
@@ -164,6 +164,22 @@ Finally, we need to call ``parallelize_module`` API to make the plan for each ``TransformerBlock``
    )

Now that we have elaborated the sharding plan for each ``TransformerBlock``, there is usually an ``nn.Embedding`` as the first layer and a final ``nn.Linear`` projection layer. The user can choose row-wise or column-wise sharding for the first ``nn.Embedding`` and column-wise sharding for the last ``nn.Linear`` projection layer, with proper input and output layouts specified.
+Here is an example:
+
+.. code-block:: python
+
+    model = parallelize_module(
+        model,
+        tp_mesh,
+        {
+            "tok_embeddings": RowwiseParallel(
+                input_layouts=Replicate(),
+            ),
+            "output": ColwiseParallel(
+                output_layouts=Replicate(),
+            ),
+        }
+    )
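The example above assumes the tensor parallel APIs are already imported and that ``tp_mesh`` was created earlier in the tutorial. A minimal setup sketch, assuming a recent PyTorch release (the exact import paths vary slightly across versions, so verify them against your installation):

.. code-block:: python

    import torch
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor.parallel import (
        ColwiseParallel,
        RowwiseParallel,
        parallelize_module,
    )
    # Replicate/Shard placements live in torch.distributed.tensor in recent
    # releases (torch.distributed._tensor in older ones).
    from torch.distributed.tensor import Replicate, Shard

    # 1-D device mesh spanning all GPUs of the job; launch with torchrun so
    # the process-group environment variables are populated.
    tp_mesh = init_device_mesh("cuda", (torch.cuda.device_count(),))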

.. note::
    If the model to be partitioned is too large to fit into CPU memory, one could either use ``meta`` device initialization (for example, initialize the model on the meta device first, shard the layers, and then materialize the model), or parallelize the ``TransformerBlock`` layer by layer during the Transformer model initialization.
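As an illustration of the meta-device route described in the note, here is a hedged sketch; ``Transformer``, ``ModelArgs``, ``tp_plan``, and ``init_weights`` are placeholder names for this example, not APIs defined by the tutorial:

.. code-block:: python

    # Build the model without allocating real storage.
    with torch.device("meta"):
        model = Transformer(ModelArgs())  # placeholder model class

    # Install the sharding plan while parameters are still meta tensors.
    model = parallelize_module(model, tp_mesh, tp_plan)  # tp_plan: the plan dict shown above

    # Materialize the (already sharded) parameters on the GPU, then initialize them.
    model = model.to_empty(device="cuda")
    model.init_weights()  # assumes the model provides its own weight-init routine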
@@ -227,17 +243,24 @@ One can see we now use ``PrepareModuleInput`` to modify the module input layouts
Just as with Tensor Parallelism, one only needs to specify the tensor sharding layouts of the inputs and outputs, and the communication between layers happens automatically.

Note that with Sequence Parallel, we assume the inputs and outputs of a ``TransformerBlock`` are always sharded on the sequence dimension, so that multiple ``TransformerBlocks`` can be concatenated seamlessly.
-The only exception is that the input to the first ``TransformerBlock`` is replicated from the data loaders, so it has to be converted explicitly:
+This can be facilitated by explicitly specifying the output of the beginning ``nn.Embedding`` layer and the input of the final ``nn.Linear`` projection layer to be ``Shard(1)``:

.. code-block:: python

    model = parallelize_module(
        model,
        tp_mesh,
-        "layers.0": PrepareModuleInput(
-            input_layouts=(Replicate(),),
-            desired_input_layouts=(Shard(1),),
-        ),
+        {
+            "tok_embeddings": RowwiseParallel(
+                input_layouts=Replicate(),
+                output_layouts=Shard(1),
+            ),
+            "norm": SequenceParallel(),
+            "output": ColwiseParallel(
+                input_layouts=Shard(1),
+                output_layouts=Replicate()
+            ),
+        }
    )

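For reference, a hedged sketch of how the sequence-parallel model above could be exercised; the batch size, sequence length, and vocabulary size here are illustrative assumptions, not values from the tutorial:

.. code-block:: python

    # Every rank feeds the same replicated token ids. RowwiseParallel on
    # ``tok_embeddings`` produces activations sharded on the sequence dimension
    # (``Shard(1)``), and ColwiseParallel on ``output`` returns the final logits
    # replicated on every rank.
    batch_size, seq_len, vocab_size = 2, 128, 32000  # illustrative sizes only
    tokens = torch.randint(0, vocab_size, (batch_size, seq_len), device="cuda")
    logits = model(tokens)  # shape [batch_size, seq_len, vocab_size]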
@@ -263,16 +286,16 @@ To apply Loss Parallel, the model predictions, usually of the shape ``[batch siz