From 82d92a586cb3b0c8604e84b9c8bb83f3330b5f46 Mon Sep 17 00:00:00 2001 From: Ali Afzal Date: Fri, 7 Nov 2025 06:26:23 -0800 Subject: [PATCH] Add raw_id_tracker for tracking hash_zch_identities of MPZCH module (#3501) Summary: This diff introduces a new `RawIdTracker` class that extends TorchRec's model delta tracking infra to capture and track raw hash identities from MCC modules during training. This is specifically required for tracking raw ids for MPZCH tables. internal This is needed to support MPZCH modules for Raw embedding streaming. Mode details : https://docs.google.com/document/d/1KEHwiXKLgXwRIdDFBYopjX3OiP3mRLM24Qkbiiu-TgE/edit?tab=t.0#bookmark=id.lhhgee2cs6ld Reviewed By: FriedCosey, chouxi Differential Revision: D84920167 --- torchrec/distributed/embedding.py | 2 +- torchrec/distributed/embedding_types.py | 18 +- torchrec/distributed/embeddingbag.py | 2 +- torchrec/distributed/mc_modules.py | 24 +- .../distributed/model_tracker/delta_store.py | 18 +- .../model_tracker/model_delta_tracker.py | 9 + .../model_tracker/trackers/__init__.py | 15 + .../model_tracker/trackers/raw_id_tracker.py | 304 ++++++++++++++++++ torchrec/distributed/model_tracker/types.py | 1 + 9 files changed, 386 insertions(+), 7 deletions(-) create mode 100644 torchrec/distributed/model_tracker/trackers/__init__.py create mode 100644 torchrec/distributed/model_tracker/trackers/raw_id_tracker.py diff --git a/torchrec/distributed/embedding.py b/torchrec/distributed/embedding.py index 2f6c6b9ed..5e9a96ca0 100644 --- a/torchrec/distributed/embedding.py +++ b/torchrec/distributed/embedding.py @@ -1605,7 +1605,7 @@ def compute_and_output_dist( ): embs = lookup(features) if self.post_lookup_tracker_fn is not None: - self.post_lookup_tracker_fn(features, embs, self) + self.post_lookup_tracker_fn(features, embs, self, None) with maybe_annotate_embedding_event( EmbeddingEvent.OUTPUT_DIST, self._module_fqn, sharding_type diff --git a/torchrec/distributed/embedding_types.py b/torchrec/distributed/embedding_types.py index d0a5ef920..5e92fa528 100644 --- a/torchrec/distributed/embedding_types.py +++ b/torchrec/distributed/embedding_types.py @@ -391,7 +391,15 @@ def __init__( self._lookups: List[nn.Module] = [] self._output_dists: List[nn.Module] = [] self.post_lookup_tracker_fn: Optional[ - Callable[[KeyedJaggedTensor, torch.Tensor, Optional[nn.Module]], None] + Callable[ + [ + KeyedJaggedTensor, + torch.Tensor, + Optional[nn.Module], + Optional[torch.Tensor], + ], + None, + ] ] = None self.post_odist_tracker_fn: Optional[Callable[..., None]] = None @@ -445,7 +453,13 @@ def train(self, mode: bool = True): # pyre-ignore[3] def register_post_lookup_tracker_fn( self, record_fn: Callable[ - [KeyedJaggedTensor, torch.Tensor, Optional[nn.Module]], None + [ + KeyedJaggedTensor, + torch.Tensor, + Optional[nn.Module], + Optional[torch.Tensor], + ], + None, ], ) -> None: """ diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py index fd6117884..754b9e6fa 100644 --- a/torchrec/distributed/embeddingbag.py +++ b/torchrec/distributed/embeddingbag.py @@ -1671,7 +1671,7 @@ def compute_and_output_dist( ): embs = lookup(features) if self.post_lookup_tracker_fn is not None: - self.post_lookup_tracker_fn(features, embs, self) + self.post_lookup_tracker_fn(features, embs, self, None) with maybe_annotate_embedding_event( EmbeddingEvent.OUTPUT_DIST, diff --git a/torchrec/distributed/mc_modules.py b/torchrec/distributed/mc_modules.py index 2dfc8f0a1..017025796 100644 --- a/torchrec/distributed/mc_modules.py +++ b/torchrec/distributed/mc_modules.py @@ -238,7 +238,15 @@ def __init__( self._use_index_dedup = use_index_dedup self._initialize_torch_state() self.post_lookup_tracker_fn: Optional[ - Callable[[KeyedJaggedTensor, torch.Tensor], None] + Callable[ + [ + KeyedJaggedTensor, + torch.Tensor, + Optional[nn.Module], + Optional[torch.Tensor], + ], + None, + ] ] = None def _initialize_torch_state(self) -> None: @@ -756,6 +764,8 @@ def compute( if self.post_lookup_tracker_fn is not None: self.post_lookup_tracker_fn( KeyedJaggedTensor.from_jt_dict(mc_input), + torch.empty(0), + None, mcm._hash_zch_identities.index_select( dim=0, index=mc_input[table].values() ), @@ -782,6 +792,8 @@ def compute( if self.post_lookup_tracker_fn is not None: self.post_lookup_tracker_fn( KeyedJaggedTensor.from_jt_dict(mc_input), + torch.empty(0), + None, mcm._hash_zch_identities.index_select(dim=0, index=values), ) @@ -876,7 +888,15 @@ def unsharded_module_type(self) -> Type[ManagedCollisionCollection]: def register_post_lookup_tracker_fn( self, - record_fn: Callable[[KeyedJaggedTensor, torch.Tensor], None], + record_fn: Callable[ + [ + KeyedJaggedTensor, + torch.Tensor, + Optional[nn.Module], + Optional[torch.Tensor], + ], + None, + ], ) -> None: """ Register a function to be called after lookup is done. This is used for diff --git a/torchrec/distributed/model_tracker/delta_store.py b/torchrec/distributed/model_tracker/delta_store.py index bd2ee1b27..cfac71b8c 100644 --- a/torchrec/distributed/model_tracker/delta_store.py +++ b/torchrec/distributed/model_tracker/delta_store.py @@ -91,6 +91,7 @@ def append( fqn: str, ids: torch.Tensor, states: Optional[torch.Tensor], + raw_ids: Optional[torch.Tensor] = None, ) -> None: """ Append a batch of ids and states to the store for a specific table. @@ -162,10 +163,11 @@ def append( fqn: str, ids: torch.Tensor, states: Optional[torch.Tensor], + raw_ids: Optional[torch.Tensor] = None, ) -> None: table_fqn_lookup = self.per_fqn_lookups.get(fqn, []) table_fqn_lookup.append( - IndexedLookup(batch_idx=batch_idx, ids=ids, states=states) + IndexedLookup(batch_idx=batch_idx, ids=ids, states=states, raw_ids=raw_ids) ) self.per_fqn_lookups[fqn] = table_fqn_lookup @@ -224,6 +226,20 @@ def compact(self, start_idx: int, end_idx: int) -> None: ) self.per_fqn_lookups = new_per_fqn_lookups + def get_indexed_lookups( + self, start_idx: int, end_idx: int + ) -> Dict[str, List[IndexedLookup]]: + r""" + Return all unique/delta ids per table from the Delta Store. + """ + per_fqn_lookups: Dict[str, List[IndexedLookup]] = {} + for table_fqn, lookups in self.per_fqn_lookups.items(): + indexices = [h.batch_idx for h in lookups] + index_l = bisect_left(indexices, start_idx) + index_r = bisect_left(indexices, end_idx) + per_fqn_lookups[table_fqn] = lookups[index_l:index_r] + return per_fqn_lookups + def get_unique(self, from_idx: int = 0) -> Dict[str, UniqueRows]: r""" Return all unique/delta ids per table from the Delta Store. diff --git a/torchrec/distributed/model_tracker/model_delta_tracker.py b/torchrec/distributed/model_tracker/model_delta_tracker.py index 50a9bd250..3e444ee4e 100644 --- a/torchrec/distributed/model_tracker/model_delta_tracker.py +++ b/torchrec/distributed/model_tracker/model_delta_tracker.py @@ -84,6 +84,7 @@ def record_lookup( kjt: KeyedJaggedTensor, states: torch.Tensor, emb_module: Optional[nn.Module] = None, + raw_ids: Optional[torch.Tensor] = None, ) -> None: """ Records the IDs from a given KeyedJaggedTensor and their corresponding embeddings/parameter states. @@ -131,6 +132,13 @@ def clear(self, consumer: Optional[str] = None) -> None: """ pass + @abstractmethod + def step(self) -> None: + """ + Advance the batch index for all consumers. + """ + pass + class ModelDeltaTrackerTrec(ModelDeltaTracker): r""" @@ -244,6 +252,7 @@ def record_lookup( kjt: KeyedJaggedTensor, states: torch.Tensor, emb_module: Optional[nn.Module] = None, + raw_ids: Optional[torch.Tensor] = None, ) -> None: """ Records the IDs from a given KeyedJaggedTensor and their corresponding embeddings/parameter states. diff --git a/torchrec/distributed/model_tracker/trackers/__init__.py b/torchrec/distributed/model_tracker/trackers/__init__.py new file mode 100644 index 000000000..07a1ae891 --- /dev/null +++ b/torchrec/distributed/model_tracker/trackers/__init__.py @@ -0,0 +1,15 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +"""MPZCH Raw ID Tracker +""" + +from torchrec.distributed.model_tracker.trackers.raw_id_tracker import ( # noqa + RawIdTracker, +) diff --git a/torchrec/distributed/model_tracker/trackers/raw_id_tracker.py b/torchrec/distributed/model_tracker/trackers/raw_id_tracker.py new file mode 100644 index 000000000..42eeb90e9 --- /dev/null +++ b/torchrec/distributed/model_tracker/trackers/raw_id_tracker.py @@ -0,0 +1,304 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +# pyre-strict + +import logging +from collections import Counter, OrderedDict +from typing import Dict, Iterable, List, Optional, Tuple + +import torch + +from torch import nn +from torchrec.distributed.embedding_types import ( + KeyedJaggedTensor, + ShardedEmbeddingTable, +) +from torchrec.distributed.mc_embeddingbag import ( + ShardedManagedCollisionEmbeddingBagCollection, +) +from torchrec.distributed.mc_modules import ShardedManagedCollisionCollection +from torchrec.distributed.model_tracker.delta_store import DeltaStoreTrec + +from torchrec.distributed.model_tracker.model_delta_tracker import ModelDeltaTracker +from torchrec.distributed.model_tracker.types import IndexedLookup, UniqueRows + +logger: logging.Logger = logging.getLogger(__name__) + +SUPPORTED_MODULES = (ShardedManagedCollisionCollection,) + + +class RawIdTracker(ModelDeltaTracker): + def __init__( + self, + model: nn.Module, + delete_on_read: bool = True, + fqns_to_skip: Iterable[str] = (), + ) -> None: + self._model = model + self._consumers: Optional[List[str]] = None + self._delete_on_read = delete_on_read + self._fqn_to_feature_map: Dict[str, List[str]] = {} + self._fqns_to_skip: Iterable[str] = fqns_to_skip + + self.curr_batch_idx: int = 0 + self.curr_compact_index: int = 0 + + # from module FQN to SUPPORTED_MODULES + self.tracked_modules: Dict[str, nn.Module] = {} + self.table_to_fqn: Dict[str, str] = {} + self.feature_to_fqn: Dict[str, str] = {} + # Generate the mapping from FQN to feature names. + self.fqn_to_feature_names() + # Validate is the mode is supported for the given module and initialize tracker functions + self._validate_and_init_tracker_fns() + # init TBE tracker wrapper and register consumer ids + self._init_tbe_tracker_wrapper(self._model) + + # per_consumer_batch_idx is used to track the batch index for each consumer. + # This is used to retrieve the delta values for a given consumer as well as + # start_ids for compaction window. + + # Note: For raw id tracking, this has to be assigned after the _init_tbe_tracker_wrapper() + # call as _init_tbe_tracker_wrapper is setting up consumers for TBEs + + self.per_consumer_batch_idx: Dict[str, int] = { + c: -1 for c in (self._consumers or [self.DEFAULT_CONSUMER]) + } + + self.store: DeltaStoreTrec = DeltaStoreTrec() + + # Mapping feature name to corresponding FQNs. This is used for retrieving + # the FQN associated with a given feature name in record_lookup(). + for fqn, feature_names in self._fqn_to_feature_map.items(): + for feature_name in feature_names: + if feature_name in self.feature_to_fqn: + logger.warning( + f"Duplicate feature name: {feature_name} in fqn {fqn}" + ) + continue + self.feature_to_fqn[feature_name] = fqn + logger.info(f"feature_to_fqn: {self.feature_to_fqn}") + + def step(self) -> None: + # Move batch index forward for all consumers. + self.curr_batch_idx += 1 + + def _should_skip_fqn(self, fqn: str) -> bool: + split_fqn = fqn.split(".") + # Skipping partial FQNs present in fqns_to_skip + # TODO: Validate if we need to support more complex patterns for skipping fqns + should_skip = False + for fqn_to_skip in self._fqns_to_skip: + if fqn_to_skip in split_fqn: + logger.info(f"Skipping {fqn} because it is part of fqns_to_skip") + should_skip = True + break + return should_skip + + def _should_track_table( + self, embedding_tables: List[ShardedEmbeddingTable] + ) -> bool: + should_track = True + for table_config in embedding_tables: + for fqn_to_skip in self._fqns_to_skip: + if fqn_to_skip in table_config.name: + should_track = False + break + return should_track + + def fqn_to_feature_names(self) -> Dict[str, List[str]]: + """ + Returns a mapping of FQN to feature names from all Supported Modules [EmbeddingCollection and EmbeddingBagCollection] present in the given model. + """ + if (self._fqn_to_feature_map is not None) and len(self._fqn_to_feature_map) > 0: + return self._fqn_to_feature_map + + table_to_feature_names: Dict[str, List[str]] = OrderedDict() + for fqn, named_module in self._model.named_modules(): + if self._should_skip_fqn(fqn): + continue + # Using FQNs of the embedding and mapping them to features as state_dict() API uses these to key states. + if isinstance(named_module, SUPPORTED_MODULES): + should_track_module = True + for table_name, config in named_module._table_name_to_config.items(): + for fqn_to_skip in self._fqns_to_skip: + if fqn_to_skip in table_name: + should_track_module = False + logger.info( + f"Found {table_name} for {fqn} with features {config.feature_names} should_track_module: {should_track_module}" + ) + table_to_feature_names[table_name] = config.feature_names + if should_track_module: + self.tracked_modules[self._clean_fqn_fn(fqn)] = named_module + for table_name in table_to_feature_names: + # Using the split FQN to get the exact table name matching. Otherwise, checking "table_name in fqn" + # will incorrectly match fqn with all the table names that have the same prefix + split_fqn = fqn.split(".") + if table_name in split_fqn: + embedding_fqn = self._clean_fqn_fn(fqn) + if table_name in self.table_to_fqn: + # Sanity check for validating that we don't have more then one table mapping to same fqn. + logger.warning( + f"Override {self.table_to_fqn[table_name]} with {embedding_fqn} for entry {table_name}" + ) + self.table_to_fqn[table_name] = embedding_fqn + logger.info(f"Table to fqn: {self.table_to_fqn}") + flatten_names = [ + name for names in table_to_feature_names.values() for name in names + ] + # TODO: Validate if there is a better way to handle duplicate feature names. + # Logging a warning if duplicate feature names are found across tables, but continue execution as this could be a valid case. + if len(set(flatten_names)) != len(flatten_names): + counts = Counter(flatten_names) + duplicates = [item for item, count in counts.items() if count > 1] + logger.warning(f"duplicate feature names found: {duplicates}") + + fqn_to_feature_names: Dict[str, List[str]] = OrderedDict() + for table_name in table_to_feature_names: + if table_name not in self.table_to_fqn: + # This is likely unexpected, where we can't locate the FQN associated with this table. + logger.warning( + f"Table {table_name} not found in {self.table_to_fqn}, skipping" + ) + continue + fqn_to_feature_names[self.table_to_fqn[table_name]] = ( + table_to_feature_names[table_name] + ) + self._fqn_to_feature_map = fqn_to_feature_names + return fqn_to_feature_names + + def record_lookup( + self, + kjt: KeyedJaggedTensor, + states: torch.Tensor, + emb_module: Optional[nn.Module] = None, + raw_ids: Optional[torch.Tensor] = None, + ) -> None: + per_table_ids: Dict[str, List[torch.Tensor]] = {} + per_table_raw_ids: Dict[str, List[torch.Tensor]] = {} + + # Skip storing invalid input or raw ids + if ( + raw_ids is None + or (kjt.values().numel() == 0) + or not (raw_ids.numel() % kjt.values().numel() == 0) + ): + return + + embeddings_2d = raw_ids.view(kjt.values().numel(), -1) + + offset: int = 0 + for key in kjt.keys(): + table_fqn = self.table_to_fqn[key] + ids_list: List[torch.Tensor] = per_table_ids.get(table_fqn, []) + emb_list: List[torch.Tensor] = per_table_raw_ids.get(table_fqn, []) + + ids = kjt[key].values() + ids_list.append(ids) + emb_list.append(embeddings_2d[offset : offset + ids.numel()]) + offset += ids.numel() + + per_table_ids[table_fqn] = ids_list + per_table_raw_ids[table_fqn] = emb_list + + for table_fqn, ids_list in per_table_ids.items(): + self.store.append( + batch_idx=self.curr_batch_idx, + fqn=table_fqn, + ids=torch.cat(ids_list), + states=None, + raw_ids=torch.cat(per_table_raw_ids[table_fqn]), + ) + + def _clean_fqn_fn(self, fqn: str) -> str: + # strip FQN prefixes added by DMP and other TorchRec operations to match state dict FQN + # handles both "_dmp_wrapped_module.module." and "module." prefixes + prefixes_to_strip = ["_dmp_wrapped_module.module.", "module."] + for prefix in prefixes_to_strip: + if fqn.startswith(prefix): + return fqn[len(prefix) :] + return fqn + + def _validate_and_init_tracker_fns(self) -> None: + "To validate the mode is supported for the given module" + for module in self.tracked_modules.values(): + if isinstance(module, SUPPORTED_MODULES): + # register post lookup function + module.register_post_lookup_tracker_fn(self.record_lookup) + + def _init_tbe_tracker_wrapper(self, module: nn.Module) -> None: + for fqn, named_module in self._model.named_modules(): + if self._should_skip_fqn(fqn): + continue + if isinstance(named_module, ShardedManagedCollisionEmbeddingBagCollection): + for lookup in named_module._embedding_module._lookups: + # pyre-ignore + for emb in lookup._emb_modules: + # Only initialize tracker for TBEs that contain tables we want to track + should_track_table = self._should_track_table( + emb._config.embedding_tables + ) + if should_track_table: + emb.init_raw_id_tracker( + self.get_indexed_lookups, + self.delete, + ) + if self._consumers is None: + self._consumers = [] + self._consumers.append(emb._emb_module.uuid) + + def get_unique_ids(self, consumer: Optional[str] = None) -> Dict[str, torch.Tensor]: + return {} + + def get_unique( + self, + consumer: Optional[str] = None, + top_percentage: Optional[float] = 1.0, + per_table_percentage: Optional[Dict[str, Tuple[float, str]]] = None, + sorted_by_indices: Optional[bool] = True, + ) -> Dict[str, UniqueRows]: + return {} + + def clear(self, consumer: Optional[str] = None) -> None: + pass + + def get_indexed_lookups( + self, + tables: List[str], + consumer: Optional[str] = None, + ) -> Dict[str, List[torch.Tensor]]: + raw_id_per_table: Dict[str, List[torch.Tensor]] = {} + consumer = consumer or self.DEFAULT_CONSUMER + assert ( + consumer in self.per_consumer_batch_idx + ), f"consumer {consumer} not present in {self.per_consumer_batch_idx.values()}" + + index_end: int = self.curr_batch_idx + 1 + index_start = self.per_consumer_batch_idx[consumer] + indexed_lookups = {} + if index_start < index_end: + self.per_consumer_batch_idx[consumer] = index_end + indexed_lookups = self.store.get_indexed_lookups(index_start, index_end) + + for table in tables: + raw_ids_list = [] + fqn = self.table_to_fqn[table] + if fqn in indexed_lookups: + for indexed_lookup in indexed_lookups[fqn]: + if indexed_lookup.raw_ids is not None: + raw_ids_list.append(indexed_lookup.raw_ids) + raw_id_per_table[table] = raw_ids_list + + if self._delete_on_read: + self.store.delete(up_to_idx=min(self.per_consumer_batch_idx.values())) + + return raw_id_per_table + + def delete(self, up_to_idx: Optional[int]) -> None: + self.store.delete(up_to_idx) diff --git a/torchrec/distributed/model_tracker/types.py b/torchrec/distributed/model_tracker/types.py index 3fbb70063..1bf00e6db 100644 --- a/torchrec/distributed/model_tracker/types.py +++ b/torchrec/distributed/model_tracker/types.py @@ -23,6 +23,7 @@ class IndexedLookup: batch_idx: int ids: torch.Tensor states: Optional[torch.Tensor] + raw_ids: Optional[torch.Tensor] = None compact: bool = False