Skip to content

Commit 3578390

Browse files
che-sh authored and facebook-github-bot committed
Support postproc inputs to be list or dict with outputs from other postproc modules (#2733)
Summary: Pull Request resolved: #2733

Postproc modules with collection inputs (list or dict) containing non-static elements (derived from the input or from another postproc) were not properly rewritten: the input elements remained fx.Nodes even during the actual model forward (i.e. outside the rewrite, during pipeline execution).

To illustrate:

```
def forward(model_input: ...) -> ...:
    modified_input = model_input.float_features + 1
    sharded_module_input = self.postproc(model_input, modified_input)  # works
    sharded_module_input = self.postproc(model_input, [123])  # works
    sharded_module_input = self.postproc(model_input, [torch.ones_like(modified_input)])  # fails
    sharded_module_input = self.postproc(model_input, [modified_input])  # fails
    sharded_module_input = self.postproc(model_input, { 'a': 123 })  # works
    sharded_module_input = self.postproc(model_input, { 'a': torch.ones_like(modified_input) })  # fails
    sharded_module_input = self.postproc(model_input, { 'a': modified_input })  # fails
    return self.ebc(sharded_module_input)
```

Reviewed By: sarckk

Differential Revision: D69292525

fbshipit-source-id: 4ba9672f3e31248b4850ba7ff35a07d7f292bd06
1 parent 62c0740 commit 3578390

File tree

3 files changed

+457
-191
lines changed

3 files changed

+457
-191
lines changed

torchrec/distributed/test_utils/test_model.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1832,6 +1832,103 @@ def forward(
18321832
return pred.sum(), pred
18331833

18341834

1835+
class TestModelWithPreprocCollectionArgs(nn.Module):
    """
    Test model whose nested postproc module receives collection (list and dict)
    arguments. The forward pass calls four postproc modules:
    - postproc_module_outer on the model input; its output is placed into the
      list and dict that are handed to postproc_module_nested
    - postproc_module_nested on the model input plus that list and dict
      (aka nested postproc)
    - postproc_nonweighted on idlist_features for the non-weighted EBC
    - postproc_weighted on idscore_features for the weighted EBC

    Both collections carry three kinds of element: a constant (1), a tensor read
    directly from the model input (float_features), and the output of
    postproc_module_outer.

    Args:
        tables: configs used to build the non-weighted EmbeddingBagCollection,
        weighted_tables: configs used to build the weighted EmbeddingBagCollection,
        device: device passed to the dense arch and to both EBCs,
        postproc_module_outer: module called with the model input,
        postproc_module_nested: module called with (model input, list, dict),
        num_float_features: passed to TestDenseArch,

    Example:
        >>> TestModelWithPreprocCollectionArgs(
        ...     tables,
        ...     weighted_tables,
        ...     device,
        ...     postproc_module_outer,
        ...     postproc_module_nested,
        ... )

    Returns:
        Tuple[torch.Tensor, torch.Tensor]
    """

    # Keys of the dict argument handed to the nested postproc module.
    CONST_DICT_KEY = "const"
    INPUT_TENSOR_DICT_KEY = "tensor_from_input"
    POSTPROC_TENSOR_DICT_KEY = "tensor_from_postproc"
    # Original (misspelled) name, kept as an alias so existing references keep working.
    POSTPTOC_TENSOR_DICT_KEY = POSTPROC_TENSOR_DICT_KEY

    def __init__(
        self,
        tables: List[EmbeddingBagConfig],
        weighted_tables: List[EmbeddingBagConfig],
        device: torch.device,
        postproc_module_outer: nn.Module,
        postproc_module_nested: nn.Module,
        num_float_features: int = 10,
    ) -> None:
        super().__init__()
        # NOTE: the dense arch is constructed here but is not called in forward().
        self.dense = TestDenseArch(num_float_features, device)

        self.ebc: EmbeddingBagCollection = EmbeddingBagCollection(
            tables=tables,
            device=device,
        )
        self.weighted_ebc = EmbeddingBagCollection(
            tables=weighted_tables,
            is_weighted=True,
            device=device,
        )
        self.postproc_nonweighted = TestPreprocNonWeighted()
        self.postproc_weighted = TestPreprocWeighted()
        self._postproc_module_outer = postproc_module_outer
        self._postproc_module_nested = postproc_module_nested

    def forward(
        self,
        input: ModelInput,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Runs the outer postproc, passes its output (inside a list and a dict) to
        the nested postproc, then runs the per-feature postprocs and both EBCs.

        Args:
            input
        Returns:
            Tuple[torch.Tensor, torch.Tensor]: (pred.sum(), pred)
        """
        outer_postproc_input = self._postproc_module_outer(input)

        # Each collection mixes a constant, a tensor taken straight from the
        # input, and a tensor produced by another postproc module.
        preproc_input_list = [
            1,
            input.float_features,
            outer_postproc_input,
        ]
        preproc_input_dict = {
            self.CONST_DICT_KEY: 1,
            self.INPUT_TENSOR_DICT_KEY: input.float_features,
            self.POSTPROC_TENSOR_DICT_KEY: outer_postproc_input,
        }

        modified_input = self._postproc_module_nested(
            input, preproc_input_list, preproc_input_dict
        )

        modified_idlist_features = self.postproc_nonweighted(
            modified_input.idlist_features
        )
        modified_idscore_features = self.postproc_weighted(
            modified_input.idscore_features
        )
        # Only element [0] of each postproc output is fed to the EBCs.
        ebc_out = self.ebc(modified_idlist_features[0])
        weighted_ebc_out = self.weighted_ebc(modified_idscore_features[0])

        pred = torch.cat([ebc_out.values(), weighted_ebc_out.values()], dim=1)
        return pred.sum(), pred
1930+
1931+
18351932
class TestNegSamplingModule(torch.nn.Module):
18361933
"""
18371934
Basic module to simulate feature augmentation postproc (e.g. neg sampling) for testing

torchrec/distributed/train_pipeline/tests/test_train_pipelines.py

Lines changed: 78 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,15 @@
1313
from contextlib import ExitStack
1414
from dataclasses import dataclass
1515
from functools import partial
16-
from typing import cast, List, Optional, Tuple, Type, Union
16+
from typing import cast, Dict, List, Optional, Tuple, Type, Union
1717
from unittest.mock import MagicMock
1818

1919
import torch
2020
from hypothesis import given, settings, strategies as st, Verbosity
2121
from torch import nn, optim
2222
from torch._dynamo.testing import reduce_to_scalar_loss
2323
from torch._dynamo.utils import counters
24+
from torch.fx._symbolic_trace import is_fx_tracing
2425
from torchrec.distributed import DistributedModelParallel
2526
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
2627
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
@@ -36,6 +37,7 @@
3637
ModelInput,
3738
TestEBCSharder,
3839
TestModelWithPreproc,
40+
TestModelWithPreprocCollectionArgs,
3941
TestNegSamplingModule,
4042
TestPositionWeightedPreprocModule,
4143
TestSparseNN,
@@ -1448,6 +1450,81 @@ def forward(
14481450
self.assertEqual(len(pipeline._pipelined_modules), 2)
14491451
self.assertEqual(len(pipeline._pipelined_postprocs), 1)
14501452

1453+
# pyre-ignore
@unittest.skipIf(
    not torch.cuda.is_available(),
    "Not enough GPUs, this test requires at least one GPU",
)
def test_pipeline_postproc_with_collection_args(self) -> None:
    """
    Covers a postproc module that takes a list argument and a dict argument
    whose elements are a mix of:
    * a constant scalar
    * a tensor read from the input batch (input.float_features)
    * a tensor produced by another postproc module from the input batch

    While the pipelined model actually runs (i.e. not under fx tracing), none
    of those elements may still be a torch.fx.Node.
    """
    # The nested module classes below need a handle on the TestCase to fail it.
    test_case = self

    class PostprocOuter(nn.Module):
        def forward(
            self,
            model_input: ModelInput,
        ) -> torch.Tensor:
            return model_input.float_features * 0.1

    class PostprocInner(nn.Module):
        def forward(
            self,
            model_input: ModelInput,
            input_list: List[Union[torch.Tensor, int]],
            input_dict: Dict[str, Union[torch.Tensor, int]],
        ) -> ModelInput:
            # Under fx tracing the input is returned untouched; the checks
            # below only apply to real execution.
            if is_fx_tracing():
                return model_input

            for idx, value in enumerate(input_list):
                if isinstance(value, torch.fx.Node):
                    test_case.fail(f"input_list[{idx}] was a fx.Node: {value}")
                model_input.float_features += value

            for key, value in input_dict.items():
                if isinstance(value, torch.fx.Node):
                    test_case.fail(f"input_dict[{key}] was a fx.Node: {value}")
                model_input.float_features += value

            return model_input

    outer_postproc = PostprocOuter()
    nested_postproc = PostprocInner()
    model = TestModelWithPreprocCollectionArgs(
        # last table is left out of both lists; original note: "postproc will remove"
        tables=self.tables[:-1],
        weighted_tables=self.weighted_tables[:-1],
        device=self.device,
        postproc_module_outer=outer_postproc,
        postproc_module_nested=nested_postproc,
    )

    _, pipeline = self._check_output_equal(
        model,
        self.sharding_type,
    )

    # both the EBC and the weighted EBC are pipelined
    self.assertEqual(len(pipeline._pipelined_modules), 2)
    # four postprocs are expected to be pipelined, matching the four postproc
    # modules the model calls: outer, nested, postproc_nonweighted and
    # postproc_weighted
    self.assertEqual(len(pipeline._pipelined_postprocs), 4)
1527+
14511528

14521529
class EmbeddingTrainPipelineTest(TrainPipelineSparseDistTestBase):
14531530
@unittest.skipIf(

0 commit comments

Comments
 (0)