diff --git a/docs/source/en/package_reference/hf_api.md b/docs/source/en/package_reference/hf_api.md
index c82e17f7c6..7ed523e22b 100644
--- a/docs/source/en/package_reference/hf_api.md
+++ b/docs/source/en/package_reference/hf_api.md
@@ -57,6 +57,10 @@ models = hf_api.list_models()
 
 [[autodoc]] huggingface_hub.hf_api.GitRefs
 
+### InferenceProviderMapping
+
+[[autodoc]] huggingface_hub.hf_api.InferenceProviderMapping
+
 ### LFSFileInfo
 
 [[autodoc]] huggingface_hub.hf_api.LFSFileInfo
diff --git a/src/huggingface_hub/hf_api.py b/src/huggingface_hub/hf_api.py
index a994915244..77581adee6 100644
--- a/src/huggingface_hub/hf_api.py
+++ b/src/huggingface_hub/hf_api.py
@@ -103,6 +103,7 @@
     RevisionNotFoundError,
 )
 from .file_download import HfFileMetadata, get_hf_file_metadata, hf_hub_url
+from .inference._providers import PROVIDER_T
 from .repocard_data import DatasetCardData, ModelCardData, SpaceCardData
 from .utils import (
     DEFAULT_IGNORE_PATTERNS,
@@ -708,13 +709,15 @@ def __init__(self, **kwargs):
 
 @dataclass
 class InferenceProviderMapping:
-    status: Literal["live", "staging"]
+    provider: str
     provider_id: str
+    status: Literal["live", "staging"]
     task: str
 
     def __init__(self, **kwargs):
-        self.status = kwargs.pop("status")
+        self.provider = kwargs.pop("provider")
         self.provider_id = kwargs.pop("providerId")
+        self.status = kwargs.pop("status")
         self.task = kwargs.pop("task")
         self.__dict__.update(**kwargs)
 
@@ -757,12 +760,10 @@ class ModelInfo:
             If so, whether there is manual or automatic approval.
         gguf (`Dict`, *optional*):
             GGUF information of the model.
-        inference (`Literal["cold", "frozen", "warm"]`, *optional*):
-            Status of the model on the inference API.
-            Warm models are available for immediate use. Cold models will be loaded on first inference call.
-            Frozen models are not available in Inference API.
-        inference_provider_mapping (`Dict`, *optional*):
-            Model's inference provider mapping.
+        inference (`Literal["warm"]`, *optional*):
+            Status of the model on Inference Providers. Warm if the model is served by at least one provider.
+        inference_provider_mapping (`List[InferenceProviderMapping]`, *optional*):
+            A list of [`InferenceProviderMapping`] entries, ordered according to the user's provider preferences.
         likes (`int`):
             Number of likes of the model.
         library_name (`str`, *optional*):
@@ -807,8 +808,8 @@ class ModelInfo:
     downloads_all_time: Optional[int]
     gated: Optional[Literal["auto", "manual", False]]
     gguf: Optional[Dict]
-    inference: Optional[Literal["warm", "cold", "frozen"]]
-    inference_provider_mapping: Optional[Dict[str, InferenceProviderMapping]]
+    inference: Optional[Literal["warm"]]
+    inference_provider_mapping: Optional[List[InferenceProviderMapping]]
     likes: Optional[int]
     library_name: Optional[str]
     tags: Optional[List[str]]
@@ -846,10 +847,9 @@ def __init__(self, **kwargs):
         self.inference = kwargs.pop("inference", None)
         self.inference_provider_mapping = kwargs.pop("inferenceProviderMapping", None)
         if self.inference_provider_mapping:
-            self.inference_provider_mapping = {
-                provider: InferenceProviderMapping(**value)
-                for provider, value in self.inference_provider_mapping.items()
-            }
+            self.inference_provider_mapping = [
+                InferenceProviderMapping(**value) for value in self.inference_provider_mapping
+            ]
 
         self.tags = kwargs.pop("tags", None)
         self.pipeline_tag = kwargs.pop("pipeline_tag", None)
@@ -1816,7 +1816,8 @@ def list_models(
         filter: Union[str, Iterable[str], None] = None,
         author: Optional[str] = None,
         gated: Optional[bool] = None,
-        inference: Optional[Literal["cold", "frozen", "warm"]] = None,
+        inference: Optional[Literal["warm"]] = None,
+        inference_provider: Optional[Union[Literal["all"], PROVIDER_T, List[PROVIDER_T]]] = None,
         library: Optional[Union[str, List[str]]] = None,
         language: Optional[Union[str, List[str]]] = None,
         model_name: Optional[str] = None,
@@ -1850,10 +1851,11 @@ def list_models(
                 A boolean to filter models on the Hub that are gated or not. By default, all models are returned.
                 If `gated=True` is passed, only gated models are returned.
                 If `gated=False` is passed, only non-gated models are returned.
-            inference (`Literal["cold", "frozen", "warm"]`, *optional*):
-                A string to filter models on the Hub by their state on the Inference API.
-                Warm models are available for immediate use. Cold models will be loaded on first inference call.
-                Frozen models are not available in Inference API.
+            inference (`Literal["warm"]`, *optional*):
+                If "warm", filter models on the Hub currently served by at least one provider.
+            inference_provider (`Literal["all"]`, `str` or `List[str]`, *optional*):
+                A string or list of strings to filter models on the Hub served by specific providers.
+                Pass `"all"` to get all models served by at least one provider.
             library (`str` or `List`, *optional*):
                 A string or list of strings of foundational libraries models were originally trained from, such as
                 pytorch, tensorflow, or allennlp.
@@ -1943,6 +1945,9 @@ def list_models(
         # List all models with "bert" in their name made by google
         >>> api.list_models(search="bert", author="google")
+
+        # List all models served by Cohere
+        >>> api.list_models(inference_provider="cohere")
         ```
         """
         if expand and (full or cardData or fetch_config):
             raise ValueError("`expand` cannot be used if `full`, `cardData` or `fetch_config` are passed.")
@@ -1983,6 +1988,8 @@ def list_models(
             params["gated"] = gated
         if inference is not None:
             params["inference"] = inference
+        if inference_provider is not None:
+            params["inference_provider"] = inference_provider
         if pipeline_tag:
             params["pipeline_tag"] = pipeline_tag
         search_list = []
diff --git a/src/huggingface_hub/inference/_providers/_common.py b/src/huggingface_hub/inference/_providers/_common.py
index d59f3f859c..d012464c76 100644
--- a/src/huggingface_hub/inference/_providers/_common.py
+++ b/src/huggingface_hub/inference/_providers/_common.py
@@ -1,5 +1,5 @@
 from functools import lru_cache
-from typing import Any, Dict, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
 
 from huggingface_hub import constants
 from huggingface_hub.inference._common import RequestParameters
@@ -8,6 +8,9 @@
 
 logger = logging.get_logger(__name__)
 
+if TYPE_CHECKING:
+    from huggingface_hub.hf_api import InferenceProviderMapping
+
 
 # Dev purposes only.
 # If you want to try to run inference for a new model locally before it's registered on huggingface.co
@@ -118,7 +121,14 @@ def _prepare_mapped_model(self, model: Optional[str]) -> str:
         if HARDCODED_MODEL_ID_MAPPING.get(self.provider, {}).get(model):
             return HARDCODED_MODEL_ID_MAPPING[self.provider][model]
 
-        provider_mapping = _fetch_inference_provider_mapping(model).get(self.provider)
+        provider_mapping = next(
+            (
+                provider_mapping
+                for provider_mapping in _fetch_inference_provider_mapping(model)
+                if provider_mapping.provider == self.provider
+            ),
+            None,
+        )
         if provider_mapping is None:
             raise ValueError(f"Model {model} is not supported by provider {self.provider}.")
 
@@ -220,7 +230,7 @@ def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model:
 
 
 @lru_cache(maxsize=None)
-def _fetch_inference_provider_mapping(model: str) -> Dict:
+def _fetch_inference_provider_mapping(model: str) -> List["InferenceProviderMapping"]:
     """
     Fetch provider mappings for a model from the Hub.
""" diff --git a/tests/test_hf_api.py b/tests/test_hf_api.py index 4a2f5f3a1c..8e0b89623e 100644 --- a/tests/test_hf_api.py +++ b/tests/test_hf_api.py @@ -2314,6 +2314,18 @@ def test_filter_models_with_card_data(self): models = self._api.list_models(filter="co2_eq_emissions") assert all(model.card_data is None for model in models) + def test_filter_models_by_inference_provider(self): + models = list( + self._api.list_models(inference_provider="hf-inference", expand=["inferenceProviderMapping"], limit=10) + ) + assert len(models) > 0 + for model in models: + assert model.inference_provider_mapping is not None + assert any(mapping.provider == "hf-inference" for mapping in model.inference_provider_mapping) + + models = self._api.list_models(filter="co2_eq_emissions") + assert all(model.card_data is None for model in models) + def test_is_emission_within_threshold(self): # tests that dictionary is handled correctly as "emissions" and that # 17g is accepted and parsed correctly as a value diff --git a/tests/test_inference_providers.py b/tests/test_inference_providers.py index 67a3033840..3765c60a05 100644 --- a/tests/test_inference_providers.py +++ b/tests/test_inference_providers.py @@ -76,7 +76,9 @@ def test_prepare_mapped_model(self, mocker, caplog: LogCaptureFixture): # Test unsupported model mocker.patch( "huggingface_hub.inference._providers._common._fetch_inference_provider_mapping", - return_value={"other-provider": "mapping"}, + return_value=[ + mocker.Mock(provider="other-provider", task="task-name", provider_id="mapped-id", status="active") + ], ) with pytest.raises(ValueError, match="Model test-model is not supported.*"): helper._prepare_mapped_model("test-model") @@ -84,7 +86,9 @@ def test_prepare_mapped_model(self, mocker, caplog: LogCaptureFixture): # Test task mismatch mocker.patch( "huggingface_hub.inference._providers._common._fetch_inference_provider_mapping", - return_value={"provider-name": mocker.Mock(task="other-task", provider_id="mapped-id", status="active")}, + return_value=[ + mocker.Mock(provider="provider-name", task="other-task", provider_id="mapped-id", status="active") + ], ) with pytest.raises(ValueError, match="Model test-model is not supported for task.*"): helper._prepare_mapped_model("test-model") @@ -92,7 +96,9 @@ def test_prepare_mapped_model(self, mocker, caplog: LogCaptureFixture): # Test staging model mocker.patch( "huggingface_hub.inference._providers._common._fetch_inference_provider_mapping", - return_value={"provider-name": mocker.Mock(task="task-name", provider_id="mapped-id", status="staging")}, + return_value=[ + mocker.Mock(provider="provider-name", task="task-name", provider_id="mapped-id", status="staging") + ], ) assert helper._prepare_mapped_model("test-model") == "mapped-id" assert_in_logs( @@ -103,7 +109,9 @@ def test_prepare_mapped_model(self, mocker, caplog: LogCaptureFixture): caplog.clear() mocker.patch( "huggingface_hub.inference._providers._common._fetch_inference_provider_mapping", - return_value={"provider-name": mocker.Mock(task="task-name", provider_id="mapped-id", status="active")}, + return_value=[ + mocker.Mock(provider="provider-name", task="task-name", provider_id="mapped-id", status="live") + ], ) assert helper._prepare_mapped_model("test-model") == "mapped-id" assert len(caplog.records) == 0