Skip to content

Commit 30e0575

Browse files
author
Tim Joseph
committed
Fixes device handling: Recursively call .to() instead of using _tree_map as latter only applies to leaves (Tensor), not to TensorContainers
1 parent a8eca59 commit 30e0575

File tree

1 file changed

+5
-17
lines changed

1 file changed

+5
-17
lines changed

src/tensorcontainer/tensor_container.py

Lines changed: 5 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -437,20 +437,6 @@ def _format_path(cls, path: pytree.KeyPath) -> str:
437437
return formatted_path
438438

439439
def __repr__(self) -> str:
440-
# Infer device for representation if not set (this part is unchanged)
441-
device_repr = self.device
442-
if device_repr is None:
443-
try:
444-
# Ensure there are leaves before trying to access device
445-
# pytree.tree_leaves can return an empty list
446-
leaves = pytree.tree_leaves(self)
447-
if leaves:
448-
device_repr = leaves[0].device
449-
except IndexError: # Should not happen if leaves is checked
450-
pass
451-
except Exception: # Catch any other pytree or attribute errors
452-
pass
453-
454440
# Use a consistent indent of 4 spaces, which is standard
455441
indent = " "
456442

@@ -480,7 +466,7 @@ def _format_item(key, value):
480466
return (
481467
f"{self.__class__.__name__}(\n"
482468
f"{indent}shape={str(self.shape)},\n"
483-
f"{indent}device={device_repr},\n"
469+
f"{indent}device={self.device},\n"
484470
f"{indent}items=\n{textwrap.indent(indented_items, indent)}\n{indent}\n"
485471
f")"
486472
)
@@ -622,8 +608,10 @@ def reshape(self: Self, *shape: int) -> Self:
622608

623609
def to(self: Self, *args, **kwargs) -> Self:
624610
with TensorContainer.unsafe_construction():
625-
tc = TensorContainer._tree_map(lambda x: x.to(*args, **kwargs), self)
626-
611+
leaves, context = self._pytree_flatten()
612+
leaves = [l.to(*args, **kwargs) for l in leaves]
613+
tc = self._pytree_unflatten(leaves, context)
614+
627615
device = self.device
628616

629617
is_device_in_args = len(args) > 0 and isinstance(args[0], (str, torch.device))

0 commit comments

Comments
 (0)