3535 LazyStackedTensorDict ,
3636 MemoryMappedTensor ,
3737 set_capture_non_tensor_stack ,
38+ set_list_to_stack ,
3839 tensorclass ,
3940 TensorClass ,
4041 TensorDict ,
@@ -1032,19 +1033,28 @@ class MyDataParent:
10321033 assert data .y .v == "test_nested"
10331034 assert data .y .batch_size == torch .Size (batch_size )
10341035
1035- def test_indexing (self ):
1036- @tensorclass
1037- class MyDataNested :
1038- X : torch .Tensor
1039- z : list
1040- y : "MyDataNested" = None
1041-
1042- X = torch .ones (3 , 4 , 5 )
1043- z = ["a" , "b" , "c" ]
1044- batch_size = [3 , 4 ]
1045- data_nest = MyDataNested (X = X , z = z , batch_size = batch_size )
1046- data = MyDataNested (X = X , y = data_nest , z = z , batch_size = batch_size )
1036+ @pytest .mark .parametrize ("list_to_stack" , [True , False ])
1037+ def test_indexing (self , list_to_stack ):
1038+ with set_list_to_stack (list_to_stack ):
10471039
1040+ @tensorclass
1041+ class MyDataNested :
1042+ X : torch .Tensor
1043+ z : list
1044+ y : "MyDataNested" = None
1045+
1046+ X = torch .ones (3 , 4 , 5 )
1047+ z = ["a" , "b" , "c" ]
1048+ batch_size = [3 , 4 ]
1049+ with (
1050+ pytest .raises (RuntimeError , match = "batch dimension mismatch" )
1051+ if list_to_stack
1052+ else contextlib .nullcontext ()
1053+ ):
1054+ data_nest = MyDataNested (X = X , z = z , batch_size = batch_size )
1055+ data = MyDataNested (X = X , y = data_nest , z = z , batch_size = batch_size )
1056+ if list_to_stack :
1057+ return
10481058 assert data [:2 ].batch_size == torch .Size ([2 , 4 ])
10491059 assert data [:2 ].X .shape == torch .Size ([2 , 4 , 5 ])
10501060 assert (data [:2 ].X == X [:2 ]).all ()
@@ -1462,6 +1472,21 @@ class Data:
14621472 assert (data_select == 1 ).all ()
14631473 assert "a" in data_select ._tensordict
14641474
1475+ @set_list_to_stack (True )
1476+ def test_set_list_in_constructor (self ):
1477+ obj = MyTensorClass (
1478+ a = ["a string" , "another string" ],
1479+ b = [torch .randn (3 ), torch .zeros (3 )],
1480+ c = "smth completly different" ,
1481+ batch_size = 2 ,
1482+ )
1483+ assert obj .shape == (2 ,)
1484+ assert obj [0 ].a == "a string"
1485+ assert obj [1 ].a == "another string"
1486+ assert (obj [0 ].b != 0 ).all ()
1487+ assert (obj [1 ].b == 0 ).all ()
1488+ assert obj .c == obj [0 ].c
1489+
14651490 def test_set_dict (self ):
14661491 @tensorclass (autocast = True )
14671492 class MyClass :
@@ -1540,7 +1565,8 @@ class MyDataParent:
15401565 # ensure optional fields are writable
15411566 data .k = torch .zeros (3 , 4 , 5 )
15421567
1543- def test_setitem (self ):
1568+ @pytest .mark .parametrize ("list_to_stack" , [True , False ])
1569+ def test_setitem (self , list_to_stack ):
15441570 data = MyData (
15451571 X = torch .ones (3 , 4 , 5 ),
15461572 y = torch .zeros (3 , 4 , 5 ),
@@ -1599,26 +1625,34 @@ class MyDataNested:
15991625 X = torch .randn (3 , 4 , 5 )
16001626 z = ["a" , "b" , "c" ]
16011627 batch_size = [3 , 4 ]
1602- data_nest = MyDataNested (X = X , z = z , batch_size = batch_size )
1603- data = MyDataNested (X = X , y = data_nest , z = z , batch_size = batch_size )
1604- X2 = torch .ones (3 , 4 , 5 )
1605- data_nest2 = MyDataNested (X = X2 , z = z , batch_size = batch_size )
1606- data2 = MyDataNested (X = X2 , y = data_nest2 , z = z , batch_size = batch_size )
1607- data [:2 ] = data2 [:2 ].clone ()
1608- assert (data [:2 ].X == data2 [:2 ].X ).all ()
1609- assert (data [:2 ].y .X == data2 [:2 ].y .X ).all ()
1610- assert data [:2 ].z == z
1611-
1612- # Negative Scenario
1613- data3 = MyDataNested (X = X2 , y = data_nest2 , z = ["e" , "f" ], batch_size = batch_size )
1614- data [:2 ] = data3 [:2 ]
1615- assert data [:2 ].z == data3 [:2 ]._get_str ("z" , None ).tolist ()
1628+ with set_list_to_stack (list_to_stack ), (
1629+ pytest .raises (RuntimeError , match = "batch dimension mismatch" )
1630+ if list_to_stack
1631+ else contextlib .nullcontext ()
1632+ ):
1633+ data_nest = MyDataNested (X = X , z = z , batch_size = batch_size )
1634+ data = MyDataNested (X = X , y = data_nest , z = z , batch_size = batch_size )
1635+ X2 = torch .ones (3 , 4 , 5 )
1636+ data_nest2 = MyDataNested (X = X2 , z = z , batch_size = batch_size )
1637+ data2 = MyDataNested (X = X2 , y = data_nest2 , z = z , batch_size = batch_size )
1638+ data [:2 ] = data2 [:2 ].clone ()
1639+ assert (data [:2 ].X == data2 [:2 ].X ).all ()
1640+ assert (data [:2 ].y .X == data2 [:2 ].y .X ).all ()
1641+ assert data [:2 ].z == z
1642+
1643+ # Negative Scenario
1644+ data3 = MyDataNested (
1645+ X = X2 , y = data_nest2 , z = ["e" , "f" ], batch_size = batch_size
1646+ )
1647+ data [:2 ] = data3 [:2 ]
1648+ assert data [:2 ].z == data3 [:2 ]._get_str ("z" , None ).tolist ()
16161649
16171650 @pytest .mark .parametrize (
16181651 "broadcast_type" ,
16191652 ["scalar" , "tensor" , "tensordict" , "maptensor" ],
16201653 )
1621- def test_setitem_broadcast (self , broadcast_type ):
1654+ @pytest .mark .parametrize ("list_to_stack" , [True , False ])
1655+ def test_setitem_broadcast (self , broadcast_type , list_to_stack ):
16221656 @tensorclass
16231657 class MyDataNested :
16241658 X : torch .Tensor
@@ -1628,22 +1662,27 @@ class MyDataNested:
16281662 X = torch .ones (3 , 4 , 5 )
16291663 z = ["a" , "b" , "c" ]
16301664 batch_size = [3 , 4 ]
1631- data_nest = MyDataNested (X = X , z = z , batch_size = batch_size )
1632- data = MyDataNested (X = X , y = data_nest , z = z , batch_size = batch_size )
1633-
1634- if broadcast_type == "scalar" :
1635- val = 0
1636- elif broadcast_type == "tensor" :
1637- val = torch .zeros (4 , 5 )
1638- elif broadcast_type == "tensordict" :
1639- val = TensorDict ({"X" : torch .zeros (2 , 4 , 5 )}, batch_size = [2 , 4 ])
1640- elif broadcast_type == "maptensor" :
1641- val = MemoryMappedTensor .from_tensor (torch .zeros (4 , 5 ))
1642-
1643- data [:2 ] = val
1644- assert (data [:2 ] == 0 ).all ()
1645- assert (data .X [:2 ] == 0 ).all ()
1646- assert (data .y .X [:2 ] == 0 ).all ()
1665+ with set_list_to_stack (list_to_stack ), (
1666+ pytest .raises (RuntimeError , match = "batch dimension mismatch" )
1667+ if list_to_stack
1668+ else contextlib .nullcontext ()
1669+ ):
1670+ data_nest = MyDataNested (X = X , z = z , batch_size = batch_size )
1671+ data = MyDataNested (X = X , y = data_nest , z = z , batch_size = batch_size )
1672+
1673+ if broadcast_type == "scalar" :
1674+ val = 0
1675+ elif broadcast_type == "tensor" :
1676+ val = torch .zeros (4 , 5 )
1677+ elif broadcast_type == "tensordict" :
1678+ val = TensorDict ({"X" : torch .zeros (2 , 4 , 5 )}, batch_size = [2 , 4 ])
1679+ elif broadcast_type == "maptensor" :
1680+ val = MemoryMappedTensor .from_tensor (torch .zeros (4 , 5 ))
1681+
1682+ data [:2 ] = val
1683+ assert (data [:2 ] == 0 ).all ()
1684+ assert (data .X [:2 ] == 0 ).all ()
1685+ assert (data .y .X [:2 ] == 0 ).all ()
16471686
16481687 def test_setitem_memmap (self ):
16491688 # regression test PR #203
0 commit comments