@@ -2665,6 +2665,38 @@ def test_split_keys(self):
26652665 assert td is td3
26662666 assert "d" not in td
26672667
2668+ @pytest .mark .parametrize ("sign" , ["plus" , "minus" ])
2669+ def test_split_uneven (self , sign ):
2670+ a = torch .arange (6 ).unsqueeze (- 1 ).expand (6 , 3 )
2671+ b = torch .arange (18 ).view (6 , 3 )
2672+ c = torch .arange (36 ).view (6 , 3 , 2 )
2673+ td = TensorDict ({"a" : a , "b" : b , "c" : c }, [6 , 3 ])
2674+
2675+ if sign == "plus" :
2676+ tds = td .split (5 , 0 )
2677+ else :
2678+ tds = td .split (5 , - 2 )
2679+ assert tds [0 ].shape == torch .Size ([5 , 3 ])
2680+ assert tds [1 ].shape == torch .Size ([1 , 3 ])
2681+ assert tds [0 ]["a" ].shape == torch .Size ([5 , 3 ])
2682+ assert tds [1 ]["a" ].shape == torch .Size ([1 , 3 ])
2683+ assert tds [0 ]["b" ].shape == torch .Size ([5 , 3 ])
2684+ assert tds [1 ]["b" ].shape == torch .Size ([1 , 3 ])
2685+ assert tds [0 ]["c" ].shape == torch .Size ([5 , 3 , 2 ])
2686+ assert tds [1 ]["c" ].shape == torch .Size ([1 , 3 , 2 ])
2687+ if sign == "plus" :
2688+ tds = td .split (2 , 1 )
2689+ else :
2690+ tds = td .split (2 , - 1 )
2691+ assert tds [0 ].shape == torch .Size ([6 , 2 ])
2692+ assert tds [1 ].shape == torch .Size ([6 , 1 ])
2693+ assert tds [0 ]["a" ].shape == torch .Size ([6 , 2 ])
2694+ assert tds [1 ]["a" ].shape == torch .Size ([6 , 1 ])
2695+ assert tds [0 ]["b" ].shape == torch .Size ([6 , 2 ])
2696+ assert tds [1 ]["b" ].shape == torch .Size ([6 , 1 ])
2697+ assert tds [0 ]["c" ].shape == torch .Size ([6 , 2 , 2 ])
2698+ assert tds [1 ]["c" ].shape == torch .Size ([6 , 1 , 2 ])
2699+
26682700 def test_setitem_nested (self ):
26692701 tensor = torch .randn (4 , 5 , 6 , 7 )
26702702 tensor2 = torch .ones (4 , 5 , 6 , 7 )
0 commit comments