Skip to content

Commit

Permalink
feat: ✨ allow passing secrets to the inference endpoint client (#2486)
Browse files Browse the repository at this point in the history
* feat: ✨ allow passing secrets through the inference endpoint client

added the secrets argument to create_inference_endpoint, update_inference_endpoint and InferenceEndpoint.update

* test: ✅ add secrets to test

* Apply suggestions from code review

---------

Co-authored-by: Lucain <[email protected]>
  • Loading branch information
LuisBlanche and Wauplin authored Aug 23, 2024
1 parent 73ee664 commit 6438044
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 5 deletions.
5 changes: 4 additions & 1 deletion src/huggingface_hub/_inference_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ def update(
revision: Optional[str] = None,
task: Optional[str] = None,
custom_image: Optional[Dict] = None,
secrets: Optional[Dict[str, str]] = None,
) -> "InferenceEndpoint":
"""Update the Inference Endpoint.
Expand Down Expand Up @@ -279,7 +280,8 @@ def update(
custom_image (`Dict`, *optional*):
A custom Docker image to use for the Inference Endpoint. This is useful if you want to deploy an
Inference Endpoint running on the `text-generation-inference` (TGI) framework (see examples).
secrets (`Dict[str, str]`, *optional*):
Secret values to inject in the container environment.
Returns:
[`InferenceEndpoint`]: the same Inference Endpoint, mutated in place with the latest data.
"""
Expand All @@ -298,6 +300,7 @@ def update(
revision=revision,
task=task,
custom_image=custom_image,
secrets=secrets,
token=self._token, # type: ignore [arg-type]
)

Expand Down
15 changes: 11 additions & 4 deletions src/huggingface_hub/hf_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,7 @@
)
from .utils import tqdm as hf_tqdm
from .utils._typing import CallableT
from .utils.endpoint_helpers import (
_is_emission_within_threshold,
)
from .utils.endpoint_helpers import _is_emission_within_threshold


R = TypeVar("R") # Return type
Expand Down Expand Up @@ -7418,6 +7416,7 @@ def create_inference_endpoint(
revision: Optional[str] = None,
task: Optional[str] = None,
custom_image: Optional[Dict] = None,
secrets: Optional[Dict[str, str]] = None,
type: InferenceEndpointType = InferenceEndpointType.PROTECTED,
namespace: Optional[str] = None,
token: Union[bool, str, None] = None,
Expand Down Expand Up @@ -7456,6 +7455,8 @@ def create_inference_endpoint(
custom_image (`Dict`, *optional*):
A custom Docker image to use for the Inference Endpoint. This is useful if you want to deploy an
Inference Endpoint running on the `text-generation-inference` (TGI) framework (see examples).
secrets (`Dict[str, str]`, *optional*):
Secret values to inject in the container environment.
            type ([`InferenceEndpointType`], *optional*):
The type of the Inference Endpoint, which can be `"protected"` (default), `"public"` or `"private"`.
namespace (`str`, *optional*):
Expand Down Expand Up @@ -7518,6 +7519,7 @@ def create_inference_endpoint(
... },
... "url": "ghcr.io/huggingface/text-generation-inference:1.1.0",
... },
... secrets={"MY_SECRET_KEY": "secret_value"},
... )
```
Expand All @@ -7543,6 +7545,7 @@ def create_inference_endpoint(
"revision": revision,
"task": task,
"image": image,
"secrets": secrets,
},
"name": name,
"provider": {
Expand Down Expand Up @@ -7625,6 +7628,7 @@ def update_inference_endpoint(
revision: Optional[str] = None,
task: Optional[str] = None,
custom_image: Optional[Dict] = None,
secrets: Optional[Dict[str, str]] = None,
# Other
namespace: Optional[str] = None,
token: Union[bool, str, None] = None,
Expand Down Expand Up @@ -7664,7 +7668,8 @@ def update_inference_endpoint(
custom_image (`Dict`, *optional*):
A custom Docker image to use for the Inference Endpoint. This is useful if you want to deploy an
Inference Endpoint running on the `text-generation-inference` (TGI) framework (see examples).
secrets (`Dict[str, str]`, *optional*):
Secret values to inject in the container environment.
namespace (`str`, *optional*):
The namespace where the Inference Endpoint will be updated. Defaults to the current user's namespace.
token (Union[bool, str, None], optional):
Expand Down Expand Up @@ -7702,6 +7707,8 @@ def update_inference_endpoint(
payload["model"]["task"] = task
if custom_image is not None:
payload["model"]["image"] = {"custom": custom_image}
if secrets is not None:
payload["model"]["secrets"] = secrets

response = get_session().put(
f"{constants.INFERENCE_ENDPOINTS_ENDPOINT}/endpoint/{namespace}/{name}",
Expand Down
3 changes: 3 additions & 0 deletions tests/test_inference_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
"task": "text-generation",
"framework": "pytorch",
"image": {"huggingface": {}},
"secret": {"token": "my-token"},
},
"status": {
"createdAt": "2023-10-26T12:41:53.263078506Z",
Expand Down Expand Up @@ -61,6 +62,7 @@
"task": "text-generation",
"framework": "pytorch",
"image": {"huggingface": {}},
"secrets": {"token": "my-token"},
},
"status": {
"createdAt": "2023-10-26T12:41:53.263Z",
Expand Down Expand Up @@ -93,6 +95,7 @@
"task": "text-generation",
"framework": "pytorch",
"image": {"huggingface": {}},
"secrets": {"token": "my-token"},
},
"status": {
"createdAt": "2023-10-26T12:41:53.263Z",
Expand Down

0 comments on commit 6438044

Please sign in to comment.