|
41 | 41 | import torch # type: ignore
|
42 | 42 |
|
43 | 43 | if is_safetensors_available():
|
| 44 | + import packaging.version |
| 45 | + import safetensors |
44 | 46 | from safetensors.torch import load_model as load_model_as_safetensor
|
45 | 47 | from safetensors.torch import save_model as save_model_as_safetensor
|
46 | 48 |
|
@@ -827,17 +829,18 @@ def _load_as_pickle(cls, model: T, model_file: str, map_location: str, strict: b
|
827 | 829 |
|
@classmethod
def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
    """Load the safetensors weights stored in `model_file` into `model`.

    Args:
        model: The model instance that receives the weights.
        model_file: Path to the safetensors file to read.
        map_location: Device the model should end up on (e.g. `"cpu"`).
        strict: Forwarded as `strict=` to the safetensors loader.

    Returns:
        The same `model` object, after loading.
    """
    # The `device=` argument is only passed when the installed safetensors is
    # at least 0.4.3; older versions take the fallback path below.
    supports_device = packaging.version.parse(safetensors.__version__) >= packaging.version.parse("0.4.3")
    if supports_device:
        # Use the same imported alias as the fallback path instead of reaching
        # through `safetensors.torch` by attribute access.
        load_model_as_safetensor(model, model_file, strict=strict, device=map_location)  # type: ignore [arg-type]
        return model

    # Fallback for older safetensors: load without a device argument, then
    # move the whole model to the requested device afterwards.
    load_model_as_safetensor(model, model_file, strict=strict)  # type: ignore [arg-type]
    if map_location != "cpu":
        logger.warning(
            "Loading model weights on other devices than 'cpu' is not supported natively in your version of safetensors."
            " This means that the model is loaded on 'cpu' first and then copied to the device."
            " This leads to a slower loading time."
            " Please update safetensors to version 0.4.3 or above for improved performance."
        )
        model.to(map_location)  # type: ignore [attr-defined]
    return model
|
842 | 845 |
|
843 | 846 |
|
|
0 commit comments