47 | 47 | EmbeddingPipelinedForward,
48 | 48 | EmbeddingTrainPipelineContext,
49 | 49 | In,
| 50 | + InSyncEmbeddingPipelinedForward, |
50 | 51 | Out,
51 | 52 | PipelinedForward,
52 | 53 | PipelinedPostproc,
@@ -540,8 +541,7 @@ def fill_pipeline(self, dataloader_iter: Iterator[In]) -> None:
540 | 541 |
541 | 542 | # modify the (sharded) sparse module forward, and invoke the first part of input_dist
542 | 543 | self._init_pipelined_modules(
543 | | - # pyre-ignore [6] |
544 | | - self.batches[0], |
| 544 | + self.batches[0],  # pyre-ignore [6] |
545 | 545 | self.contexts[0],
546 | 546 | self._pipelined_forward_type,
547 | 547 | )
@@ -803,6 +803,176 @@ def _fill_pipeline(self, dataloader_iter: Iterator[In]) -> None:
803 | 803 | self._batch_ip1 = self._copy_batch_to_gpu(dataloader_iter)
804 | 804 |
805 | 805 |
| 806 | +class TrainPipelineFusedSparseDist(TrainPipelineSparseDist[In, Out]): |
| 807 | + """ |
| 808 | +    This pipeline modifies TrainPipelineSparseDist by running the embedding lookup in a |
| 809 | +    separate stream so that it can overlap with the previous batch's optimizer step. The |
| 810 | +    assumption made here is that the embeddings are updated in the fused backward |
| 811 | +    (fused-TBE), so the embedding lookup can start immediately after backward completes, |
| 812 | +    without any dependency on the optimizer. |
| 813 | +
| 814 | +    NOTE: This assumption does not hold if feature processors are present. |
| 815 | +    NOTE: This pipeline is still experimental; users should always run NE parity tests. |
| 816 | +
| 817 | + batch i+0: |
| 818 | + ShardedModule.compute_and_output_dist - uses emb_lookup CUDA stream |
| 819 | + forward (without emb lookup) |
| 820 | + backward and optimizer |
| 821 | + batch i+1: |
| 822 | + ShardedModule.input_dist() - uses data_dist CUDA stream |
| 823 | + batch i+2: |
| 824 | + copy batch to device |
| 825 | +
| 826 | + `ShardedModule.input_dist()` is only done for top-level modules in the call graph. |
| 827 | + To be considered a top-level module, a module can only depend on 'getattr' calls on |
| 828 | + input. |
| 829 | +
| 830 | + Input model must be symbolically traceable with the exception of `ShardedModule` and |
| 831 | + `DistributedDataParallel` modules. |
| 832 | +
| 833 | + Args: |
| 834 | + model (torch.nn.Module): model to pipeline. |
| 835 | + optimizer (torch.optim.Optimizer): optimizer to use. |
| 836 | + device (torch.device): device where device transfer, sparse data dist, and |
| 837 | + forward/backward pass will happen. |
| 838 | + execute_all_batches (bool): executes remaining batches in pipeline after |
| 839 | + exhausting dataloader iterator. |
| 840 | + apply_jit (bool): apply torch.jit.script to non-pipelined (unsharded) modules. |
| 841 | + TODO: pipeline_postproc, custom_model_fwd, strict |
| 842 | +        use_emb_lookup_stream (bool): if True, invoke compute_and_output_dist |
| 843 | +            (for batch i+1) on a dedicated stream; otherwise re-use the data_dist stream. |
| 844 | + """ |
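The overlap described in this docstring boils down to ordering a side CUDA stream after the fused backward but not after the optimizer step. A minimal standalone sketch of that pattern, outside of torchrec, with hypothetical `run_emb_lookup` / `run_optimizer_step` callables (not part of this diff):

```python
import torch

def overlapped_step(run_emb_lookup, run_optimizer_step):
    # Sketch only: assumes the fused backward has already been queued on the
    # default stream, so the embedding tables are up to date once it finishes.
    default_stream = torch.cuda.current_stream()
    lookup_stream = torch.cuda.Stream()

    # Order the lookup after everything already queued (including backward) ...
    lookup_stream.wait_stream(default_stream)
    with torch.cuda.stream(lookup_stream):
        lookup_out = run_emb_lookup()  # overlaps with the optimizer below

    run_optimizer_step()  # stays on the default stream

    # ... and make later default-stream work wait for the lookup results.
    default_stream.wait_stream(lookup_stream)
    return lookup_out
```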
| 845 | + |
| 846 | + # The PipelinedForward class that is used in _rewrite_model |
| 847 | + _pipelined_forward_type = InSyncEmbeddingPipelinedForward # pyre-ignore |
| 848 | + |
| 849 | + def __init__( |
| 850 | + self, |
| 851 | + model: torch.nn.Module, |
| 852 | + optimizer: torch.optim.Optimizer, |
| 853 | + device: torch.device, |
| 854 | + execute_all_batches: bool = True, |
| 855 | + apply_jit: bool = False, |
| 856 | + pipeline_postproc: bool = True, |
| 857 | + custom_model_fwd: Optional[ |
| 858 | + Callable[[Optional[In]], Tuple[torch.Tensor, Out]] |
| 859 | + ] = None, |
| 860 | + strict: bool = False, |
| 861 | + use_emb_lookup_stream: bool = False, # default False explained below |
| 862 | + ) -> None: |
| 863 | + super().__init__( |
| 864 | + model=model, |
| 865 | + optimizer=optimizer, |
| 866 | + device=device, |
| 867 | + execute_all_batches=execute_all_batches, |
| 868 | + apply_jit=apply_jit, |
| 869 | + context_type=EmbeddingTrainPipelineContext, |
| 870 | + pipeline_postproc=pipeline_postproc, |
| 871 | + custom_model_fwd=custom_model_fwd, |
| 872 | + ) |
| 873 | + if use_emb_lookup_stream: |
| 874 | + self._emb_lookup_stream: Optional[torch.Stream] = ( |
| 875 | + (torch.get_device_module(device).Stream()) |
| 876 | + if device.type in ["cuda", "mtia"] |
| 877 | + else None |
| 878 | + ) |
| 879 | + else: |
| 880 | + # default to False: re-use data_dist stream for emb lookup to reduce CUDA memory footprint |
| 881 | + # due to Caching Allocator reserving the memory for each stream |
| 882 | + self._emb_lookup_stream = self._data_dist_stream |
| 883 | + |
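Construction and the training loop are the same as for TrainPipelineSparseDist. A hypothetical usage sketch, assuming a sharded `model`, an `optimizer`, and a `dataloader` prepared exactly as they would be for the base pipeline:

```python
pipeline = TrainPipelineFusedSparseDist(
    model=model,
    optimizer=optimizer,
    device=torch.device("cuda"),
    # default False: share the data_dist stream to keep the caching-allocator
    # footprint lower; set True to give the lookup its own stream.
    use_emb_lookup_stream=False,
)

dataloader_iter = iter(dataloader)
while True:
    try:
        output = pipeline.progress(dataloader_iter)  # one fused pipeline step
    except StopIteration:
        break
```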
| 884 | + def wait_embedding_lookup(self) -> None: |
| 885 | + """ |
| 886 | +        Waits on the embedding lookup stream so that the embedding lookup output tensors are ready to use. |
| 887 | + """ |
| 888 | + current_stream = torch.get_device_module(self._device).current_stream() |
| 889 | + current_stream.wait_stream(self._emb_lookup_stream) |
| 890 | + |
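`wait_stream` only inserts a GPU-side ordering dependency; the host returns immediately. A tiny self-contained illustration of the idiom used here (plain CUDA tensors rather than embedding lookups):

```python
import torch

side = torch.cuda.Stream()
with torch.cuda.stream(side):
    a = torch.randn(1024, 1024, device="cuda")
    b = a @ a  # queued on the side stream

# Non-blocking on the CPU; later kernels on the current stream are ordered
# after everything queued on `side` so far.
torch.cuda.current_stream().wait_stream(side)
c = b.sum()  # safe to consume `b` on the current stream now
```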
| 891 | + def start_embedding_lookup( |
| 892 | + self, |
| 893 | + batch: Optional[In], |
| 894 | + context: EmbeddingTrainPipelineContext, |
| 895 | + ) -> None: |
| 896 | + """ |
| 897 | +        Waits for the batch to finish copying to the GPU, then starts the embedding lookup on the emb_lookup stream. |
| 898 | + """ |
| 899 | + if batch is None: |
| 900 | + return |
| 901 | + |
| 902 | + with record_function(f"## start_embedding_lookup {context.index} ##"): |
| 903 | + current_stream = torch.get_device_module(self._device).current_stream() |
| 904 | + with self._stream_context(self._emb_lookup_stream): |
| 905 | + for module in self._pipelined_modules: |
| 906 | + _start_embedding_lookup( |
| 907 | + module, |
| 908 | + context, |
| 909 | + source_stream=self._emb_lookup_stream, |
| 910 | + target_stream=current_stream, |
| 911 | + stream_context=self._stream_context, |
| 912 | + ) |
| 913 | + |
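`_start_embedding_lookup` itself is defined elsewhere in this module; the `source_stream` / `target_stream` arguments suggest the usual producer-side handoff, sketched below under that assumption (the `lookup` callable is hypothetical):

```python
import torch

def launch_on_side_stream(lookup, side_stream: torch.cuda.Stream) -> torch.Tensor:
    # The side stream must not start before inputs queued on the caller's
    # (target) stream are ready.
    target_stream = torch.cuda.current_stream()
    side_stream.wait_stream(target_stream)

    with torch.cuda.stream(side_stream):
        out = lookup()

    # Tell the caching allocator the output is also used on the target stream,
    # so its memory is not recycled too early when `out` is freed.
    out.record_stream(target_stream)
    # The consumer must still wait on `side_stream` (as wait_embedding_lookup
    # does above) before reading `out`.
    return out
```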
| 914 | + def progress(self, dataloader_iter: Iterator[In]) -> Out: |
| 915 | + """ |
| 916 | + For TrainPipelineSparseDist, we assume the max pipelined batches == 3 (capacity): |
| 917 | + batches[0]: i+0 batch, fwd/bwd/opt (expecting output_dist) |
| 918 | + batches[1]: i+1 batch, for input_dist (expecting copied to device), and compute_and_output_dist |
| 919 | + batches[2]: i+2 batch, for copy_batch_to_gpu (expecting non-exhausted dataloader iter) |
| 920 | + """ |
| 921 | + |
| 922 | +        # attach the model just in case the user forgot to re-attach it, especially after |
| 923 | +        # pausing pipeline.progress and detaching the model for other purposes. |
| 924 | + if not self._model_attached: |
| 925 | + self.attach(self._model) |
| 926 | + |
| 927 | +        # filling the pipeline is only needed at the beginning, when the pipeline (batches) is empty |
| 928 | + self.fill_pipeline(dataloader_iter) |
| 929 | + |
| 930 | + # here is the expected stop after exhausting all batches |
| 931 | + if not self.batches: |
| 932 | + raise StopIteration |
| 933 | + |
| 934 | + # TODO: Remove once Bulk Eval migrated (needed for bwd compat, this class only) |
| 935 | + self._set_module_context(self.contexts[0]) |
| 936 | + |
| 937 | + # start embedding_lookup so it can overlap with previous optimizer |
| 938 | + # pyre-ignore [6] |
| 939 | + self.start_embedding_lookup(self.batches[0], self.contexts[0]) |
| 940 | + |
| 941 | + if self._model.training: |
| 942 | + with record_function("## zero_grad ##"): |
| 943 | + self._optimizer.zero_grad() |
| 944 | + |
| 945 | +        # wait for batches[0] to be available on device; this should always be completed since |
| 946 | +        # the input_dist of batches[0] was invoked in the previous iter. TODO: fact check |
| 947 | + self._wait_for_batch() |
| 948 | + |
| 949 | + if len(self.batches) >= 2: |
| 950 | + # invoke splits all_to_all comms (first part of input_dist) |
| 951 | + self.start_sparse_data_dist(self.batches[1], self.contexts[1]) |
| 952 | + |
| 953 | +            # batch i+2: load data and copy to GPU; the dataloader iter will be exhausted here first |
| 954 | + self.enqueue_batch(dataloader_iter) |
| 955 | + |
| 956 | + # forward |
| 957 | + with record_function("## forward ##"): |
| 958 | + losses, output = self._model_fwd(self.batches[0]) |
| 959 | + |
| 960 | + if len(self.batches) >= 2: |
| 961 | + # invoke data (values, lengths, etc.) all_to_all comms (second part of input_dist) |
| 962 | + self.wait_sparse_data_dist(self.contexts[1]) |
| 963 | + |
| 964 | + if self._model.training: |
| 965 | + # backward |
| 966 | + self._backward(losses) |
| 967 | + |
| 968 | + # update |
| 969 | + with record_function("## optimizer ##"): |
| 970 | + self._optimizer.step() |
| 971 | + |
| 972 | + self.dequeue_batch() |
| 973 | + return output |
| 974 | + |
| 975 | + |
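To summarize the schedule implemented by `progress()` above (paraphrased, using the method names from this class and its base):

```python
# One progress() iteration with up to three batches in flight:
#   1. start_embedding_lookup(batches[0])  # emb_lookup stream; overlaps the
#                                          # previous iteration's optimizer
#   2. zero_grad(); _wait_for_batch()      # batches[0] ready on device
#   3. start_sparse_data_dist(batches[1])  # splits all_to_all on data_dist stream
#   4. enqueue_batch(...)                  # copy batch i+2 to the device
#   5. forward(batches[0])                 # without the embedding lookup
#   6. wait_sparse_data_dist(contexts[1])  # finish input_dist for batch i+1
#   7. backward(losses); optimizer.step()
#   8. dequeue_batch()
```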
806 | 976 | class TrainPipelineSemiSync(TrainPipelineSparseDist[In, Out]):
807 | 977 | """
808 | 978 | Novel method for RecSys model training by leveraging "Semi-Synchronous" training,