diff --git a/torchrec/distributed/train_pipeline/train_pipelines.py b/torchrec/distributed/train_pipeline/train_pipelines.py
index 4977477f3..e8994d51a 100644
--- a/torchrec/distributed/train_pipeline/train_pipelines.py
+++ b/torchrec/distributed/train_pipeline/train_pipelines.py
@@ -10,6 +10,7 @@
 import abc
 import contextlib
 import logging
+import threading
 from collections import deque
 from dataclasses import dataclass
 from typing import (
@@ -469,6 +470,9 @@ def __init__(
         dmp_collection_sync_interval_batches: Optional[int] = 1,
         enqueue_batch_after_forward: bool = False,
         inplace_copy_batch_to_gpu: bool = False,
+        # To overcome a host bottleneck, enable `pipeline_thread` to overlap the next
+        # iteration's input prep with training; this starts an extra daemon thread.
+        pipeline_thread: bool = False,
     ) -> None:
         self._model = model
         self._optimizer = optimizer
@@ -549,6 +553,17 @@ def __init__(
         self._batch_ip2: Optional[In] = None
         self._context: TrainPipelineContext = context_type(version=0)
 
+        # parallel pipeline: optionally prepare the next iteration on a helper thread
+        self.pipeline_thread = pipeline_thread
+        self._cur_dliter: Optional[Iterator[In]] = None
+        if self.pipeline_thread:
+            self.helper_go = threading.Event()
+            self.helper_done = threading.Event()
+            self.helper_thread = threading.Thread(
+                target=self.progress_helper, daemon=True
+            )
+            self.helper_thread.start()
+
     def detach(self) -> torch.nn.Module:
         """
         Detaches the model from sparse data dist (SDD) pipeline. A user might want to get
@@ -692,6 +707,30 @@ def _backward(self, losses: torch.Tensor) -> None:
         with record_function(f"## backward {batch_id} ##"):
             torch.sum(losses, dim=0).backward()
 
+    def progress_helper(self) -> None:
+        # helper-thread loop: prepare the next iteration, then signal the main thread
+        while True:
+            self.helper_go.wait()
+            self.helper_go.clear()
+            # preprocess next context
+            self.pipeline_prepare()
+            self.helper_done.set()
+
+    def pipeline_prepare(self) -> bool:
+        self.fill_pipeline(self._cur_dliter)
+        if not self.batches:
+            return False
+        self._wait_for_batch()
+        if len(self.batches) >= 2:
+            # invoke splits all_to_all comms (first part of input_dist)
+            self.start_sparse_data_dist(self.batches[1], self.contexts[1])
+        self.enqueue_batch(self._cur_dliter)
+
+        if len(self.batches) >= 2:
+            # invoke data (values, lengths, etc.) all_to_all comms (second part of input_dist)
+            self.wait_sparse_data_dist(self.contexts[1])
+        return True
+
     def progress(self, dataloader_iter: Iterator[In]) -> Out:
         """
         For TrainPipelineSparseDist, we assume the max pipelined batches == 3 (capacity):
@@ -706,46 +745,29 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
         if not self._model_attached:
             self.attach(self._model)
 
-        # fill the pipeline is only needed for the beginning when the pipeline (batches) is empty
-        self.fill_pipeline(dataloader_iter)
-
-        # here is the expected stop after exhausting all batches
-        if not self.batches:
-            raise StopIteration
+        self._cur_dliter = dataloader_iter
+        # prepare on the main thread if nothing is queued or the helper is disabled
+        if len(self.contexts) == 0 or not self.pipeline_thread:
+            if not self.pipeline_prepare():
+                raise StopIteration
 
         # TODO: Remove once Bulk Eval migrated (needed for bwd compat, this class only)
-        self._set_module_context(self.contexts[0])
+        cur_batch = self.batches.popleft()
+        cur_context = self.contexts.popleft()
+        self._set_module_context(cur_context)
 
         if self._model.training:
             with record_function("## zero_grad ##"):
                 self._optimizer.zero_grad()
 
-        # wait for batches[0] being available on device, this should always be completed since
-        # the input_dist of batches[0] has be invoked in previous iter. TODO: fact check
-        self._wait_for_batch()
-
-        if len(self.batches) >= 2:
-            # invoke splits all_to_all comms (first part of input_dist)
-            self.start_sparse_data_dist(self.batches[1], self.contexts[1])
-
-        if not self._enqueue_batch_after_forward:
-            # batch i+2: load data and copy to gpu, the dataload iter will first exhaust here
-            self.enqueue_batch(dataloader_iter)
-
         # forward
-        with record_function(f"## forward {self.contexts[0].index} ##"):
+        with record_function(f"## forward {cur_context.index} ##"):
             self._state = PipelineState.CALL_FWD
-            losses, output = self._model_fwd(self.batches[0])
-
-            if self._enqueue_batch_after_forward:
-                # batch i+2: load data and copy to gpu, the dataload iter will first exhaust here.
-                # Start this step after the forward of batch i, so that the H2D copy doesn't compete
-                # for pcie bandwidth with embedding lookup from UVM/UVM_CACHING.
-                self.enqueue_batch(dataloader_iter)
+            losses, output = self._model_fwd(cur_batch)
 
-        if len(self.batches) >= 2:
-            # invoke data (values, lengths, etc.) all_to_all comms (second part of input_dist)
-            self.wait_sparse_data_dist(self.contexts[1])
+        # signal the helper thread to prepare the next iteration during backward
+        if self.pipeline_thread:
+            self.helper_go.set()
 
         if self._model.training:
             # backward
@@ -755,14 +777,19 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
             self.sync_embeddings(
                 self._model,
                 self._dmp_collection_sync_interval_batches,
-                self.contexts[0],
+                cur_context,
            )
 
             # update
-            with record_function(f"## optimizer {self.contexts[0].index} ##"):
+            with record_function(f"## optimizer {cur_context.index} ##"):
                 self._optimizer.step()
 
-        self.dequeue_batch()
+        if self.pipeline_thread:
+            self.helper_done.wait()  # block until the helper finishes preparing
+            self.helper_done.clear()
+        # update PipelinedForward context to match the next forward pass
+        if len(self.batches) >= 1:
+            self._set_module_context(self.contexts[0])
         return output
 
     def _create_context(self) -> TrainPipelineContext:
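
Usage sketch (illustrative, not part of the patch): assuming `model`, `optimizer`, `device`, and `dataloader` are already set up exactly as for existing `TrainPipelineSparseDist` usage, the new flag is passed at construction time and the training loop is unchanged.

```python
# Sketch only: `model`, `optimizer`, `device`, and `dataloader` are assumed to be
# constructed the same way as for any current TrainPipelineSparseDist user.
from torchrec.distributed.train_pipeline import TrainPipelineSparseDist

pipeline = TrainPipelineSparseDist(
    model=model,
    optimizer=optimizer,
    device=device,
    pipeline_thread=True,  # new flag: prepare the next iteration on a helper thread
)

dataloader_iter = iter(dataloader)
while True:
    try:
        output = pipeline.progress(dataloader_iter)
    except StopIteration:
        break
```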