
Commit f03bd79

TroyGarden authored and facebook-github-bot committed
Explore new pipeline that overlaps optimizer with emb_lookup (#2916)
Summary:
Pull Request resolved: #2916

# context
* This workstream started from a training QPS optimization initiated on the PG side (see doc in the reference section), observing that the embedding lookup can overlap with the optimizer.
* Embedding table weights are updated in the fused backward (fused-TBE), so the embedding lookup can start immediately after backward is completed, without dependency on the optimizer.
* We use a separate stream to run the embedding lookup so that it can overlap with the previous optimizer (changed, see below).
* There is also an option to use the data_dist stream for this embedding lookup: the output_dist won't be blocked but start_sparse_data_dist will be, which results in a smaller memory footprint.

WARNING: This pipeline **DOES NOT** work for EBC/EC with feature processors, because the embedding lookup is started immediately after the TBE backward (where the embedding tables' weights have been updated).

# benchmark readings
* runtime: SemiSync < FusedSparseDist (lookup after opt) < FusedSparseDist (lookup before opt) < SparseDist
```
TrainPipelineSemiSync         | Runtime (P90): 5447.42 ms | Peak Memory alloc (P90): 61.63 GB | Peak Memory reserved (P90): 64.31 GB
TrainPipelineFusedSparseDist  | Runtime (P90): 5605.63 ms | Peak Memory alloc (P90): 53.23 GB | Peak Memory reserved (P90): 68.61 GB
TrainPipelineFusedSparseDist* | Runtime (P90): 5661.92 ms | Peak Memory alloc (P90): 53.23 GB | Peak Memory reserved (P90): 68.67 GB
TrainPipelineSparseDist       | Runtime (P90): 6034.46 ms | Peak Memory alloc (P90): 51.80 GB | Peak Memory reserved (P90): 62.25 GB
* embedding_lookup_after_opt = False
```
* traces show that:
  (1) the emb_lookup is right behind the TBE-bwd (on the same CUDA stream)
  (2) the output_dist is invoked right after each emb_lookup (there are two: one for the unweighted EBC, one for the weighted)
  (3) the optimizer does **NOT** overlap with the emb_lookup kernel when `embedding_lookup_after_opt = False` {F1977309185}
  (4) the optimizer still does **NOT** overlap with the emb_lookup kernel, but it fills in the gap between the `KJTTensorAwaitable.wait()` and the embedding lookup kernel when `embedding_lookup_after_opt = True` {F1977309202}
  (5) if a separate stream is used for the embedding lookup, the following `start_sparse_data_dist` can start immediately; however, this causes extra memory consumption {F1977366363}
  (6) if the data_dist stream is re-used for the embedding lookup, the following `start_sparse_data_dist` will wait for the embedding lookup to complete; the measured memory footprint is smaller {F1977366349}

NOTE: Based on (5) and (6), `use_emb_lookup_stream = False` is the default behavior.

# conclusions
* Based on a simple model (SparseNN), both the "Fused Sparse Dist" pipeline and the "Semi Sync" pipeline are faster than the current default (commonly used) "Sparse Dist" pipeline: -7% (fused sparse dist) and -10% (semi sync) in runtime, respectively.
* In a more realistic scenario, the optimizer step has a longer runtime footprint, which can amplify this optimization.
* The "Semi Sync" pipeline has a larger QPS win but produces slightly different numerical training results, while the "Fused Sparse Dist" pipeline, with a slightly smaller QPS win, should be numerically the same as the default pipeline.
* It is the user's choice which one to use.

# reference
* https://dev-discuss.pytorch.org/t/fsdp-cudacachingallocator-an-outsider-newb-perspective/1486

Reviewed By: dstaay-fb

Differential Revision: D64479105

fbshipit-source-id: c6bd1306823d02afafd7e2f1d9f95e1d068d6f61
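As a quick orientation (not part of this commit), the sketch below shows how the new pipeline could be driven. The sharded `model`, fused `optimizer`, `dataloader`, and the helper function name are hypothetical placeholders assumed to be built elsewhere; only `TrainPipelineFusedSparseDist` and its `progress()` loop come from the commit itself.

```python
# Sketch only: driving TrainPipelineFusedSparseDist (placeholders, not benchmark code).
from typing import Iterable, List

import torch

from torchrec.distributed.train_pipeline import TrainPipelineFusedSparseDist


def train_with_fused_sparse_dist(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    dataloader: Iterable,
    device: torch.device,
) -> List[object]:
    pipeline = TrainPipelineFusedSparseDist(
        model=model,
        optimizer=optimizer,
        device=device,
        # default: re-use the data_dist stream for the emb lookup, trading a bit of
        # overlap for a smaller CUDA caching-allocator footprint (see summary above)
        use_emb_lookup_stream=False,
    )
    outputs = []
    dataloader_iter = iter(dataloader)
    while True:
        try:
            # fwd/bwd/opt for batch i, while input_dist(i+1) and the H2D copy(i+2) are pipelined
            outputs.append(pipeline.progress(dataloader_iter))
        except StopIteration:
            break
    return outputs
```

Since the class subclasses `TrainPipelineSparseDist` and keeps the same `progress()` signature, switching pipelines is essentially a constructor swap.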
1 parent f35befa commit f03bd79

File tree

4 files changed: +190 -2 lines changed

torchrec/distributed/benchmark/benchmark_train_sparsenn.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -42,6 +42,7 @@
 from torchrec.distributed.train_pipeline import (
     TrainPipeline,
     TrainPipelineBase,
+    TrainPipelineFusedSparseDist,
     TrainPipelineSparseDist,
 )
 from torchrec.distributed.train_pipeline.train_pipelines import (
@@ -106,6 +107,7 @@ def generate_pipeline(
     ] = {
         "base": TrainPipelineBase,
         "sparse": TrainPipelineSparseDist,
+        "fused": TrainPipelineFusedSparseDist,
         "semi": TrainPipelineSemiSync,
         "prefetch": PrefetchTrainPipelineSparseDist,
     }
```
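For illustration only, the snippet below shows how such a mapping resolves a string key to a pipeline class; it is a simplified stand-in for the benchmark's actual `generate_pipeline` wiring, abbreviated to the three classes whose import path is visible in this diff.

```python
# Sketch: resolving a pipeline class by key, mirroring the mapping extended above.
from torchrec.distributed.train_pipeline import (
    TrainPipelineBase,
    TrainPipelineFusedSparseDist,
    TrainPipelineSparseDist,
)

_pipelines = {
    "base": TrainPipelineBase,
    "sparse": TrainPipelineSparseDist,
    "fused": TrainPipelineFusedSparseDist,  # new in this commit
}

pipeline_cls = _pipelines["fused"]  # -> TrainPipelineFusedSparseDist
```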

torchrec/distributed/train_pipeline/__init__.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -15,6 +15,7 @@
     TorchCompileConfig,  # noqa
     TrainPipeline,  # noqa
     TrainPipelineBase,  # noqa
+    TrainPipelineFusedSparseDist,  # noqa
     TrainPipelinePT2,  # noqa
     TrainPipelineSparseDist,  # noqa
     TrainPipelineSparseDistCompAutograd,  # noqa
```

torchrec/distributed/train_pipeline/train_pipelines.py

Lines changed: 172 additions & 2 deletions

```diff
@@ -47,6 +47,7 @@
     EmbeddingPipelinedForward,
     EmbeddingTrainPipelineContext,
     In,
+    InSyncEmbeddingPipelinedForward,
     Out,
     PipelinedForward,
     PipelinedPostproc,
@@ -540,8 +541,7 @@ def fill_pipeline(self, dataloader_iter: Iterator[In]) -> None:
 
         # modify the (sharded) sparse module forward, and invoke the first part of input_dist
         self._init_pipelined_modules(
-            # pyre-ignore [6]
-            self.batches[0],
+            self.batches[0],  # pyre-ignore [6]
             self.contexts[0],
             self._pipelined_forward_type,
         )
@@ -803,6 +803,176 @@ def _fill_pipeline(self, dataloader_iter: Iterator[In]) -> None:
         self._batch_ip1 = self._copy_batch_to_gpu(dataloader_iter)
 
 
+class TrainPipelineFusedSparseDist(TrainPipelineSparseDist[In, Out]):
+    """
+    This pipeline modifies TrainPipelineSparseDist by running the embedding lookup in a
+    separate stream so that it can overlap with the previous optimizer. The assumption
+    made here is that the embedding is updated in the fused backward (fused-TBE), so the
+    embedding lookup can start immediately after backward is completed without dependency
+    on the optimizer.
+
+    NOTE: This assumption is not true if there are feature processors.
+    NOTE: This pipeline is still experimental; users should always run NE parity tests.
+
+    batch i+0:
+        ShardedModule.compute_and_output_dist - uses emb_lookup CUDA stream
+        forward (without emb lookup)
+        backward and optimizer
+    batch i+1:
+        ShardedModule.input_dist() - uses data_dist CUDA stream
+    batch i+2:
+        copy batch to device
+
+    `ShardedModule.input_dist()` is only done for top-level modules in the call graph.
+    To be considered a top-level module, a module can only depend on 'getattr' calls on
+    input.
+
+    Input model must be symbolically traceable with the exception of `ShardedModule` and
+    `DistributedDataParallel` modules.
+
+    Args:
+        model (torch.nn.Module): model to pipeline.
+        optimizer (torch.optim.Optimizer): optimizer to use.
+        device (torch.device): device where device transfer, sparse data dist, and
+            forward/backward pass will happen.
+        execute_all_batches (bool): executes remaining batches in pipeline after
+            exhausting dataloader iterator.
+        apply_jit (bool): apply torch.jit.script to non-pipelined (unsharded) modules.
+        TODO: pipeline_postproc, custom_model_fwd, strict
+        use_emb_lookup_stream (bool): if true, invoke the compute_and_output_dist
+            (for batch i+1) using a new stream, else re-use the data_dist stream
+    """
+
+    # The PipelinedForward class that is used in _rewrite_model
+    _pipelined_forward_type = InSyncEmbeddingPipelinedForward  # pyre-ignore
+
+    def __init__(
+        self,
+        model: torch.nn.Module,
+        optimizer: torch.optim.Optimizer,
+        device: torch.device,
+        execute_all_batches: bool = True,
+        apply_jit: bool = False,
+        pipeline_postproc: bool = True,
+        custom_model_fwd: Optional[
+            Callable[[Optional[In]], Tuple[torch.Tensor, Out]]
+        ] = None,
+        strict: bool = False,
+        use_emb_lookup_stream: bool = False,  # default False explained below
+    ) -> None:
+        super().__init__(
+            model=model,
+            optimizer=optimizer,
+            device=device,
+            execute_all_batches=execute_all_batches,
+            apply_jit=apply_jit,
+            context_type=EmbeddingTrainPipelineContext,
+            pipeline_postproc=pipeline_postproc,
+            custom_model_fwd=custom_model_fwd,
+        )
+        if use_emb_lookup_stream:
+            self._emb_lookup_stream: Optional[torch.Stream] = (
+                (torch.get_device_module(device).Stream())
+                if device.type in ["cuda", "mtia"]
+                else None
+            )
+        else:
+            # default to False: re-use data_dist stream for emb lookup to reduce CUDA memory footprint
+            # due to the Caching Allocator reserving memory for each stream
+            self._emb_lookup_stream = self._data_dist_stream
+
+    def wait_embedding_lookup(self) -> None:
+        """
+        Waits on the embedding lookup requests to get the embedding lookup tensors
+        """
+        current_stream = torch.get_device_module(self._device).current_stream()
+        current_stream.wait_stream(self._emb_lookup_stream)
+
+    def start_embedding_lookup(
+        self,
+        batch: Optional[In],
+        context: EmbeddingTrainPipelineContext,
+    ) -> None:
+        """
+        Waits for the batch to finish getting copied to the GPU, then starts the embedding lookup. This is the event-based version.
+        """
+        if batch is None:
+            return
+
+        with record_function(f"## start_embedding_lookup {context.index} ##"):
+            current_stream = torch.get_device_module(self._device).current_stream()
+            with self._stream_context(self._emb_lookup_stream):
+                for module in self._pipelined_modules:
+                    _start_embedding_lookup(
+                        module,
+                        context,
+                        source_stream=self._emb_lookup_stream,
+                        target_stream=current_stream,
+                        stream_context=self._stream_context,
+                    )
+
+    def progress(self, dataloader_iter: Iterator[In]) -> Out:
+        """
+        For TrainPipelineSparseDist, we assume the max pipelined batches == 3 (capacity):
+            batches[0]: i+0 batch, fwd/bwd/opt (expecting output_dist)
+            batches[1]: i+1 batch, for input_dist (expecting copied to device), and compute_and_output_dist
+            batches[2]: i+2 batch, for copy_batch_to_gpu (expecting non-exhausted dataloader iter)
+        """
+
+        # attach the model just in case the user forgets to call it, especially when the user
+        # pauses pipeline.progress and detaches the model for other purposes.
+        if not self._model_attached:
+            self.attach(self._model)
+
+        # filling the pipeline is only needed at the beginning, when the pipeline (batches) is empty
+        self.fill_pipeline(dataloader_iter)
+
+        # here is the expected stop after exhausting all batches
+        if not self.batches:
+            raise StopIteration
+
+        # TODO: Remove once Bulk Eval migrated (needed for bwd compat, this class only)
+        self._set_module_context(self.contexts[0])
+
+        # start embedding_lookup so it can overlap with the previous optimizer
+        # pyre-ignore [6]
+        self.start_embedding_lookup(self.batches[0], self.contexts[0])
+
+        if self._model.training:
+            with record_function("## zero_grad ##"):
+                self._optimizer.zero_grad()
+
+        # wait for batches[0] to be available on device; this should always be completed since
+        # the input_dist of batches[0] has been invoked in the previous iter. TODO: fact check
+        self._wait_for_batch()
+
+        if len(self.batches) >= 2:
+            # invoke splits all_to_all comms (first part of input_dist)
+            self.start_sparse_data_dist(self.batches[1], self.contexts[1])
+
+        # batch i+2: load data and copy to gpu; the dataloader iter will first exhaust here
+        self.enqueue_batch(dataloader_iter)
+
+        # forward
+        with record_function("## forward ##"):
+            losses, output = self._model_fwd(self.batches[0])
+
+        if len(self.batches) >= 2:
+            # invoke data (values, lengths, etc.) all_to_all comms (second part of input_dist)
+            self.wait_sparse_data_dist(self.contexts[1])
+
+        if self._model.training:
+            # backward
+            self._backward(losses)
+
+            # update
+            with record_function("## optimizer ##"):
+                self._optimizer.step()
+
+        self.dequeue_batch()
+        return output
+
+
 class TrainPipelineSemiSync(TrainPipelineSparseDist[In, Out]):
     """
     Novel method for RecSys model training by leveraging "Semi-Synchronous" training,
```
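The overlap described in `TrainPipelineFusedSparseDist`'s docstring relies on the standard CUDA side-stream pattern: launch the embedding lookup on one stream while optimizer-like work runs on the current stream, then synchronize with `wait_stream` before consuming the results. Below is a generic, self-contained sketch of that pattern using plain matmuls as stand-ins for the embedding lookup and optimizer kernels; it is not torchrec code, and whether a second stream is worth the extra caching-allocator memory is exactly the `use_emb_lookup_stream` trade-off discussed in the summary.

```python
# Generic sketch of the side-stream overlap pattern underlying the pipeline above.
# The matmuls are placeholders for the embedding lookup and optimizer kernels.
import torch

assert torch.cuda.is_available(), "sketch assumes a CUDA device"
device = torch.device("cuda")

main_stream = torch.cuda.current_stream(device)
side_stream = torch.cuda.Stream(device=device)  # analogous to the emb_lookup stream

x = torch.randn(4096, 4096, device=device)

# the side stream must wait for the work on the main stream that produced `x`
side_stream.wait_stream(main_stream)
with torch.cuda.stream(side_stream):
    x.record_stream(side_stream)  # `x` is now also used on the side stream
    lookup_out = x @ x  # stand-in for the embedding lookup kernel

# meanwhile, optimizer-like work keeps running on the main stream and can overlap
y = torch.randn(4096, 4096, device=device)
opt_out = y @ y  # stand-in for the optimizer step

# before the main stream consumes lookup_out, synchronize with the side stream
main_stream.wait_stream(side_stream)
lookup_out.record_stream(main_stream)  # inform the caching allocator of cross-stream use
result = (lookup_out.sum() + opt_out.sum()).item()
```

Re-using an existing stream (as the pipeline does by default) avoids the caching allocator reserving a separate memory pool for an extra stream, which is where the reserved-memory difference in the benchmark table comes from.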

torchrec/distributed/train_pipeline/utils.py

Lines changed: 15 additions & 0 deletions

```diff
@@ -670,6 +670,20 @@ def detach_embeddings(
         self._context.detached_embedding_tensors.append(detached_tensors)
 
 
+class InSyncEmbeddingPipelinedForward(EmbeddingPipelinedForward):
+    """
+    This pipeline is used in TrainPipelineFusedSparseDist
+    """
+
+    def detach_embeddings(
+        self,
+        embeddings: Union[Dict[str, JaggedTensor], KeyedTensor],
+        cur_stream: torch.Stream,
+    ) -> None:
+        # doing nothing
+        pass
+
+
 class PrefetchPipelinedForward(BaseForward[PrefetchTrainPipelineContext]):
     """
     This pipeline is used in PrefetchTrainPipelineSparseDist
@@ -853,6 +867,7 @@ def _start_data_dist(
             PipelinedForward,
             PrefetchPipelinedForward,
             EmbeddingPipelinedForward,
+            InSyncEmbeddingPipelinedForward,
         ),
     )
 
```