23 | 23 |
24 | 24 | from concurrent.futures import Future, ThreadPoolExecutor, wait |
25 | 25 | from copy import copy |
26 | | -from functools import partial, wraps |
| 26 | +from functools import wraps |
27 | 27 | from pathlib import Path |
28 | 28 | from textwrap import indent |
29 | 29 | from threading import Thread |
@@ -2187,7 +2187,6 @@ def _from_dict_validated(cls, *args, **kwargs): |
2187 | 2187 |
2188 | 2188 | By default, falls back on :meth:`~.from_dict`. |
2189 | 2189 | """ |
2190 | | - kwargs.setdefault("auto_batch_size", True) |
2191 | 2190 | return cls.from_dict(*args, **kwargs) |
2192 | 2191 |
2193 | 2192 | @abc.abstractmethod |
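With the `setdefault` call gone, `_from_dict_validated` forwards exactly the arguments it receives, so batch-size inference is no longer switched on implicitly. A minimal sketch of the caller-visible difference, using the public `TensorDict.from_dict` API (the shapes and inferred sizes here are illustrative assumptions):

```python
import torch
from tensordict import TensorDict

data = {"a": torch.zeros(3, 4), "b": torch.ones(3, 4)}

# After this change, batch-size inference must be requested explicitly:
td_plain = TensorDict.from_dict(data)  # no implicit auto_batch_size
td_auto = TensorDict.from_dict(data, auto_batch_size=True)  # infers common leading dims
```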
@@ -4993,8 +4992,13 @@ def _reduce_vals_and_metadata(self, *, dtype=NO_DEFAULT, requires_metadata): |
4993 | 4992 |
4994 | 4993 | if requires_metadata: |
4995 | 4994 | # metadata is nested |
| 4995 | + cls = type(self) |
| 4996 | + from tensordict._reductions import CLS_MAP |
| 4997 | + |
| 4998 | + if cls.__name__ in CLS_MAP: |
| 4999 | + cls = cls.__name__ |
4996 | 5000 | metadata_dict = { |
4997 | | - "cls": type(self).__name__, |
| 5001 | + "cls": cls, |
4998 | 5002 | "non_tensors": {}, |
4999 | 5003 | "leaves": {}, |
5000 | 5004 | "cls_metadata": self._reduce_get_metadata(), |
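With this change, the `"cls"` entry holds the class *name* whenever the type is registered in `tensordict._reductions.CLS_MAP`, and the class object itself otherwise. A minimal sketch of the lookup, assuming `CLS_MAP` maps built-in class names to their classes (`tensordict._reductions` is a private module, imported here only to mirror the diff):

```python
from tensordict import TensorDict
from tensordict._reductions import CLS_MAP

cls = TensorDict
if cls.__name__ in CLS_MAP:
    cls = cls.__name__  # registered types are stored by name

print(cls)  # expected: the string "TensorDict", which json.dumps can handle
```

Unregistered classes fall through unchanged; the `json.dumps` guard added further down turns the resulting serialization failure into an explicit error.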
@@ -5054,18 +5058,25 @@ def assign( |
5054 | 5058 | elif _is_tensor_collection(cls): |
5055 | 5059 | metadata_dict_key = None |
5056 | 5060 | if requires_metadata: |
| 5061 | + from tensordict._reductions import CLS_MAP |
| 5062 | + |
| 5063 | + if cls.__name__ in CLS_MAP: |
| 5064 | + cls = cls.__name__ |
5057 | 5065 | metadata_dict_key = metadata_dict[key] = { |
5058 | | - "cls": cls.__name__, |
| 5066 | + "cls": cls, |
5059 | 5067 | "non_tensors": {}, |
5060 | 5068 | "leaves": {}, |
5061 | 5069 | "cls_metadata": value._reduce_get_metadata(), |
5062 | 5070 | } |
5063 | | - local_assign = partial( |
5064 | | - assign, |
5065 | | - track_key=total_key, |
5066 | | - metadata_dict=metadata_dict_key, |
5067 | | - flat_size=flat_size, |
5068 | | - ) |
| 5071 | + |
| 5072 | + def local_assign(*t): |
| 5073 | + return assign( |
| 5074 | + *t, |
| 5075 | + track_key=total_key, |
| 5076 | + metadata_dict=metadata_dict_key, |
| 5077 | + flat_size=flat_size, |
| 5078 | + ) |
| 5079 | + |
5069 | 5080 | value._fast_apply( |
5070 | 5081 | local_assign, |
5071 | 5082 | named=True, |
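The closure is call-for-call equivalent to the `functools.partial` it replaces: positionals from `_fast_apply` are forwarded and the three keywords stay bound. A self-contained sketch of that equivalence, with a stub `assign` standing in for the real function:

```python
from functools import partial

def assign(*t, track_key=None, metadata_dict=None, flat_size=None):
    # Stub standing in for the real assign: just echo the inputs.
    return t, track_key, metadata_dict, flat_size

old_style = partial(assign, track_key="key", metadata_dict={}, flat_size=[])

def new_style(*t):
    return assign(*t, track_key="key", metadata_dict={}, flat_size=[])

assert old_style("name", "leaf") == new_style("name", "leaf")
```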
@@ -5253,7 +5264,15 @@ def consolidate( |
5253 | 5264 | storage.share_memory_() |
5254 | 5265 | else: |
5255 | 5266 | # Convert the dict to json |
5256 | | - metadata_dict_json = json.dumps(metadata_dict) |
| 5267 | + try: |
| 5268 | + metadata_dict_json = json.dumps(metadata_dict) |
| 5269 | + except TypeError as e: |
| 5270 | + raise RuntimeError( |
| 5271 | + "Failed to convert the metadata to json. " |
| 5272 | + "This is usually due to a nested class that is unaccounted for by the serializer, " |
| 5273 | + "such as a custom TensorClass. " |
| 5274 | + "If you encounter this error, please file an issue on GitHub." |
| 5275 | + ) from e |
5257 | 5276 | # Represent as a tensor |
5258 | 5277 | metadata_dict_json = torch.as_tensor( |
5259 | 5278 | bytearray(metadata_dict_json), dtype=torch.uint8 |
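The new guard converts the otherwise cryptic `TypeError` raised by `json` into an actionable `RuntimeError`. A small reproduction of the failure mode it catches, using a hypothetical class that is absent from `CLS_MAP` and therefore survives as a raw class object in the metadata:

```python
import json

class MyTensorClass:  # hypothetical custom class, not registered in CLS_MAP
    pass

try:
    json.dumps({"cls": MyTensorClass})  # class objects are not JSON-serializable
except TypeError as exc:
    print(exc)  # e.g. "Object of type type is not JSON serializable"
```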