Fix inference search #3022

Draft · wants to merge 1 commit into main

4 changes: 4 additions & 0 deletions docs/source/en/package_reference/hf_api.md
@@ -57,6 +57,10 @@ models = hf_api.list_models()

[[autodoc]] huggingface_hub.hf_api.GitRefs

### InferenceProviderMapping

[[autodoc]] huggingface_hub.hf_api.InferenceProviderMapping

### LFSFileInfo

[[autodoc]] huggingface_hub.hf_api.LFSFileInfo
45 changes: 26 additions & 19 deletions src/huggingface_hub/hf_api.py
@@ -103,6 +103,7 @@
RevisionNotFoundError,
)
from .file_download import HfFileMetadata, get_hf_file_metadata, hf_hub_url
from .inference._providers import PROVIDER_T
from .repocard_data import DatasetCardData, ModelCardData, SpaceCardData
from .utils import (
DEFAULT_IGNORE_PATTERNS,
@@ -708,13 +709,15 @@ def __init__(self, **kwargs):

@dataclass
class InferenceProviderMapping:
status: Literal["live", "staging"]
provider: str
provider_id: str
status: Literal["live", "staging"]
task: str

def __init__(self, **kwargs):
self.status = kwargs.pop("status")
self.provider = kwargs.pop("provider")
self.provider_id = kwargs.pop("providerId")
self.status = kwargs.pop("status")
self.task = kwargs.pop("task")
self.__dict__.update(**kwargs)
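
For context on the shape this dataclass now parses, here is a minimal sketch of constructing it from one raw entry of the Hub payload. The key names (`provider`, `providerId`, `status`, `task`) follow the `kwargs.pop` calls above; the values are invented for illustration.

```python
from huggingface_hub.hf_api import InferenceProviderMapping

# Hypothetical raw entry for one provider, mirroring the keys popped above.
raw_entry = {
    "provider": "hf-inference",
    "providerId": "meta-llama/Llama-3.1-8B-Instruct",  # provider-side model id (made up)
    "status": "live",  # "live" or "staging"
    "task": "conversational",
}

mapping = InferenceProviderMapping(**raw_entry)
print(mapping.provider, mapping.provider_id, mapping.status, mapping.task)
```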

@@ -757,12 +760,10 @@ class ModelInfo:
If so, whether there is manual or automatic approval.
gguf (`Dict`, *optional*):
GGUF information of the model.
inference (`Literal["cold", "frozen", "warm"]`, *optional*):
Status of the model on the inference API.
Warm models are available for immediate use. Cold models will be loaded on first inference call.
Frozen models are not available in Inference API.
inference_provider_mapping (`Dict`, *optional*):
Model's inference provider mapping.
inference (`Literal["warm"]`, *optional*):
Status of the model on Inference Providers. Warm if the model is served by at least one provider.
inference_provider_mapping (`List[InferenceProviderMapping]`, *optional*):
A list of [`InferenceProviderMapping`] entries, ordered according to the user's preferred provider order.
likes (`int`):
Number of likes of the model.
library_name (`str`, *optional*):
@@ -807,8 +808,8 @@ class ModelInfo:
downloads_all_time: Optional[int]
gated: Optional[Literal["auto", "manual", False]]
gguf: Optional[Dict]
inference: Optional[Literal["warm", "cold", "frozen"]]
inference_provider_mapping: Optional[Dict[str, InferenceProviderMapping]]
inference: Optional[Literal["warm"]]
inference_provider_mapping: Optional[List[InferenceProviderMapping]]
likes: Optional[int]
library_name: Optional[str]
tags: Optional[List[str]]
@@ -846,10 +847,9 @@ def __init__(self, **kwargs):
self.inference = kwargs.pop("inference", None)
self.inference_provider_mapping = kwargs.pop("inferenceProviderMapping", None)
if self.inference_provider_mapping:
self.inference_provider_mapping = {
provider: InferenceProviderMapping(**value)
for provider, value in self.inference_provider_mapping.items()
}
self.inference_provider_mapping = [
InferenceProviderMapping(**value) for value in self.inference_provider_mapping
]

self.tags = kwargs.pop("tags", None)
self.pipeline_tag = kwargs.pop("pipeline_tag", None)
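
Since `inference_provider_mapping` changes here from a dict keyed by provider name to a plain list, callers that previously indexed by provider now have to iterate. A short sketch of reading the new shape; the model id is just an example, and `expand=["inferenceProviderMapping"]` is the same expand property exercised in the tests below.

```python
from huggingface_hub import HfApi

api = HfApi()
# Example model id; expand=["inferenceProviderMapping"] asks the Hub to include the mapping.
info = api.model_info("deepseek-ai/DeepSeek-R1", expand=["inferenceProviderMapping"])

# Before this change: info.inference_provider_mapping["together"]  (dict keyed by provider)
# After this change:  a list of InferenceProviderMapping objects, in the user's provider order.
for mapping in info.inference_provider_mapping or []:
    print(mapping.provider, mapping.provider_id, mapping.task, mapping.status)
```
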
@@ -1816,7 +1816,8 @@ def list_models(
filter: Union[str, Iterable[str], None] = None,
author: Optional[str] = None,
gated: Optional[bool] = None,
inference: Optional[Literal["cold", "frozen", "warm"]] = None,
inference: Optional[Literal["warm"]] = None,
inference_provider: Optional[Union[Literal["all"], PROVIDER_T, List[PROVIDER_T]]] = None,
library: Optional[Union[str, List[str]]] = None,
language: Optional[Union[str, List[str]]] = None,
model_name: Optional[str] = None,
@@ -1850,10 +1851,11 @@
A boolean to filter models on the Hub that are gated or not. By default, all models are returned.
If `gated=True` is passed, only gated models are returned.
If `gated=False` is passed, only non-gated models are returned.
inference (`Literal["cold", "frozen", "warm"]`, *optional*):
A string to filter models on the Hub by their state on the Inference API.
Warm models are available for immediate use. Cold models will be loaded on first inference call.
Frozen models are not available in Inference API.
inference (`Literal["warm"]`, *optional*):
If "warm", filter models on the Hub currently served by at least one provider.
inference_provider (`Literal["all"]` or `str`, *optional*):
A string to filter models on the Hub that are served by a specific provider.
Pass `"all"` to get all models served by at least one provider.
library (`str` or `List`, *optional*):
A string or list of strings of foundational libraries models were
originally trained from, such as pytorch, tensorflow, or allennlp.
@@ -1943,6 +1945,9 @@
# List all models with "bert" in their name made by google
>>> api.list_models(search="bert", author="google")

# List all models served by Cohere
>>> api.list_models(inference_provider="cohere")
```
"""
if expand and (full or cardData or fetch_config):
raise ValueError("`expand` cannot be used if `full`, `cardData` or `fetch_config` are passed.")
@@ -1983,6 +1988,8 @@
params["gated"] = gated
if inference is not None:
params["inference"] = inference
if inference_provider is not None:
params["inference_provider"] = inference_provider
if pipeline_tag:
params["pipeline_tag"] = pipeline_tag
search_list = []
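
With the new parameter wired into `params`, filtering by provider becomes a one-liner. A usage sketch combining the new filter with the expand property used in the test below; the provider names are examples of valid `PROVIDER_T` values.

```python
from huggingface_hub import HfApi

api = HfApi()

# Models currently served by a specific provider.
for model in api.list_models(inference_provider="cohere", limit=5):
    print(model.id)

# Models served by at least one provider, with the mapping included in the results.
for model in api.list_models(inference_provider="all", expand=["inferenceProviderMapping"], limit=5):
    providers = [m.provider for m in model.inference_provider_mapping or []]
    print(model.id, providers)
```
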
16 changes: 13 additions & 3 deletions src/huggingface_hub/inference/_providers/_common.py
@@ -1,5 +1,5 @@
from functools import lru_cache
from typing import Any, Dict, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

from huggingface_hub import constants
from huggingface_hub.inference._common import RequestParameters
@@ -8,6 +8,9 @@

logger = logging.get_logger(__name__)

if TYPE_CHECKING:
from huggingface_hub.hf_api import InferenceProviderMapping


# Dev purposes only.
# If you want to try to run inference for a new model locally before it's registered on huggingface.co
@@ -118,7 +121,14 @@ def _prepare_mapped_model(self, model: Optional[str]) -> str:
if HARDCODED_MODEL_ID_MAPPING.get(self.provider, {}).get(model):
return HARDCODED_MODEL_ID_MAPPING[self.provider][model]

provider_mapping = _fetch_inference_provider_mapping(model).get(self.provider)
provider_mapping = next(
(
provider_mapping
for provider_mapping in _fetch_inference_provider_mapping(model)
if provider_mapping.provider == self.provider
),
None,
)
if provider_mapping is None:
raise ValueError(f"Model {model} is not supported by provider {self.provider}.")
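
The lookup switches from a `dict.get(provider)` call to a `next()` scan over the list. A standalone sketch of the same pattern with simplified stand-in types (nothing here is the library's API):

```python
from typing import List, Optional


class FakeMapping:
    """Stand-in with the same attributes the real mapping exposes."""

    def __init__(self, provider: str, provider_id: str):
        self.provider = provider
        self.provider_id = provider_id


def find_mapping(mappings: List[FakeMapping], provider: str) -> Optional[FakeMapping]:
    # First entry whose provider matches, or None: the list equivalent of dict.get(provider).
    return next((m for m in mappings if m.provider == provider), None)


mappings = [FakeMapping("hf-inference", "model-a"), FakeMapping("together", "model-b")]
assert find_mapping(mappings, "together").provider_id == "model-b"
assert find_mapping(mappings, "replicate") is None
```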

@@ -220,7 +230,7 @@ def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model:


@lru_cache(maxsize=None)
def _fetch_inference_provider_mapping(model: str) -> Dict:
def _fetch_inference_provider_mapping(model: str) -> List["InferenceProviderMapping"]:
"""
Fetch provider mappings for a model from the Hub.
"""
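
The function body is outside this hunk. Below is a plausible sketch consistent with the new return type, assuming it delegates to `HfApi.model_info` with the `inferenceProviderMapping` expand property; the actual implementation may differ, for example in its error handling.

```python
from functools import lru_cache
from typing import List

from huggingface_hub.hf_api import HfApi, InferenceProviderMapping


@lru_cache(maxsize=None)
def _fetch_inference_provider_mapping_sketch(model: str) -> List[InferenceProviderMapping]:
    # Illustrative only: fetch the model with the provider mapping expanded and
    # return the parsed list (empty if the model is not served by any provider).
    info = HfApi().model_info(model, expand=["inferenceProviderMapping"])
    return info.inference_provider_mapping or []
```
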
12 changes: 12 additions & 0 deletions tests/test_hf_api.py
@@ -2314,6 +2314,18 @@ def test_filter_models_with_card_data(self):
models = self._api.list_models(filter="co2_eq_emissions")
assert all(model.card_data is None for model in models)

def test_filter_models_by_inference_provider(self):
models = list(
self._api.list_models(inference_provider="hf-inference", expand=["inferenceProviderMapping"], limit=10)
)
assert len(models) > 0
for model in models:
assert model.inference_provider_mapping is not None
assert any(mapping.provider == "hf-inference" for mapping in model.inference_provider_mapping)

def test_is_emission_within_threshold(self):
# tests that dictionary is handled correctly as "emissions" and that
# 17g is accepted and parsed correctly as a value
16 changes: 12 additions & 4 deletions tests/test_inference_providers.py
@@ -76,23 +76,29 @@ def test_prepare_mapped_model(self, mocker, caplog: LogCaptureFixture):
# Test unsupported model
mocker.patch(
"huggingface_hub.inference._providers._common._fetch_inference_provider_mapping",
return_value={"other-provider": "mapping"},
return_value=[
mocker.Mock(provider="other-provider", task="task-name", provider_id="mapped-id", status="active")
],
)
with pytest.raises(ValueError, match="Model test-model is not supported.*"):
helper._prepare_mapped_model("test-model")

# Test task mismatch
mocker.patch(
"huggingface_hub.inference._providers._common._fetch_inference_provider_mapping",
return_value={"provider-name": mocker.Mock(task="other-task", provider_id="mapped-id", status="active")},
return_value=[
mocker.Mock(provider="provider-name", task="other-task", provider_id="mapped-id", status="active")
],
)
with pytest.raises(ValueError, match="Model test-model is not supported for task.*"):
helper._prepare_mapped_model("test-model")

# Test staging model
mocker.patch(
"huggingface_hub.inference._providers._common._fetch_inference_provider_mapping",
return_value={"provider-name": mocker.Mock(task="task-name", provider_id="mapped-id", status="staging")},
return_value=[
mocker.Mock(provider="provider-name", task="task-name", provider_id="mapped-id", status="staging")
],
)
assert helper._prepare_mapped_model("test-model") == "mapped-id"
assert_in_logs(
@@ -103,7 +109,9 @@ def test_prepare_mapped_model(self, mocker, caplog: LogCaptureFixture):
caplog.clear()
mocker.patch(
"huggingface_hub.inference._providers._common._fetch_inference_provider_mapping",
return_value={"provider-name": mocker.Mock(task="task-name", provider_id="mapped-id", status="active")},
return_value=[
mocker.Mock(provider="provider-name", task="task-name", provider_id="mapped-id", status="live")
],
)
assert helper._prepare_mapped_model("test-model") == "mapped-id"
assert len(caplog.records) == 0