Skip to content

Commit c88acce

Browse files
author
Vincent Moens
committed
[BugFix] Pass type directly during reduction
ghstack-source-id: 737f037 Pull Request resolved: #1225 (cherry picked from commit ad0a8dd)
1 parent 3e0c2d8 commit c88acce

File tree

4 files changed

+67
-40
lines changed

4 files changed

+67
-40
lines changed

tensordict/_lazy.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -381,17 +381,6 @@ def from_dict(
381381
stack_dim_name=None,
382382
stack_dim=0,
383383
):
384-
# if batch_size is not None:
385-
# batch_size = list(batch_size)
386-
# if stack_dim is None:
387-
# stack_dim = 0
388-
# n = batch_size.pop(stack_dim)
389-
# if n != len(input_dict):
390-
# raise ValueError(
391-
# "The number of dicts and the corresponding batch-size must match, "
392-
# f"got len(input_dict)={len(input_dict)} and batch_size[{stack_dim}]={n}."
393-
# )
394-
# batch_size = torch.Size(batch_size)
395384
return LazyStackedTensorDict(
396385
*(
397386
TensorDict.from_dict(
@@ -1992,11 +1981,22 @@ def _apply_nest(
19921981
if all(r is None for r in results) and filter_empty in (None, True):
19931982
return
19941983
if not inplace:
1995-
out = type(self)(
1996-
*results,
1997-
stack_dim=self.stack_dim,
1998-
stack_dim_name=self._td_dim_name,
1999-
)
1984+
if not results or any(r is not None for r in results):
1985+
try:
1986+
out = type(self)(
1987+
*results,
1988+
stack_dim=self.stack_dim,
1989+
stack_dim_name=self._td_dim_name,
1990+
)
1991+
except Exception as e:
1992+
raise RuntimeError(
1993+
f"Failed to reconstruct the lazy stack of tensordicts with class: {type(self)}. "
1994+
f"One common issue is that the outputs of apply are a mix of None and non-None "
1995+
f"values. Check that the outputs of apply() are all None or all non-None. "
1996+
f"Otherwise, please report this bug on tensordict github."
1997+
) from e
1998+
else:
1999+
out = None
20002000
else:
20012001
out = self
20022002
if names is not NO_DEFAULT:

tensordict/_reductions.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,14 @@
1111
from tensordict._lazy import LazyStackedTensorDict
1212
from tensordict._td import TensorDict
1313

14-
from tensordict.tensorclass import NonTensorData
15-
from tensordict.utils import _STRDTYPE2DTYPE
14+
from tensordict.tensorclass import NonTensorData, NonTensorStack
15+
from tensordict.utils import _is_tensorclass, _STRDTYPE2DTYPE
1616

1717
CLS_MAP = {
1818
"TensorDict": TensorDict,
1919
"LazyStackedTensorDict": LazyStackedTensorDict,
20+
"NonTensorData": NonTensorData,
21+
"NonTensorStack": NonTensorStack,
2022
}
2123

2224

@@ -57,7 +59,9 @@ def from_metadata(metadata=metadata_dict, prefix=None):
5759
d[k] = from_metadata(
5860
v, prefix=prefix + (k,) if prefix is not None else (k,)
5961
)
60-
result = CLS_MAP[cls]._from_dict_validated(d, **cls_metadata)
62+
if isinstance(cls, str):
63+
cls = CLS_MAP[cls]
64+
result = cls._from_dict_validated(d, **cls_metadata)
6165
if is_locked:
6266
result.lock_()
6367
# if is_shared:
@@ -121,10 +125,15 @@ def from_metadata(metadata=metadata, prefix=None):
121125
d[k] = from_metadata(
122126
v, prefix=prefix + (k,) if prefix is not None else (k,)
123127
)
124-
result = CLS_MAP[cls]._from_dict_validated(d, **cls_metadata)
128+
if isinstance(cls, str):
129+
cls = CLS_MAP[cls]
130+
result = cls._from_dict_validated(d, **cls_metadata)
125131
if is_locked:
126132
result = result.lock_()
127-
result._consolidated = consolidated
133+
if _is_tensorclass(cls):
134+
result._tensordict._consolidated = consolidated
135+
else:
136+
result._consolidated = consolidated
128137
return result
129138

130139
return from_metadata()

tensordict/_td.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2064,6 +2064,8 @@ def from_dict(
20642064
batch_dims=None,
20652065
names=None,
20662066
):
2067+
if _is_tensor_collection(type(input_dict)):
2068+
return input_dict
20672069
if others:
20682070
if batch_size is not None:
20692071
raise TypeError(
@@ -2120,14 +2122,7 @@ def from_dict(
21202122
)
21212123
if batch_size is None:
21222124
if auto_batch_size is None and batch_dims is None:
2123-
warn(
2124-
"The batch-size was not provided and auto_batch_size isn't set either. "
2125-
"Currently, from_dict will call set auto_batch_size=True but this behaviour "
2126-
"will be changed in v0.8 and auto_batch_size will be False onward. "
2127-
"To silence this warning, pass auto_batch_size directly.",
2128-
category=DeprecationWarning,
2129-
)
2130-
auto_batch_size = True
2125+
auto_batch_size = False
21312126
elif auto_batch_size is None:
21322127
auto_batch_size = True
21332128
if auto_batch_size:

tensordict/base.py

Lines changed: 34 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
from concurrent.futures import Future, ThreadPoolExecutor, wait
2525
from copy import copy
26-
from functools import partial, wraps
26+
from functools import wraps
2727
from pathlib import Path
2828
from textwrap import indent
2929
from threading import Thread
@@ -2187,7 +2187,6 @@ def _from_dict_validated(cls, *args, **kwargs):
21872187

21882188
By default, falls back on :meth:`~.from_dict`.
21892189
"""
2190-
kwargs.setdefault("auto_batch_size", True)
21912190
return cls.from_dict(*args, **kwargs)
21922191

21932192
@abc.abstractmethod
@@ -4993,8 +4992,15 @@ def _reduce_vals_and_metadata(self, *, dtype=NO_DEFAULT, requires_metadata):
49934992

49944993
if requires_metadata:
49954994
# metadata is nested
4995+
cls = type(self)
4996+
from tensordict._reductions import CLS_MAP
4997+
4998+
if cls.__name__ in CLS_MAP:
4999+
cls = cls.__name__
5000+
else:
5001+
pass
49965002
metadata_dict = {
4997-
"cls": type(self).__name__,
5003+
"cls": cls,
49985004
"non_tensors": {},
49995005
"leaves": {},
50005006
"cls_metadata": self._reduce_get_metadata(),
@@ -5054,18 +5060,27 @@ def assign(
50545060
elif _is_tensor_collection(cls):
50555061
metadata_dict_key = None
50565062
if requires_metadata:
5063+
from tensordict._reductions import CLS_MAP
5064+
5065+
if cls.__name__ in CLS_MAP:
5066+
cls = cls.__name__
5067+
else:
5068+
pass
50575069
metadata_dict_key = metadata_dict[key] = {
5058-
"cls": cls.__name__,
5070+
"cls": cls,
50595071
"non_tensors": {},
50605072
"leaves": {},
50615073
"cls_metadata": value._reduce_get_metadata(),
50625074
}
5063-
local_assign = partial(
5064-
assign,
5065-
track_key=total_key,
5066-
metadata_dict=metadata_dict_key,
5067-
flat_size=flat_size,
5068-
)
5075+
5076+
def local_assign(*t):
5077+
return assign(
5078+
*t,
5079+
track_key=total_key,
5080+
metadata_dict=metadata_dict_key,
5081+
flat_size=flat_size,
5082+
)
5083+
50695084
value._fast_apply(
50705085
local_assign,
50715086
named=True,
@@ -5253,7 +5268,15 @@ def consolidate(
52535268
storage.share_memory_()
52545269
else:
52555270
# Convert the dict to json
5256-
metadata_dict_json = json.dumps(metadata_dict)
5271+
try:
5272+
metadata_dict_json = json.dumps(metadata_dict)
5273+
except TypeError as e:
5274+
raise RuntimeError(
5275+
"Failed to convert the metadata to json. "
5276+
"This is usually due to a nested class that is unaccounted for by the serializer, "
5277+
"such as custom TensorClass. "
5278+
"If you encounter this error, please file an issue on github."
5279+
) from e
52575280
# Represent as a tensor
52585281
metadata_dict_json = torch.as_tensor(
52595282
bytearray(metadata_dict_json), dtype=torch.uint8

0 commit comments

Comments (0)