
Commit 233da76

Merge pull request #20 from mctigger/fix-mixed-shape-empty-container

refactor(container): centralize pytree mapping and fixes operations o…

2 parents 8a36497 + 4a6e86b

File tree

8 files changed: +109 -151 lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "tensorcontainer"
-version = "0.7.1"
+version = "0.8.0"
 description = "TensorDict-like functionality for PyTorch with PyTree compatibility and torch.compile support"
 authors = [{name="Tim Joseph", email="tim@mctigger.com"}]
 license = {text = "MIT"}

src/tensorcontainer/tensor_container.py

Lines changed: 24 additions & 2 deletions

@@ -312,7 +312,29 @@ def wrapped_func(keypath, x, *xs):
                 message = f"Error at path {path}: {type(e).__name__}: {e}"
                 raise type(e)(message) from e
 
-        return pytree.tree_map_with_path(wrapped_func, tree, *rests, is_leaf=is_leaf)
+        return cls.tree_map_with_path(wrapped_func, tree, *rests, is_leaf=is_leaf)
+
+    @classmethod
+    def tree_map_with_path(
+        cls,
+        func: Callable[..., Any],
+        tree: PyTree,
+        *rests: PyTree,
+        is_leaf: Optional[Callable[[PyTree], bool]] = None,
+    ) -> PyTree:
+        # This is copied from pytree.tree_map_with_path().
+        # We add the check for no leaves, as operations are currently not supported for
+        # empty TensorContainers.
+        keypath_leaves, treespec = pytree.tree_flatten_with_path(tree, is_leaf)
+
+        if len(keypath_leaves) == 0:
+            raise RuntimeError(
+                "TensorContainer does not allow operations on containers without leaves (i.e. not containing any tensors)."
+            )
+
+        keypath_leaves = list(zip(*keypath_leaves))
+        all_keypath_leaves = keypath_leaves + [treespec.flatten_up_to(r) for r in rests]
+        return treespec.unflatten(func(*xs) for xs in zip(*all_keypath_leaves))
 
     @classmethod
     def _is_shape_compatible(cls, parent: TensorContainer, child: TCCompatible):

@@ -358,7 +380,7 @@ def ndim(self):
     # --- Overloaded methods leveraging PyTrees ---
 
     def copy(self) -> Self:
-        return pytree.tree_map(lambda x: x, self)
+        return self._tree_map(lambda x: x, self)
 
     def get_number_of_consuming_dims(self, item) -> int:
         if item is Ellipsis or item is None:
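
Note: the new tree_map_with_path classmethod mirrors torch.utils._pytree.tree_map_with_path but fails fast on leaf-free containers. A minimal sketch of the resulting behavior (hypothetical usage, assuming TensorDict is registered as a pytree and copy() routes through _tree_map as in this diff):

import torch
from tensorcontainer.tensor_dict import TensorDict

# A container with at least one tensor leaf maps normally.
td = TensorDict({"a": torch.zeros(4)}, shape=(4,))
td.copy()  # fine: there is one leaf to map over

# A container without tensor leaves now raises instead of passing through.
empty = TensorDict({}, shape=(4,))
try:
    empty.copy()
except RuntimeError as e:
    print(e)  # "TensorContainer does not allow operations on containers without leaves ..."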

src/tensorcontainer/tensor_dict.py

Lines changed: 3 additions & 13 deletions

@@ -49,7 +49,6 @@
 class TensorDictPytreeContext(NamedTuple):
     keys: Tuple[str, ...]
     event_ndims: Tuple[int, ...]
-    shape_context: torch.Size
     device_context: torch.device | None
     metadata: Dict[str, Any]

@@ -160,9 +159,7 @@ def _get_pytree_context(
         """
         batch_ndim = len(self.shape)
         event_ndims = tuple(leaf.ndim - batch_ndim for leaf in flat_leaves)
-        return TensorDictPytreeContext(
-            tuple(keys), event_ndims, self.shape, self.device, metadata
-        )
+        return TensorDictPytreeContext(tuple(keys), event_ndims, self.device, metadata)
 
     def _pytree_flatten(
         self,

@@ -229,7 +226,7 @@ def _pytree_unflatten(
         from the context. The device is restored from the context.
         """
         # Unpack context tuple
-        keys, event_ndims, shape_context, device_context, metadata = context
+        keys, event_ndims, device_context, metadata = context
 
         obj = cls.__new__(cls)
         obj.device = device_context

@@ -241,22 +238,15 @@
         data.update(metadata)
         obj.data = data
 
-        if not leaves_list:
-            # Empty case - use shape from context
-            obj.shape = shape_context
-            return obj
-
         first_leaf = leaves_list[0]
 
         # Infer batch shape from first leaf and event_ndims
         if (
             event_ndims and event_ndims[0] == 0
         ):  # Leaf was a scalar or had only batch dimensions originally
             reconstructed_shape = first_leaf.shape
-        elif event_ndims:  # Leaf had event dimensions originally
+        else:  # Leaf had event dimensions originally
             reconstructed_shape = first_leaf.shape[: -event_ndims[0]]
-        else:  # No leaves with event_ndims, use context
-            reconstructed_shape = shape_context
 
         obj.shape = reconstructed_shape
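
Note: with shape_context removed, the batch shape must be recoverable from the leaves themselves. A standalone sketch of the reconstruction rule the last hunk implements (variable names mirror the diff):

import torch

# Recorded at flatten time: event_ndims[i] = leaf.ndim - batch_ndim.
first_leaf = torch.rand(2, 3, 5)  # batch dims (2, 3) plus one event dim (5)
event_ndims = (1,)

if event_ndims[0] == 0:
    reconstructed_shape = first_leaf.shape  # leaf had batch dims only
else:
    reconstructed_shape = first_leaf.shape[: -event_ndims[0]]  # strip event dims

assert reconstructed_shape == torch.Size([2, 3])

This is also why the empty case had to become an error: with no leaves there is nothing left to infer the shape from.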

tests/conftest.py

Lines changed: 13 additions & 0 deletions

@@ -83,3 +83,16 @@ def pytest_configure(config):
 )
 def device(request):
     return torch.device(request.param)
+
+
+@pytest.fixture(autouse=True)
+def dynamo_reset():
+    """
+    A pytest fixture that automatically resets torch._dynamo state
+    before and after every test function.
+    """
+    # Code before the test runs
+    torch._dynamo.reset()
+    yield
+    # Code after the test runs (optional cleanup)
+    torch._dynamo.reset()
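
Note: torch._dynamo.reset() clears compiled graphs and guard state, so compilation artifacts from one test cannot leak into the next. A standalone illustration of what the fixture wraps around each test (not part of the diff):

import torch

@torch.compile
def double(x):
    return x * 2

double(torch.ones(2))   # first call triggers compilation
torch._dynamo.reset()   # drop cached graphs and guards, as the fixture does
double(torch.ones(2))   # compiles again from a clean slate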

tests/tensor_dict/test_cat.py

Lines changed: 26 additions & 38 deletions

@@ -3,29 +3,25 @@
 
 from tensorcontainer.tensor_dict import TensorDict  # adjust import as needed
 from tests.conftest import skipif_no_compile
-from tests.tensor_dict import common
-from tests.tensor_dict.common import compare_nested_dict, compute_cat_shape
 
-nested_dict = common.nested_dict
+
+def create_nested_dict(shape):
+    a = torch.rand(*shape)
+    b = torch.rand(*shape)
+    y = torch.rand(*shape)
+    return {"x": {"a": a, "b": b}, "y": y}
+
 
 # Define parameter sets
 SHAPE_DIM_PARAMS_VALID = [
     # 1D
-    ((4,), 0),
-    ((4,), -1),
+    ((4,), (4,), 0, (8,)),
+    ((4,), (4,), -1, (8,)),
     # 2D
-    ((2, 2), 0),
-    ((2, 2), 1),
-    ((2, 2), -1),
-    ((1, 4), 0),
-    ((1, 4), 1),
-    ((1, 4), -2),
-    # 3D
-    ((2, 1, 2), 0),
-    ((2, 1, 2), 1),
-    ((2, 1, 2), 2),
-    ((2, 1, 2), -1),
-    ((2, 1, 2), -3),
+    ((2, 2), (3, 2), 0, (5, 2)),
+    ((2, 2), (2, 3), 1, (2, 5)),
+    ((2, 2), (2, 3), -1, (2, 5)),
+    ((2, 2), (3, 2), -2, (5, 2)),
 ]
 
 SHAPE_DIM_PARAMS_INVALID = [

@@ -42,33 +38,24 @@
 
 
 # ——— Valid concatenation dims across several shapes ———
-@pytest.mark.parametrize("shape, dim", SHAPE_DIM_PARAMS_VALID)
-def test_cat_valid_eager(nested_dict, shape, dim):
-    data = nested_dict(shape)
-    td = TensorDict(data, shape)
+@pytest.mark.parametrize("shape1, shape2, dim, expected_shape", SHAPE_DIM_PARAMS_VALID)
+def test_cat_valid_eager(shape1, shape2, dim, expected_shape):
+    data1 = create_nested_dict(shape1)
+    data2 = create_nested_dict(shape2)
 
-    def cat_operation(tensor_dict_instance, cat_dimension):
-        return torch.cat(
-            [tensor_dict_instance, tensor_dict_instance], dim=cat_dimension
-        )
+    td1 = TensorDict(data1, shape1)
+    td2 = TensorDict(data2, shape2)
 
-    cat_td = cat_operation(td, dim)
+    cat_td = torch.cat([td1, td2], dim=dim)
 
-    # compute expected shape
-    expected_shape = compute_cat_shape(shape, dim)
     assert cat_td.shape == expected_shape
 
-    # Compare nested structure and values
-    # The lambda for comparison should always use eager torch.cat on original tensor data
-    compare_nested_dict(
-        data, cat_td, lambda orig_tensor: torch.cat([orig_tensor, orig_tensor], dim=dim)
-    )
-
 
 # ——— Error on invalid dims ———
 @pytest.mark.parametrize("shape, dim", SHAPE_DIM_PARAMS_INVALID)
-def test_cat_invalid_dim_raises_eager(shape, dim, nested_dict):
-    td = TensorDict(nested_dict(shape), shape)
+def test_cat_invalid_dim_raises_eager(shape, dim):
+    data = create_nested_dict(shape)
+    td = TensorDict(data, shape)
 
     def cat_operation(tensor_dict_instance, cat_dimension):
         # This is the operation that is expected to raise an error

@@ -82,8 +69,9 @@ def cat_operation(tensor_dict_instance, cat_dimension):
 
 @skipif_no_compile
 @pytest.mark.parametrize("shape, dim", SHAPE_DIM_PARAMS_INVALID)
-def test_cat_invalid_dim_raises_compile(shape, dim, nested_dict):
-    td = TensorDict(nested_dict(shape), shape)
+def test_cat_invalid_dim_raises_compile(shape, dim):
+    data = create_nested_dict(shape)
+    td = TensorDict(data, shape)
 
     def cat_operation(tensor_dict_instance, cat_dimension):
         # This is the operation that is expected to raise an error
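
Note: the reworked parameters now cover the mixed-shape case from the PR branch name: the two inputs may differ along the concatenation dimension. A quick sketch of one parameter row, ((2, 2), (3, 2), 0, (5, 2)), assuming the create_nested_dict helper above:

import torch
from tensorcontainer.tensor_dict import TensorDict

td1 = TensorDict(create_nested_dict((2, 2)), (2, 2))
td2 = TensorDict(create_nested_dict((3, 2)), (3, 2))

cat_td = torch.cat([td1, td2], dim=0)
assert cat_td.shape == (5, 2)  # 2 + 3 along dim 0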

tests/tensor_dict/test_copy.py

Lines changed: 0 additions & 28 deletions

@@ -152,34 +152,6 @@ def copy_fn(td):
     assert "c" not in td["x"]
 
 
-def test_copy_of_empty_tensor_dict(nested_dict):
-    # an empty dict should still copy correctly
-    td = TensorDict({}, shape=())
-    td_copy = td.copy()
-    assert isinstance(td_copy, TensorDict)
-    assert td_copy is not td
-    assert td_copy.shape == torch.Size([])
-    assert len(td_copy) == 0
-
-
-@skipif_no_compile
-def test_copy_of_empty_tensor_dict_compiled():
-    """Test that copying an empty TensorDict works with torch.compile."""
-
-    def copy_empty_td(td):
-        return td.copy()
-
-    td = TensorDict({}, shape=())
-
-    eager_result, compiled_result = run_and_compare_compiled(copy_empty_td, td)
-
-    # Additional checks specific to empty TensorDict
-    assert isinstance(eager_result, TensorDict)
-    assert eager_result is not td
-    assert eager_result.shape == torch.Size([])
-    assert len(eager_result) == 0
-
-
 def test_copy_with_pytree(nested_dict):
     data = nested_dict((2, 2))
     td = TensorDict(data, shape=(2, 2))

tests/tensor_dict/test_metadata.py

Lines changed: 0 additions & 21 deletions

@@ -84,24 +84,3 @@ def test_nested_tensordict_with_metadata(self):
             td_doubled["nested"]["nested_tensor"], torch.ones(4, 2) * 2
         )
         assert td_doubled["nested"]["nested_meta"] == "level2"
-
-    def test_metadata_only_tensordict(self):
-        """
-        Tests the edge case where a TensorDict contains no tensors at all, only
-        metadata. Pytree operations should not alter it.
-        """
-        td = TensorDict({"meta1": "a", "meta2": 123}, shape=(4,))
-        td_unchanged = tree_map(lambda x: x * 2, td)
-
-        assert td_unchanged.data == td.data
-
-    def test_empty_tensordict(self):
-        """
-        Tests that an empty TensorDict remains empty and handles pytree
-        operations gracefully without errors.
-        """
-        td = TensorDict({}, shape=(4,))
-        td_unchanged = tree_map(lambda x: x * 2, td)
-
-        assert len(td_unchanged) == 0
-        assert td_unchanged.shape == (4,)
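
Note: the deleted tests in test_copy.py and test_metadata.py all exercised empty or metadata-only containers. Their removal matches the new leaf check in tensor_container.py and the removal of shape_context from the pytree context: such containers no longer support pytree-based operations, so the old "gracefully unchanged" behavior is gone by design.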
