
Commit d116de0

lhoestq and Wauplin authored

[HfFileSystem] improve cache for multiprocessing fork and multithreading (#3500)

* improve hffs cache for multiprocessing fork
* minor
* mypy
* fix for token
* add test
* fix for threading too
* comment
* fix CI: make HfHubHTTPError picklable
* fix tests
* better naming
* clear instance cache before testing to ignore remaining Mock objects
* don't test "fork" on windows
* Apply suggestions from code review
* use insecure_hashlib
* style

Co-authored-by: Lucain <lucain@huggingface.co>
1 parent fcac2c0 commit d116de0

File tree

3 files changed: +123 -9 lines changed

src/huggingface_hub/errors.py
src/huggingface_hub/hf_file_system.py
tests/test_hf_file_system.py

src/huggingface_hub/errors.py

Lines changed: 4 additions & 0 deletions
@@ -84,6 +84,10 @@ def append_to_message(self, additional_message: str) -> None:
         """Append additional information to the `HfHubHTTPError` initial message."""
         self.args = (self.args[0] + additional_message,) + self.args[1:]
 
+    def __reduce_ex__(self, protocol):
+        """Fix pickling of Exception subclass with kwargs. We need to override __reduce_ex__ of the parent class"""
+        return (self.__class__, (str(self),), {"response": self.response, "server_message": self.server_message})
+
 
 # INFERENCE CLIENT ERRORS
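Why this is needed (context, not part of the diff): by default, pickle re-creates an exception by calling its class with `self.args` only, so a subclass whose `__init__` takes extra required arguments that aren't stored in `args` fails on `pickle.loads`. A minimal sketch with hypothetical classes (not the real `HfHubHTTPError` hierarchy) showing the failure and the pattern the override applies:

import pickle


class BrokenError(Exception):
    def __init__(self, message, response):
        super().__init__(message)  # self.args == (message,)
        self.response = response


class FixedError(Exception):
    def __init__(self, message, response=None, *, server_message=None):
        super().__init__(message)
        self.response = response
        self.server_message = server_message

    def __reduce_ex__(self, protocol):
        # Rebuild from the plain message; pickle restores the remaining
        # attributes from the state dict instead of replaying kwargs.
        return (self.__class__, (str(self),), {"response": self.response, "server_message": self.server_message})


try:
    pickle.loads(pickle.dumps(BrokenError("boom", response=None)))
except TypeError:
    pass  # __init__() missing 1 required positional argument: 'response'

restored = pickle.loads(pickle.dumps(FixedError("boom", server_message="Repo Not Found")))
assert restored.server_message == "Repo Not Found"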

src/huggingface_hub/hf_file_system.py

Lines changed: 81 additions & 9 deletions
@@ -1,8 +1,10 @@
 import os
 import re
 import tempfile
+import threading
 from collections import deque
 from contextlib import ExitStack
+from copy import deepcopy
 from dataclasses import dataclass, field
 from datetime import datetime
 from itertools import chain
@@ -21,6 +23,7 @@
 from .file_download import hf_hub_url, http_get
 from .hf_api import HfApi, LastCommitInfo, RepoFile
 from .utils import HFValidationError, hf_raise_for_status, http_backoff, http_stream_backoff
+from .utils.insecure_hashlib import md5
 
 
 # Regex used to match special revisions with "/" in them (see #1710)
@@ -56,7 +59,61 @@ def unresolve(self) -> str:
         return f"{repo_path}/{self.path_in_repo}".rstrip("/")
 
 
-class HfFileSystem(fsspec.AbstractFileSystem):
+# We need to improve fsspec.spec._Cached, which is AbstractFileSystem's metaclass
+_cached_base: Any = type(fsspec.AbstractFileSystem)
+
+
+class _Cached(_cached_base):
+    """
+    Metaclass for caching HfFileSystem instances according to the args.
+
+    This creates an additional reference to the filesystem, which prevents the
+    filesystem from being garbage collected when all *user* references go away.
+    A call to the :meth:`AbstractFileSystem.clear_instance_cache` must *also*
+    be made for a filesystem instance to be garbage collected.
+
+    This is a slightly modified version of `fsspec.spec._Cached` to improve it.
+    In particular, in `_tokenize` the pid isn't taken into account for the
+    `fs_token` used to identify cached instances. The `fs_token` logic is also
+    robust to default values and the order of the args. Finally, new instances
+    reuse the state from sister instances in the main thread.
+    """
+
+    def __init__(cls, *args, **kwargs):
+        # Hack: override https://github.com/fsspec/filesystem_spec/blob/dcb167e8f50e6273d4cfdfc4cab8fc5aa4c958bf/fsspec/spec.py#L53
+        super().__init__(*args, **kwargs)
+        # Note: we intentionally create a reference here, to avoid garbage
+        # collecting instances when all other references are gone. To really
+        # delete a FileSystem, the cache must be cleared.
+        cls._cache = {}
+
+    def __call__(cls, *args, **kwargs):
+        # Hack: override https://github.com/fsspec/filesystem_spec/blob/dcb167e8f50e6273d4cfdfc4cab8fc5aa4c958bf/fsspec/spec.py#L65
+        skip = kwargs.pop("skip_instance_cache", False)
+        fs_token = cls._tokenize(cls, threading.get_ident(), *args, **kwargs)
+        fs_token_main_thread = cls._tokenize(cls, threading.main_thread().ident, *args, **kwargs)
+        if not skip and cls.cachable and fs_token in cls._cache:
+            # reuse cached instance
+            cls._latest = fs_token
+            return cls._cache[fs_token]
+        else:
+            # create new instance
+            obj = type.__call__(cls, *args, **kwargs)
+            if not skip and cls.cachable and fs_token_main_thread in cls._cache:
+                # reuse the state from the main thread instance in the new instance
+                instance_state = cls._cache[fs_token_main_thread]._get_instance_state()
+                for attr, state_value in instance_state.items():
+                    setattr(obj, attr, state_value)
+            obj._fs_token_ = fs_token
+            obj.storage_args = args
+            obj.storage_options = kwargs
+            if cls.cachable and not skip:
+                cls._latest = fs_token
+                cls._cache[fs_token] = obj
+            return obj
+
+
+class HfFileSystem(fsspec.AbstractFileSystem, metaclass=_Cached):
     """
     Access a remote Hugging Face Hub repository as if it were a local file system.
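To illustrate the per-thread caching and main-thread state reuse described in the docstring above, here is a small sketch (the `dircache` entry is a made-up placeholder):

from concurrent.futures import ThreadPoolExecutor

from huggingface_hub import HfFileSystem

HfFileSystem.clear_instance_cache()
fs = HfFileSystem()
fs.dircache["dummy"] = []  # placeholder state on the main-thread instance

with ThreadPoolExecutor(max_workers=1) as executor:
    thread_fs = executor.submit(HfFileSystem).result()  # same args, different thread

assert thread_fs is not fs                # the fs_token includes the thread id
assert thread_fs.dircache == fs.dircache  # but state is seeded from the main thread's instance

Because the seeding goes through `_get_instance_state`, which deep-copies, later mutations in one thread don't leak into the other thread's instance.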
@@ -119,6 +176,18 @@ def __init__(
         # Maps parent directory path to path infos
         self.dircache: dict[str, list[dict[str, Any]]] = {}
 
+    @classmethod
+    def _tokenize(cls, threading_ident: int, *args, **kwargs) -> str:
+        """Deterministic token for caching"""
+        # make fs_token robust to default values and to kwargs order
+        kwargs["endpoint"] = kwargs.get("endpoint") or constants.ENDPOINT
+        kwargs["token"] = kwargs.get("token")
+        kwargs = {key: kwargs[key] for key in sorted(kwargs)}
+        # contrary to fsspec, we don't include pid here
+        tokenize_args = (cls, threading_ident, args, kwargs)
+        h = md5(str(tokenize_args).encode())
+        return h.hexdigest()
+
     def _repo_and_revision_exist(
         self, repo_type: str, repo_id: str, revision: Optional[str]
     ) -> tuple[bool, Optional[Exception]]:
@@ -931,17 +1000,20 @@ def start_transaction(self):
         raise NotImplementedError("Transactional commits are not supported.")
 
     def __reduce__(self):
-        # re-populate the instance cache at HfFileSystem._cache and re-populate the cache attributes of every instance
+        # re-populate the instance cache at HfFileSystem._cache and re-populate the state of every instance
         return make_instance, (
             type(self),
             self.storage_args,
             self.storage_options,
-            {
-                "dircache": self.dircache,
-                "_repo_and_revision_exists_cache": self._repo_and_revision_exists_cache,
-            },
+            self._get_instance_state(),
         )
 
+    def _get_instance_state(self):
+        return {
+            "dircache": deepcopy(self.dircache),
+            "_repo_and_revision_exists_cache": deepcopy(self._repo_and_revision_exists_cache),
+        }
+
 
 class HfFileSystemFile(fsspec.spec.AbstractBufferedFile):
     def __init__(self, fs: HfFileSystem, path: str, revision: Optional[str] = None, **kwargs):
@@ -1178,8 +1250,8 @@ def _partial_read(response: httpx.Response, length: int = -1) -> bytes:
     return bytes(buf)  # may be < length if response ended
 
 
-def make_instance(cls, args, kwargs, instance_cache_attributes_dict):
+def make_instance(cls, args, kwargs, instance_state):
     fs = cls(*args, **kwargs)
-    for attr, cached_value in instance_cache_attributes_dict.items():
-        setattr(fs, attr, cached_value)
+    for attr, state_value in instance_state.items():
+        setattr(fs, attr, state_value)
     return fs
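Putting `__reduce__` and `make_instance` together: pickling a filesystem ships its constructor args plus its state, so an unpickled copy (for example in a "spawn" worker) starts warm instead of with an empty `dircache`. A minimal in-process round trip, using a placeholder cache entry:

import pickle

from huggingface_hub import HfFileSystem

fs = HfFileSystem()
fs.dircache["dummy"] = []  # placeholder entry

# loads() calls make_instance(), which re-creates the (cached) instance
# and then restores the captured state onto it.
restored = pickle.loads(pickle.dumps(fs))
assert restored.dircache == fs.dircache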

tests/test_hf_file_system.py

Lines changed: 38 additions & 0 deletions
@@ -1,6 +1,8 @@
 import copy
 import datetime
 import io
+import multiprocessing
+import multiprocessing.pool
 import os
 import pickle
 import tempfile
@@ -644,6 +646,42 @@ def test_exists_after_repo_deletion():
     assert not hffs.exists(repo_id, refresh=True)
 
 
+def _get_fs_token_and_dircache(fs):
+    fs = HfFileSystem(endpoint=fs.endpoint, token=fs.token)
+    return fs._fs_token, fs.dircache
+
+
+def test_cache():
+    HfFileSystem.clear_instance_cache()
+    fs = HfFileSystem()
+    fs.dircache = {"dummy": []}
+
+    assert HfFileSystem() is fs
+    assert HfFileSystem(endpoint=constants.ENDPOINT) is fs
+    assert HfFileSystem(token=None, endpoint=constants.ENDPOINT) is fs
+
+    another_fs = HfFileSystem(endpoint="something-else")
+    assert another_fs is not fs
+    assert another_fs.dircache != fs.dircache
+
+    with multiprocessing.get_context("spawn").Pool() as pool:
+        (fs_token, dircache), (_, another_dircache) = pool.map(_get_fs_token_and_dircache, [fs, another_fs])
+        assert dircache == fs.dircache
+        assert another_dircache != fs.dircache
+
+    if os.name != "nt":  # "fork" is unavailable on windows
+        with multiprocessing.get_context("fork").Pool() as pool:
+            (fs_token, dircache), (_, another_dircache) = pool.map(_get_fs_token_and_dircache, [fs, another_fs])
+            assert dircache == fs.dircache
+            assert another_dircache != fs.dircache
+
+    with multiprocessing.pool.ThreadPool() as pool:
+        (fs_token, dircache), (_, another_dircache) = pool.map(_get_fs_token_and_dircache, [fs, another_fs])
+        assert dircache == fs.dircache
+        assert another_dircache != fs.dircache
+        assert fs_token != fs._fs_token  # use a different instance for thread safety
+
+
 @with_production_testing
 def test_hf_file_system_file_can_handle_gzipped_file():
     """Test that HfFileSystemStreamFile.read() can handle gzipped files."""
