From 6fcefaf03c649635cb9f9671831bf3fb0c95be17 Mon Sep 17 00:00:00 2001 From: Tianyu Liu Date: Thu, 16 May 2024 14:26:39 -0700 Subject: [PATCH] [Tensor Parallel] update tutorial to simplify embedding + first transformer block --- intermediate_source/TP_tutorial.rst | 47 +++++++++++++++++++++-------- 1 file changed, 35 insertions(+), 12 deletions(-) diff --git a/intermediate_source/TP_tutorial.rst b/intermediate_source/TP_tutorial.rst index 72e72869348..2d0193990d4 100644 --- a/intermediate_source/TP_tutorial.rst +++ b/intermediate_source/TP_tutorial.rst @@ -164,6 +164,22 @@ Finally, we need to call ``parallelize_module`` API to make the plan for each `` ) Now that we have elaborated the sharding plan for each ``TransformerBlock``, there is usually a ``nn.Embedding`` in the first layer and a final ``nn.Linear`` projection layer, where user could choose row-wise or column-wise sharding to the first ``nn.Embedding`` and column-wise sharding to the last ``nn.Linear`` projection layer with proper input and output layouts specified. +Here is an example: + +.. code-block:: python + + model = parallelize_module( + model, + tp_mesh, + { + "tok_embeddings": RowwiseParallel( + input_layouts=Replicate(), + ), + "output": ColwiseParallel( + output_layouts=Replicate(), + ), + } + ) .. note:: If the model to be partitioned is too large to fit into CPU memory, one could either use ``meta`` device initialization (for example, initialize the model on meta device first, shard the layers, and the materialize the model), or parallelize the ``TransformerBlock`` layer by layer during the Transformer model initialization. @@ -203,6 +219,7 @@ Next let's adjust the ``layer_tp_plan`` to enable sequence parallel on the ``RMS layer_tp_plan = { # Now the input and output of SequenceParallel has Shard(1) layouts, # to represent the input/output tensors sharded on the sequence dimension + "attention_norm": SequenceParallel(), "attention": PrepareModuleInput( input_layouts=(Shard(1),), desired_input_layouts=(Replicate(),), @@ -211,7 +228,7 @@ Next let's adjust the ``layer_tp_plan`` to enable sequence parallel on the ``RMS "attention.wk": ColwiseParallel(), "attention.wv": ColwiseParallel(), "attention.wo": RowwiseParallel(output_layouts=Shard(1)), - "attention_norm": SequenceParallel(), + "ffn_norm": SequenceParallel(), "feed_forward": PrepareModuleInput( input_layouts=(Shard(1),), desired_input_layouts=(Replicate(),), @@ -219,7 +236,6 @@ Next let's adjust the ``layer_tp_plan`` to enable sequence parallel on the ``RMS "feed_forward.w1": ColwiseParallel(), "feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)), "feed_forward.w3": ColwiseParallel(), - "ffn_norm": SequenceParallel(), } @@ -227,17 +243,24 @@ One can see we now use ``PrepareModuleInput`` to modify the module input layouts Just like what happens to Tensor Parallelism, one only needs to specify the tensor sharding layouts of the inputs and outputs, and the communication between layers will happen automatically. Note that with Sequence Parallel, we assume the inputs and outputs of a ``TransformerBlock`` are always sharded on the sequence dimension, so that multiple ``TransformerBlocks`` can be concatenated seamlessly. -The only exception is that the input to the first ``TransformerBlock`` is replicated from the data loaders, so it has to be converted explicitly: +This can be facilitated by explicitly specifying the output of the beginning ``nn.Embedding`` layer and the input of the final ``nn.Linear`` projection layer to be ``Shard(1)``: .. code-block:: python model = parallelize_module( model, tp_mesh, - "layers.0": PrepareModuleInput( - input_layouts=(Replicate(),), - desired_input_layouts=(Shard(1),), - ), + { + "tok_embeddings": RowwiseParallel( + input_layouts=Replicate(), + output_layouts=Shard(1), + ), + "norm": SequenceParallel(), + "output": ColwiseParallel( + input_layouts=Shard(1), + output_layouts=Replicate() + ), + } ) @@ -263,16 +286,16 @@ To apply Loss Parallel, the model predictions, usually of the shape ``[batch siz model, tp_mesh, { + "tok_embeddings": RowwiseParallel( + input_layouts=Replicate(), + output_layouts=Shard(1), + ), + "norm": SequenceParallel(), "output": ColwiseParallel( input_layouts=Shard(1), # use DTensor as the output use_local_output=False, ), - "norm": SequenceParallel(), - "layers.0": PrepareModuleInput( - input_layouts=(Replicate(),), - desired_input_layouts=(Shard(1),), - ), }, )