diff --git a/requirements.txt b/requirements.txt
index 6239d0d90..fc79635f5 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -16,6 +16,7 @@
 tqdm
 usort
 parameterized
 PyYAML
+psutil # for tests

 # https://github.com/pytorch/pytorch/blob/b96b1e8cff029bb0a73283e6e7f6cc240313f1dc/requirements.txt#L3
diff --git a/torchrec/distributed/mc_modules.py b/torchrec/distributed/mc_modules.py
index 54b54d5a1..2dfc8f0a1 100644
--- a/torchrec/distributed/mc_modules.py
+++ b/torchrec/distributed/mc_modules.py
@@ -13,7 +13,17 @@
 import math
 from collections import defaultdict, OrderedDict
 from dataclasses import dataclass
-from typing import Any, DefaultDict, Dict, Iterator, List, Optional, Type, Union
+from typing import (
+    Any,
+    Callable,
+    DefaultDict,
+    Dict,
+    Iterator,
+    List,
+    Optional,
+    Type,
+    Union,
+)

 import torch
 import torch.distributed as dist
@@ -58,6 +68,7 @@
     ShardingType,
 )
 from torchrec.distributed.utils import append_prefix
+from torchrec.modules.embedding_configs import BaseEmbeddingConfig
 from torchrec.modules.mc_modules import ManagedCollisionCollection
 from torchrec.modules.utils import construct_jagged_tensors
 from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor
@@ -215,6 +226,9 @@ def __init__(
         self._feature_to_table: Dict[str, str] = module._feature_to_table
         self._table_to_features: Dict[str, List[str]] = module._table_to_features
+        self._table_name_to_config: Dict[str, BaseEmbeddingConfig] = (
+            module._table_name_to_config
+        )
         self._has_uninitialized_input_dists: bool = True
         self._input_dists: List[nn.Module] = []
         self._managed_collision_modules = nn.ModuleDict()
@@ -223,6 +237,9 @@ def __init__(
         self._create_output_dists()
         self._use_index_dedup = use_index_dedup
         self._initialize_torch_state()
+        self.post_lookup_tracker_fn: Optional[
+            Callable[[KeyedJaggedTensor, torch.Tensor], None]
+        ] = None

     def _initialize_torch_state(self) -> None:
         self._model_parallel_mc_buffer_name_to_sharded_tensor = OrderedDict()
@@ -732,6 +749,17 @@ def compute(
                 mc_input = mcm.remap(mc_input)
                 mc_input = self.global_to_local_index(mc_input)
                 output.update(mc_input)
+                if hasattr(
+                    mcm,
+                    "_hash_zch_identities",
+                ):
+                    if self.post_lookup_tracker_fn is not None:
+                        self.post_lookup_tracker_fn(
+                            KeyedJaggedTensor.from_jt_dict(mc_input),
+                            mcm._hash_zch_identities.index_select(
+                                dim=0, index=mc_input[table].values()
+                            ),
+                        )
             values = torch.cat([jt.values() for jt in output.values()])
         else:
             table: str = tables[0]
@@ -750,6 +778,12 @@ def compute(
             mc_input = mcm.remap(mc_input)
             mc_input = self.global_to_local_index(mc_input)
             values = mc_input[table].values()
+            if hasattr(mcm, "_hash_zch_identities"):
+                if self.post_lookup_tracker_fn is not None:
+                    self.post_lookup_tracker_fn(
+                        KeyedJaggedTensor.from_jt_dict(mc_input),
+                        mcm._hash_zch_identities.index_select(dim=0, index=values),
+                    )

         remapped_kjts.append(
             KeyedJaggedTensor(
@@ -840,6 +874,24 @@ def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]:
     def unsharded_module_type(self) -> Type[ManagedCollisionCollection]:
         return ManagedCollisionCollection

+    def register_post_lookup_tracker_fn(
+        self,
+        record_fn: Callable[[KeyedJaggedTensor, torch.Tensor], None],
+    ) -> None:
+        """
+        Register a function to be called after lookup is done. This is used for
+        tracking the lookup results and optimizer states.
+
+        Args:
+            record_fn (Callable[[KeyedJaggedTensor, torch.Tensor], None]): A custom record function to be called after lookup is done.
+
+        """
+        if self.post_lookup_tracker_fn is not None:
+            logger.warning(
+                "[ModelDeltaTracker] Custom record function already defined, overriding with new callable"
+            )
+        self.post_lookup_tracker_fn = record_fn
+

 class ManagedCollisionCollectionSharder(
     BaseEmbeddingSharder[ManagedCollisionCollection]
diff --git a/torchrec/distributed/utils.py b/torchrec/distributed/utils.py
index f2dbefa91..b255f6239 100644
--- a/torchrec/distributed/utils.py
+++ b/torchrec/distributed/utils.py
@@ -781,7 +781,8 @@ def modify_input_for_feature_processor(

     if is_collection:
         if hasattr(feature_processors, "pre_process_pipeline_input"):
-            feature_processors.pre_process_pipeline_input(features)  # pyre-ignore[29]
+            # pyre-ignore[29]
+            feature_processors.pre_process_pipeline_input(features)
         else:
             logging.info(
                 f"[Feature Processor Pipeline] Skipping pre_process_pipeline_input for feature processor {feature_processors=}"
diff --git a/torchrec/modules/mc_modules.py b/torchrec/modules/mc_modules.py
index 39cb2a847..54f0c671e 100644
--- a/torchrec/modules/mc_modules.py
+++ b/torchrec/modules/mc_modules.py
@@ -357,9 +357,11 @@ def __init__(
             len(features) for features in self._table_to_features.values()
         ]

-        table_to_config = {config.name: config for config in embedding_configs}
+        self._table_name_to_config: Dict[str, BaseEmbeddingConfig] = {
+            config.name: config for config in embedding_configs
+        }

-        for name, config in table_to_config.items():
+        for name, config in self._table_name_to_config.items():
             if name not in managed_collision_modules:
                 raise ValueError(
                     f"Table {name} is not present in managed_collision_modules"