Skip to content

Commit

Permalink
Filter models by inference status (#2517)
Browse files Browse the repository at this point in the history
* Filter models by inference status

* fix
  • Loading branch information
Wauplin authored Sep 6, 2024
1 parent 1b9e5b0 commit 63353cf
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 5 deletions.
23 changes: 18 additions & 5 deletions src/huggingface_hub/hf_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -718,13 +718,17 @@ class ModelInfo:
Is the repo private.
disabled (`bool`, *optional*):
Is the repo disabled.
gated (`Literal["auto", "manual", False]`, *optional*):
Is the repo gated.
If so, whether there is manual or automatic approval.
downloads (`int`):
Number of downloads of the model over the last 30 days.
downloads_all_time (`int`):
Cumulated number of downloads of the model since its creation.
gated (`Literal["auto", "manual", False]`, *optional*):
Is the repo gated.
If so, whether there is manual or automatic approval.
inference (`Literal["cold", "frozen", "warm"]`, *optional*):
Status of the model on the inference API.
Warm models are available for immediate use. Cold models will be loaded on first inference call.
Frozen models are not available in Inference API.
likes (`int`):
Number of likes of the model.
library_name (`str`, *optional*):
Expand Down Expand Up @@ -760,10 +764,11 @@ class ModelInfo:
created_at: Optional[datetime]
last_modified: Optional[datetime]
private: Optional[bool]
gated: Optional[Literal["auto", "manual", False]]
disabled: Optional[bool]
downloads: Optional[int]
downloads_all_time: Optional[int]
gated: Optional[Literal["auto", "manual", False]]
inference: Optional[Literal["warm", "cold", "frozen"]]
likes: Optional[int]
library_name: Optional[str]
tags: Optional[List[str]]
Expand Down Expand Up @@ -793,6 +798,7 @@ def __init__(self, **kwargs):
self.downloads_all_time = kwargs.pop("downloadsAllTime", None)
self.likes = kwargs.pop("likes", None)
self.library_name = kwargs.pop("library_name", None)
self.inference = kwargs.pop("inference", None)
self.tags = kwargs.pop("tags", None)
self.pipeline_tag = kwargs.pop("pipeline_tag", None)
self.mask_token = kwargs.pop("mask_token", None)
Expand Down Expand Up @@ -1611,6 +1617,7 @@ def list_models(
filter: Union[str, Iterable[str], None] = None,
author: Optional[str] = None,
gated: Optional[bool] = None,
inference: Optional[Literal["cold", "frozen", "warm"]] = None,
library: Optional[Union[str, List[str]]] = None,
language: Optional[Union[str, List[str]]] = None,
model_name: Optional[str] = None,
Expand Down Expand Up @@ -1639,11 +1646,15 @@ def list_models(
A string or list of string to filter models on the Hub.
author (`str`, *optional*):
A string which identify the author (user or organization) of the
returned models
returned models.
gated (`bool`, *optional*):
A boolean to filter models on the Hub that are gated or not. By default, all models are returned.
If `gated=True` is passed, only gated models are returned.
If `gated=False` is passed, only non-gated models are returned.
inference (`Literal["cold", "frozen", "warm"]`, *optional*):
A string to filter models on the Hub by their state on the Inference API.
Warm models are available for immediate use. Cold models will be loaded on first inference call.
Frozen models are not available in Inference API.
library (`str` or `List`, *optional*):
A string or list of strings of foundational libraries models were
originally trained from, such as pytorch, tensorflow, or allennlp.
Expand Down Expand Up @@ -1771,6 +1782,8 @@ def list_models(
params["author"] = author
if gated is not None:
params["gated"] = gated
if inference is not None:
params["inference"] = inference
if pipeline_tag:
params["pipeline_tag"] = pipeline_tag
search_list = []
Expand Down
8 changes: 8 additions & 0 deletions tests/test_hf_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1792,6 +1792,14 @@ def test_list_models_non_gated_only(self):
for model in self._api.list_models(expand=["gated"], gated=False, limit=5):
assert model.gated is False

def test_list_models_inference_warm(self):
    """Filtering by inference status returns only warm models.

    `inference` is annotated as `Optional[Literal["cold", "frozen", "warm"]]`,
    i.e. a single string — passing a list (`["warm"]`) contradicted the
    declared parameter type, so a plain string is used here.
    """
    for model in self._api.list_models(inference="warm", expand=["inference"], limit=5):
        assert model.inference == "warm"

def test_list_models_inference_cold(self):
    """Filtering by inference status returns only cold models.

    `inference` is annotated as `Optional[Literal["cold", "frozen", "warm"]]`,
    i.e. a single string — passing a list (`["cold"]`) contradicted the
    declared parameter type, so a plain string is used here.
    """
    for model in self._api.list_models(inference="cold", expand=["inference"], limit=5):
        assert model.inference == "cold"

def test_model_info(self):
model = self._api.model_info(repo_id=DUMMY_MODEL_ID)
self.assertIsInstance(model, ModelInfo)
Expand Down

0 comments on commit 63353cf

Please sign in to comment.