             input_layouts=Replicate(),
             output_layouts=Shard(1),
         ),
+        "norm": SequenceParallel(),
         "output": ColwiseParallel(
             input_layouts=Shard(1),
             output_layouts=Replicate()
         ),
-        "norm": SequenceParallel(),
     }
 )

 for layer_id, transformer_block in enumerate(model.layers):
     layer_tp_plan = {
+        "attention_norm": SequenceParallel(),
         "attention": PrepareModuleInput(
             input_layouts=(Shard(1), None),
             desired_input_layouts=(Replicate(), None),
         ),
         "attention.wq": ColwiseParallel(),
         "attention.wk": ColwiseParallel(),
         "attention.wv": ColwiseParallel(),
         "attention.wo": RowwiseParallel(output_layouts=Shard(1)),
-        "attention_norm": SequenceParallel(),
+        "ffn_norm": SequenceParallel(),
         "feed_forward": PrepareModuleInput(
             input_layouts=(Shard(1),),
             desired_input_layouts=(Replicate(),),
         ),
         "feed_forward.w1": ColwiseParallel(),
         "feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)),
         "feed_forward.w3": ColwiseParallel(),
-        "ffn_norm": SequenceParallel(),
     }

     # Adjust attention module to use the local number of heads
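The trailing comment refers to the step that follows this hunk: because `attention.wq`/`wk`/`wv` are `ColwiseParallel`, each rank ends up holding only its slice of the attention heads, so the block's head counts are divided by the tensor-parallel mesh size before the per-layer plan is applied with `parallelize_module`. Below is a minimal sketch of that step, assuming (as in the PyTorch Llama-style tensor-parallel example this code resembles) a `tp_mesh` `DeviceMesh` and an attention module exposing `n_heads`/`n_kv_heads`; the helper name `apply_layer_tp_plan` is illustrative and not part of this change.

```python
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor.parallel import parallelize_module


def apply_layer_tp_plan(transformer_block, tp_mesh: DeviceMesh, layer_tp_plan: dict) -> None:
    # wq/wk/wv are ColwiseParallel, so after sharding each rank only computes
    # n_heads // tp_mesh.size() of the heads; shrink the head counts to the
    # local (per-rank) values before the module is parallelized.
    attn = transformer_block.attention
    attn.n_heads = attn.n_heads // tp_mesh.size()
    attn.n_kv_heads = attn.n_kv_heads // tp_mesh.size()

    # Swap the submodules listed in layer_tp_plan for their sharded
    # counterparts on this tensor-parallel mesh.
    parallelize_module(
        module=transformer_block,
        device_mesh=tp_mesh,
        parallelize_plan=layer_tp_plan,
    )
```

In the surrounding loop this would be called once per block, passing the `layer_tp_plan` dict built in the diff above.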