Skip to content

Commit 7fde08d

Browse files
[Inference Endpoints] fix inference endpoint creation with custom image (#3076)
* fix inference endpoint creation with custom image
* add test
* Update src/huggingface_hub/constants.py (suggestion by Lucain <lucain@huggingface.co>)

Co-authored-by: Lucain <lucain@huggingface.co>
1 parent 68475be commit 7fde08d

File tree

3 files changed

+155
-1
lines changed

3 files changed

+155
-1
lines changed

src/huggingface_hub/constants.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,17 @@ def _as_int(value: Optional[str]) -> Optional[int]:
8282
INFERENCE_ENDPOINTS_ENDPOINT = "https://api.endpoints.huggingface.cloud/v2"
8383
INFERENCE_CATALOG_ENDPOINT = "https://endpoints.huggingface.co/api/catalog"
8484

85+
# Recognized top-level keys for the "image" section of an Inference Endpoint
# creation payload. A user-supplied `custom_image` dict whose first key is one
# of these is forwarded as-is; anything else is wrapped under "custom".
# See https://api.endpoints.huggingface.cloud/#post-/v2/endpoint/-namespace-
INFERENCE_ENDPOINT_IMAGE_KEYS = [
    "custom",
    "huggingface",
    "huggingfaceNeuron",
    "llamacpp",
    "tei",
    "tgi",
    "tgiNeuron",
]
8596
# Proxy for third-party providers
8697
INFERENCE_PROXY_TEMPLATE = "https://router.huggingface.co/{provider}"
8798

src/huggingface_hub/hf_api.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7698,7 +7698,15 @@ def create_inference_endpoint(
76987698
"""
76997699
namespace = namespace or self._get_namespace(token=token)
77007700

7701-
image = {"custom": custom_image} if custom_image is not None else {"huggingface": {}}
7701+
if custom_image is not None:
7702+
image = (
7703+
custom_image
7704+
if next(iter(custom_image)) in constants.INFERENCE_ENDPOINT_IMAGE_KEYS
7705+
else {"custom": custom_image}
7706+
)
7707+
else:
7708+
image = {"huggingface": {}}
7709+
77027710
payload: Dict = {
77037711
"accountId": account_id,
77047712
"compute": {

tests/test_hf_api.py

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4501,3 +4501,138 @@ def test_create_inference_endpoint_from_catalog(self, mock_get_session: Mock) ->
45014501
)
45024502
assert isinstance(endpoint, InferenceEndpoint)
45034503
assert endpoint.name == "llama-3-2-3b-instruct-eey"
4504+
4505+
4506+
@pytest.mark.parametrize(
    "custom_image, expected_image_payload",
    [
        # No image override: the default "huggingface" image is sent.
        pytest.param(
            None,
            {
                "huggingface": {},
            },
            id="no_custom_image",
        ),
        # Flat dict (no recognized top-level key): must be wrapped under "custom".
        pytest.param(
            {
                "url": "my.registry/my-image:latest",
                "port": 8080,
            },
            {
                "custom": {
                    "url": "my.registry/my-image:latest",
                    "port": 8080,
                }
            },
            id="flat_dict_custom_image",
        ),
        # Already keyed with a known image type ("tgi"): forwarded unchanged.
        pytest.param(
            {
                "tgi": {
                    "url": "ghcr.io/huggingface/text-generation-inference:latest",
                }
            },
            {
                "tgi": {
                    "url": "ghcr.io/huggingface/text-generation-inference:latest",
                }
            },
            id="keyed_tgi_custom_image",
        ),
        # Already keyed with "custom": forwarded unchanged (no double-wrapping).
        pytest.param(
            {
                "custom": {
                    "url": "another.registry/custom:v2",
                }
            },
            {
                "custom": {
                    "url": "another.registry/custom:v2",
                }
            },
            id="keyed_custom_custom_image",
        ),
    ],
)
@patch("huggingface_hub.hf_api.get_session")
def test_create_inference_endpoint_custom_image_payload(
    mock_get_session: Mock,
    custom_image: Optional[dict],
    expected_image_payload: dict,
):
    """Check the `model.image` section of the payload built by `create_inference_endpoint`.

    The POST call is mocked, so no network request is made; only the JSON
    payload handed to the session is inspected.
    """
    endpoint_kwargs = {
        "name": "test-endpoint-custom-img",
        "repository": "meta-llama/Llama-2-7b-chat-hf",
        "framework": "pytorch",
        "accelerator": "gpu",
        "instance_size": "medium",
        "instance_type": "nvidia-a10g",
        "region": "us-east-1",
        "vendor": "aws",
        "type": "protected",
        "task": "text-generation",
        "namespace": "Wauplin",
    }

    # Canned server response for the mocked POST; its content only needs to
    # parse into a valid InferenceEndpoint, it is not what the test asserts on.
    server_response = {
        "compute": {
            "accelerator": "gpu",
            "id": "aws-us-east-1-nvidia-l4-x1",
            "instanceSize": "x1",
            "instanceType": "nvidia-l4",
            "scaling": {
                "maxReplica": 1,
                "measure": {"hardwareUsage": None},
                "metric": "hardwareUsage",
                "minReplica": 0,
                "scaleToZeroTimeout": 15,
            },
        },
        "model": {
            "env": {},
            "framework": "pytorch",
            "image": {
                "tgi": {
                    "disableCustomKernels": False,
                    "healthRoute": "/health",
                    "port": 80,
                    "url": "ghcr.io/huggingface/text-generation-inference:3.1.1",
                }
            },
            "repository": "meta-llama/Llama-3.2-3B-Instruct",
            "revision": "0cb88a4f764b7a12671c53f0838cd831a0843b95",
            "secrets": {},
            "task": "text-generation",
        },
        "name": "llama-3-2-3b-instruct-eey",
        "provider": {"region": "us-east-1", "vendor": "aws"},
        "status": {
            "createdAt": "2025-03-07T15:30:13.949Z",
            "createdBy": {"id": "6273f303f6d63a28483fde12", "name": "Wauplin"},
            "message": "Endpoint waiting to be scheduled",
            "readyReplica": 0,
            "state": "pending",
            "targetReplica": 1,
            "updatedAt": "2025-03-07T15:30:13.949Z",
            "updatedBy": {"id": "6273f303f6d63a28483fde12", "name": "Wauplin"},
        },
        "type": "protected",
    }

    # `get_session()` -> session; `session.post(...)` -> mocked HTTP response.
    post_mock = mock_get_session.return_value.post
    response_mock = Mock()
    response_mock.raise_for_status.return_value = None
    response_mock.json.return_value = server_response
    post_mock.return_value = response_mock

    api = HfApi(endpoint=ENDPOINT_STAGING, token=TOKEN)
    call_kwargs_in = dict(endpoint_kwargs)
    if custom_image is not None:
        call_kwargs_in["custom_image"] = custom_image
    api.create_inference_endpoint(**call_kwargs_in)

    post_mock.assert_called_once()
    _, request_kwargs = post_mock.call_args
    payload = request_kwargs.get("json", {})

    # The image section must match exactly what the Endpoints API expects.
    assert "model" in payload
    assert "image" in payload["model"]
    assert payload["model"]["image"] == expected_image_payload

0 commit comments

Comments (0)