
Commit 0960b2e

Update on "[Tensor Parallel] update examples to simplify embedding + first transformer block"

Following the changes in pytorch/torchtitan#314, apply a reduce-scatter instead of the more expensive all-reduce + local chunk. Cross PR with pytorch/tutorials#2871.

[ghstack-poisoned]

1 parent 0ab7688
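
The collective change described in the commit message can be sketched as follows (this is not part of the three-line diff below): with the token embedding sharded row-wise, requesting a Shard(1), i.e. sequence-sharded, output lets DTensor combine the partial per-rank results with a single reduce-scatter, whereas a Replicate() output would require an all-reduce followed by a local chunk along the sequence dimension. The Shard(1) configuration matches the tutorial's current code; the Replicate() variant and the plan variable names are shown only for contrast and are not taken from this commit.

# Sketch only: how the requested output layout of the row-wise-parallel
# embedding decides which collective runs (module names follow the
# Llama-style model used in fsdp_tp_example.py).
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.tensor.parallel import RowwiseParallel

# After the change: the embedding output is produced directly in the Shard(1)
# layout that the first transformer block's SequenceParallel attention_norm
# expects, so the partial results are combined with one reduce-scatter.
plan_with_reduce_scatter = {
    "tok_embeddings": RowwiseParallel(
        input_layouts=Replicate(),
        output_layouts=Shard(1),
    ),
}

# Before (for contrast, an assumption about the pre-change setup): a replicated
# output requires an all-reduce, after which the sequence-parallel region must
# chunk the activation locally along dim 1.
plan_with_all_reduce = {
    "tok_embeddings": RowwiseParallel(
        input_layouts=Replicate(),
        output_layouts=Replicate(),
    ),
}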

File tree

1 file changed (+3, -3 lines)

distributed/tensor_parallelism/fsdp_tp_example.py

Lines changed: 3 additions & 3 deletions
@@ -109,16 +109,17 @@
             input_layouts=Replicate(),
             output_layouts=Shard(1),
         ),
+        "norm": SequenceParallel(),
         "output": ColwiseParallel(
             input_layouts=Shard(1),
             output_layouts=Replicate()
         ),
-        "norm": SequenceParallel(),
     }
 )
 
 for layer_id, transformer_block in enumerate(model.layers):
     layer_tp_plan = {
+        "attention_norm": SequenceParallel(),
         "attention": PrepareModuleInput(
             input_layouts=(Shard(1), None),
             desired_input_layouts=(Replicate(), None),
@@ -127,15 +128,14 @@
         "attention.wk": ColwiseParallel(),
         "attention.wv": ColwiseParallel(),
         "attention.wo": RowwiseParallel(output_layouts=Shard(1)),
-        "attention_norm": SequenceParallel(),
+        "ffn_norm": SequenceParallel(),
         "feed_forward": PrepareModuleInput(
             input_layouts=(Shard(1),),
             desired_input_layouts=(Replicate(),),
         ),
         "feed_forward.w1": ColwiseParallel(),
         "feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)),
         "feed_forward.w3": ColwiseParallel(),
-        "ffn_norm": SequenceParallel(),
     }
 
     # Adjust attention module to use the local number of heads
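
For reference, below is a self-contained sketch of the two plan dictionaries as they read after this commit. Everything outside the hunks above (the imports, the "attention.wq" entry, and the variable names model_tp_plan and layer_tp_plan) is reconstructed from the surrounding tutorial code rather than shown in this diff; in the example these dicts are passed to parallelize_module for the top-level model and for each transformer block.

# Consolidated sketch of the plans after this commit (reconstructed; the
# parallelize_module(model, tp_mesh, ...) calls that consume them are omitted).
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    PrepareModuleInput,
    RowwiseParallel,
    SequenceParallel,
)

# Top-level plan: the final norm entry now appears before the output projection.
model_tp_plan = {
    "tok_embeddings": RowwiseParallel(
        input_layouts=Replicate(),
        output_layouts=Shard(1),
    ),
    "norm": SequenceParallel(),
    "output": ColwiseParallel(
        input_layouts=Shard(1),
        output_layouts=Replicate(),
    ),
}

# Per-block plan: built for each transformer_block in the loop over
# model.layers; the norm entries now precede the modules they feed.
layer_tp_plan = {
    "attention_norm": SequenceParallel(),
    "attention": PrepareModuleInput(
        input_layouts=(Shard(1), None),
        desired_input_layouts=(Replicate(), None),
    ),
    "attention.wq": ColwiseParallel(),
    "attention.wk": ColwiseParallel(),
    "attention.wv": ColwiseParallel(),
    "attention.wo": RowwiseParallel(output_layouts=Shard(1)),
    "ffn_norm": SequenceParallel(),
    "feed_forward": PrepareModuleInput(
        input_layouts=(Shard(1),),
        desired_input_layouts=(Replicate(),),
    ),
    "feed_forward.w1": ColwiseParallel(),
    "feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)),
    "feed_forward.w3": ColwiseParallel(),
}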
