
Commit 9242c60

H-Huang and svekars authored
update pipelining tutorial (#3182)
* update pipelining tutorial
* Update intermediate_source/pipelining_tutorial.rst

Co-authored-by: Svetlana Karslioglu <svekars@meta.com>
1 parent 7038ce7 commit 9242c60

File tree

1 file changed (+14, -10 lines)

intermediate_source/pipelining_tutorial.rst

Lines changed: 14 additions & 10 deletions
@@ -67,7 +67,7 @@ chunks. First, let us define the model:
             h = layer(h, h)
 
         h = self.norm(h) if self.norm else h
-        output = self.output(h).float() if self.output else h
+        output = self.output(h).clone() if self.output else h
         return output
 
 
 Then, we need to import the necessary libraries in our script and initialize the distributed training process. In this case, we are defining some global variables to use
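For context, a minimal sketch of how the forward pass reads after this change. Only the two ``output`` lines are quoted from the hunk; the surrounding method body, the ``tok_embeddings`` guard, and the ``layers.values()`` iteration are reconstructed from the rest of this diff and are assumptions about the tutorial's Transformer class, not a quote of the file.

.. code:: python

    # Sketch of the tutorial Transformer's forward method after this commit.
    def forward(self, tokens: torch.Tensor):
        # Attributes may be set to None by the pipeline split, so guard each step.
        h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens
        for layer in self.layers.values():
            h = layer(h, h)
        h = self.norm(h) if self.norm else h
        # Changed in this commit: return a cloned tensor instead of casting to float.
        output = self.output(h).clone() if self.output else h
        return output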
@@ -109,32 +109,29 @@ Step 1: Partition the Transformer Model
 There are two different ways of partitioning the model:
 
 First is the manual mode in which we can manually create two instances of the model by deleting portions of
-attributes of the model. In this example for a 2 stage (2 ranks) the model is cut in half.
+attributes of the model. In this example for two stages (2 ranks), the model is cut in half.
 
 .. code:: python
 
-    def manual_model_split(model, example_input_microbatch, model_args) -> PipelineStage:
+    def manual_model_split(model) -> PipelineStage:
         if stage_index == 0:
             # prepare the first stage model
             for i in range(4, 8):
                 del model.layers[str(i)]
             model.norm = None
             model.output = None
-            stage_input_microbatch = example_input_microbatch
 
         elif stage_index == 1:
             # prepare the second stage model
             for i in range(4):
                 del model.layers[str(i)]
             model.tok_embeddings = None
-            stage_input_microbatch = torch.randn(example_input_microbatch.shape[0], example_input_microbatch.shape[1], model_args.dim)
 
         stage = PipelineStage(
             model,
             stage_index,
             num_stages,
             device,
-            input_args=stage_input_microbatch,
         )
         return stage
 
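Assembled from the hunk above, the updated helper reads roughly as follows. This is a sketch, not a quote of the file: ``stage_index``, ``num_stages``, and ``device`` are globals defined earlier in the tutorial, and the import line is added here only so the snippet is self-contained.

.. code:: python

    from torch.distributed.pipelining import PipelineStage

    def manual_model_split(model) -> PipelineStage:
        # stage_index, num_stages, and device are globals set up earlier in the tutorial.
        if stage_index == 0:
            # First stage keeps layers 0-3 and drops the final norm and output head.
            for i in range(4, 8):
                del model.layers[str(i)]
            model.norm = None
            model.output = None
        elif stage_index == 1:
            # Second stage keeps layers 4-7 and drops the token embeddings.
            for i in range(4):
                del model.layers[str(i)]
            model.tok_embeddings = None

        # input_args is no longer passed; this commit removes the example
        # micro-batch plumbing that previously fed it.
        stage = PipelineStage(model, stage_index, num_stages, device)
        return stage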
@@ -181,18 +178,19 @@ as well as multiple-stage-per-rank schedules such as ``Interleaved1F1B`` and ``L
     example_input_microbatch = x.chunk(num_microbatches)[0]
 
     # Option 1: Manual model splitting
-    stage = manual_model_split(model, example_input_microbatch, model_args)
+    stage = manual_model_split(model)
 
     # Option 2: Tracer model splitting
     # stage = tracer_model_split(model, example_input_microbatch)
 
+    model.to(device)
     x = x.to(device)
     y = y.to(device)
 
     def tokenwise_loss_fn(outputs, targets):
         loss_fn = nn.CrossEntropyLoss()
-        outputs = outputs.view(-1, model_args.vocab_size)
-        targets = targets.view(-1)
+        outputs = outputs.reshape(-1, model_args.vocab_size)
+        targets = targets.reshape(-1)
         return loss_fn(outputs, targets)
 
     schedule = ScheduleGPipe(stage, n_microbatches=num_microbatches, loss_fn=tokenwise_loss_fn)
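A plausible motivation for the ``view`` → ``reshape`` switch (the commit message does not say) is that ``Tensor.view`` requires a compatible contiguous memory layout, while ``reshape`` falls back to a copy when needed, which is safer for tensors coming out of a pipeline stage. A standalone illustration:

.. code:: python

    import torch

    t = torch.arange(6).reshape(2, 3).t()  # transposing makes the tensor non-contiguous
    print(t.is_contiguous())               # False
    # t.view(-1)                           # would raise a RuntimeError about incompatible view
    flat = t.reshape(-1)                   # reshape copies when necessary and succeeds
    print(flat.shape)                      # torch.Size([6])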
@@ -202,6 +200,7 @@ as well as multiple-stage-per-rank schedules such as ``Interleaved1F1B`` and ``L
     elif rank == 1:
         losses = []
         output = schedule.step(target=y, losses=losses)
+        print(f"losses: {losses}")
     dist.destroy_process_group()
 
 In the example above, we are using the manual method to split the model, but the code can be uncommented to also try the
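The newly printed ``losses`` list is filled in by ``schedule.step`` on the last stage, one entry per micro-batch. A hedged sketch of reducing it to a single logged number; ``report_average_loss`` is a hypothetical helper, not part of the tutorial, and it assumes each entry is a scalar loss tensor as produced by ``tokenwise_loss_fn``.

.. code:: python

    import torch

    def report_average_loss(losses: list[torch.Tensor]) -> None:
        # Hypothetical helper: average the per-micro-batch losses collected by
        # schedule.step on the last pipeline stage and print one number.
        avg = torch.stack([loss.detach() for loss in losses]).mean()
        print(f"average micro-batch loss: {avg.item():.4f}")

    # e.g. on rank 1, right after schedule.step(target=y, losses=losses):
    # report_average_loss(losses)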
@@ -232,5 +231,10 @@ We discussed two methods of model partitioning, manual and tracer-based, and dem
 micro-batches across different stages. Finally, we covered the execution of the pipeline schedule and the launch of distributed
 processes using ``torchrun``.
 
-For a production ready usage of pipeline parallelism as well as composition with other distributed techniques, see also
+Additional Resources
+--------------------
+
+We have successfully integrated ``torch.distributed.pipelining`` into the `torchtitan repository <https://github.com/pytorch/torchtitan>`__. TorchTitan is a clean, minimal code base for
+large-scale LLM training using native PyTorch. For a production ready usage of pipeline
+parallelism as well as composition with other distributed techniques, see
 `TorchTitan end to end example of 3D parallelism <https://github.com/pytorch/torchtitan>`__.
