|
14 | 14 |
|
15 | 15 | import torch |
16 | 16 | import torch.nn as nn |
17 | | -from tensordict import TensorDict |
18 | 17 | from torchrec.distributed.embedding_tower_sharding import ( |
19 | 18 | EmbeddingTowerCollectionSharder, |
20 | 19 | EmbeddingTowerSharder, |
|
47 | 46 | @dataclass |
48 | 47 | class ModelInput(Pipelineable): |
49 | 48 | float_features: torch.Tensor |
50 | | - idlist_features: Union[KeyedJaggedTensor, TensorDict] |
51 | | - idscore_features: Optional[Union[KeyedJaggedTensor, TensorDict]] |
| 49 | + idlist_features: KeyedJaggedTensor |
| 50 | + idscore_features: Optional[KeyedJaggedTensor] |
52 | 51 | label: torch.Tensor |
53 | 52 |
|
54 | 53 | @staticmethod |
@@ -77,13 +76,11 @@ def generate( |
77 | 76 | randomize_indices: bool = True, |
78 | 77 | device: Optional[torch.device] = None, |
79 | 78 | max_feature_lengths: Optional[List[int]] = None, |
80 | | - input_type: str = "kjt", |
81 | 79 | ) -> Tuple["ModelInput", List["ModelInput"]]: |
82 | 80 | """ |
83 | 81 | Returns a global (single-rank training) batch |
84 | 82 | and a list of local (multi-rank training) batches of world_size. |
85 | 83 | """ |
86 | | - |
87 | 84 | batch_size_by_rank = [batch_size] * world_size |
88 | 85 | if variable_batch_size: |
89 | 86 | batch_size_by_rank = [ |
@@ -202,26 +199,11 @@ def _validate_pooling_factor( |
202 | 199 | ) |
203 | 200 | global_idlist_lengths.append(lengths) |
204 | 201 | global_idlist_indices.append(indices) |
205 | | - |
206 | | - if input_type == "kjt": |
207 | | - global_idlist_input = KeyedJaggedTensor( |
208 | | - keys=idlist_features, |
209 | | - values=torch.cat(global_idlist_indices), |
210 | | - lengths=torch.cat(global_idlist_lengths), |
211 | | - ) |
212 | | - elif input_type == "td": |
213 | | - dict_of_nt = { |
214 | | - k: torch.nested.nested_tensor_from_jagged( |
215 | | - values=values, |
216 | | - lengths=lengths, |
217 | | - ) |
218 | | - for k, values, lengths in zip( |
219 | | - idlist_features, global_idlist_indices, global_idlist_lengths |
220 | | - ) |
221 | | - } |
222 | | - global_idlist_input = TensorDict(source=dict_of_nt) |
223 | | - else: |
224 | | - raise ValueError(f"For IdList features, unknown input type {input_type}") |
| 202 | + global_idlist_kjt = KeyedJaggedTensor( |
| 203 | + keys=idlist_features, |
| 204 | + values=torch.cat(global_idlist_indices), |
| 205 | + lengths=torch.cat(global_idlist_lengths), |
| 206 | + ) |
225 | 207 |
|
226 | 208 | for idx in range(len(idscore_ind_ranges)): |
227 | 209 | ind_range = idscore_ind_ranges[idx] |
@@ -263,25 +245,16 @@ def _validate_pooling_factor( |
263 | 245 | global_idscore_lengths.append(lengths) |
264 | 246 | global_idscore_indices.append(indices) |
265 | 247 | global_idscore_weights.append(weights) |
266 | | - |
267 | | - if input_type == "kjt": |
268 | | - global_idscore_input = ( |
269 | | - KeyedJaggedTensor( |
270 | | - keys=idscore_features, |
271 | | - values=torch.cat(global_idscore_indices), |
272 | | - lengths=torch.cat(global_idscore_lengths), |
273 | | - weights=torch.cat(global_idscore_weights), |
274 | | - ) |
275 | | - if global_idscore_indices |
276 | | - else None |
| 248 | + global_idscore_kjt = ( |
| 249 | + KeyedJaggedTensor( |
| 250 | + keys=idscore_features, |
| 251 | + values=torch.cat(global_idscore_indices), |
| 252 | + lengths=torch.cat(global_idscore_lengths), |
| 253 | + weights=torch.cat(global_idscore_weights), |
277 | 254 | ) |
278 | | - elif input_type == "td": |
279 | | - assert ( |
280 | | - len(idscore_features) == 0 |
281 | | - ), "TensorDict does not support weighted features" |
282 | | - global_idscore_input = None |
283 | | - else: |
284 | | - raise ValueError(f"For weighted features, unknown input type {input_type}") |
| 255 | + if global_idscore_indices |
| 256 | + else None |
| 257 | + ) |
285 | 258 |
|
286 | 259 | if randomize_indices: |
287 | 260 | global_float = torch.rand( |
@@ -330,57 +303,36 @@ def _validate_pooling_factor( |
330 | 303 | weights[lengths_cumsum[r] : lengths_cumsum[r + 1]] |
331 | 304 | ) |
332 | 305 |
|
333 | | - if input_type == "kjt": |
334 | | - local_idlist_input = KeyedJaggedTensor( |
335 | | - keys=idlist_features, |
336 | | - values=torch.cat(local_idlist_indices), |
337 | | - lengths=torch.cat(local_idlist_lengths), |
338 | | - ) |
339 | | - |
340 | | - local_idscore_input = ( |
341 | | - KeyedJaggedTensor( |
342 | | - keys=idscore_features, |
343 | | - values=torch.cat(local_idscore_indices), |
344 | | - lengths=torch.cat(local_idscore_lengths), |
345 | | - weights=torch.cat(local_idscore_weights), |
346 | | - ) |
347 | | - if local_idscore_indices |
348 | | - else None |
349 | | - ) |
350 | | - elif input_type == "td": |
351 | | - dict_of_nt = { |
352 | | - k: torch.nested.nested_tensor_from_jagged( |
353 | | - values=values, |
354 | | - lengths=lengths, |
355 | | - ) |
356 | | - for k, values, lengths in zip( |
357 | | - idlist_features, local_idlist_indices, local_idlist_lengths |
358 | | - ) |
359 | | - } |
360 | | - local_idlist_input = TensorDict(source=dict_of_nt) |
361 | | - assert ( |
362 | | - len(idscore_features) == 0 |
363 | | - ), "TensorDict does not support weighted features" |
364 | | - local_idscore_input = None |
| 306 | + local_idlist_kjt = KeyedJaggedTensor( |
| 307 | + keys=idlist_features, |
| 308 | + values=torch.cat(local_idlist_indices), |
| 309 | + lengths=torch.cat(local_idlist_lengths), |
| 310 | + ) |
365 | 311 |
|
366 | | - else: |
367 | | - raise ValueError( |
368 | | - f"For weighted features, unknown input type {input_type}" |
| 312 | + local_idscore_kjt = ( |
| 313 | + KeyedJaggedTensor( |
| 314 | + keys=idscore_features, |
| 315 | + values=torch.cat(local_idscore_indices), |
| 316 | + lengths=torch.cat(local_idscore_lengths), |
| 317 | + weights=torch.cat(local_idscore_weights), |
369 | 318 | ) |
| 319 | + if local_idscore_indices |
| 320 | + else None |
| 321 | + ) |
370 | 322 |
|
371 | 323 | local_input = ModelInput( |
372 | 324 | float_features=global_float[r * batch_size : (r + 1) * batch_size], |
373 | | - idlist_features=local_idlist_input, |
374 | | - idscore_features=local_idscore_input, |
| 325 | + idlist_features=local_idlist_kjt, |
| 326 | + idscore_features=local_idscore_kjt, |
375 | 327 | label=global_label[r * batch_size : (r + 1) * batch_size], |
376 | 328 | ) |
377 | 329 | local_inputs.append(local_input) |
378 | 330 |
|
379 | 331 | return ( |
380 | 332 | ModelInput( |
381 | 333 | float_features=global_float, |
382 | | - idlist_features=global_idlist_input, |
383 | | - idscore_features=global_idscore_input, |
| 334 | + idlist_features=global_idlist_kjt, |
| 335 | + idscore_features=global_idscore_kjt, |
384 | 336 | label=global_label, |
385 | 337 | ), |
386 | 338 | local_inputs, |
@@ -671,9 +623,8 @@ def to(self, device: torch.device, non_blocking: bool = False) -> "ModelInput": |
671 | 623 |
|
672 | 624 | def record_stream(self, stream: torch.Stream) -> None: |
673 | 625 | self.float_features.record_stream(stream) |
674 | | - if isinstance(self.idlist_features, KeyedJaggedTensor): |
675 | | - self.idlist_features.record_stream(stream) |
676 | | - if isinstance(self.idscore_features, KeyedJaggedTensor): |
| 626 | + self.idlist_features.record_stream(stream) |
| 627 | + if self.idscore_features is not None: |
677 | 628 | self.idscore_features.record_stream(stream) |
678 | 629 | self.label.record_stream(stream) |
679 | 630 |
|
@@ -1880,8 +1831,6 @@ def forward(self, input: ModelInput) -> ModelInput: |
1880 | 1831 | ) |
1881 | 1832 |
|
1882 | 1833 | # stride will be same but features will be joined |
1883 | | - assert isinstance(modified_input.idlist_features, KeyedJaggedTensor) |
1884 | | - assert isinstance(self._extra_input.idlist_features, KeyedJaggedTensor) |
1885 | 1834 | modified_input.idlist_features = KeyedJaggedTensor.concat( |
1886 | 1835 | [modified_input.idlist_features, self._extra_input.idlist_features] |
1887 | 1836 | ) |
|
0 commit comments