 from torch.nn.utils._named_member_accessor import swap_tensor
 from torch.utils._pytree import tree_map

-try:
-    import orjson as json
-except ImportError:
-    # Fallback
-    import json
 try:
     from functorch import dim as ftdim

@@ -4175,14 +4170,17 @@ def save_metadata(prefix=prefix, self=self):
             if not prefix.exists():
                 os.makedirs(prefix, exist_ok=True)
             with open(prefix / "meta.json", "wb") as f:
-                f.write(
-                    json.dumps(
-                        {
-                            "_type": str(type(self)),
-                            "index": _index_to_str(self.idx),
-                        }
-                    )
+                from tensordict.utils import json_dumps
+
+                metadata_json = json_dumps(
+                    {
+                        "_type": str(type(self)),
+                        "index": _index_to_str(self.idx),
+                    }
                 )
+                if isinstance(metadata_json, str):
+                    metadata_json = metadata_json.encode("utf-8")
+                f.write(metadata_json)

         if executor is None:
             save_metadata()
@@ -4682,7 +4680,14 @@ def _save_metadata(data: TensorDictBase, prefix: Path, metadata=None):
         }
     )
     with open(filepath, "wb") as json_metadata:
-        json_metadata.write(json.dumps(metadata))
+        from tensordict.utils import json_dumps
+
+        json_str = json_dumps(metadata)
+        # Ensure we write bytes to the binary file
+        if isinstance(json_str, str):
+            json_metadata.write(json_str.encode("utf-8"))
+        else:
+            json_metadata.write(json_str)


 # user did specify location and memmap is in wrong place, so we copy
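
Both changed call sites route serialization through a single json_dumps helper imported from tensordict.utils and then guard against a str return value before writing to a file opened in binary mode. The helper itself is not shown in this diff; the sketch below is only an assumption of what such a helper might look like if it prefers orjson (which returns bytes) and falls back to the stdlib json module (which returns str), mirroring the module-level try/except import that this change removes.

# Hypothetical sketch only -- the real tensordict.utils.json_dumps may differ.
from typing import Any, Union

try:
    import orjson as _orjson  # optional fast path; orjson.dumps returns bytes
except ImportError:
    _orjson = None

import json as _json


def json_dumps(obj: Any) -> Union[str, bytes]:
    """Serialize obj to JSON, preferring orjson when it is installed."""
    if _orjson is not None:
        # orjson.dumps returns bytes, ready to be written to a binary file
        return _orjson.dumps(obj)
    # stdlib json.dumps returns str; callers encode to UTF-8 before writing
    return _json.dumps(obj)

With a helper along these lines, the isinstance(..., str) checks at the call sites keep both backends compatible with files opened in "wb" mode, since only the stdlib path needs an explicit .encode("utf-8").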