Commit 7528728

Author: Vincent Moens
[Feature] TD+NJT to(device) support
ghstack-source-id: 7181249
Pull Request resolved: #1022
1 parent 5a50f89 commit 7528728

File tree

3 files changed: +63, -3 lines

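A rough sketch of what this change enables, assuming the tensordict and torch.nested APIs exercised in the diffs below (variable names are illustrative): a TensorDict holding a jagged-layout NestedTensor (NJT) can be consolidated into a single storage and then moved to another device.

    import torch
    from tensordict import TensorDict

    # One dense leaf plus one jagged-layout NestedTensor (NJT) leaf.
    td = TensorDict(
        {
            "dense": torch.zeros(4, 3),
            "njt": torch.nested.nested_tensor_from_jagged(
                torch.arange(10), offsets=torch.tensor([0, 2, 5, 8, 10])
            ),
        },
        batch_size=[4],
    )
    td_c = td.consolidate()        # single contiguous storage holding every leaf
    # td_cuda = td_c.to("cuda:0")  # with this patch, NJT leaves move along with the rest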

tensordict/base.py

Lines changed: 14 additions & 0 deletions
@@ -10346,6 +10346,20 @@ def _to_consolidated(self, *, device, pin_memory, num_threads, non_blocking):
         untyped_storage = storage_cast.untyped_storage()
 
         def set_(x):
+            if x.is_nested:
+                if x.layout != torch.jagged:
+                    raise RuntimeError(
+                        "to(device) with nested tensors that do not have a jagged layout is not implemented yet. "
+                        "Please raise an issue on GitHub."
+                    )
+                values = x._values
+                lengths = x._lengths
+                offsets = x._offsets
+                return torch.nested.nested_tensor_from_jagged(
+                    set_(values),
+                    offsets=set_(offsets),
+                    lengths=set_(lengths) if lengths is not None else None,
+                )
             storage_offset = x.storage_offset()
             stride = x.stride()
             return torch.empty_like(x, device=device).set_(
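The new branch in set_ decomposes a jagged NestedTensor into its values, offsets, and lengths, re-points each component to the consolidated storage on the target device, and rebuilds the NJT with torch.nested.nested_tensor_from_jagged. A minimal, self-contained sketch of that decompose-and-rebuild pattern (not the library code: njt_to is a hypothetical helper that moves components with .to() instead of re-pointing storages):

    import torch

    def njt_to(x: torch.Tensor, device: torch.device) -> torch.Tensor:
        # Only jagged-layout nested tensors expose values()/offsets().
        if x.is_nested:
            if x.layout != torch.jagged:
                raise RuntimeError("only jagged-layout nested tensors are handled here")
            values = x.values().to(device)
            offsets = x.offsets().to(device)
            lengths = x._lengths  # internal attribute; None when offsets alone describe the layout
            return torch.nested.nested_tensor_from_jagged(
                values,
                offsets=offsets,
                lengths=lengths.to(device) if lengths is not None else None,
            )
        return x.to(device)

    njt = torch.nested.nested_tensor_from_jagged(
        torch.arange(10.0), offsets=torch.tensor([0, 2, 5, 8, 10])
    )
    moved = njt_to(njt, torch.device("cpu"))  # use "cuda:0" when a GPU is available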

tensordict/utils.py

Lines changed: 8 additions & 3 deletions
@@ -1542,9 +1542,14 @@ def assert_close(
         elif not isinstance(input1, torch.Tensor):
             continue
         if input1.is_nested:
-            input1 = input1._base
-            input2 = input2._base
-        mse = (input1.to(torch.float) - input2.to(torch.float)).pow(2).sum()
+            input1v = input1.values()
+            input2v = input2.values()
+            mse = (input1v.to(torch.float) - input2v.to(torch.float)).pow(2).sum()
+            input1o = input1.offsets()
+            input2o = input2.offsets()
+            mse = mse + (input1o.to(torch.float) - input2o.to(torch.float)).pow(2).sum()
+        else:
+            mse = (input1.to(torch.float) - input2.to(torch.float)).pow(2).sum()
         mse = mse.div(input1.numel()).sqrt().item()
 
         local_msg = f"key {key} does not match, got mse = {mse:4.4f}"
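With this change, assert_close compares jagged nested tensors through their values and offsets rather than through _base, so both the flat data and the ragged structure are checked. A minimal sketch of that component-wise comparison (njt_mse is a hypothetical helper, assuming two jagged NJTs of matching structure):

    import torch

    def njt_mse(a: torch.Tensor, b: torch.Tensor) -> float:
        # Accumulate squared error over the flat values and over the offsets.
        err = (a.values().to(torch.float) - b.values().to(torch.float)).pow(2).sum()
        err = err + (a.offsets().to(torch.float) - b.offsets().to(torch.float)).pow(2).sum()
        return err.div(a.numel()).sqrt().item()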

test/test_tensordict.py

Lines changed: 41 additions & 0 deletions
@@ -7963,6 +7963,47 @@ def test_consolidate_to_device(self):
         assert td_c_device["d"] == [["a string!"] * 3]
         assert len(dataptrs) == 1
 
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="no cuda device detected")
+    def test_consolidate_to_device_njt(self):
+        td = TensorDict(
+            {
+                "a": torch.arange(3).expand(4, 3).clone(),
+                "d": "a string!",
+                "njt": torch.nested.nested_tensor_from_jagged(
+                    torch.arange(10), offsets=torch.tensor([0, 2, 5, 8, 10])
+                ),
+                "njt_lengths": torch.nested.nested_tensor_from_jagged(
+                    torch.arange(10),
+                    offsets=torch.tensor([0, 2, 5, 8, 10]),
+                    lengths=torch.tensor([2, 3, 3, 2]),
+                ),
+            },
+            device="cpu",
+            batch_size=[4],
+        )
+        device = torch.device("cuda:0")
+        td_c = td.consolidate()
+        assert td_c.device == torch.device("cpu")
+        td_c_device = td_c.to(device)
+        assert td_c_device.device == device
+        assert td_c_device.is_consolidated()
+        dataptrs = set()
+        for tensor in td_c_device.values(True, True, is_leaf=_NESTED_TENSORS_AS_LISTS):
+            assert tensor.device == device
+            if tensor.is_nested:
+                vals = tensor._values
+                dataptrs.add(vals.untyped_storage().data_ptr())
+                offsets = tensor._offsets
+                dataptrs.add(offsets.untyped_storage().data_ptr())
+                lengths = tensor._lengths
+                if lengths is not None:
+                    dataptrs.add(lengths.untyped_storage().data_ptr())
+            else:
+                dataptrs.add(tensor.untyped_storage().data_ptr())
+        assert len(dataptrs) == 1
+        assert assert_allclose_td(td_c_device.cpu(), td)
+        assert td_c_device["njt_lengths"]._lengths is not None
+
     def test_create_empty(self):
         td = LazyStackedTensorDict(stack_dim=0)
         assert td.device is None
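The test checks that after to(device) every leaf, NJT components included, points into a single consolidated storage. A stripped-down version of that data-pointer check, on dense leaves only and assuming consolidate() behaves as exercised above:

    import torch
    from tensordict import TensorDict

    td = TensorDict({"x": torch.ones(4, 2), "y": torch.zeros(4, 5)}, batch_size=[4])
    td_c = td.consolidate()
    ptrs = {t.untyped_storage().data_ptr() for t in td_c.values(True, True)}
    assert len(ptrs) == 1  # every leaf lives in the same consolidated buffer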
