Skip to content

Commit d297df5

Browse files
Implemented #2516 (#2532)
* implemented #2516 * fix ruff error * Update setup.py removed torch dependency for ["dev"] Co-authored-by: Lucain <lucainp@gmail.com> --------- Co-authored-by: Lucain <lucainp@gmail.com>
1 parent 855755b commit d297df5

File tree

1 file changed

+14
-11
lines changed

1 file changed

+14
-11
lines changed

src/huggingface_hub/hub_mixin.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@
4141
import torch # type: ignore
4242

4343
if is_safetensors_available():
44+
import packaging.version
45+
import safetensors
4446
from safetensors.torch import load_model as load_model_as_safetensor
4547
from safetensors.torch import save_model as save_model_as_safetensor
4648

@@ -827,17 +829,18 @@ def _load_as_pickle(cls, model: T, model_file: str, map_location: str, strict: b
827829

828830
@classmethod
829831
def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
830-
load_model_as_safetensor(model, model_file, strict=strict) # type: ignore [arg-type]
831-
if map_location != "cpu":
832-
# TODO: remove this once https://github.com/huggingface/safetensors/pull/449 is merged.
833-
logger.warning(
834-
"Loading model weights on other devices than 'cpu' is not supported natively."
835-
" This means that the model is loaded on 'cpu' first and then copied to the device."
836-
" This leads to a slower loading time."
837-
" Support for loading directly on other devices is planned to be added in future releases."
838-
" See https://github.com/huggingface/huggingface_hub/pull/2086 for more details."
839-
)
840-
model.to(map_location) # type: ignore [attr-defined]
832+
if packaging.version.parse(safetensors.__version__) < packaging.version.parse("0.4.3"):
833+
load_model_as_safetensor(model, model_file, strict=strict) # type: ignore [arg-type]
834+
if map_location != "cpu":
835+
logger.warning(
836+
"Loading model weights on other devices than 'cpu' is not supported natively in your version of safetensors."
837+
" This means that the model is loaded on 'cpu' first and then copied to the device."
838+
" This leads to a slower loading time."
839+
" Please update safetensors to version 0.4.3 or above for improved performance."
840+
)
841+
model.to(map_location) # type: ignore [attr-defined]
842+
else:
843+
safetensors.torch.load_model(model, model_file, strict=strict, device=map_location)
841844
return model
842845

843846

0 commit comments

Comments
 (0)