Skip to content

Commit fc6b26d

Browse files
committed
[BugFix] JSON/orjson compatibility (#1373)
1 parent fcdb06d commit fc6b26d

File tree

7 files changed

+118
-249
lines changed

7 files changed

+118
-249
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ dependencies = [
4040
"importlib_metadata",
4141
# orjson fails to be installed in python 3.13t
4242
'orjson ; python_version < "3.13"',
43+
"pyvers (>=0.1.0,<0.2.0)",
4344
]
4445

4546
[project.urls]

tensordict/_lazy.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -81,11 +81,6 @@
8181
from torch import Tensor
8282
from torch.nn.utils.rnn import pad_sequence
8383

84-
try:
85-
import orjson as json
86-
except ImportError:
87-
# Fallback
88-
import json
8984

9085
try:
9186
from functorch import dim as ftdim
@@ -2876,11 +2871,16 @@ def save_metadata(prefix=prefix, self=self):
28762871
if not prefix.exists():
28772872
os.makedirs(prefix, exist_ok=True)
28782873
with open(prefix / "meta.json", "wb") as f:
2879-
f.write(
2880-
json.dumps(
2881-
{"_type": str(type(self)), "stack_dim": self.stack_dim}
2882-
)
2874+
from tensordict.utils import json_dumps
2875+
2876+
json_str = json_dumps(
2877+
{"_type": str(type(self)), "stack_dim": self.stack_dim}
28832878
)
2879+
# Ensure we write bytes to the binary file
2880+
if isinstance(json_str, str):
2881+
f.write(json_str.encode("utf-8"))
2882+
else:
2883+
f.write(json_str)
28842884

28852885
if executor is None:
28862886
save_metadata()
@@ -4201,7 +4201,14 @@ def save_metadata(data: TensorDictBase, filepath, metadata=None):
42014201
}
42024202
)
42034203
with open(filepath, "wb") as json_metadata:
4204-
json_metadata.write(json.dumps(metadata))
4204+
from tensordict.utils import json_dumps
4205+
4206+
json_str = json_dumps(metadata)
4207+
# Ensure we write bytes to the binary file
4208+
if isinstance(json_str, str):
4209+
json_metadata.write(json_str.encode("utf-8"))
4210+
else:
4211+
json_metadata.write(json_str)
42054212

42064213
if prefix is not None:
42074214
prefix = Path(prefix)

tensordict/_td.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -93,11 +93,6 @@
9393
from torch.nn.utils._named_member_accessor import swap_tensor
9494
from torch.utils._pytree import tree_map
9595

96-
try:
97-
import orjson as json
98-
except ImportError:
99-
# Fallback
100-
import json
10196
try:
10297
from functorch import dim as ftdim
10398

@@ -4175,14 +4170,17 @@ def save_metadata(prefix=prefix, self=self):
41754170
if not prefix.exists():
41764171
os.makedirs(prefix, exist_ok=True)
41774172
with open(prefix / "meta.json", "wb") as f:
4178-
f.write(
4179-
json.dumps(
4180-
{
4181-
"_type": str(type(self)),
4182-
"index": _index_to_str(self.idx),
4183-
}
4184-
)
4173+
from tensordict.utils import json_dumps
4174+
4175+
metadata_json = json_dumps(
4176+
{
4177+
"_type": str(type(self)),
4178+
"index": _index_to_str(self.idx),
4179+
}
41854180
)
4181+
if isinstance(metadata_json, str):
4182+
metadata_json = metadata_json.encode("utf-8")
4183+
f.write(metadata_json)
41864184

41874185
if executor is None:
41884186
save_metadata()
@@ -4682,7 +4680,14 @@ def _save_metadata(data: TensorDictBase, prefix: Path, metadata=None):
46824680
}
46834681
)
46844682
with open(filepath, "wb") as json_metadata:
4685-
json_metadata.write(json.dumps(metadata))
4683+
from tensordict.utils import json_dumps
4684+
4685+
json_str = json_dumps(metadata)
4686+
# Ensure we write bytes to the binary file
4687+
if isinstance(json_str, str):
4688+
json_metadata.write(json_str.encode("utf-8"))
4689+
else:
4690+
json_metadata.write(json_str)
46864691

46874692

46884693
# user did specify location and memmap is in wrong place, so we copy

tensordict/base.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313
import gc
1414
import importlib
1515
import importlib.util
16+
17+
# JSON backend is now handled by utils.json_dumps
18+
import json
1619
import os.path
1720
import queue
1821
import uuid
@@ -111,12 +114,6 @@
111114
from torch.nn.parameter import Parameter, UninitializedTensorMixin
112115
from torch.utils._pytree import tree_map
113116

114-
try:
115-
import orjson as json
116-
except ImportError:
117-
# Fallback for 3.13
118-
import json
119-
120117
try:
121118
from torch.compiler import is_compiling
122119
except ImportError: # torch 2.0
@@ -5355,7 +5352,9 @@ def consolidate(
53555352
else:
53565353
# Convert the dict to json
53575354
try:
5358-
metadata_dict_json = json.dumps(metadata_dict)
5355+
from tensordict.utils import json_dumps
5356+
5357+
metadata_dict_json = json_dumps(metadata_dict)
53595358
except TypeError as e:
53605359
raise RuntimeError(
53615360
"Failed to convert the metatdata to json. "
@@ -5364,6 +5363,8 @@ def consolidate(
53645363
"If you encounter this error, please file an issue on github."
53655364
) from e
53665365
# Represent as a tensor
5366+
if isinstance(metadata_dict_json, str):
5367+
metadata_dict_json = metadata_dict_json.encode("utf-8")
53675368
metadata_dict_json = torch.as_tensor(
53685369
bytearray(metadata_dict_json), dtype=torch.uint8
53695370
)

tensordict/persistent.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,6 @@
5656
)
5757
from torch import multiprocessing as mp
5858

59-
try:
60-
import orjson as json
61-
except ImportError:
62-
# Fallback for 3.13
63-
import json
6459

6560
_has_h5 = importlib.util.find_spec("h5py", None) is not None
6661

@@ -768,7 +763,14 @@ def save_metadata(data: TensorDictBase, filepath, metadata=None):
768763
}
769764
)
770765
with open(filepath, "wb") as json_metadata:
771-
json_metadata.write(json.dumps(metadata))
766+
from tensordict.utils import json_dumps
767+
768+
json_str = json_dumps(metadata)
769+
# Ensure we write bytes to the binary file
770+
if isinstance(json_str, str):
771+
json_metadata.write(json_str.encode("utf-8"))
772+
else:
773+
json_metadata.write(json_str)
772774

773775
if prefix is not None:
774776
prefix = Path(prefix)

tensordict/tensorclass.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import dataclasses
1313
import functools
1414
import inspect
15+
1516
import multiprocessing.managers
1617
import multiprocessing.sharedctypes
1718
import numbers
@@ -80,11 +81,6 @@
8081
from torch.multiprocessing import Manager
8182
from torch.utils._pytree import tree_map
8283

83-
try:
84-
import orjson as json
85-
except ImportError:
86-
# Fallback for 3.13
87-
import json
8884
try:
8985
from torch.compiler import is_compiling
9086
except ImportError: # torch 2.0
@@ -1532,7 +1528,14 @@ def save_metadata(cls=cls, _non_tensordict=_non_tensordict, prefix=prefix):
15321528
metadata[key] = value
15331529
else:
15341530
to_pickle[key] = value
1535-
f.write(json.dumps(metadata))
1531+
from tensordict.utils import json_dumps
1532+
1533+
json_str = json_dumps(metadata)
1534+
# Ensure we write bytes to the binary file
1535+
if isinstance(json_str, str):
1536+
f.write(json_str.encode("utf-8"))
1537+
else:
1538+
f.write(json_str)
15361539
if to_pickle:
15371540
with open(prefix / "other.pickle", "wb") as pickle_file:
15381541
pickle.dump(to_pickle, pickle_file)
@@ -4094,7 +4097,14 @@ def save_metadata(prefix=prefix, self=self):
40944097
with open(prefix / "pickle.pkl", "wb") as f:
40954098
pickle.dump(data, f)
40964099
with open(prefix / "meta.json", "wb") as f:
4097-
f.write(json.dumps(jsondict))
4100+
from tensordict.utils import json_dumps
4101+
4102+
json_str = json_dumps(jsondict, separators=(",", ":"))
4103+
# Ensure we write bytes to the binary file
4104+
if isinstance(json_str, str):
4105+
f.write(json_str.encode("utf-8"))
4106+
else:
4107+
f.write(json_str)
40984108

40994109
if executor is None:
41004110
save_metadata()

0 commit comments

Comments
 (0)