Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion torchrec/distributed/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -1605,7 +1605,7 @@ def compute_and_output_dist(
):
embs = lookup(features)
if self.post_lookup_tracker_fn is not None:
self.post_lookup_tracker_fn(features, embs, self)
self.post_lookup_tracker_fn(features, embs, self, None)

with maybe_annotate_embedding_event(
EmbeddingEvent.OUTPUT_DIST, self._module_fqn, sharding_type
Expand Down
18 changes: 16 additions & 2 deletions torchrec/distributed/embedding_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,15 @@ def __init__(
self._lookups: List[nn.Module] = []
self._output_dists: List[nn.Module] = []
self.post_lookup_tracker_fn: Optional[
Callable[[KeyedJaggedTensor, torch.Tensor, Optional[nn.Module]], None]
Callable[
[
KeyedJaggedTensor,
torch.Tensor,
Optional[nn.Module],
Optional[torch.Tensor],
],
None,
]
] = None
self.post_odist_tracker_fn: Optional[Callable[..., None]] = None

Expand Down Expand Up @@ -445,7 +453,13 @@ def train(self, mode: bool = True): # pyre-ignore[3]
def register_post_lookup_tracker_fn(
self,
record_fn: Callable[
[KeyedJaggedTensor, torch.Tensor, Optional[nn.Module]], None
[
KeyedJaggedTensor,
torch.Tensor,
Optional[nn.Module],
Optional[torch.Tensor],
],
None,
],
) -> None:
"""
Expand Down
2 changes: 1 addition & 1 deletion torchrec/distributed/embeddingbag.py
Original file line number Diff line number Diff line change
Expand Up @@ -1671,7 +1671,7 @@ def compute_and_output_dist(
):
embs = lookup(features)
if self.post_lookup_tracker_fn is not None:
self.post_lookup_tracker_fn(features, embs, self)
self.post_lookup_tracker_fn(features, embs, self, None)

with maybe_annotate_embedding_event(
EmbeddingEvent.OUTPUT_DIST,
Expand Down
24 changes: 22 additions & 2 deletions torchrec/distributed/mc_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,15 @@ def __init__(
self._use_index_dedup = use_index_dedup
self._initialize_torch_state()
self.post_lookup_tracker_fn: Optional[
Callable[[KeyedJaggedTensor, torch.Tensor], None]
Callable[
[
KeyedJaggedTensor,
torch.Tensor,
Optional[nn.Module],
Optional[torch.Tensor],
],
None,
]
] = None

def _initialize_torch_state(self) -> None:
Expand Down Expand Up @@ -756,6 +764,8 @@ def compute(
if self.post_lookup_tracker_fn is not None:
self.post_lookup_tracker_fn(
KeyedJaggedTensor.from_jt_dict(mc_input),
torch.empty(0),
None,
mcm._hash_zch_identities.index_select(
dim=0, index=mc_input[table].values()
),
Expand All @@ -782,6 +792,8 @@ def compute(
if self.post_lookup_tracker_fn is not None:
self.post_lookup_tracker_fn(
KeyedJaggedTensor.from_jt_dict(mc_input),
torch.empty(0),
None,
mcm._hash_zch_identities.index_select(dim=0, index=values),
)

Expand Down Expand Up @@ -876,7 +888,15 @@ def unsharded_module_type(self) -> Type[ManagedCollisionCollection]:

def register_post_lookup_tracker_fn(
self,
record_fn: Callable[[KeyedJaggedTensor, torch.Tensor], None],
record_fn: Callable[
[
KeyedJaggedTensor,
torch.Tensor,
Optional[nn.Module],
Optional[torch.Tensor],
],
None,
],
) -> None:
"""
Register a function to be called after lookup is done. This is used for
Expand Down
18 changes: 17 additions & 1 deletion torchrec/distributed/model_tracker/delta_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def append(
fqn: str,
ids: torch.Tensor,
states: Optional[torch.Tensor],
raw_ids: Optional[torch.Tensor] = None,
) -> None:
"""
Append a batch of ids and states to the store for a specific table.
Expand Down Expand Up @@ -162,10 +163,11 @@ def append(
fqn: str,
ids: torch.Tensor,
states: Optional[torch.Tensor],
raw_ids: Optional[torch.Tensor] = None,
) -> None:
table_fqn_lookup = self.per_fqn_lookups.get(fqn, [])
table_fqn_lookup.append(
IndexedLookup(batch_idx=batch_idx, ids=ids, states=states)
IndexedLookup(batch_idx=batch_idx, ids=ids, states=states, raw_ids=raw_ids)
)
self.per_fqn_lookups[fqn] = table_fqn_lookup

Expand Down Expand Up @@ -224,6 +226,20 @@ def compact(self, start_idx: int, end_idx: int) -> None:
)
self.per_fqn_lookups = new_per_fqn_lookups

def get_indexed_lookups(
    self, start_idx: int, end_idx: int
) -> Dict[str, List[IndexedLookup]]:
    r"""
    Return, for every table FQN in the store, the recorded lookups whose
    ``batch_idx`` falls in the half-open window ``[start_idx, end_idx)``.

    Args:
        start_idx (int): first batch index to include.
        end_idx (int): batch index at which the window stops (excluded).

    Returns:
        Dict[str, List[IndexedLookup]]: table FQN -> slice of that table's
        lookups inside the window. Every table present in the store gets an
        entry, possibly an empty list. The lists are new, but they hold the
        same ``IndexedLookup`` objects as the store.
    """
    windowed: Dict[str, List[IndexedLookup]] = {}
    for fqn, entries in self.per_fqn_lookups.items():
        # Binary search is only valid on sorted keys, so this relies on each
        # table's lookups already being ordered by ascending batch_idx.
        batch_indices = [entry.batch_idx for entry in entries]
        lo = bisect_left(batch_indices, start_idx)
        # bisect_left on end_idx lands on the first entry with
        # batch_idx >= end_idx, which makes the upper bound exclusive.
        hi = bisect_left(batch_indices, end_idx)
        windowed[fqn] = entries[lo:hi]
    return windowed

def get_unique(self, from_idx: int = 0) -> Dict[str, UniqueRows]:
r"""
Return all unique/delta ids per table from the Delta Store.
Expand Down
9 changes: 9 additions & 0 deletions torchrec/distributed/model_tracker/model_delta_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def record_lookup(
kjt: KeyedJaggedTensor,
states: torch.Tensor,
emb_module: Optional[nn.Module] = None,
raw_ids: Optional[torch.Tensor] = None,
) -> None:
"""
Records the IDs from a given KeyedJaggedTensor and their corresponding embeddings/parameter states.
Expand Down Expand Up @@ -131,6 +132,13 @@ def clear(self, consumer: Optional[str] = None) -> None:
"""
pass

@abstractmethod
def step(self) -> None:
    """
    Advance the batch index for all consumers.

    Abstract hook: this base declaration does nothing and returns None.
    """


class ModelDeltaTrackerTrec(ModelDeltaTracker):
r"""
Expand Down Expand Up @@ -244,6 +252,7 @@ def record_lookup(
kjt: KeyedJaggedTensor,
states: torch.Tensor,
emb_module: Optional[nn.Module] = None,
raw_ids: Optional[torch.Tensor] = None,
) -> None:
"""
Records the IDs from a given KeyedJaggedTensor and their corresponding embeddings/parameter states.
Expand Down
15 changes: 15 additions & 0 deletions torchrec/distributed/model_tracker/trackers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

"""MPZCH Raw ID Tracker
"""

from torchrec.distributed.model_tracker.trackers.raw_id_tracker import ( # noqa
RawIdTracker,
)
Loading
Loading