|
27 | 27 |
|
28 | 28 | import torch |
29 | 29 | from fbgemm_gpu.permute_pooled_embedding_modules import PermutePooledEmbeddings |
30 | | -from tensordict import TensorDict |
31 | 30 | from torch import distributed as dist, nn, Tensor |
32 | 31 | from torch.autograd.profiler import record_function |
33 | 32 | from torch.distributed._shard.sharded_tensor import TensorProperties |
|
95 | 94 | from torchrec.optim.fused import EmptyFusedOptimizer, FusedOptimizerModule |
96 | 95 | from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizer |
97 | 96 | from torchrec.sparse.jagged_tensor import _to_offsets, KeyedJaggedTensor, KeyedTensor |
98 | | -from torchrec.sparse.tensor_dict import maybe_td_to_kjt |
99 | 97 |
|
100 | 98 | try: |
101 | 99 | torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") |
@@ -658,7 +656,9 @@ def __init__( |
658 | 656 | self._inverse_indices_permute_indices: Optional[torch.Tensor] = None |
659 | 657 | # to support mean pooling callback hook |
660 | 658 | self._has_mean_pooling_callback: bool = ( |
661 | | - PoolingType.MEAN.value in self._pooling_type_to_rs_features |
| 659 | + True |
| 660 | + if PoolingType.MEAN.value in self._pooling_type_to_rs_features |
| 661 | + else False |
662 | 662 | ) |
663 | 663 | self._dim_per_key: Optional[torch.Tensor] = None |
664 | 664 | self._kjt_key_indices: Dict[str, int] = {} |
@@ -1189,16 +1189,8 @@ def _create_inverse_indices_permute_indices( |
1189 | 1189 |
|
1190 | 1190 | # pyre-ignore [14] |
1191 | 1191 | def input_dist( |
1192 | | - self, |
1193 | | - ctx: EmbeddingBagCollectionContext, |
1194 | | - features: Union[KeyedJaggedTensor, TensorDict], |
| 1192 | + self, ctx: EmbeddingBagCollectionContext, features: KeyedJaggedTensor |
1195 | 1193 | ) -> Awaitable[Awaitable[KJTList]]: |
1196 | | - if isinstance(features, TensorDict): |
1197 | | - feature_keys = list(features.keys()) # pyre-ignore[6] |
1198 | | - if len(self._features_order) > 0: |
1199 | | - feature_keys = [feature_keys[i] for i in self._features_order] |
1200 | | - self._has_features_permute = False # feature_keys are in order |
1201 | | - features = maybe_td_to_kjt(features, feature_keys) # pyre-ignore[6] |
1202 | 1194 | ctx.variable_batch_per_feature = features.variable_stride_per_key() |
1203 | 1195 | ctx.inverse_indices = features.inverse_indices_or_none() |
1204 | 1196 | if self._has_uninitialized_input_dist: |
|
0 commit comments