@@ -7963,6 +7963,47 @@ def test_consolidate_to_device(self):
79637963 assert td_c_device ["d" ] == [["a string!" ] * 3 ]
79647964 assert len (dataptrs ) == 1
79657965
@pytest.mark.skipif(not torch.cuda.is_available(), reason="no cuda device detected")
def test_consolidate_to_device_njt(self):
    """Consolidate a TensorDict containing nested (jagged) tensors, move it to
    CUDA, and check that (a) every leaf — including the values/offsets/lengths
    components of nested tensors — shares one storage, and (b) data and the
    lengths metadata survive the round-trip back to CPU.
    """
    source = TensorDict(
        {
            "a": torch.arange(3).expand(4, 3).clone(),
            "d": "a string!",
            "njt": torch.nested.nested_tensor_from_jagged(
                torch.arange(10), offsets=torch.tensor([0, 2, 5, 8, 10])
            ),
            "njt_lengths": torch.nested.nested_tensor_from_jagged(
                torch.arange(10),
                offsets=torch.tensor([0, 2, 5, 8, 10]),
                lengths=torch.tensor([2, 3, 3, 2]),
            ),
        },
        device="cpu",
        batch_size=[4],
    )
    cuda = torch.device("cuda:0")

    consolidated = source.consolidate()
    # Consolidation alone must not move the data off the CPU.
    assert consolidated.device == torch.device("cpu")

    on_cuda = consolidated.to(cuda)
    assert on_cuda.device == cuda
    assert on_cuda.is_consolidated()

    # Collect the storage pointer of every leaf; a consolidated TensorDict
    # backs all of them with a single buffer, so the set must have size 1.
    storage_ptrs = set()
    for leaf in on_cuda.values(True, True, is_leaf=_NESTED_TENSORS_AS_LISTS):
        assert leaf.device == cuda
        if not leaf.is_nested:
            storage_ptrs.add(leaf.untyped_storage().data_ptr())
            continue
        # Nested tensors expose their jagged components separately; each one
        # (values, offsets, and lengths when present) must sit in the same
        # consolidated storage.
        components = [leaf._values, leaf._offsets]
        if leaf._lengths is not None:
            components.append(leaf._lengths)
        for component in components:
            storage_ptrs.add(component.untyped_storage().data_ptr())
    assert len(storage_ptrs) == 1

    # Values are unchanged after the CUDA round-trip, and the explicit lengths
    # of "njt_lengths" were not dropped along the way.
    assert assert_allclose_td(on_cuda.cpu(), source)
    assert on_cuda["njt_lengths"]._lengths is not None
79668007 def test_create_empty (self ):
79678008 td = LazyStackedTensorDict (stack_dim = 0 )
79688009 assert td .device is None
0 commit comments