Skip to content

Commit 3543c37

Browse files
author
Vincent Moens
committed
[Feature] TD+NJT to(device) support
ghstack-source-id: 792ce21 Pull Request resolved: #1022
1 parent 038a707 commit 3543c37

File tree

3 files changed

+57
-2
lines changed

3 files changed

+57
-2
lines changed

tensordict/base.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10346,6 +10346,20 @@ def _to_consolidated(self, *, device, pin_memory, num_threads, non_blocking):
1034610346
untyped_storage = storage_cast.untyped_storage()
1034710347

1034810348
def set_(x):
10349+
if x.is_nested:
10350+
if x.layout != torch.jagged:
10351+
raise RuntimeError(
10352+
"to(device) with nested tensors that do not have a jagged layout is not implemented yet. "
10353+
"Please raise an issue on GitHub."
10354+
)
10355+
values = x._values
10356+
lengths = x._lengths
10357+
offsets = x._offsets
10358+
return torch.nested.nested_tensor_from_jagged(
10359+
set_(values),
10360+
offsets=set_(offsets),
10361+
lengths=set_(lengths) if lengths is not None else None,
10362+
)
1034910363
storage_offset = x.storage_offset()
1035010364
stride = x.stride()
1035110365
return torch.empty_like(x, device=device).set_(

tensordict/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1542,8 +1542,8 @@ def assert_close(
15421542
elif not isinstance(input1, torch.Tensor):
15431543
continue
15441544
if input1.is_nested:
1545-
input1 = input1._base
1546-
input2 = input2._base
1545+
input1 = input1.to_padded_tensor(0)
1546+
input2 = input2.to_padded_tensor(0)
15471547
mse = (input1.to(torch.float) - input2.to(torch.float)).pow(2).sum()
15481548
mse = mse.div(input1.numel()).sqrt().item()
15491549

test/test_tensordict.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7963,6 +7963,47 @@ def test_consolidate_to_device(self):
79637963
assert td_c_device["d"] == [["a string!"] * 3]
79647964
assert len(dataptrs) == 1
79657965

7966+
@pytest.mark.skipif(not torch.cuda.is_available(), reason="no cuda device detected")
7967+
def test_consolidate_to_device_njt(self):
7968+
td = TensorDict(
7969+
{
7970+
"a": torch.arange(3).expand(4, 3).clone(),
7971+
"d": "a string!",
7972+
"njt": torch.nested.nested_tensor_from_jagged(
7973+
torch.arange(10), offsets=torch.tensor([0, 2, 5, 8, 10])
7974+
),
7975+
"njt_lengths": torch.nested.nested_tensor_from_jagged(
7976+
torch.arange(10),
7977+
offsets=torch.tensor([0, 2, 5, 8, 10]),
7978+
lengths=torch.tensor([2, 3, 3, 2]),
7979+
),
7980+
},
7981+
device="cpu",
7982+
batch_size=[4],
7983+
)
7984+
device = torch.device("cuda:0")
7985+
td_c = td.consolidate()
7986+
assert td_c.device == torch.device("cpu")
7987+
td_c_device = td_c.to(device)
7988+
assert td_c_device.device == device
7989+
assert td_c_device.is_consolidated()
7990+
dataptrs = set()
7991+
for tensor in td_c_device.values(True, True, is_leaf=_NESTED_TENSORS_AS_LISTS):
7992+
assert tensor.device == device
7993+
if tensor.is_nested:
7994+
vals = tensor._values
7995+
dataptrs.add(vals.untyped_storage().data_ptr())
7996+
offsets = tensor._offsets
7997+
dataptrs.add(offsets.untyped_storage().data_ptr())
7998+
lengths = tensor._lengths
7999+
if lengths is not None:
8000+
dataptrs.add(lengths.untyped_storage().data_ptr())
8001+
else:
8002+
dataptrs.add(tensor.untyped_storage().data_ptr())
8003+
assert len(dataptrs) == 1
8004+
assert assert_allclose_td(td_c_device.cpu(), td)
8005+
assert td_c_device["njt_lengths"]._lengths is not None
8006+
79668007
def test_create_empty(self):
79678008
td = LazyStackedTensorDict(stack_dim=0)
79688009
assert td.device is None

0 commit comments

Comments
 (0)