
Commit 63353cf

Filter models by inference status (#2517)
* Filter models by inference status
* fiox
1 parent: 1b9e5b0 · commit: 63353cf

2 files changed (+26, -5 lines)

src/huggingface_hub/hf_api.py

Lines changed: 18 additions & 5 deletions
@@ -718,13 +718,17 @@ class ModelInfo:
             Is the repo private.
         disabled (`bool`, *optional*):
             Is the repo disabled.
-        gated (`Literal["auto", "manual", False]`, *optional*):
-            Is the repo gated.
-            If so, whether there is manual or automatic approval.
         downloads (`int`):
             Number of downloads of the model over the last 30 days.
         downloads_all_time (`int`):
             Cumulated number of downloads of the model since its creation.
+        gated (`Literal["auto", "manual", False]`, *optional*):
+            Is the repo gated.
+            If so, whether there is manual or automatic approval.
+        inference (`Literal["cold", "frozen", "warm"]`, *optional*):
+            Status of the model on the inference API.
+            Warm models are available for immediate use. Cold models will be loaded on first inference call.
+            Frozen models are not available in Inference API.
         likes (`int`):
             Number of likes of the model.
         library_name (`str`, *optional*):
@@ -760,10 +764,11 @@ class ModelInfo:
     created_at: Optional[datetime]
     last_modified: Optional[datetime]
     private: Optional[bool]
-    gated: Optional[Literal["auto", "manual", False]]
     disabled: Optional[bool]
     downloads: Optional[int]
     downloads_all_time: Optional[int]
+    gated: Optional[Literal["auto", "manual", False]]
+    inference: Optional[Literal["warm", "cold", "frozen"]]
     likes: Optional[int]
     library_name: Optional[str]
     tags: Optional[List[str]]
@@ -793,6 +798,7 @@ def __init__(self, **kwargs):
         self.downloads_all_time = kwargs.pop("downloadsAllTime", None)
         self.likes = kwargs.pop("likes", None)
         self.library_name = kwargs.pop("library_name", None)
+        self.inference = kwargs.pop("inference", None)
         self.tags = kwargs.pop("tags", None)
         self.pipeline_tag = kwargs.pop("pipeline_tag", None)
         self.mask_token = kwargs.pop("mask_token", None)
@@ -1611,6 +1617,7 @@ def list_models(
         filter: Union[str, Iterable[str], None] = None,
         author: Optional[str] = None,
         gated: Optional[bool] = None,
+        inference: Optional[Literal["cold", "frozen", "warm"]] = None,
         library: Optional[Union[str, List[str]]] = None,
         language: Optional[Union[str, List[str]]] = None,
         model_name: Optional[str] = None,
@@ -1639,11 +1646,15 @@ def list_models(
                 A string or list of string to filter models on the Hub.
             author (`str`, *optional*):
                 A string which identify the author (user or organization) of the
-                returned models
+                returned models.
             gated (`bool`, *optional*):
                 A boolean to filter models on the Hub that are gated or not. By default, all models are returned.
                 If `gated=True` is passed, only gated models are returned.
                 If `gated=False` is passed, only non-gated models are returned.
+            inference (`Literal["cold", "frozen", "warm"]`, *optional*):
+                A string to filter models on the Hub by their state on the Inference API.
+                Warm models are available for immediate use. Cold models will be loaded on first inference call.
+                Frozen models are not available in Inference API.
             library (`str` or `List`, *optional*):
                 A string or list of strings of foundational libraries models were
                 originally trained from, such as pytorch, tensorflow, or allennlp.
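
A hedged usage sketch of the new `inference` filter as documented above: warm models are ready to serve immediately, so this lists a handful of them and reads back the status on each result.

```python
from huggingface_hub import HfApi

api = HfApi()

# Only models reported as "warm" on the Inference API (ready for immediate use).
for model in api.list_models(inference="warm", expand="inference", limit=5):
    print(model.id, model.inference)  # each returned model should report "warm"
```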
@@ -1771,6 +1782,8 @@ def list_models(
             params["author"] = author
         if gated is not None:
             params["gated"] = gated
+        if inference is not None:
+            params["inference"] = inference
         if pipeline_tag:
            params["pipeline_tag"] = pipeline_tag
         search_list = []
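
The two added lines simply forward the argument as a query parameter. Roughly, and only as a sketch assuming the Hub's `/api/models` endpoint accepts `inference` the way this client change implies, the wire-level equivalent looks like this:

```python
import requests

# Approximate equivalent of list_models(inference="warm", expand="inference", limit=5);
# the real client also handles auth headers, pagination, and response parsing.
resp = requests.get(
    "https://huggingface.co/api/models",
    params={"inference": "warm", "expand": "inference", "limit": 5},
    timeout=10,
)
resp.raise_for_status()
for model in resp.json():
    print(model["id"], model.get("inference"))
```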

tests/test_hf_api.py

Lines changed: 8 additions & 0 deletions
@@ -1792,6 +1792,14 @@ def test_list_models_non_gated_only(self):
         for model in self._api.list_models(expand=["gated"], gated=False, limit=5):
             assert model.gated is False
 
+    def test_list_models_inference_warm(self):
+        for model in self._api.list_models(inference=["warm"], expand="inference", limit=5):
+            assert model.inference == "warm"
+
+    def test_list_models_inference_cold(self):
+        for model in self._api.list_models(inference=["cold"], expand="inference", limit=5):
+            assert model.inference == "cold"
+
     def test_model_info(self):
         model = self._api.model_info(repo_id=DUMMY_MODEL_ID)
         self.assertIsInstance(model, ModelInfo)
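
The new tests iterate over live results and assert that each reported status matches the filter (note they pass `inference=["warm"]` as a list, even though the parameter is type-hinted as a plain string). As a hedged follow-on sketch, the same check can be combined with the existing `pipeline_tag` filter; the combination below is illustrative only and is not asserted anywhere in the commit.

```python
from huggingface_hub import HfApi

api = HfApi()

# Same idea as the tests above, combined with the existing pipeline_tag filter.
for model in api.list_models(
    inference="warm",
    pipeline_tag="text-generation",
    expand=["inference", "pipeline_tag"],
    limit=5,
):
    assert model.inference == "warm"
    assert model.pipeline_tag == "text-generation"
```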
