@@ -11851,6 +11851,26 @@ def test_new_empty_nontensorstack(self):
1185111851 assert isinstance (td .new_empty ((4 ,), empty_lazy = True ).get ("a" ), NonTensorStack )
1185211852 assert isinstance (td .new_empty ((1 ,), empty_lazy = True ).get ("a" ), NonTensorStack )
1185311853
11854+ def test_new_empty_setitem (self ):
11855+ td = TensorDict (
11856+ a = TensorDict (
11857+ b = NonTensorStack ("a" , "b" , "c" ).unsqueeze (- 1 ), batch_size = (3 ,)
11858+ ),
11859+ batch_size = (3 ,),
11860+ ).to_lazystack ()
11861+ tdz = td .new_zeros ((4 ,), empty_lazy = True )
11862+ tdz [torch .tensor ([True , True , False , True ])] = td
11863+ assert tdz .get (("a" , "b" )).tolist () == [["a" ], ["b" ], ["a" ], ["c" ]]
11864+
11865+ def test_new_empty_setitem_2 (self ):
11866+ td = TensorDict (
11867+ a = TensorDict (b = NonTensorStack ("a" ), batch_size = (1 ,)), batch_size = (1 ,)
11868+ ).to_lazystack ()
11869+ tdz = td .new_zeros ((4 ,), empty_lazy = True )
11870+ td ["a" , "b" ] = "new"
11871+ tdz [torch .tensor ([False , False , False , True ])] = td
11872+ assert tdz ["a" , "b" ][- 1 ] == "new"
11873+
1185411874 def test_non_tensor_call (self ):
1185511875 td0 = TensorDict ({"a" : 0 , "b" : 0 })
1185611876 td1 = TensorDict ({"a" : 1 , "b" : 1 })
0 commit comments