@@ -180,6 +180,7 @@ def test_multi_nested_tensor_basics(device):
180180 # Test multi_nested_tensor[List[int]] indexing
181181 for index in [[4 ], [2 , 2 ], [- 4 , 1 , 7 ], [3 , - 7 , 1 , 0 ], []]:
182182 multi_nested_tensor_indexed = multi_nested_tensor [index ]
183+ assert multi_nested_tensor_indexed .dtype == torch .long
183184 assert multi_nested_tensor_indexed .shape [0 ] == len (index )
184185 assert multi_nested_tensor_indexed .shape [1 ] == num_cols
185186 for i , idx in enumerate (index ):
@@ -208,8 +209,10 @@ def test_multi_nested_tensor_basics(device):
208209
209210 # Test column List[int] indexing
210211 for index in [[4 ], [2 , 2 ], [- 4 , 1 , 7 ], [3 , - 7 , 1 , 0 ], []]:
212+ multi_nested_tensor_indexed = multi_nested_tensor [:, index ]
211213 assert_equal (column_select (tensor_mat , index ),
212- multi_nested_tensor [:, index ])
214+ multi_nested_tensor_indexed )
215+ assert multi_nested_tensor_indexed .dtype == torch .long
213216
214217 # Test column-wise Boolean masking
215218 for index in [[4 ], [2 , 3 ], [0 , 1 , 7 ], []]:
@@ -245,6 +248,7 @@ def test_multi_nested_tensor_basics(device):
245248 empty_multi_nested_tensor = multi_nested_tensor [:, 5 :3 ]
246249 assert empty_multi_nested_tensor .shape [0 ] == num_rows
247250 assert empty_multi_nested_tensor .shape [1 ] == 0
251+ assert empty_multi_nested_tensor .dtype == torch .long
248252
249253 # Test column narrow
250254 assert_equal (column_select (tensor_mat , slice (3 , 3 + 2 )),
0 commit comments