Skip to content

Commit 80763ca

Browse files
authored
Merge pull request #4 from mctigger/device-fix
Fix TensorContainer.to() device propagation
2 parents a8eca59 + 3b24bb5 commit 80763ca

File tree

2 files changed

+5
-17
lines changed

2 files changed

+5
-17
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
44

55
[project]
66
name = "tensorcontainer"
7-
version = "0.6.0"
7+
version = "0.6.1"
88
description = "TensorDict-like functionality for PyTorch with PyTree compatibility and torch.compile support"
99
authors = [{name="Tim Joseph", email="tim@mctigger.com"}]
1010
license = {text = "MIT"}

src/tensorcontainer/tensor_container.py

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -437,20 +437,6 @@ def _format_path(cls, path: pytree.KeyPath) -> str:
437437
return formatted_path
438438

439439
def __repr__(self) -> str:
440-
# Infer device for representation if not set (this part is unchanged)
441-
device_repr = self.device
442-
if device_repr is None:
443-
try:
444-
# Ensure there are leaves before trying to access device
445-
# pytree.tree_leaves can return an empty list
446-
leaves = pytree.tree_leaves(self)
447-
if leaves:
448-
device_repr = leaves[0].device
449-
except IndexError: # Should not happen if leaves is checked
450-
pass
451-
except Exception: # Catch any other pytree or attribute errors
452-
pass
453-
454440
# Use a consistent indent of 4 spaces, which is standard
455441
indent = " "
456442

@@ -480,7 +466,7 @@ def _format_item(key, value):
480466
return (
481467
f"{self.__class__.__name__}(\n"
482468
f"{indent}shape={str(self.shape)},\n"
483-
f"{indent}device={device_repr},\n"
469+
f"{indent}device={self.device},\n"
484470
f"{indent}items=\n{textwrap.indent(indented_items, indent)}\n{indent}\n"
485471
f")"
486472
)
@@ -622,7 +608,9 @@ def reshape(self: Self, *shape: int) -> Self:
622608

623609
def to(self: Self, *args, **kwargs) -> Self:
624610
with TensorContainer.unsafe_construction():
625-
tc = TensorContainer._tree_map(lambda x: x.to(*args, **kwargs), self)
611+
leaves, context = self._pytree_flatten()
612+
leaves = [leaf.to(*args, **kwargs) for leaf in leaves]
613+
tc = self._pytree_unflatten(leaves, context)
626614

627615
device = self.device
628616

0 commit comments

Comments
 (0)