33
44from tensorcontainer .tensor_dict import TensorDict # adjust import as needed
55from tests .conftest import skipif_no_compile
6- from tests .tensor_dict import common
7- from tests .tensor_dict .common import compare_nested_dict , compute_cat_shape
86
9- nested_dict = common .nested_dict
7+
8+ def create_nested_dict (shape ):
9+ a = torch .rand (* shape )
10+ b = torch .rand (* shape )
11+ y = torch .rand (* shape )
12+ return {"x" : {"a" : a , "b" : b }, "y" : y }
13+
1014
1115# Define parameter sets
1216SHAPE_DIM_PARAMS_VALID = [
1317 # 1D
14- ((4 ,), 0 ),
15- ((4 ,), - 1 ),
18+ ((4 ,), ( 4 ,), 0 , ( 8 ,) ),
19+ ((4 ,), ( 4 ,), - 1 , ( 8 ,) ),
1620 # 2D
17- ((2 , 2 ), 0 ),
18- ((2 , 2 ), 1 ),
19- ((2 , 2 ), - 1 ),
20- ((1 , 4 ), 0 ),
21- ((1 , 4 ), 1 ),
22- ((1 , 4 ), - 2 ),
23- # 3D
24- ((2 , 1 , 2 ), 0 ),
25- ((2 , 1 , 2 ), 1 ),
26- ((2 , 1 , 2 ), 2 ),
27- ((2 , 1 , 2 ), - 1 ),
28- ((2 , 1 , 2 ), - 3 ),
21+ ((2 , 2 ), (3 , 2 ), 0 , (5 , 2 )),
22+ ((2 , 2 ), (2 , 3 ), 1 , (2 , 5 )),
23+ ((2 , 2 ), (2 , 3 ), - 1 , (2 , 5 )),
24+ ((2 , 2 ), (3 , 2 ), - 2 , (5 , 2 )),
2925]
3026
3127SHAPE_DIM_PARAMS_INVALID = [
4238
4339
4440# ——— Valid concatenation dims across several shapes ———
45- @pytest .mark .parametrize ("shape, dim" , SHAPE_DIM_PARAMS_VALID )
46- def test_cat_valid_eager (nested_dict , shape , dim ):
47- data = nested_dict ( shape )
48- td = TensorDict ( data , shape )
41+ @pytest .mark .parametrize ("shape1, shape2, dim, expected_shape " , SHAPE_DIM_PARAMS_VALID )
42+ def test_cat_valid_eager (shape1 , shape2 , dim , expected_shape ):
43+ data1 = create_nested_dict ( shape1 )
44+ data2 = create_nested_dict ( shape2 )
4945
50- def cat_operation (tensor_dict_instance , cat_dimension ):
51- return torch .cat (
52- [tensor_dict_instance , tensor_dict_instance ], dim = cat_dimension
53- )
46+ td1 = TensorDict (data1 , shape1 )
47+ td2 = TensorDict (data2 , shape2 )
5448
55- cat_td = cat_operation ( td , dim )
49+ cat_td = torch . cat ([ td1 , td2 ], dim = dim )
5650
57- # compute expected shape
58- expected_shape = compute_cat_shape (shape , dim )
5951 assert cat_td .shape == expected_shape
6052
61- # Compare nested structure and values
62- # The lambda for comparison should always use eager torch.cat on original tensor data
63- compare_nested_dict (
64- data , cat_td , lambda orig_tensor : torch .cat ([orig_tensor , orig_tensor ], dim = dim )
65- )
66-
6753
6854# ——— Error on invalid dims ———
6955@pytest .mark .parametrize ("shape, dim" , SHAPE_DIM_PARAMS_INVALID )
70- def test_cat_invalid_dim_raises_eager (shape , dim , nested_dict ):
71- td = TensorDict (nested_dict (shape ), shape )
56+ def test_cat_invalid_dim_raises_eager (shape , dim ):
57+ data = create_nested_dict (shape )
58+ td = TensorDict (data , shape )
7259
7360 def cat_operation (tensor_dict_instance , cat_dimension ):
7461 # This is the operation that is expected to raise an error
@@ -82,8 +69,9 @@ def cat_operation(tensor_dict_instance, cat_dimension):
8269
8370@skipif_no_compile
8471@pytest .mark .parametrize ("shape, dim" , SHAPE_DIM_PARAMS_INVALID )
85- def test_cat_invalid_dim_raises_compile (shape , dim , nested_dict ):
86- td = TensorDict (nested_dict (shape ), shape )
72+ def test_cat_invalid_dim_raises_compile (shape , dim ):
73+ data = create_nested_dict (shape )
74+ td = TensorDict (data , shape )
8775
8876 def cat_operation (tensor_dict_instance , cat_dimension ):
8977 # This is the operation that is expected to raise an error
0 commit comments