update pipelining tutorial #3182

Merged
4 commits merged on Dec 12, 2024
intermediate_source/pipelining_tutorial.rst (24 changes: 14 additions, 10 deletions)
@@ -67,7 +67,7 @@ chunks. First, let us define the model:
h = layer(h, h)

h = self.norm(h) if self.norm else h
output = self.output(h).float() if self.output else h
output = self.output(h).clone() if self.output else h
return output

Then, we need to import the necessary libraries in our script and initialize the distributed training process. In this case, we are defining some global variables to use
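
The imports and initialization code referenced here are collapsed in this diff. As a minimal sketch of what such initialization might look like when launched with ``torchrun`` (the helper name ``init_distributed`` and the globals ``rank``, ``device``, ``pp_group``, ``stage_index``, and ``num_stages`` are assumptions for illustration, not lines taken from the diff):

.. code:: python

   import os
   import torch
   import torch.distributed as dist
   # Used later in the script for the manual split and the pipeline schedule.
   from torch.distributed.pipelining import PipelineStage, ScheduleGPipe

   def init_distributed():
       # Globals shared by the rest of the script; names are illustrative.
       global rank, device, pp_group, stage_index, num_stages
       rank = int(os.environ["LOCAL_RANK"])
       world_size = int(os.environ["WORLD_SIZE"])
       device = torch.device(f"cuda:{rank}") if torch.cuda.is_available() else torch.device("cpu")
       # torchrun supplies MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE via the environment.
       dist.init_process_group()
       pp_group = dist.new_group()
       stage_index = rank        # one pipeline stage per rank in this example
       num_stages = world_size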
@@ -109,32 +109,29 @@ Step 1: Partition the Transformer Model
There are two different ways of partitioning the model:

First is the manual mode in which we can manually create two instances of the model by deleting portions of
attributes of the model. In this example for a 2 stage (2 ranks) the model is cut in half.
attributes of the model. In this example for two stages (2 ranks), the model is cut in half.

.. code:: python

def manual_model_split(model, example_input_microbatch, model_args) -> PipelineStage:
def manual_model_split(model) -> PipelineStage:
if stage_index == 0:
# prepare the first stage model
for i in range(4, 8):
del model.layers[str(i)]
model.norm = None
model.output = None
stage_input_microbatch = example_input_microbatch

elif stage_index == 1:
# prepare the second stage model
for i in range(4):
del model.layers[str(i)]
model.tok_embeddings = None
stage_input_microbatch = torch.randn(example_input_microbatch.shape[0], example_input_microbatch.shape[1], model_args.dim)

stage = PipelineStage(
model,
stage_index,
num_stages,
device,
input_args=stage_input_microbatch,
)
return stage

@@ -181,18 +178,19 @@ as well as multiple-stage-per-rank schedules such as ``Interleaved1F1B`` and ``L
example_input_microbatch = x.chunk(num_microbatches)[0]

# Option 1: Manual model splitting
stage = manual_model_split(model, example_input_microbatch, model_args)
stage = manual_model_split(model)

# Option 2: Tracer model splitting
# stage = tracer_model_split(model, example_input_microbatch)

model.to(device)
x = x.to(device)
y = y.to(device)

def tokenwise_loss_fn(outputs, targets):
loss_fn = nn.CrossEntropyLoss()
outputs = outputs.view(-1, model_args.vocab_size)
targets = targets.view(-1)
outputs = outputs.reshape(-1, model_args.vocab_size)

Contributor: Just curious, why this change?

Member Author: These changes weren't necessary but I do think they are clearer.

targets = targets.reshape(-1)
return loss_fn(outputs, targets)

schedule = ScheduleGPipe(stage, n_microbatches=num_microbatches, loss_fn=tokenwise_loss_fn)
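
As an illustrative aside, not part of the diff: the hunk header above mentions other pipeline schedules. Assuming the ``stage``, ``num_microbatches``, and ``tokenwise_loss_fn`` defined in the code, swapping in another single-stage-per-rank schedule such as ``Schedule1F1B`` is a one-line change (a sketch, not the tutorial's code):

.. code:: python

   from torch.distributed.pipelining import Schedule1F1B

   # 1F1B interleaves forward and backward micro-batches to shrink the pipeline
   # bubble; it accepts the same arguments as ScheduleGPipe in this setup.
   schedule = Schedule1F1B(stage, n_microbatches=num_microbatches, loss_fn=tokenwise_loss_fn)

Multiple-stage-per-rank schedules such as ``Interleaved1F1B`` typically take a list of stages per rank rather than a single stage.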
@@ -202,6 +200,7 @@ as well as multiple-stage-per-rank schedules such as ``Interleaved1F1B`` and ``L
elif rank == 1:
losses = []
output = schedule.step(target=y, losses=losses)
print(f"losses: {losses}")
dist.destroy_process_group()

In the example above, we are using the manual method to split the model, but the code can be uncommented to also try the tracer-based splitting.
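
The ``tracer_model_split`` helper referenced in Option 2 is not shown in this diff. As a rough sketch of how a tracer-based split can be written with ``torch.distributed.pipelining`` (the split point ``"layers.4"`` mirrors the manual 4/4 layer split above and is an assumption, as are the ``stage_index`` and ``device`` globals from the initialization sketch):

.. code:: python

   from torch.distributed.pipelining import SplitPoint, pipeline

   def tracer_model_split(model, example_input_microbatch):
       # Trace the model on one micro-batch and split it before layer 4.
       pipe = pipeline(
           module=model,
           mb_args=(example_input_microbatch,),
           split_spec={"layers.4": SplitPoint.BEGINNING},
       )
       # Materialize the stage owned by this rank on its device.
       return pipe.build_stage(stage_index, device=device)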
@@ -232,5 +231,10 @@ We discussed two methods of model partitioning, manual and tracer-based, and dem
micro-batches across different stages. Finally, we covered the execution of the pipeline schedule and the launch of distributed
processes using ``torchrun``.

For a production ready usage of pipeline parallelism as well as composition with other distributed techniques, see also
Additional Resources
--------------------

We have successfully integrated ``torch.distributed.pipelining`` into the `torchtitan repository <https://github.com/pytorch/torchtitan>`__. TorchTitan is a clean, minimal code base for
large-scale LLM training using native PyTorch. For a production ready usage of pipeline
parallelism as well as composition with other distributed techniques, see
`TorchTitan end to end example of 3D parallelism <https://github.com/pytorch/torchtitan>`__.