 import os
 import re
 import tempfile
+import threading
 from collections import deque
 from contextlib import ExitStack
+from copy import deepcopy
 from dataclasses import dataclass, field
 from datetime import datetime
 from itertools import chain

 from .file_download import hf_hub_url, http_get
 from .hf_api import HfApi, LastCommitInfo, RepoFile
 from .utils import HFValidationError, hf_raise_for_status, http_backoff, http_stream_backoff
+from .utils.insecure_hashlib import md5


 # Regex used to match special revisions with "/" in them (see #1710)
@@ -56,7 +59,61 @@ def unresolve(self) -> str:
         return f"{repo_path}/{self.path_in_repo}".rstrip("/")


-class HfFileSystem(fsspec.AbstractFileSystem):
+# We need to improve fsspec.spec._Cached, which is AbstractFileSystem's metaclass
+_cached_base: Any = type(fsspec.AbstractFileSystem)
+
+
+class _Cached(_cached_base):
| 67 | + """ |
| 68 | + Metaclass for caching HfFileSystem instances according to the args. |
| 69 | +
|
| 70 | + This creates an additional reference to the filesystem, which prevents the |
| 71 | + filesystem from being garbage collected when all *user* references go away. |
| 72 | + A call to the :meth:`AbstractFileSystem.clear_instance_cache` must *also* |
| 73 | + be made for a filesystem instance to be garbage collected. |
| 74 | +
|
| 75 | + This is a slightly modified version of `fsspec.spec._Cached` to improve it. |
| 76 | + In particular in `_tokenize` the pid isn't taken into account for the |
| 77 | + `fs_token` used to identify cached instances. The `fs_token` logic is also |
| 78 | + robust to defaults values and the order of the args. Finally new instances |
| 79 | + reuse the states from sister instances in the main thread. |
| 80 | + """ |
+
+    def __init__(cls, *args, **kwargs):
+        # Hack: override https://github.com/fsspec/filesystem_spec/blob/dcb167e8f50e6273d4cfdfc4cab8fc5aa4c958bf/fsspec/spec.py#L53
+        super().__init__(*args, **kwargs)
+        # Note: we intentionally create a reference here, to avoid garbage
+        # collecting instances when all other references are gone. To really
+        # delete a FileSystem, the cache must be cleared.
+        cls._cache = {}
+
+    def __call__(cls, *args, **kwargs):
+        # Hack: override https://github.com/fsspec/filesystem_spec/blob/dcb167e8f50e6273d4cfdfc4cab8fc5aa4c958bf/fsspec/spec.py#L65
+        skip = kwargs.pop("skip_instance_cache", False)
+        fs_token = cls._tokenize(cls, threading.get_ident(), *args, **kwargs)
+        fs_token_main_thread = cls._tokenize(cls, threading.main_thread().ident, *args, **kwargs)
+        if not skip and cls.cachable and fs_token in cls._cache:
+            # reuse cached instance
+            cls._latest = fs_token
+            return cls._cache[fs_token]
+        else:
+            # create new instance
+            obj = type.__call__(cls, *args, **kwargs)
+            if not skip and cls.cachable and fs_token_main_thread in cls._cache:
+                # reuse the cache from the main thread instance in the new instance
+                instance_state = cls._cache[fs_token_main_thread]._get_instance_state()
+                for attr, state_value in instance_state.items():
+                    setattr(obj, attr, state_value)
+            obj._fs_token_ = fs_token
+            obj.storage_args = args
+            obj.storage_options = kwargs
+            if cls.cachable and not skip:
+                cls._latest = fs_token
+                cls._cache[fs_token] = obj
+            return obj
+
+
+class HfFileSystem(fsspec.AbstractFileSystem, metaclass=_Cached):
     """
     Access a remote Hugging Face Hub repository as if it were a local file system.

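To make the caching semantics above concrete, here is a minimal sketch (an editor's illustration, not part of this diff) of how the metaclass behaves once the patch is applied: constructions with the same arguments return the cached instance within a thread, while other threads get their own instance seeded from the main thread's state.

```python
import threading

from huggingface_hub import HfFileSystem

fs = HfFileSystem()
assert HfFileSystem() is fs  # same args, same thread -> same cached instance

def worker():
    # a different thread gets a distinct instance for the same args,
    # pre-populated with the main-thread instance's cached state
    thread_fs = HfFileSystem()
    assert thread_fs is not fs

t = threading.Thread(target=worker)
t.start()
t.join()

HfFileSystem(skip_instance_cache=True)  # always builds a fresh, uncached instance
```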
@@ -119,6 +176,18 @@ def __init__(
         # Maps parent directory path to path infos
         self.dircache: dict[str, list[dict[str, Any]]] = {}

+    @classmethod
+    def _tokenize(cls, threading_ident: int, *args, **kwargs) -> str:
+        """Deterministic token for caching"""
+        # make fs_token robust to default values and to kwargs order
+        kwargs["endpoint"] = kwargs.get("endpoint") or constants.ENDPOINT
+        kwargs["token"] = kwargs.get("token")
+        kwargs = {key: kwargs[key] for key in sorted(kwargs)}
+        # contrary to fsspec, we don't include pid here
+        tokenize_args = (cls, threading_ident, args, kwargs)
+        h = md5(str(tokenize_args).encode())
+        return h.hexdigest()
+
     def _repo_and_revision_exist(
         self, repo_type: str, repo_id: str, revision: Optional[str]
     ) -> tuple[bool, Optional[Exception]]:
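As a quick sanity check (editor's sketch, not part of the diff), the token really is insensitive to kwarg order and to whether the defaults are spelled out, since `_tokenize` normalizes `endpoint`/`token` and sorts the kwargs before hashing:

```python
import threading

from huggingface_hub import HfFileSystem

ident = threading.get_ident()
t0 = HfFileSystem._tokenize(ident)
t1 = HfFileSystem._tokenize(ident, endpoint=None, token=None)
t2 = HfFileSystem._tokenize(ident, token=None, endpoint=None)
assert t0 == t1 == t2  # defaults filled in and kwargs sorted before hashing
```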
@@ -931,17 +1000,20 @@ def start_transaction(self):
         raise NotImplementedError("Transactional commits are not supported.")

     def __reduce__(self):
-        # re-populate the instance cache at HfFileSystem._cache and re-populate the cache attributes of every instance
+        # re-populate the instance cache at HfFileSystem._cache and re-populate the state of every instance
         return make_instance, (
             type(self),
             self.storage_args,
             self.storage_options,
-            {
-                "dircache": self.dircache,
-                "_repo_and_revision_exists_cache": self._repo_and_revision_exists_cache,
-            },
+            self._get_instance_state(),
         )

+    def _get_instance_state(self):
+        return {
+            "dircache": deepcopy(self.dircache),
+            "_repo_and_revision_exists_cache": deepcopy(self._repo_and_revision_exists_cache),
+        }
+

 class HfFileSystemFile(fsspec.spec.AbstractBufferedFile):
     def __init__(self, fs: HfFileSystem, path: str, revision: Optional[str] = None, **kwargs):
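A short illustration (editor's sketch, not part of the change) of what `__reduce__` now ships: unpickling rebuilds the filesystem through `make_instance` and restores a deep copy of the cached state, which is what lets a child process start with the parent's directory cache instead of re-listing from the Hub. The example lists the public `gpt2` model repo to populate the cache.

```python
import pickle

from huggingface_hub import HfFileSystem

fs = HfFileSystem()
fs.ls("gpt2")  # populate fs.dircache with the repo's root listing

restored = pickle.loads(pickle.dumps(fs))
# __reduce__ carries _get_instance_state(), i.e. deep copies of dircache and
# _repo_and_revision_exists_cache, so the restored filesystem already knows
# the listing without hitting the Hub again
assert restored.dircache == fs.dircache
```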
@@ -1178,8 +1250,8 @@ def _partial_read(response: httpx.Response, length: int = -1) -> bytes:
     return bytes(buf)  # may be < length if response ended


-def make_instance(cls, args, kwargs, instance_cache_attributes_dict):
+def make_instance(cls, args, kwargs, instance_state):
     fs = cls(*args, **kwargs)
-    for attr, cached_value in instance_cache_attributes_dict.items():
-        setattr(fs, attr, cached_value)
+    for attr, state_value in instance_state.items():
+        setattr(fs, attr, state_value)
     return fs