     PrefetchPipelinedForward,
 )
 from torchrec.distributed.train_pipeline.tracing import PipelinedPostproc
+from torchrec.distributed.train_pipeline.types import PipelineState
 from torchrec.distributed.train_pipeline.utils import (
     _override_input_dist_forwards,
     _pipeline_detach_model,
@@ -72,6 +73,7 @@
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
 from torchrec.streamable import Pipelineable
 
+
 logger: logging.Logger = logging.getLogger(__name__)
 
 # This is required to support older torch package export for older models
@@ -104,6 +106,10 @@ class TrainPipeline(abc.ABC, Generic[In, Out]):
     def progress(self, dataloader_iter: Iterator[In]) -> Out:
         pass
 
+    def __init__(self) -> None:
+        # pipeline state (e.g. in forward, in backward), used in training recovery scenarios
+        self._state: PipelineState = PipelineState.IDLE
+
     def sync_embeddings(
         self,
         model: torch.nn.Module,
@@ -192,6 +198,7 @@ def __init__(
         self._cur_batch: Optional[In] = None
         self._connected = False
         self._data_iter_stopped = False
+        super().__init__()
 
     def _reset_data_iter(self) -> None:
         self._connected = False
@@ -311,6 +318,7 @@ def __init__(
         self._cur_batch: Optional[In] = None
 
     def progress(self, dataloader_iter: Iterator[In]) -> Out:
+        self._state = PipelineState.IDLE
         if self._iter == 0:
             # Turn on sync collectives for PT2 pipeline.
             # To have similar logic between compiled/graph_break ranks.
@@ -335,6 +343,7 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
         self._optimizer.zero_grad()
 
         with record_function("## forward ##"):
+            self._state = PipelineState.CALL_FWD
             if self._iter == cc.compile_on_iter:
                 logger.info("Compiling model...")
                 if self._pre_compile_fn:
@@ -362,6 +371,7 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
 
         if self._model.training:
             with record_function("## backward ##"):
+                self._state = PipelineState.CALL_BWD
                 torch.sum(losses).backward()
 
         with record_function("## optimizer ##"):
@@ -478,11 +488,13 @@ def __init__(
         self._dmp_collection_sync_interval_batches = (
             dmp_collection_sync_interval_batches
         )
+
         if self._dmp_collection_sync_interval_batches is not None:
             logger.info(
                 f"{self.__class__.__name__}: [Sparse 2D] DMP collection will sync every "
                 f"{self._dmp_collection_sync_interval_batches} batches"
             )
+        super().__init__()
 
         # DEPRECATED FIELDS
         self._batch_i: Optional[In] = None
@@ -634,6 +646,7 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
         batches[2]: i+2 batch, for copy_batch_to_gpu (expecting non-exhausted dataloader iter)
         """
 
+        self._state = PipelineState.IDLE
         # attach the model just in case the user forgets to call it, especially when the user
         # pauses the pipeline.progress and detach the model for other purpose.
         if not self._model_attached:
@@ -667,6 +680,7 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
 
         # forward
         with record_function("## forward ##"):
+            self._state = PipelineState.CALL_FWD
             losses, output = self._model_fwd(self.batches[0])
 
         if self._enqueue_batch_after_forward:
@@ -681,6 +695,7 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
 
         if self._model.training:
             # backward
+            self._state = PipelineState.CALL_BWD
             self._backward(losses)
 
             self.sync_embeddings(
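
For context on the states this change introduces: `_state` only ever takes the three values visible in the diff (`IDLE`, `CALL_FWD`, `CALL_BWD`), imported from `torchrec.distributed.train_pipeline.types`. The sketch below shows one way a training-recovery hook might consume it; the `PipelineState` enum definition here is an assumption (the diff only imports it, it does not define it), and `should_checkpoint_on_failure` is a hypothetical helper, not part of torchrec.

```python
# Minimal sketch, not torchrec code: PipelineState is assumed to be an enum
# whose members include the three values set in this diff; the real
# definition lives in torchrec/distributed/train_pipeline/types.py.
import logging
from enum import Enum, unique

logger = logging.getLogger(__name__)


@unique
class PipelineState(Enum):
    IDLE = "idle"
    CALL_FWD = "call_fwd"
    CALL_BWD = "call_bwd"


def should_checkpoint_on_failure(state: PipelineState) -> bool:
    """Hypothetical recovery policy keyed off the state set by progress().

    IDLE means no forward/backward was in flight, so model and optimizer
    state are consistent and an emergency checkpoint is safe. CALL_FWD and
    CALL_BWD mean the failure interrupted a forward or backward pass, where
    activations and gradients may be only partially materialized.
    """
    if state is PipelineState.IDLE:
        return True
    logger.warning(
        "Pipeline failed during %s; skipping emergency checkpoint", state.name
    )
    return False
```

A caller would wrap `pipeline.progress(dataloader_iter)` in a try/except and consult `pipeline._state` in the handler to decide whether checkpointing at the failure point is safe.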