Skip to content

Commit fea80fb

Browse files
author
Tim Joseph
committed
test(tensor-ops): refine error message matching and improve test cases
Update `pytest.raises` match strings for `IndexError` to use a more flexible regex (`Dimension .* out of range`), accounting for variations in error messages. Refactor `test_stack_inconsistent_shapes_raises` to correctly assert `RuntimeError` instead of `ValueError` and improve the test setup to explicitly create instances with incompatible batch shapes. Remove the redundant `test_cat_incompatible_batch_shapes` as its functionality is covered by `test_cat_inconsistent_shapes_raises`. Explicitly set `device=torch.device("cpu")` in `OptionalFieldsTestClass` test cases for clarity and consistency.
1 parent 03d8bc4 commit fea80fb

File tree

3 files changed

+39
-25
lines changed

3 files changed

+39
-25
lines changed

tests/tensor_dataclass/test_cat.py

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,8 @@ def test_cat_valid(self, nested_tensor_data_class, dim):
7272
def test_cat_invalid_dim_raises(self, nested_tensor_data_class, dim):
7373
"""Test cat operation raises with invalid event dimensions."""
7474
td1, td2 = self._create_test_pair(nested_tensor_data_class)
75-
with pytest.raises(IndexError, match="Dimension out of range"):
76-
r = self._cat_operation([td1, td2], dim)
77-
print(r.shape, r.tensor.shape)
75+
with pytest.raises(IndexError, match="Dimension .* out of range"):
76+
self._cat_operation([td1, td2], dim)
7877

7978
def test_cat_inconsistent_meta_data_raises(self, nested_tensor_data_class):
8079
"""Test cat operation raises with inconsistent metadata."""
@@ -124,17 +123,8 @@ def test_cat_dim_exceeds_batch_ndim(self, nested_tensor_data_class, dim_offset):
124123
"""Test cat operation raises IndexError when dim exceeds batch ndim."""
125124
td1, td2 = self._create_test_pair(nested_tensor_data_class)
126125
invalid_dim = td1.ndim + dim_offset
127-
with pytest.raises(IndexError, match="Dimension out of range"):
126+
with pytest.raises(IndexError, match="Dimension .* out of range"):
128127
self._cat_operation([td1, td2], invalid_dim)
129128

130-
def test_cat_incompatible_batch_shapes(self, nested_tensor_data_class):
131-
"""Test cat operation raises ValueError when batch shapes are incompatible."""
132-
td1, td2 = self._create_test_pair(nested_tensor_data_class)
133-
134-
# Manually set the shape attribute for the test
135-
td2.shape = (td2.shape[0], td2.shape[1] + 1)
136-
137-
with pytest.raises(
138-
ValueError, match="TensorContainer batch shapes must be identical"
139-
):
140-
self._cat_operation([td1, td2], 0)
129+
# Note: test_cat_inconsistent_shapes_raises already tests incompatible tensor shapes
130+
# This test is redundant and has been removed to avoid duplication

tests/tensor_dataclass/test_stack.py

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
OptionalFieldsTestClass,
2222
compute_stack_shape,
2323
)
24+
from tests.tensor_dataclass.conftest import FlatTensorDataClass
2425

2526

2627
def _stack_operation(tensor_dataclass_list, dim_arg):
@@ -89,16 +90,39 @@ def test_stack_valid_dims(self, nested_tensor_data_class, dim):
8990
def test_stack_invalid_dim_raises(self, nested_tensor_data_class, dim):
9091
"""Tests that stacking with an invalid dimension raises an IndexError."""
9192
td1, td2 = _create_test_pair(nested_tensor_data_class)
92-
with pytest.raises(IndexError, match="Dimension out of range"):
93+
with pytest.raises(IndexError, match="Dimension .* out of range"):
9394
_stack_operation([td1, td2], dim)
9495

9596
def test_stack_inconsistent_shapes_raises(self, nested_tensor_data_class):
96-
"""Tests that stacking instances with inconsistent shapes raises a ValueError."""
97-
td1, td2 = _create_test_pair(nested_tensor_data_class)
98-
# Create an inconsistent shape for one of the tensors.
99-
td2.shape = (td1.shape[0] + 1, td1.shape[1])
97+
"""Tests that stacking instances with inconsistent shapes raises a RuntimeError."""
98+
99+
# Get the first instance from the fixture
100+
td1 = nested_tensor_data_class
101+
102+
# Create a second instance with different batch shape (incompatible for stacking)
103+
# td1 has batch shape (2, 3), create td2 with batch shape (3, 3)
104+
different_batch_shape = (3, 3) # Different first dimension
105+
event_shape = (4, 5) # Same event shape as td1
106+
107+
# Create nested tensor with different batch shape
108+
flat_td2 = FlatTensorDataClass(
109+
tensor=torch.randn(*different_batch_shape, *event_shape, device=td1.device),
110+
meta_data=td1.tensor_data_class.meta_data, # Same metadata to avoid other errors
111+
shape=different_batch_shape,
112+
device=td1.device,
113+
)
114+
115+
# Create main tensor container with different batch shape
116+
td2 = NestedTensorDataClass(
117+
tensor=torch.randn(*different_batch_shape, *event_shape, device=td1.device),
118+
tensor_data_class=flat_td2,
119+
meta_data=td1.meta_data, # Same metadata to avoid other errors
120+
shape=different_batch_shape,
121+
device=td1.device,
122+
)
123+
100124
with pytest.raises(
101-
ValueError, match="stack expects each TensorContainer to be equal size"
125+
RuntimeError, match="stack expects each tensor to be equal size"
102126
):
103127
_stack_operation([td1, td2], 0)
104128

@@ -132,7 +156,7 @@ def test_stack_with_optional_tensor_as_none(self):
132156
def _test_stack_none():
133157
data1 = OptionalFieldsTestClass(
134158
shape=(4,),
135-
device=None,
159+
device=torch.device("cpu"),
136160
obs=torch.ones(4, 32, 32),
137161
reward=None,
138162
info=["step1"],
@@ -154,7 +178,7 @@ def test_stack_with_optional_tensor_as_tensor(self):
154178
def _test_stack_tensor():
155179
data1 = OptionalFieldsTestClass(
156180
shape=(4,),
157-
device=None,
181+
device=torch.device("cpu"),
158182
obs=torch.ones(4, 32, 32),
159183
reward=torch.ones(4),
160184
)
@@ -198,7 +222,7 @@ def test_default_factory_tensor_stack(self):
198222
def _test_default_factory_tensor():
199223
data1 = OptionalFieldsTestClass(
200224
shape=(4,),
201-
device=None,
225+
device=torch.device("cpu"),
202226
obs=torch.ones(4, 32, 32),
203227
reward=None,
204228
info=["step1"],

tests/tensor_dict/test_cat.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,5 +83,5 @@ def cat_operation(tensor_dict_instance, cat_dimension):
8383
# are often wrapped in TorchRuntimeError.
8484
# We compile first, then expect the error upon execution of the compiled function.
8585
compiled_cat_op = torch.compile(cat_operation)
86-
with pytest.raises(IndexError, match="Dimension out of range"):
86+
with pytest.raises(IndexError, match="Dimension .* out of range"):
8787
compiled_cat_op(td, dim)

0 commit comments

Comments
 (0)