diff --git a/torchrec/distributed/embedding.py b/torchrec/distributed/embedding.py
index 2f6c6b9ed..5e9a96ca0 100644
--- a/torchrec/distributed/embedding.py
+++ b/torchrec/distributed/embedding.py
@@ -1605,7 +1605,7 @@ def compute_and_output_dist(
         ):
             embs = lookup(features)
             if self.post_lookup_tracker_fn is not None:
-                self.post_lookup_tracker_fn(features, embs, self)
+                self.post_lookup_tracker_fn(features, embs, self, None)
 
         with maybe_annotate_embedding_event(
             EmbeddingEvent.OUTPUT_DIST, self._module_fqn, sharding_type
diff --git a/torchrec/distributed/embedding_types.py b/torchrec/distributed/embedding_types.py
index d0a5ef920..5e92fa528 100644
--- a/torchrec/distributed/embedding_types.py
+++ b/torchrec/distributed/embedding_types.py
@@ -391,7 +391,15 @@ def __init__(
         self._lookups: List[nn.Module] = []
         self._output_dists: List[nn.Module] = []
         self.post_lookup_tracker_fn: Optional[
-            Callable[[KeyedJaggedTensor, torch.Tensor, Optional[nn.Module]], None]
+            Callable[
+                [
+                    KeyedJaggedTensor,
+                    torch.Tensor,
+                    Optional[nn.Module],
+                    Optional[torch.Tensor],
+                ],
+                None,
+            ]
         ] = None
         self.post_odist_tracker_fn: Optional[Callable[..., None]] = None
 
@@ -445,7 +453,13 @@ def train(self, mode: bool = True):  # pyre-ignore[3]
     def register_post_lookup_tracker_fn(
         self,
         record_fn: Callable[
-            [KeyedJaggedTensor, torch.Tensor, Optional[nn.Module]], None
+            [
+                KeyedJaggedTensor,
+                torch.Tensor,
+                Optional[nn.Module],
+                Optional[torch.Tensor],
+            ],
+            None,
         ],
     ) -> None:
         """
diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py
index fd6117884..754b9e6fa 100644
--- a/torchrec/distributed/embeddingbag.py
+++ b/torchrec/distributed/embeddingbag.py
@@ -1671,7 +1671,7 @@ def compute_and_output_dist(
         ):
             embs = lookup(features)
             if self.post_lookup_tracker_fn is not None:
-                self.post_lookup_tracker_fn(features, embs, self)
+                self.post_lookup_tracker_fn(features, embs, self, None)
 
         with maybe_annotate_embedding_event(
             EmbeddingEvent.OUTPUT_DIST,
diff --git a/torchrec/distributed/mc_modules.py b/torchrec/distributed/mc_modules.py
index 2dfc8f0a1..017025796 100644
--- a/torchrec/distributed/mc_modules.py
+++ b/torchrec/distributed/mc_modules.py
@@ -238,7 +238,15 @@ def __init__(
         self._use_index_dedup = use_index_dedup
         self._initialize_torch_state()
         self.post_lookup_tracker_fn: Optional[
-            Callable[[KeyedJaggedTensor, torch.Tensor], None]
+            Callable[
+                [
+                    KeyedJaggedTensor,
+                    torch.Tensor,
+                    Optional[nn.Module],
+                    Optional[torch.Tensor],
+                ],
+                None,
+            ]
         ] = None
 
     def _initialize_torch_state(self) -> None:
@@ -756,6 +764,8 @@ def compute(
             if self.post_lookup_tracker_fn is not None:
                 self.post_lookup_tracker_fn(
                     KeyedJaggedTensor.from_jt_dict(mc_input),
+                    torch.empty(0),
+                    None,
                     mcm._hash_zch_identities.index_select(
                         dim=0, index=mc_input[table].values()
                     ),
@@ -782,6 +792,8 @@ def compute(
             if self.post_lookup_tracker_fn is not None:
                 self.post_lookup_tracker_fn(
                     KeyedJaggedTensor.from_jt_dict(mc_input),
+                    torch.empty(0),
+                    None,
                     mcm._hash_zch_identities.index_select(dim=0, index=values),
                 )
 
@@ -876,7 +888,15 @@ def unsharded_module_type(self) -> Type[ManagedCollisionCollection]:
 
     def register_post_lookup_tracker_fn(
         self,
-        record_fn: Callable[[KeyedJaggedTensor, torch.Tensor], None],
+        record_fn: Callable[
+            [
+                KeyedJaggedTensor,
+                torch.Tensor,
+                Optional[nn.Module],
+                Optional[torch.Tensor],
+            ],
+            None,
+        ],
     ) -> None:
         """
         Register a function to be called after lookup is done. This is used for
diff --git a/torchrec/distributed/model_tracker/delta_store.py b/torchrec/distributed/model_tracker/delta_store.py
index bd2ee1b27..cfac71b8c 100644
--- a/torchrec/distributed/model_tracker/delta_store.py
+++ b/torchrec/distributed/model_tracker/delta_store.py
@@ -91,6 +91,7 @@ def append(
         fqn: str,
         ids: torch.Tensor,
         states: Optional[torch.Tensor],
+        raw_ids: Optional[torch.Tensor] = None,
     ) -> None:
         """
         Append a batch of ids and states to the store for a specific table.
@@ -162,10 +163,11 @@ def append(
         fqn: str,
         ids: torch.Tensor,
         states: Optional[torch.Tensor],
+        raw_ids: Optional[torch.Tensor] = None,
     ) -> None:
         table_fqn_lookup = self.per_fqn_lookups.get(fqn, [])
         table_fqn_lookup.append(
-            IndexedLookup(batch_idx=batch_idx, ids=ids, states=states)
+            IndexedLookup(batch_idx=batch_idx, ids=ids, states=states, raw_ids=raw_ids)
         )
         self.per_fqn_lookups[fqn] = table_fqn_lookup
 
@@ -224,6 +226,20 @@ def compact(self, start_idx: int, end_idx: int) -> None:
             )
         self.per_fqn_lookups = new_per_fqn_lookups
 
+    def get_indexed_lookups(
+        self, start_idx: int, end_idx: int
+    ) -> Dict[str, List[IndexedLookup]]:
+        r"""
+        Return the IndexedLookups recorded per table for batch indices in [start_idx, end_idx).
+        """
+        per_fqn_lookups: Dict[str, List[IndexedLookup]] = {}
+        for table_fqn, lookups in self.per_fqn_lookups.items():
+            indices = [h.batch_idx for h in lookups]
+            index_l = bisect_left(indices, start_idx)
+            index_r = bisect_left(indices, end_idx)
+            per_fqn_lookups[table_fqn] = lookups[index_l:index_r]
+        return per_fqn_lookups
+
     def get_unique(self, from_idx: int = 0) -> Dict[str, UniqueRows]:
         r"""
         Return all unique/delta ids per table from the Delta Store.
diff --git a/torchrec/distributed/model_tracker/model_delta_tracker.py b/torchrec/distributed/model_tracker/model_delta_tracker.py
index 50a9bd250..3e444ee4e 100644
--- a/torchrec/distributed/model_tracker/model_delta_tracker.py
+++ b/torchrec/distributed/model_tracker/model_delta_tracker.py
@@ -84,6 +84,7 @@ def record_lookup(
         kjt: KeyedJaggedTensor,
         states: torch.Tensor,
         emb_module: Optional[nn.Module] = None,
+        raw_ids: Optional[torch.Tensor] = None,
     ) -> None:
         """
         Records the IDs from a given KeyedJaggedTensor and their corresponding embeddings/parameter states.
@@ -131,6 +132,13 @@ def clear(self, consumer: Optional[str] = None) -> None:
         """
         pass
 
+    @abstractmethod
+    def step(self) -> None:
+        """
+        Advance the batch index for all consumers.
+        """
+        pass
+
 
 class ModelDeltaTrackerTrec(ModelDeltaTracker):
     r"""
@@ -244,6 +252,7 @@ def record_lookup(
         kjt: KeyedJaggedTensor,
         states: torch.Tensor,
         emb_module: Optional[nn.Module] = None,
+        raw_ids: Optional[torch.Tensor] = None,
     ) -> None:
         """
         Records the IDs from a given KeyedJaggedTensor and their corresponding embeddings/parameter states.
diff --git a/torchrec/distributed/model_tracker/trackers/__init__.py b/torchrec/distributed/model_tracker/trackers/__init__.py
new file mode 100644
index 000000000..07a1ae891
--- /dev/null
+++ b/torchrec/distributed/model_tracker/trackers/__init__.py
@@ -0,0 +1,15 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+"""MPZCH Raw ID Tracker
+"""
+
+from torchrec.distributed.model_tracker.trackers.raw_id_tracker import (  # noqa
+    RawIdTracker,
+)
diff --git a/torchrec/distributed/model_tracker/trackers/raw_id_tracker.py b/torchrec/distributed/model_tracker/trackers/raw_id_tracker.py
new file mode 100644
index 000000000..42eeb90e9
--- /dev/null
+++ b/torchrec/distributed/model_tracker/trackers/raw_id_tracker.py
@@ -0,0 +1,304 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+# pyre-strict
+
+import logging
+from collections import Counter, OrderedDict
+from typing import Dict, Iterable, List, Optional, Tuple
+
+import torch
+
+from torch import nn
+from torchrec.distributed.embedding_types import (
+    KeyedJaggedTensor,
+    ShardedEmbeddingTable,
+)
+from torchrec.distributed.mc_embeddingbag import (
+    ShardedManagedCollisionEmbeddingBagCollection,
+)
+from torchrec.distributed.mc_modules import ShardedManagedCollisionCollection
+from torchrec.distributed.model_tracker.delta_store import DeltaStoreTrec
+
+from torchrec.distributed.model_tracker.model_delta_tracker import ModelDeltaTracker
+from torchrec.distributed.model_tracker.types import IndexedLookup, UniqueRows
+
+logger: logging.Logger = logging.getLogger(__name__)
+
+SUPPORTED_MODULES = (ShardedManagedCollisionCollection,)
+
+
+class RawIdTracker(ModelDeltaTracker):
+    def __init__(
+        self,
+        model: nn.Module,
+        delete_on_read: bool = True,
+        fqns_to_skip: Iterable[str] = (),
+    ) -> None:
+        self._model = model
+        self._consumers: Optional[List[str]] = None
+        self._delete_on_read = delete_on_read
+        self._fqn_to_feature_map: Dict[str, List[str]] = {}
+        self._fqns_to_skip: Iterable[str] = fqns_to_skip
+
+        self.curr_batch_idx: int = 0
+        self.curr_compact_index: int = 0
+
+        # from module FQN to SUPPORTED_MODULES
+        self.tracked_modules: Dict[str, nn.Module] = {}
+        self.table_to_fqn: Dict[str, str] = {}
+        self.feature_to_fqn: Dict[str, str] = {}
+        # Generate the mapping from FQN to feature names.
+        self.fqn_to_feature_names()
+        # Validate that tracking is supported for the given modules and initialize tracker functions.
+        self._validate_and_init_tracker_fns()
+        # Init the TBE tracker wrapper and register consumer ids.
+        self._init_tbe_tracker_wrapper(self._model)
+
+        # per_consumer_batch_idx is used to track the batch index for each consumer.
+        # This is used to retrieve the delta values for a given consumer as well as
+        # start_ids for the compaction window.
+
+        # Note: For raw id tracking, this has to be assigned after the _init_tbe_tracker_wrapper()
+        # call, as _init_tbe_tracker_wrapper sets up the consumers for TBEs.
+
+        self.per_consumer_batch_idx: Dict[str, int] = {
+            c: -1 for c in (self._consumers or [self.DEFAULT_CONSUMER])
+        }
+
+        self.store: DeltaStoreTrec = DeltaStoreTrec()
+
+        # Mapping from feature name to corresponding FQN. This is used for retrieving
+        # the FQN associated with a given feature name in record_lookup().
+        for fqn, feature_names in self._fqn_to_feature_map.items():
+            for feature_name in feature_names:
+                if feature_name in self.feature_to_fqn:
+                    logger.warning(
+                        f"Duplicate feature name: {feature_name} in fqn {fqn}"
+                    )
+                    continue
+                self.feature_to_fqn[feature_name] = fqn
+        logger.info(f"feature_to_fqn: {self.feature_to_fqn}")
+
+    def step(self) -> None:
+        # Move the batch index forward for all consumers.
+        self.curr_batch_idx += 1
+
+    def _should_skip_fqn(self, fqn: str) -> bool:
+        split_fqn = fqn.split(".")
+        # Skipping partial FQNs present in fqns_to_skip
+        # TODO: Validate if we need to support more complex patterns for skipping fqns
+        should_skip = False
+        for fqn_to_skip in self._fqns_to_skip:
+            if fqn_to_skip in split_fqn:
+                logger.info(f"Skipping {fqn} because it is part of fqns_to_skip")
+                should_skip = True
+                break
+        return should_skip
+
+    def _should_track_table(
+        self, embedding_tables: List[ShardedEmbeddingTable]
+    ) -> bool:
+        should_track = True
+        for table_config in embedding_tables:
+            for fqn_to_skip in self._fqns_to_skip:
+                if fqn_to_skip in table_config.name:
+                    should_track = False
+                    break
+        return should_track
+
+    def fqn_to_feature_names(self) -> Dict[str, List[str]]:
+        """
+        Returns a mapping of FQN to feature names for all supported modules
+        (ShardedManagedCollisionCollection) present in the given model.
+        """
+        if (self._fqn_to_feature_map is not None) and len(self._fqn_to_feature_map) > 0:
+            return self._fqn_to_feature_map
+
+        table_to_feature_names: Dict[str, List[str]] = OrderedDict()
+        for fqn, named_module in self._model.named_modules():
+            if self._should_skip_fqn(fqn):
+                continue
+            # Using FQNs of the embedding and mapping them to features as state_dict() API uses these to key states.
+            if isinstance(named_module, SUPPORTED_MODULES):
+                should_track_module = True
+                for table_name, config in named_module._table_name_to_config.items():
+                    for fqn_to_skip in self._fqns_to_skip:
+                        if fqn_to_skip in table_name:
+                            should_track_module = False
+                    logger.info(
+                        f"Found {table_name} for {fqn} with features {config.feature_names} should_track_module: {should_track_module}"
+                    )
+                    table_to_feature_names[table_name] = config.feature_names
+                if should_track_module:
+                    self.tracked_modules[self._clean_fqn_fn(fqn)] = named_module
+            for table_name in table_to_feature_names:
+                # Using the split FQN to get the exact table name matching. Otherwise, checking "table_name in fqn"
+                # will incorrectly match fqn with all the table names that have the same prefix.
+                split_fqn = fqn.split(".")
+                if table_name in split_fqn:
+                    embedding_fqn = self._clean_fqn_fn(fqn)
+                    if table_name in self.table_to_fqn:
+                        # Sanity check for validating that we don't have more than one table mapping to the same fqn.
+                        logger.warning(
+                            f"Override {self.table_to_fqn[table_name]} with {embedding_fqn} for entry {table_name}"
+                        )
+                    self.table_to_fqn[table_name] = embedding_fqn
+        logger.info(f"Table to fqn: {self.table_to_fqn}")
+        flatten_names = [
+            name for names in table_to_feature_names.values() for name in names
+        ]
+        # TODO: Validate if there is a better way to handle duplicate feature names.
+        # Logging a warning if duplicate feature names are found across tables, but continuing execution as this could be a valid case.
+        if len(set(flatten_names)) != len(flatten_names):
+            counts = Counter(flatten_names)
+            duplicates = [item for item, count in counts.items() if count > 1]
+            logger.warning(f"duplicate feature names found: {duplicates}")
+
+        fqn_to_feature_names: Dict[str, List[str]] = OrderedDict()
+        for table_name in table_to_feature_names:
+            if table_name not in self.table_to_fqn:
+                # This is likely unexpected: we can't locate the FQN associated with this table.
+                logger.warning(
+                    f"Table {table_name} not found in {self.table_to_fqn}, skipping"
+                )
+                continue
+            fqn_to_feature_names[self.table_to_fqn[table_name]] = (
+                table_to_feature_names[table_name]
+            )
+        self._fqn_to_feature_map = fqn_to_feature_names
+        return fqn_to_feature_names
+
+    def record_lookup(
+        self,
+        kjt: KeyedJaggedTensor,
+        states: torch.Tensor,
+        emb_module: Optional[nn.Module] = None,
+        raw_ids: Optional[torch.Tensor] = None,
+    ) -> None:
+        per_table_ids: Dict[str, List[torch.Tensor]] = {}
+        per_table_raw_ids: Dict[str, List[torch.Tensor]] = {}
+
+        # Skip storing invalid input or raw ids
+        if (
+            raw_ids is None
+            or (kjt.values().numel() == 0)
+            or not (raw_ids.numel() % kjt.values().numel() == 0)
+        ):
+            return
+
+        embeddings_2d = raw_ids.view(kjt.values().numel(), -1)
+
+        offset: int = 0
+        for key in kjt.keys():
+            table_fqn = self.table_to_fqn[key]
+            ids_list: List[torch.Tensor] = per_table_ids.get(table_fqn, [])
+            emb_list: List[torch.Tensor] = per_table_raw_ids.get(table_fqn, [])
+
+            ids = kjt[key].values()
+            ids_list.append(ids)
+            emb_list.append(embeddings_2d[offset : offset + ids.numel()])
+            offset += ids.numel()
+
+            per_table_ids[table_fqn] = ids_list
+            per_table_raw_ids[table_fqn] = emb_list
+
+        for table_fqn, ids_list in per_table_ids.items():
+            self.store.append(
+                batch_idx=self.curr_batch_idx,
+                fqn=table_fqn,
+                ids=torch.cat(ids_list),
+                states=None,
+                raw_ids=torch.cat(per_table_raw_ids[table_fqn]),
+            )
+
+    def _clean_fqn_fn(self, fqn: str) -> str:
+        # strip FQN prefixes added by DMP and other TorchRec operations to match state dict FQN
+        # handles both "_dmp_wrapped_module.module." and "module." prefixes
+        prefixes_to_strip = ["_dmp_wrapped_module.module.", "module."]
+        for prefix in prefixes_to_strip:
+            if fqn.startswith(prefix):
+                return fqn[len(prefix) :]
+        return fqn
+
+    def _validate_and_init_tracker_fns(self) -> None:
+        "To validate the mode is supported for the given module"
+        for module in self.tracked_modules.values():
+            if isinstance(module, SUPPORTED_MODULES):
+                # register post lookup function
+                module.register_post_lookup_tracker_fn(self.record_lookup)
+
+    def _init_tbe_tracker_wrapper(self, module: nn.Module) -> None:
+        for fqn, named_module in self._model.named_modules():
+            if self._should_skip_fqn(fqn):
+                continue
+            if isinstance(named_module, ShardedManagedCollisionEmbeddingBagCollection):
+                for lookup in named_module._embedding_module._lookups:
+                    # pyre-ignore
+                    for emb in lookup._emb_modules:
+                        # Only initialize tracker for TBEs that contain tables we want to track
+                        should_track_table = self._should_track_table(
+                            emb._config.embedding_tables
+                        )
+                        if should_track_table:
+                            emb.init_raw_id_tracker(
+                                self.get_indexed_lookups,
+                                self.delete,
+                            )
+                            if self._consumers is None:
+                                self._consumers = []
+                            self._consumers.append(emb._emb_module.uuid)
+
+    def get_unique_ids(self, consumer: Optional[str] = None) -> Dict[str, torch.Tensor]:
+        return {}
+
+    def get_unique(
+        self,
+        consumer: Optional[str] = None,
+        top_percentage: Optional[float] = 1.0,
+        per_table_percentage: Optional[Dict[str, Tuple[float, str]]] = None,
+        sorted_by_indices: Optional[bool] = True,
+    ) -> Dict[str, UniqueRows]:
+        return {}
+
+    def clear(self, consumer: Optional[str] = None) -> None:
+        pass
+
+    def get_indexed_lookups(
+        self,
+        tables: List[str],
+        consumer: Optional[str] = None,
+    ) -> Dict[str, List[torch.Tensor]]:
+        raw_id_per_table: Dict[str, List[torch.Tensor]] = {}
+        consumer = consumer or self.DEFAULT_CONSUMER
+        assert (
+            consumer in self.per_consumer_batch_idx
+ ), f"consumer {consumer} not present in {self.per_consumer_batch_idx.values()}" + + index_end: int = self.curr_batch_idx + 1 + index_start = self.per_consumer_batch_idx[consumer] + indexed_lookups = {} + if index_start < index_end: + self.per_consumer_batch_idx[consumer] = index_end + indexed_lookups = self.store.get_indexed_lookups(index_start, index_end) + + for table in tables: + raw_ids_list = [] + fqn = self.table_to_fqn[table] + if fqn in indexed_lookups: + for indexed_lookup in indexed_lookups[fqn]: + if indexed_lookup.raw_ids is not None: + raw_ids_list.append(indexed_lookup.raw_ids) + raw_id_per_table[table] = raw_ids_list + + if self._delete_on_read: + self.store.delete(up_to_idx=min(self.per_consumer_batch_idx.values())) + + return raw_id_per_table + + def delete(self, up_to_idx: Optional[int]) -> None: + self.store.delete(up_to_idx) diff --git a/torchrec/distributed/model_tracker/types.py b/torchrec/distributed/model_tracker/types.py index 3fbb70063..1bf00e6db 100644 --- a/torchrec/distributed/model_tracker/types.py +++ b/torchrec/distributed/model_tracker/types.py @@ -23,6 +23,7 @@ class IndexedLookup: batch_idx: int ids: torch.Tensor states: Optional[torch.Tensor] + raw_ids: Optional[torch.Tensor] = None compact: bool = False