diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py index db1a26aba..9beee8203 100644 --- a/torchrec/sparse/jagged_tensor.py +++ b/torchrec/sparse/jagged_tensor.py @@ -1402,6 +1402,7 @@ def _maybe_compute_kjt_to_jt_dict( variable_stride_per_key: bool, weights: Optional[torch.Tensor], jt_dict: Optional[Dict[str, JaggedTensor]], + compute_offsets: bool = True, ) -> Dict[str, JaggedTensor]: if not length_per_key: return {} @@ -1418,50 +1419,49 @@ def _maybe_compute_kjt_to_jt_dict( torch._check(cat_size <= total_size) torch._check(cat_size == total_size) torch._check_is_size(stride) + values_list = torch.split(values, length_per_key) + split_offsets: list[torch.Tensor] = [] if variable_stride_per_key: split_lengths = torch.split(lengths, stride_per_key) - split_offsets = [ - torch.ops.fbgemm.asynchronous_complete_cumsum(lengths) - for lengths in split_lengths - ] + if compute_offsets: + split_offsets = [ + torch.ops.fbgemm.asynchronous_complete_cumsum(lengths) + for lengths in split_lengths + ] elif pt2_guard_size_oblivious(lengths.numel() > 0): strided_lengths = lengths.view(len(keys), stride) if not torch.jit.is_scripting() and is_torchdynamo_compiling(): torch._check(strided_lengths.size(0) > 0) torch._check(strided_lengths.size(1) > 0) - split_lengths = torch.unbind( - strided_lengths, - dim=0, - ) - split_offsets = torch.unbind( - _batched_lengths_to_offsets(strided_lengths), - dim=0, - ) + + split_lengths = torch.unbind(strided_lengths, dim=0) + if compute_offsets: + split_offsets = torch.unbind( # pyre-ignore + _batched_lengths_to_offsets(strided_lengths), dim=0 + ) else: split_lengths = torch.unbind(lengths, dim=0) - split_offsets = torch.unbind(lengths, dim=0) + if compute_offsets: + split_offsets = split_lengths # pyre-ignore if weights is not None: weights_list = torch.split(weights, length_per_key) for idx, key in enumerate(keys): - length = split_lengths[idx] - offset = split_offsets[idx] _jt_dict[key] = JaggedTensor( - 
lengths=length, - offsets=offset, + lengths=split_lengths[idx], + offsets=split_offsets[idx] if compute_offsets else None, values=values_list[idx], weights=weights_list[idx], ) else: for idx, key in enumerate(keys): - length = split_lengths[idx] - offset = split_offsets[idx] _jt_dict[key] = JaggedTensor( - lengths=length, - offsets=offset, + lengths=split_lengths[idx], + offsets=split_offsets[idx] if compute_offsets else None, values=values_list[idx], ) + return _jt_dict @@ -2698,11 +2698,14 @@ def __getitem__(self, key: str) -> JaggedTensor: offsets=None, ) - def to_dict(self) -> Dict[str, JaggedTensor]: + def to_dict(self, compute_offsets: bool = True) -> Dict[str, JaggedTensor]: """ Returns a dictionary of JaggedTensor for each key. Will cache result in self._jt_dict. + Args: + compute_offsets (bool): if True, compute offsets for each JaggedTensor. + Returns: Dict[str, JaggedTensor]: dictionary of JaggedTensor for each key. """ @@ -2720,6 +2723,7 @@ def to_dict(self) -> Dict[str, JaggedTensor]: variable_stride_per_key=self.variable_stride_per_key(), weights=self.weights_or_none(), jt_dict=self._jt_dict, + compute_offsets=compute_offsets, ) self._jt_dict = _jt_dict return _jt_dict diff --git a/torchrec/sparse/tests/test_jagged_tensor.py b/torchrec/sparse/tests/test_jagged_tensor.py index 09a7e6b5f..7a52b4a4d 100644 --- a/torchrec/sparse/tests/test_jagged_tensor.py +++ b/torchrec/sparse/tests/test_jagged_tensor.py @@ -811,6 +811,54 @@ def test_from_jt_dict_vb(self) -> None: torch.equal(j1.values(), torch.Tensor([3.0, 4.0, 5.0, 6.0, 7.0, 8.0])) ) + def test_to_dict_compute_offsets_false(self) -> None: + # Setup: KJT with two keys and standard stride + values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]) + keys = ["f1", "f2"] + lengths = torch.IntTensor([2, 0, 1, 1, 1, 3]) + + kjt = KeyedJaggedTensor(values=values, keys=keys, lengths=lengths) + + # Execute: call to_dict with compute_offsets=False + jt_dict = kjt.to_dict(compute_offsets=False) + + # Assert: offsets_or_none()
should be None for each JaggedTensor + self.assertIsNone(jt_dict["f1"].offsets_or_none()) + self.assertIsNone(jt_dict["f2"].offsets_or_none()) + # Lengths should still be available + self.assertTrue( + torch.equal(jt_dict["f1"].lengths(), torch.IntTensor([2, 0, 1])) + ) + self.assertTrue( + torch.equal(jt_dict["f2"].lengths(), torch.IntTensor([1, 1, 3])) + ) + + def test_to_dict_compute_offsets_false_variable_stride(self) -> None: + # Setup: KJT with variable stride per key (reusing test_from_jt_dict_vb data) + values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]) + keys = ["f1", "f2"] + lengths = torch.IntTensor([2, 0, 1, 1, 1, 3]) + stride_per_key_per_rank = [[2], [4]] + + kjt = KeyedJaggedTensor( + values=values, + keys=keys, + lengths=lengths, + stride_per_key_per_rank=stride_per_key_per_rank, + ) + + # Execute: call to_dict with compute_offsets=False + jt_dict = kjt.to_dict(compute_offsets=False) + + # Assert: offsets_or_none() should be None for each JaggedTensor + self.assertIsNone(jt_dict["f1"].offsets_or_none()) + self.assertIsNone(jt_dict["f2"].offsets_or_none()) + # Lengths should still be available + self.assertTrue(torch.equal(jt_dict["f1"].lengths(), torch.IntTensor([2, 0]))) + self.assertTrue( + torch.equal(jt_dict["f2"].lengths(), torch.IntTensor([1, 1, 1, 3])) + ) + class TestJaggedTensorTracing(unittest.TestCase): def test_jagged_tensor(self) -> None: