1313import math
1414from collections import defaultdict , OrderedDict
1515from dataclasses import dataclass
16- from typing import Any , DefaultDict , Dict , Iterator , List , Optional , Type , Union
16+ from typing import (
17+ Any ,
18+ Callable ,
19+ DefaultDict ,
20+ Dict ,
21+ Iterator ,
22+ List ,
23+ Optional ,
24+ Type ,
25+ Union ,
26+ )
1727
1828import torch
1929import torch .distributed as dist
5868 ShardingType ,
5969)
6070from torchrec .distributed .utils import append_prefix
71+ from torchrec .modules .embedding_configs import BaseEmbeddingConfig
72+ from torchrec .modules .hash_mc_modules import HashZchManagedCollisionModule
6173from torchrec .modules .mc_modules import ManagedCollisionCollection
6274from torchrec .modules .utils import construct_jagged_tensors
6375from torchrec .sparse .jagged_tensor import JaggedTensor , KeyedJaggedTensor
@@ -215,6 +227,9 @@ def __init__(
215227
216228 self ._feature_to_table : Dict [str , str ] = module ._feature_to_table
217229 self ._table_to_features : Dict [str , List [str ]] = module ._table_to_features
230+ self ._table_name_to_config : Dict [str , BaseEmbeddingConfig ] = (
231+ module ._table_name_to_config
232+ )
218233 self ._has_uninitialized_input_dists : bool = True
219234 self ._input_dists : List [nn .Module ] = []
220235 self ._managed_collision_modules = nn .ModuleDict ()
@@ -223,6 +238,9 @@ def __init__(
223238 self ._create_output_dists ()
224239 self ._use_index_dedup = use_index_dedup
225240 self ._initialize_torch_state ()
241+ self .post_lookup_tracker_fn : Optional [
242+ Callable [[KeyedJaggedTensor , torch .Tensor ], None ]
243+ ] = None
226244
227245 def _initialize_torch_state (self ) -> None :
228246 self ._model_parallel_mc_buffer_name_to_sharded_tensor = OrderedDict ()
@@ -732,6 +750,17 @@ def compute(
732750 mc_input = mcm .remap (mc_input )
733751 mc_input = self .global_to_local_index (mc_input )
734752 output .update (mc_input )
753+ if isinstance (
754+ mcm ,
755+ HashZchManagedCollisionModule ,
756+ ):
757+ if self .post_lookup_tracker_fn is not None :
758+ self .post_lookup_tracker_fn (
759+ KeyedJaggedTensor .from_jt_dict (output ),
760+ mcm ._hash_zch_identities .index_select (
761+ dim = 0 , index = mc_input [table ].values ()
762+ ),
763+ )
735764 values = torch .cat ([jt .values () for jt in output .values ()])
736765 else :
737766 table : str = tables [0 ]
@@ -750,6 +779,12 @@ def compute(
750779 mc_input = mcm .remap (mc_input )
751780 mc_input = self .global_to_local_index (mc_input )
752781 values = mc_input [table ].values ()
782+ if isinstance (mcm , HashZchManagedCollisionModule ):
783+ if self .post_lookup_tracker_fn is not None :
784+ self .post_lookup_tracker_fn (
785+ KeyedJaggedTensor .from_jt_dict (mc_input ),
786+ mcm ._hash_zch_identities .index_select (dim = 0 , index = values ),
787+ )
753788
754789 remapped_kjts .append (
755790 KeyedJaggedTensor (
@@ -840,6 +875,24 @@ def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]:
840875 def unsharded_module_type (self ) -> Type [ManagedCollisionCollection ]:
841876 return ManagedCollisionCollection
842877
def register_post_lookup_tracker_fn(
    self,
    record_fn: Callable[[KeyedJaggedTensor, torch.Tensor], None],
) -> None:
    """
    Install ``record_fn`` as the post-lookup tracker callback of this module.

    The callable is stored on ``self.post_lookup_tracker_fn``. Only one
    callback is kept at a time: if one is already registered, a warning is
    logged and it is replaced by ``record_fn``.

    Args:
        record_fn (Callable[[KeyedJaggedTensor, torch.Tensor], None]): custom
            record function to invoke once a lookup has been performed. It
            receives a ``KeyedJaggedTensor`` and a ``torch.Tensor`` and its
            return value is ignored.

    Returns:
        None
    """
    already_registered = self.post_lookup_tracker_fn is not None
    if already_registered:
        # Registration is last-writer-wins; surface the override in the logs
        # instead of silently dropping the previous callback.
        logger.warning(
            "[ModelDeltaTracker] Custom record function already defined, overriding with new callable"
        )
    self.post_lookup_tracker_fn = record_fn
843896
844897class ManagedCollisionCollectionSharder (
845898 BaseEmbeddingSharder [ManagedCollisionCollection ]
0 commit comments