From e99989b591c79af03cc57651b217ff357440c27c Mon Sep 17 00:00:00 2001
From: Malay Bag
Date: Tue, 2 Dec 2025 00:17:04 -0800
Subject: [PATCH] Update EBC pruning logic when dynamo-disabled nodes get
 partitioned into a new method (#3566)

Summary:
When the dynamo-disabled kt_regroup call is partitioned into its own
submodule (e.g. by fuse_by_partitions), the tree_unflatten node that feeds it
lives in the parent graph, so the pruning in _short_circuit_pytree_ebc_regroup
has to look one level up. Update the pruning logic accordingly and add a test
covering the partitioned case.

Differential Revision: D86708767
---
 torchrec/ir/tests/test_serializer.py | 104 +++++++++++++++++++++++++++
 torchrec/ir/utils.py                 |  27 +++++--
 2 files changed, 125 insertions(+), 6 deletions(-)

diff --git a/torchrec/ir/tests/test_serializer.py b/torchrec/ir/tests/test_serializer.py
index 8acc36698..09bc2b387 100644
--- a/torchrec/ir/tests/test_serializer.py
+++ b/torchrec/ir/tests/test_serializer.py
@@ -15,6 +15,7 @@
 import torch
 from torch import nn
+from torch.fx.passes.utils.fuser_utils import fuse_by_partitions
 
 from torchrec.ir.serializer import JsonSerializer
 
 from torchrec.ir.utils import (
@@ -903,6 +904,109 @@ def forward(
         for key in eager_out.keys():
             torch.testing.assert_close(deserialized_out[key], eager_out[key])
 
+    def test_key_order_with_ebc_and_regroup_in_subgraph(self) -> None:
+        tb1_config = EmbeddingBagConfig(
+            name="t1",
+            embedding_dim=3,
+            num_embeddings=10,
+            feature_names=["f1"],
+        )
+        tb2_config = EmbeddingBagConfig(
+            name="t2",
+            embedding_dim=4,
+            num_embeddings=10,
+            feature_names=["f2"],
+        )
+        tb3_config = EmbeddingBagConfig(
+            name="t3",
+            embedding_dim=5,
+            num_embeddings=10,
+            feature_names=["f3"],
+        )
+        id_list_features = KeyedJaggedTensor.from_offsets_sync(
+            keys=["f1", "f2", "f3", "f4", "f5"],
+            values=torch.tensor([0, 1, 2, 3, 2, 3, 4, 5, 6, 7, 8, 9, 1, 1, 2]),
+            offsets=torch.tensor([0, 2, 2, 3, 4, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 15]),
+        )
+        ebc1 = EmbeddingBagCollection(
+            tables=[tb1_config, tb2_config, tb3_config],
+            is_weighted=False,
+        )
+        ebc2 = EmbeddingBagCollection(
+            tables=[tb1_config, tb3_config, tb2_config],
+            is_weighted=False,
+        )
+        ebc2.load_state_dict(ebc1.state_dict())
+        regroup = KTRegroupAsDict([["f1", "f3"], ["f2"]], ["odd", "even"])
+
+        class mySparse(nn.Module):
+            def __init__(self, ebc, regroup):
+                super().__init__()
+                self.ebc = ebc
+                self.regroup = regroup
+
+            def forward(
+                self,
+                features: KeyedJaggedTensor,
+            ) -> Dict[str, torch.Tensor]:
+                return self.regroup(keyed_tensors=[self.ebc(features)])
+
+        class myModel(nn.Module):
+            def __init__(self, ebc, regroup):
+                super().__init__()
+                self.sparse = mySparse(ebc, regroup)
+
+            def forward(
+                self,
+                features: KeyedJaggedTensor,
+            ) -> Dict[str, torch.Tensor]:
+                return self.sparse(features)
+
+        model = myModel(ebc1, regroup)
+        eager_out = model(id_list_features)
+
+        model, sparse_fqns = encapsulate_ir_modules(model, JsonSerializer)
+        ep = torch.export.export(
+            model,
+            (id_list_features,),
+            {},
+            strict=False,
+            # Allows KJT to not be unflattened and run a forward on the unflattened EP
+            preserve_module_call_signature=(tuple(sparse_fqns)),
+        )
+        unflatten_ep = torch.export.unflatten(ep)
+        # Create a subgraph for the regroup node
+        regroup_nodes = []
+        mod = unflatten_ep.sparse
+        for node in mod.graph.nodes:
+            if node.op == "call_module" and "regroup" in str(node.target):
+                regroup_nodes.append(node)
+        partitions: List[Dict[torch.fx.Node, Optional[int]]] = []
+        partitions.append(dict.fromkeys(regroup_nodes, None))
+        fuse_by_partitions(unflatten_ep.sparse.graph_module, partitions)
+        sub_mod = torch.fx.graph_module._get_attr(mod.graph_module, "fused_0")
+        mod.add_module("fused_0", sub_mod)
+        mod.graph = mod.graph_module.graph
+
+        # decapsulate method will short-circuit the kt_regroup node
+        deserialized_model = decapsulate_ir_modules(
+            unflatten_ep,
+            JsonSerializer,
+            short_circuit_pytree_ebc_regroup=True,
+            finalize_interpreter_modules=True,
+        )
+
+        # We export the model with ebc1 and unflatten it, then swap in ebc2
+        # (you can think of this as the sharding process producing a sharded
+        # EBC), so that we can mimic the key-order change.
+        # pyre-fixme[16]: `Module` has no attribute `ebc`.
+        # pyre-fixme[16]: `Tensor` has no attribute `ebc`.
+        deserialized_model.sparse.ebc = ebc2
+
+        deserialized_out = deserialized_model(id_list_features)
+        for key in eager_out.keys():
+            torch.testing.assert_close(deserialized_out[key], eager_out[key])
+
     def test_cast_in_regroup(self) -> None:
         class Model(nn.Module):
             def __init__(self, ebc, fpebc, regroup):
diff --git a/torchrec/ir/utils.py b/torchrec/ir/utils.py
index 05bbfe9fb..ca1fa0fc8 100644
--- a/torchrec/ir/utils.py
+++ b/torchrec/ir/utils.py
@@ -312,7 +312,7 @@ def _short_circuit_pytree_ebc_regroup(module: nn.Module) -> nn.Module:
     """
     ebc_fqns: List[str] = []
     regroup_fqns: List[str] = []
-    for fqn, m in module.named_modules():
+    for fqn, m in module.named_modules(remove_duplicate=False):
         if isinstance(m, FeatureProcessedEmbeddingBagCollection):
             ebc_fqns.append(fqn)
         elif isinstance(m, EmbeddingBagCollection):
@@ -357,11 +357,11 @@ def prune_pytree_flatten_unflatten(
         [tensors and specs] ==> (in-coming) pytree.unflatten ==> "preserved module"
     """
 
-    def _get_graph_node(mod: nn.Module, fqn: str) -> Tuple[nn.Module, Node]:
+    def _get_graph_node(mod: nn.Module, fqn: str) -> Tuple[nn.Module, Node, str]:
         # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `nodes`.
         for node in mod.graph.nodes:
             if node.op == "call_module" and node.target == fqn:
-                return mod, node
+                return mod, node, fqn
         assert "." in fqn, f"can't find {fqn} in the graph of {mod}"
         curr, fqn = fqn.split(".", maxsplit=1)
         mod = getattr(mod, curr)
@@ -369,10 +369,25 @@ def _get_graph_node(mod: nn.Module, fqn: str) -> Tuple[nn.Module, Node]:
     # remove tree_unflatten from the in_fqns (in-coming nodes)
     for fqn in in_fqns:
-        submodule, node = _get_graph_node(module, fqn)
+        submodule, node, submod_name = _get_graph_node(module, fqn)
+
         # kt_regroup node will have either one arg or one kwarg
-        assert len(node.args) == 1 or len(node.kwargs) == 1
         use_args = len(node.args) == 1
+        use_kwargs = len(node.kwargs) == 1
+        assert use_args or use_kwargs
+
+        # In case the kt_regroup module has been partitioned into a submodule,
+        # we need to check the parent module for the tree_unflatten node.
+
+        if (use_args and cast(Node, node.args[0]).op == "placeholder") or (
+            use_kwargs and cast(Node, list(node.kwargs.values())[0]).op == "placeholder"
+        ):
+            submodule, node, _ = _get_graph_node(
+                module, fqn.replace("." + submod_name, "")
+            )
+            use_args = len(node.args) == 1
+            use_kwargs = len(node.kwargs) == 1
+            assert use_args or use_kwargs
 
         getitem_getitem = cast(
             Node, node.args[0] if use_args else list(node.kwargs.values())[0]
         )
@@ -403,7 +418,7 @@ def _get_graph_node(mod: nn.Module, fqn: str) -> Tuple[nn.Module, Node]:
 
     # remove tree_flatten_spec from the out_fqns (out-going nodes)
     for fqn in out_fqns:
-        submodule, node = _get_graph_node(module, fqn)
+        submodule, node, _ = _get_graph_node(module, fqn)
         users = list(node.users.keys())
         assert (
             len(users) == 1
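
Illustration (not part of the diff above): a minimal sketch of what "partitioned
into a submodule" means for the pruning logic. It uses a toy module and torch.relu
in place of the real EBC/KTRegroupAsDict stack; the Toy class and the node
selection are made up for this example, and the partition format mirrors the
dict-based one used in the new test (which assumes a recent torch.fx API). After
fuse_by_partitions, the fused call's input inside the child graph is just a
placeholder, so the node that actually produces it (tree_unflatten in the real
model) is only visible from the parent graph, which is why _get_graph_node is
re-run one level up.

import torch
import torch.fx
from torch.fx.passes.utils.fuser_utils import fuse_by_partitions


class Toy(torch.nn.Module):
    # Stand-in for the sparse module: one producer op and one op to partition.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = x + 1          # producer, stays in the parent graph
        z = torch.relu(y)  # gets moved into a child submodule below
        return z * 2


gm = torch.fx.symbolic_trace(Toy())

# Collect the node(s) to partition, mirroring how the test collects the
# regroup call before handing it to fuse_by_partitions.
relu_nodes = [n for n in gm.graph.nodes if n.target is torch.relu]
fuse_by_partitions(gm, [dict.fromkeys(relu_nodes, None)])

# The partition now lives in a child GraphModule named "fused_0". Inside it,
# the relu's input is a placeholder; its real producer (x + 1) is only
# reachable from the parent graph, one level up.
sub = gm.get_submodule("fused_0")
print([n.op for n in sub.graph.nodes])  # e.g. ['placeholder', 'call_function', 'output']
print([n.op for n in gm.graph.nodes])   # parent still holds the add and a call_module node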