Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 26 additions & 22 deletions torchrec/sparse/jagged_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1402,6 +1402,7 @@ def _maybe_compute_kjt_to_jt_dict(
variable_stride_per_key: bool,
weights: Optional[torch.Tensor],
jt_dict: Optional[Dict[str, JaggedTensor]],
compute_offsets: bool = True,
) -> Dict[str, JaggedTensor]:
if not length_per_key:
return {}
Expand All @@ -1418,50 +1419,49 @@ def _maybe_compute_kjt_to_jt_dict(
torch._check(cat_size <= total_size)
torch._check(cat_size == total_size)
torch._check_is_size(stride)

values_list = torch.split(values, length_per_key)
split_offsets: list[torch.Tensor] = []
if variable_stride_per_key:
split_lengths = torch.split(lengths, stride_per_key)
split_offsets = [
torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)
for lengths in split_lengths
]
if compute_offsets:
split_offsets = [
torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)
for lengths in split_lengths
]
elif pt2_guard_size_oblivious(lengths.numel() > 0):
strided_lengths = lengths.view(len(keys), stride)
if not torch.jit.is_scripting() and is_torchdynamo_compiling():
torch._check(strided_lengths.size(0) > 0)
torch._check(strided_lengths.size(1) > 0)
split_lengths = torch.unbind(
strided_lengths,
dim=0,
)
split_offsets = torch.unbind(
_batched_lengths_to_offsets(strided_lengths),
dim=0,
)

split_lengths = torch.unbind(strided_lengths, dim=0)
if compute_offsets:
split_offsets = torch.unbind( # pyre-ignore
_batched_lengths_to_offsets(strided_lengths), dim=0
)
else:
split_lengths = torch.unbind(lengths, dim=0)
split_offsets = torch.unbind(lengths, dim=0)
if compute_offsets:
split_offsets = split_lengths # pyre-ignore

if weights is not None:
weights_list = torch.split(weights, length_per_key)
for idx, key in enumerate(keys):
length = split_lengths[idx]
offset = split_offsets[idx]
_jt_dict[key] = JaggedTensor(
lengths=length,
offsets=offset,
lengths=split_lengths[idx],
offsets=split_offsets[idx] if compute_offsets else None,
values=values_list[idx],
weights=weights_list[idx],
)
else:
for idx, key in enumerate(keys):
length = split_lengths[idx]
offset = split_offsets[idx]
_jt_dict[key] = JaggedTensor(
lengths=length,
offsets=offset,
lengths=split_lengths[idx],
offsets=split_offsets[idx] if compute_offsets else None,
values=values_list[idx],
)

return _jt_dict


Expand Down Expand Up @@ -2698,11 +2698,14 @@ def __getitem__(self, key: str) -> JaggedTensor:
offsets=None,
)

def to_dict(self) -> Dict[str, JaggedTensor]:
def to_dict(self, compute_offsets: bool = True) -> Dict[str, JaggedTensor]:
"""
Returns a dictionary of JaggedTensor for each key.
Will cache result in self._jt_dict.

Args:
    compute_offsets (bool): if True (the default), compute and attach per-key offsets; if False, skip offset computation and leave offsets unset on each returned JaggedTensor.

Returns:
Dict[str, JaggedTensor]: dictionary of JaggedTensor for each key.
"""
Expand All @@ -2720,6 +2723,7 @@ def to_dict(self) -> Dict[str, JaggedTensor]:
variable_stride_per_key=self.variable_stride_per_key(),
weights=self.weights_or_none(),
jt_dict=self._jt_dict,
compute_offsets=compute_offsets,
)
self._jt_dict = _jt_dict
return _jt_dict
Expand Down
48 changes: 48 additions & 0 deletions torchrec/sparse/tests/test_jagged_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -811,6 +811,54 @@ def test_from_jt_dict_vb(self) -> None:
torch.equal(j1.values(), torch.Tensor([3.0, 4.0, 5.0, 6.0, 7.0, 8.0]))
)

def test_to_dict_compute_offsets_false(self) -> None:
    """to_dict(compute_offsets=False) must leave offsets unset on every JT."""
    # Two keys over a fixed stride of 3; the six lengths sum to the 8 values.
    kjt = KeyedJaggedTensor(
        values=torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]),
        keys=["f1", "f2"],
        lengths=torch.IntTensor([2, 0, 1, 1, 1, 3]),
    )

    jt_dict = kjt.to_dict(compute_offsets=False)

    expected_lengths = {
        "f1": torch.IntTensor([2, 0, 1]),
        "f2": torch.IntTensor([1, 1, 3]),
    }
    for key, lengths in expected_lengths.items():
        # Offsets were skipped, but per-key lengths must survive untouched.
        self.assertIsNone(jt_dict[key].offsets_or_none())
        self.assertTrue(torch.equal(jt_dict[key].lengths(), lengths))

def test_to_dict_compute_offsets_false_variable_stride(self) -> None:
    """Variable-stride KJT: to_dict(compute_offsets=False) still skips offsets."""
    # Same data as test_from_jt_dict_vb: f1 spans 2 samples, f2 spans 4
    # (one rank each), so the six lengths split unevenly between the keys.
    kjt = KeyedJaggedTensor(
        values=torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]),
        keys=["f1", "f2"],
        lengths=torch.IntTensor([2, 0, 1, 1, 1, 3]),
        stride_per_key_per_rank=[[2], [4]],
    )

    jt_dict = kjt.to_dict(compute_offsets=False)

    expected_lengths = {
        "f1": torch.IntTensor([2, 0]),
        "f2": torch.IntTensor([1, 1, 1, 3]),
    }
    for key, lengths in expected_lengths.items():
        # No offsets were materialized, but lengths remain per key.
        self.assertIsNone(jt_dict[key].offsets_or_none())
        self.assertTrue(torch.equal(jt_dict[key].lengths(), lengths))


class TestJaggedTensorTracing(unittest.TestCase):
def test_jagged_tensor(self) -> None:
Expand Down
Loading