From 0ab7688addef27e94ff3e5aa5751f4ac4927337e Mon Sep 17 00:00:00 2001
From: Tianyu Liu
Date: Wed, 15 May 2024 17:57:47 -0700
Subject: [PATCH 1/2] [Tensor Parallel] update examples to simplify embedding + first transformer block

[ghstack-poisoned]
---
 distributed/tensor_parallelism/fsdp_tp_example.py | 6 +-----
 1 file changed, 1 insertion(+), 5 deletions(-)

diff --git a/distributed/tensor_parallelism/fsdp_tp_example.py b/distributed/tensor_parallelism/fsdp_tp_example.py
index 258da9fd64..1a1e942861 100644
--- a/distributed/tensor_parallelism/fsdp_tp_example.py
+++ b/distributed/tensor_parallelism/fsdp_tp_example.py
@@ -107,17 +107,13 @@
     {
         "tok_embeddings": RowwiseParallel(
             input_layouts=Replicate(),
+            output_layouts=Shard(1),
         ),
         "output": ColwiseParallel(
             input_layouts=Shard(1),
             output_layouts=Replicate()
         ),
         "norm": SequenceParallel(),
-        "layers.0": PrepareModuleInput(
-            input_layouts=(Replicate(), None),
-            desired_input_layouts=(Shard(1), None),
-            use_local_output=True,
-        ),
     }
 )
 

From 0960b2e204dfc7c413d79f98459826928408e029 Mon Sep 17 00:00:00 2001
From: Tianyu Liu
Date: Wed, 15 May 2024 18:42:20 -0700
Subject: [PATCH 2/2] Update on "[Tensor Parallel] update examples to simplify embedding + first transformer block"

Following changes in https://github.com/pytorch/torchtitan/pull/314, this
applies a reduce-scatter instead of the more expensive all-reduce + local
chunk.

Cross PR with https://github.com/pytorch/tutorials/pull/2871

[ghstack-poisoned]
---
 distributed/tensor_parallelism/fsdp_tp_example.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/distributed/tensor_parallelism/fsdp_tp_example.py b/distributed/tensor_parallelism/fsdp_tp_example.py
index 1a1e942861..dbab48c1b8 100644
--- a/distributed/tensor_parallelism/fsdp_tp_example.py
+++ b/distributed/tensor_parallelism/fsdp_tp_example.py
@@ -109,16 +109,17 @@
             input_layouts=Replicate(),
             output_layouts=Shard(1),
         ),
+        "norm": SequenceParallel(),
         "output": ColwiseParallel(
             input_layouts=Shard(1),
             output_layouts=Replicate()
         ),
-        "norm": SequenceParallel(),
     }
 )
 
 for layer_id, transformer_block in enumerate(model.layers):
     layer_tp_plan = {
+        "attention_norm": SequenceParallel(),
         "attention": PrepareModuleInput(
             input_layouts=(Shard(1), None),
             desired_input_layouts=(Replicate(), None),
@@ -127,7 +128,7 @@
         "attention.wk": ColwiseParallel(),
         "attention.wv": ColwiseParallel(),
         "attention.wo": RowwiseParallel(output_layouts=Shard(1)),
-        "attention_norm": SequenceParallel(),
+        "ffn_norm": SequenceParallel(),
         "feed_forward": PrepareModuleInput(
             input_layouts=(Shard(1),),
             desired_input_layouts=(Replicate(),),
@@ -135,7 +136,7 @@
         "feed_forward.w1": ColwiseParallel(),
         "feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)),
         "feed_forward.w3": ColwiseParallel(),
-        "ffn_norm": SequenceParallel(),
     }
 
     # Adjust attention module to use the local number of heads
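
For reference, a minimal sketch of the simplified top-level plan these two
patches arrive at, applied with parallelize_module. It assumes `model` and
`tp_mesh` are built as in fsdp_tp_example.py (a Llama-style Transformer
exposing tok_embeddings, norm, and output, with tp_mesh being the
tensor-parallel dimension of the 2-D FSDP + TP device mesh), that the script
is launched with torchrun, and that the import paths are the ones the example
uses at the time of these patches; they may differ in later PyTorch releases.
Because the row-wise embedding now emits Shard(1), its output is produced by
a reduce-scatter and reaches layers.0 already sharded on the sequence
dimension, which is why the layers.0 PrepareModuleInput entry can be dropped.

from torch.distributed._tensor import Replicate, Shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    SequenceParallel,
    parallelize_module,
)

# `model` and `tp_mesh` are assumed to come from fsdp_tp_example.py.
model = parallelize_module(
    model,
    tp_mesh,
    {
        # Row-wise sharded embedding: emitting Shard(1) lets the output be
        # formed by a reduce-scatter rather than an all-reduce + local
        # chunk, so layers.0 needs no PrepareModuleInput redistribution.
        "tok_embeddings": RowwiseParallel(
            input_layouts=Replicate(),
            output_layouts=Shard(1),
        ),
        # Final norm runs sequence-parallel on the Shard(1) activations.
        "norm": SequenceParallel(),
        # Output projection consumes the sequence-sharded activations and
        # returns replicated logits.
        "output": ColwiseParallel(
            input_layouts=Shard(1),
            output_layouts=Replicate(),
        ),
    },
)

The per-layer plan in PATCH 2/2 follows the same pattern: each
SequenceParallel norm hands a Shard(1) activation to the PrepareModuleInput
of the block that consumes it, which redistributes to Replicate() only where
the attention and feed-forward projections need it.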