Commit dd5457c
add NJT/TD support for EC (meta-pytorch#2596)
Summary:
Pull Request resolved: meta-pytorch#2596
# Documents
* [TorchRec NJT Work Items](https://fburl.com/gdoc/gcqq6luv)
* [KJT <> TensorDict](https://docs.google.com/document/d/1zqJL5AESnoKeIt5VZ6K1289fh_1QcSwu76yo0nB4Ecw/edit?tab=t.0#heading=h.bn9zwvg79)
{F1949248817}
# Context
* Continued from previous D66465376, which adds NJT/TD support for EBC, this diff is for EC
* As depicted above, we are extending TorchRec input data type from KJT (KeyedJaggedTensor) to TD (TensorDict)
* Basically we can support TensorDict in both **eager mode** and **distributed (sharded) mode**: `Input (Union[KJT, TD]) ==> EC ==> Output (KT)`
* In eager mode, we directly call `td_to_kjt` in the forward function to convert TD to KJT.
* In distributed mode, we do the conversion inside the `ShardedEmbeddingCollection`, specifically in the `input_dist`, where the input sparse features are prepared (permuted) for the `KJTAllToAll` communication.
* In the KJT scenario, the input KJT would be permuted (and partially duplicated in some cases), followed by the `KJTAllToAll` communication.
While in the TD scenario, the input TD will directly be converted to the permuted KJT ready for the following `KJTAllToAll` communication.
NOTE: This diff re-used a number of existing test framework/cases with minimal critical changes in the `EmbeddingCollection` and `shardedEmbeddingCollection`. Please see the follwoing verifications for NJT/TD correctness.
# Verification - input with TensorDict
* breakpoint at [sharding_single_rank_test](https://fburl.com/code/x74s13fd)
* sharded model
```
(Pdb) local_model
DistributedModelParallel(
(_dmp_wrapped_module): DistributedDataParallel(
(module): TestSequenceSparseNN(
(dense): TestDenseArch(
(linear): Linear(in_features=16, out_features=8, bias=True)
)
(sparse): TestSequenceSparseArch(
(ec): ShardedEmbeddingCollection(
(lookups):
GroupedEmbeddingsLookup(
(_emb_modules): ModuleList(
(0): BatchedDenseEmbedding(
(_emb_module): DenseTableBatchedEmbeddingBagsCodegen()
)
)
)
(_input_dists):
RwSparseFeaturesDist(
(_dist): KJTAllToAll()
)
(_output_dists):
RwSequenceEmbeddingDist(
(_dist): SequenceEmbeddingsAllToAll()
)
(embeddings): ModuleDict(
(table_0): Module()
(table_1): Module()
(table_2): Module()
(table_3): Module()
(table_4): Module()
(table_5): Module()
)
)
)
(over): TestSequenceOverArch(
(linear): Linear(in_features=1928, out_features=16, bias=True)
)
)
)
)
```
* TD input
```
(Pdb) local_input
ModelInput(float_features=tensor([[0.8893, 0.6990, 0.6512, 0.9617, 0.5531, 0.9029, 0.8455, 0.9288, 0.2433,
0.8901, 0.8849, 0.3849, 0.4535, 0.9318, 0.5002, 0.8056],
[0.1978, 0.4822, 0.2907, 0.9947, 0.6707, 0.4246, 0.2294, 0.6623, 0.7146,
0.1914, 0.6517, 0.9449, 0.5650, 0.2358, 0.6787, 0.3671],
[0.3964, 0.6190, 0.7695, 0.6526, 0.7095, 0.2790, 0.0581, 0.2470, 0.8315,
0.9374, 0.0215, 0.3572, 0.0516, 0.1447, 0.0811, 0.2678],
[0.0475, 0.9740, 0.0039, 0.6126, 0.9783, 0.5080, 0.5583, 0.0703, 0.8320,
0.9837, 0.3936, 0.6329, 0.8229, 0.8486, 0.7715, 0.9617]],
device='cuda:0'), idlist_features=TensorDict(
fields={
feature_0: NestedTensor(shape=torch.Size([4, j5]), device=cuda:0, dtype=torch.int64, is_shared=True),
feature_1: NestedTensor(shape=torch.Size([4, j6]), device=cuda:0, dtype=torch.int64, is_shared=True),
feature_2: NestedTensor(shape=torch.Size([4, j7]), device=cuda:0, dtype=torch.int64, is_shared=True),
feature_3: NestedTensor(shape=torch.Size([4, j8]), device=cuda:0, dtype=torch.int64, is_shared=True)},
batch_size=torch.Size([]),
device=cuda:0,
is_shared=True), idscore_features=None, label=tensor([0.2093, 0.6164, 0.1763, 0.1895], device='cuda:0'))
```
* unsharded model
```
(Pdb) global_model
TestSequenceSparseNN(
(dense): TestDenseArch(
(linear): Linear(in_features=16, out_features=8, bias=True)
)
(sparse): TestSequenceSparseArch(
(ec): EmbeddingCollection(
(embeddings): ModuleDict(
(table_0): Embedding(11, 16)
(table_1): Embedding(22, 16)
(table_2): Embedding(33, 16)
(table_3): Embedding(44, 16)
(table_4): Embedding(11, 16)
(table_5): Embedding(22, 16)
)
)
)
(over): TestSequenceOverArch(
(linear): Linear(in_features=1928, out_features=16, bias=True)
)
)
```
* TD input
```
(Pdb) global_input
ModelInput(float_features=tensor([[0.8893, 0.6990, 0.6512, 0.9617, 0.5531, 0.9029, 0.8455, 0.9288, 0.2433,
0.8901, 0.8849, 0.3849, 0.4535, 0.9318, 0.5002, 0.8056],
[0.1978, 0.4822, 0.2907, 0.9947, 0.6707, 0.4246, 0.2294, 0.6623, 0.7146,
0.1914, 0.6517, 0.9449, 0.5650, 0.2358, 0.6787, 0.3671],
[0.3964, 0.6190, 0.7695, 0.6526, 0.7095, 0.2790, 0.0581, 0.2470, 0.8315,
0.9374, 0.0215, 0.3572, 0.0516, 0.1447, 0.0811, 0.2678],
[0.0475, 0.9740, 0.0039, 0.6126, 0.9783, 0.5080, 0.5583, 0.0703, 0.8320,
0.9837, 0.3936, 0.6329, 0.8229, 0.8486, 0.7715, 0.9617],
[0.6807, 0.7970, 0.1164, 0.8487, 0.7730, 0.1654, 0.5599, 0.5923, 0.3909,
0.4720, 0.9423, 0.7868, 0.3710, 0.6075, 0.6849, 0.1366],
[0.0246, 0.5967, 0.2838, 0.8114, 0.3761, 0.3963, 0.7792, 0.9119, 0.4026,
0.4769, 0.1477, 0.0923, 0.0723, 0.4416, 0.4560, 0.9548],
[0.8666, 0.6254, 0.9162, 0.1954, 0.8466, 0.6498, 0.3412, 0.2098, 0.9786,
0.3349, 0.7625, 0.3615, 0.8880, 0.0751, 0.8417, 0.5380],
[0.2857, 0.6871, 0.6694, 0.8206, 0.5142, 0.5641, 0.3780, 0.9441, 0.0964,
0.2007, 0.1148, 0.8054, 0.1520, 0.3742, 0.6364, 0.9797]],
device='cuda:0'), idlist_features=TensorDict(
fields={
feature_0: NestedTensor(shape=torch.Size([8, j1]), device=cuda:0, dtype=torch.int64, is_shared=True),
feature_1: NestedTensor(shape=torch.Size([8, j2]), device=cuda:0, dtype=torch.int64, is_shared=True),
feature_2: NestedTensor(shape=torch.Size([8, j3]), device=cuda:0, dtype=torch.int64, is_shared=True),
feature_3: NestedTensor(shape=torch.Size([8, j4]), device=cuda:0, dtype=torch.int64, is_shared=True)},
batch_size=torch.Size([]),
device=cuda:0,
is_shared=True), idscore_features=None, label=tensor([0.2093, 0.6164, 0.1763, 0.1895, 0.3132, 0.2133, 0.4997, 0.0055],
device='cuda:0'))
```
Reviewed By: dstaay-fb
Differential Revision: D66521351
fbshipit-source-id: af433f18f27e26fafbb0ad61b2314da94f99d8901 parent f0ccbfb commit dd5457c
File tree
4 files changed
+84
-10
lines changed- torchrec
- distributed
- test_utils
- tests
- modules
4 files changed
+84
-10
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
26 | 26 | | |
27 | 27 | | |
28 | 28 | | |
| 29 | + | |
29 | 30 | | |
30 | 31 | | |
31 | 32 | | |
| |||
90 | 91 | | |
91 | 92 | | |
92 | 93 | | |
| 94 | + | |
93 | 95 | | |
94 | 96 | | |
95 | 97 | | |
| |||
1198 | 1200 | | |
1199 | 1201 | | |
1200 | 1202 | | |
1201 | | - | |
| 1203 | + | |
1202 | 1204 | | |
| 1205 | + | |
| 1206 | + | |
| 1207 | + | |
| 1208 | + | |
| 1209 | + | |
| 1210 | + | |
| 1211 | + | |
1203 | 1212 | | |
1204 | 1213 | | |
1205 | 1214 | | |
| |||
1209 | 1218 | | |
1210 | 1219 | | |
1211 | 1220 | | |
1212 | | - | |
| 1221 | + | |
1213 | 1222 | | |
1214 | 1223 | | |
1215 | 1224 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
147 | 147 | | |
148 | 148 | | |
149 | 149 | | |
| 150 | + | |
150 | 151 | | |
151 | 152 | | |
152 | 153 | | |
| |||
177 | 178 | | |
178 | 179 | | |
179 | 180 | | |
180 | | - | |
181 | | - | |
182 | | - | |
| 181 | + | |
| 182 | + | |
| 183 | + | |
183 | 184 | | |
184 | 185 | | |
185 | 186 | | |
| |||
188 | 189 | | |
189 | 190 | | |
190 | 191 | | |
191 | | - | |
192 | | - | |
| 192 | + | |
| 193 | + | |
| 194 | + | |
| 195 | + | |
| 196 | + | |
| 197 | + | |
| 198 | + | |
| 199 | + | |
| 200 | + | |
| 201 | + | |
| 202 | + | |
| 203 | + | |
| 204 | + | |
| 205 | + | |
| 206 | + | |
| 207 | + | |
| 208 | + | |
| 209 | + | |
| 210 | + | |
| 211 | + | |
193 | 212 | | |
194 | 213 | | |
195 | 214 | | |
| |||
200 | 219 | | |
201 | 220 | | |
202 | 221 | | |
203 | | - | |
204 | 222 | | |
205 | 223 | | |
206 | 224 | | |
| |||
297 | 315 | | |
298 | 316 | | |
299 | 317 | | |
| 318 | + | |
300 | 319 | | |
301 | 320 | | |
302 | 321 | | |
| |||
319 | 338 | | |
320 | 339 | | |
321 | 340 | | |
| 341 | + | |
322 | 342 | | |
323 | 343 | | |
324 | 344 | | |
| |||
Lines changed: 41 additions & 0 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
376 | 376 | | |
377 | 377 | | |
378 | 378 | | |
| 379 | + | |
| 380 | + | |
| 381 | + | |
| 382 | + | |
| 383 | + | |
| 384 | + | |
| 385 | + | |
| 386 | + | |
| 387 | + | |
| 388 | + | |
| 389 | + | |
| 390 | + | |
| 391 | + | |
| 392 | + | |
| 393 | + | |
| 394 | + | |
| 395 | + | |
| 396 | + | |
| 397 | + | |
| 398 | + | |
| 399 | + | |
| 400 | + | |
| 401 | + | |
| 402 | + | |
| 403 | + | |
| 404 | + | |
| 405 | + | |
| 406 | + | |
| 407 | + | |
| 408 | + | |
| 409 | + | |
| 410 | + | |
| 411 | + | |
| 412 | + | |
| 413 | + | |
| 414 | + | |
| 415 | + | |
| 416 | + | |
| 417 | + | |
| 418 | + | |
| 419 | + | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
219 | 219 | | |
220 | 220 | | |
221 | 221 | | |
222 | | - | |
| 222 | + | |
| 223 | + | |
| 224 | + | |
| 225 | + | |
223 | 226 | | |
224 | 227 | | |
225 | 228 | | |
| |||
450 | 453 | | |
451 | 454 | | |
452 | 455 | | |
453 | | - | |
| 456 | + | |
454 | 457 | | |
455 | 458 | | |
456 | 459 | | |
| |||
463 | 466 | | |
464 | 467 | | |
465 | 468 | | |
| 469 | + | |
466 | 470 | | |
467 | 471 | | |
468 | 472 | | |
| |||
0 commit comments