@@ -437,20 +437,6 @@ def _format_path(cls, path: pytree.KeyPath) -> str:
437437 return formatted_path
438438
439439 def __repr__ (self ) -> str :
440- # Infer device for representation if not set (this part is unchanged)
441- device_repr = self .device
442- if device_repr is None :
443- try :
444- # Ensure there are leaves before trying to access device
445- # pytree.tree_leaves can return an empty list
446- leaves = pytree .tree_leaves (self )
447- if leaves :
448- device_repr = leaves [0 ].device
449- except IndexError : # Should not happen if leaves is checked
450- pass
451- except Exception : # Catch any other pytree or attribute errors
452- pass
453-
454440 # Use a consistent indent of 4 spaces, which is standard
455441 indent = " "
456442
@@ -480,7 +466,7 @@ def _format_item(key, value):
480466 return (
481467 f"{ self .__class__ .__name__ } (\n "
482468 f"{ indent } shape={ str (self .shape )} ,\n "
483- f"{ indent } device={ device_repr } ,\n "
469+ f"{ indent } device={ self . device } ,\n "
484470 f"{ indent } items=\n { textwrap .indent (indented_items , indent )} \n { indent } \n "
485471 f")"
486472 )
@@ -622,7 +608,9 @@ def reshape(self: Self, *shape: int) -> Self:
622608
623609 def to (self : Self , * args , ** kwargs ) -> Self :
624610 with TensorContainer .unsafe_construction ():
625- tc = TensorContainer ._tree_map (lambda x : x .to (* args , ** kwargs ), self )
611+ leaves , context = self ._pytree_flatten ()
612+ leaves = [leaf .to (* args , ** kwargs ) for leaf in leaves ]
613+ tc = self ._pytree_unflatten (leaves , context )
626614
627615 device = self .device
628616
0 commit comments