diff --git a/distributed/tensor_parallelism/fsdp_tp_example.py b/distributed/tensor_parallelism/fsdp_tp_example.py
index 258da9fd64..dbab48c1b8 100644
--- a/distributed/tensor_parallelism/fsdp_tp_example.py
+++ b/distributed/tensor_parallelism/fsdp_tp_example.py
@@ -107,22 +107,19 @@
     {
         "tok_embeddings": RowwiseParallel(
             input_layouts=Replicate(),
+            output_layouts=Shard(1),
         ),
+        "norm": SequenceParallel(),
         "output": ColwiseParallel(
             input_layouts=Shard(1),
             output_layouts=Replicate()
         ),
-        "norm": SequenceParallel(),
-        "layers.0": PrepareModuleInput(
-            input_layouts=(Replicate(), None),
-            desired_input_layouts=(Shard(1), None),
-            use_local_output=True,
-        ),
     }
 )
 
 for layer_id, transformer_block in enumerate(model.layers):
     layer_tp_plan = {
+        "attention_norm": SequenceParallel(),
         "attention": PrepareModuleInput(
             input_layouts=(Shard(1), None),
             desired_input_layouts=(Replicate(), None),
@@ -131,7 +128,7 @@
         "attention.wk": ColwiseParallel(),
         "attention.wv": ColwiseParallel(),
         "attention.wo": RowwiseParallel(output_layouts=Shard(1)),
-        "attention_norm": SequenceParallel(),
+        "ffn_norm": SequenceParallel(),
         "feed_forward": PrepareModuleInput(
             input_layouts=(Shard(1),),
             desired_input_layouts=(Replicate(),),
@@ -139,7 +136,6 @@
         "feed_forward.w1": ColwiseParallel(),
         "feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)),
         "feed_forward.w3": ColwiseParallel(),
-        "ffn_norm": SequenceParallel(),
     }
 
 # Adjust attention module to use the local number of heads
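
Note (not part of the patch): after this change, tok_embeddings emits its output already sharded on the sequence dimension (Shard(1)), the norm entries sit next to the blocks they precede, and the old layers.0 PrepareModuleInput redistribution is no longer needed because the embedding output already arrives in the layout the first transformer block expects. Below is a minimal sketch of how the resulting top-level plan is typically applied together with FSDP over a 2D device mesh; the mesh sizes, the model variable, and the surrounding distributed setup are assumptions drawn from the example's context, not part of this diff.

    # Sketch only: assumes torch.distributed is already initialized and that
    # `model` is the Llama-style Transformer used by the example script.
    import torch.distributed as dist
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed._tensor import Replicate, Shard
    from torch.distributed.tensor.parallel import (
        ColwiseParallel,
        RowwiseParallel,
        SequenceParallel,
        parallelize_module,
    )

    tp_size = 2                                   # assumed tensor-parallel degree
    dp_size = dist.get_world_size() // tp_size    # remaining ranks form the FSDP group

    # 2D mesh: "dp" dimension for FSDP sharding, "tp" dimension for tensor parallel.
    mesh_2d = init_device_mesh("cuda", (dp_size, tp_size), mesh_dim_names=("dp", "tp"))
    tp_mesh, dp_mesh = mesh_2d["tp"], mesh_2d["dp"]

    # Top-level plan as it stands after this patch: the embedding output is
    # sequence-sharded, the final norm runs sequence-parallel, and the output
    # projection gathers back to a replicated tensor.
    model = parallelize_module(
        model,
        tp_mesh,
        {
            "tok_embeddings": RowwiseParallel(
                input_layouts=Replicate(),
                output_layouts=Shard(1),
            ),
            "norm": SequenceParallel(),
            "output": ColwiseParallel(
                input_layouts=Shard(1),
                output_layouts=Replicate(),
            ),
        },
    )

    # Wrap the tensor-parallel model with FSDP over the data-parallel mesh.
    model = FSDP(model, device_mesh=dp_mesh, use_orig_params=True)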