Skip to content

Commit 907a5ae

Browse files
aliafzalfacebook-github-bot
authored andcommitted
add tracker_fn to MCC
Summary: Adding post lookup tracker function within MMC module to allow tracking of hash_zch_identities with delta tracker. internal This is needed to support MPZCH modules for Raw embedding streaming. Mode details : https://docs.google.com/document/d/1KEHwiXKLgXwRIdDFBYopjX3OiP3mRLM24Qkbiiu-TgE/edit?tab=t.0#bookmark=id.lhhgee2cs6ld Differential Revision: D84920121
1 parent 75dfb7f commit 907a5ae

File tree

2 files changed

+58
-3
lines changed

2 files changed

+58
-3
lines changed

torchrec/distributed/mc_modules.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,17 @@
1313
import math
1414
from collections import defaultdict, OrderedDict
1515
from dataclasses import dataclass
16-
from typing import Any, DefaultDict, Dict, Iterator, List, Optional, Type, Union
16+
from typing import (
17+
Any,
18+
Callable,
19+
DefaultDict,
20+
Dict,
21+
Iterator,
22+
List,
23+
Optional,
24+
Type,
25+
Union,
26+
)
1727

1828
import torch
1929
import torch.distributed as dist
@@ -58,6 +68,8 @@
5868
ShardingType,
5969
)
6070
from torchrec.distributed.utils import append_prefix
71+
from torchrec.modules.embedding_configs import BaseEmbeddingConfig
72+
from torchrec.modules.hash_mc_modules import HashZchManagedCollisionModule
6173
from torchrec.modules.mc_modules import ManagedCollisionCollection
6274
from torchrec.modules.utils import construct_jagged_tensors
6375
from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor
@@ -215,6 +227,9 @@ def __init__(
215227

216228
self._feature_to_table: Dict[str, str] = module._feature_to_table
217229
self._table_to_features: Dict[str, List[str]] = module._table_to_features
230+
self._table_name_to_config: Dict[str, BaseEmbeddingConfig] = (
231+
module._table_name_to_config
232+
)
218233
self._has_uninitialized_input_dists: bool = True
219234
self._input_dists: List[nn.Module] = []
220235
self._managed_collision_modules = nn.ModuleDict()
@@ -223,6 +238,9 @@ def __init__(
223238
self._create_output_dists()
224239
self._use_index_dedup = use_index_dedup
225240
self._initialize_torch_state()
241+
self.post_lookup_tracker_fn: Optional[
242+
Callable[[KeyedJaggedTensor, torch.Tensor], None]
243+
] = None
226244

227245
def _initialize_torch_state(self) -> None:
228246
self._model_parallel_mc_buffer_name_to_sharded_tensor = OrderedDict()
@@ -732,6 +750,17 @@ def compute(
732750
mc_input = mcm.remap(mc_input)
733751
mc_input = self.global_to_local_index(mc_input)
734752
output.update(mc_input)
753+
if isinstance(
754+
mcm,
755+
HashZchManagedCollisionModule,
756+
):
757+
if self.post_lookup_tracker_fn is not None:
758+
self.post_lookup_tracker_fn(
759+
KeyedJaggedTensor.from_jt_dict(output),
760+
mcm._hash_zch_identities.index_select(
761+
dim=0, index=mc_input[table].values()
762+
),
763+
)
735764
values = torch.cat([jt.values() for jt in output.values()])
736765
else:
737766
table: str = tables[0]
@@ -750,6 +779,12 @@ def compute(
750779
mc_input = mcm.remap(mc_input)
751780
mc_input = self.global_to_local_index(mc_input)
752781
values = mc_input[table].values()
782+
if isinstance(mcm, HashZchManagedCollisionModule):
783+
if self.post_lookup_tracker_fn is not None:
784+
self.post_lookup_tracker_fn(
785+
KeyedJaggedTensor.from_jt_dict(mc_input),
786+
mcm._hash_zch_identities.index_select(dim=0, index=values),
787+
)
753788

754789
remapped_kjts.append(
755790
KeyedJaggedTensor(
@@ -840,6 +875,24 @@ def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]:
840875
def unsharded_module_type(self) -> Type[ManagedCollisionCollection]:
841876
return ManagedCollisionCollection
842877

878+
def register_post_lookup_tracker_fn(
879+
self,
880+
record_fn: Callable[[KeyedJaggedTensor, torch.Tensor], None],
881+
) -> None:
882+
"""
883+
Register a function to be called after lookup is done. This is used for
884+
tracking the lookup results and optimizer states.
885+
886+
Args:
887+
record_fn (Callable[[KeyedJaggedTensor, torch.Tensor], None]): A custom record function to be called after lookup is done.
888+
889+
"""
890+
if self.post_lookup_tracker_fn is not None:
891+
logger.warning(
892+
"[ModelDeltaTracker] Custom record function already defined, overriding with new callable"
893+
)
894+
self.post_lookup_tracker_fn = record_fn
895+
843896

844897
class ManagedCollisionCollectionSharder(
845898
BaseEmbeddingSharder[ManagedCollisionCollection]

torchrec/modules/mc_modules.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -357,9 +357,11 @@ def __init__(
357357
len(features) for features in self._table_to_features.values()
358358
]
359359

360-
table_to_config = {config.name: config for config in embedding_configs}
360+
self._table_name_to_config: Dict[str, BaseEmbeddingConfig] = {
361+
config.name: config for config in embedding_configs
362+
}
361363

362-
for name, config in table_to_config.items():
364+
for name, config in self._table_name_to_config.items():
363365
if name not in managed_collision_modules:
364366
raise ValueError(
365367
f"Table {name} is not present in managed_collision_modules"

0 commit comments

Comments
 (0)