@@ -4501,3 +4501,138 @@ def test_create_inference_endpoint_from_catalog(self, mock_get_session: Mock) ->
4501
4501
)
4502
4502
assert isinstance (endpoint , InferenceEndpoint )
4503
4503
assert endpoint .name == "llama-3-2-3b-instruct-eey"
4504
+
4505
+
4506
@pytest.mark.parametrize(
    "custom_image, expected_image_payload",
    [
        # Nothing passed for `custom_image`: the payload is expected to carry an empty "huggingface" image.
        pytest.param(
            None,
            {"huggingface": {}},
            id="no_custom_image",
        ),
        # Flat dict (no image-type key): the payload is expected to nest it under "custom".
        pytest.param(
            {"url": "my.registry/my-image:latest", "port": 8080},
            {"custom": {"url": "my.registry/my-image:latest", "port": 8080}},
            id="flat_dict_custom_image",
        ),
        # Dict already keyed by "tgi": expected to be forwarded unchanged.
        pytest.param(
            {"tgi": {"url": "ghcr.io/huggingface/text-generation-inference:latest"}},
            {"tgi": {"url": "ghcr.io/huggingface/text-generation-inference:latest"}},
            id="keyed_tgi_custom_image",
        ),
        # Dict already keyed by "custom": expected to be forwarded unchanged.
        pytest.param(
            {"custom": {"url": "another.registry/custom:v2"}},
            {"custom": {"url": "another.registry/custom:v2"}},
            id="keyed_custom_custom_image",
        ),
    ],
)
@patch("huggingface_hub.hf_api.get_session")
def test_create_inference_endpoint_custom_image_payload(
    mock_post: Mock,
    custom_image: Optional[dict],
    expected_image_payload: dict,
):
    """Check the ``model.image`` section of the JSON body sent by ``create_inference_endpoint``.

    ``huggingface_hub.hf_api.get_session`` is patched so that the session's ``post`` returns a canned
    endpoint description. The test then inspects the ``json`` keyword argument of the single ``post``
    call and compares its ``["model"]["image"]`` entry with ``expected_image_payload``.

    Args:
        mock_post: Mock injected by ``@patch``. Despite its name it stands in for ``get_session``;
            the session object is its ``return_value``.
        custom_image: Value forwarded as ``custom_image``, or ``None`` to omit the argument entirely.
        expected_image_payload: Expected content of ``payload["model"]["image"]``.
    """
    # Canned server response, assembled piece by piece to keep each section readable.
    scaling_section = {
        "maxReplica": 1,
        "measure": {"hardwareUsage": None},
        "metric": "hardwareUsage",
        "minReplica": 0,
        "scaleToZeroTimeout": 15,
    }
    compute_section = {
        "accelerator": "gpu",
        "id": "aws-us-east-1-nvidia-l4-x1",
        "instanceSize": "x1",
        "instanceType": "nvidia-l4",
        "scaling": scaling_section,
    }
    model_section = {
        "env": {},
        "framework": "pytorch",
        "image": {
            "tgi": {
                "disableCustomKernels": False,
                "healthRoute": "/health",
                "port": 80,
                "url": "ghcr.io/huggingface/text-generation-inference:3.1.1",
            }
        },
        "repository": "meta-llama/Llama-3.2-3B-Instruct",
        "revision": "0cb88a4f764b7a12671c53f0838cd831a0843b95",
        "secrets": {},
        "task": "text-generation",
    }
    status_section = {
        "createdAt": "2025-03-07T15:30:13.949Z",
        "createdBy": {"id": "6273f303f6d63a28483fde12", "name": "Wauplin"},
        "message": "Endpoint waiting to be scheduled",
        "readyReplica": 0,
        "state": "pending",
        "targetReplica": 1,
        "updatedAt": "2025-03-07T15:30:13.949Z",
        "updatedBy": {"id": "6273f303f6d63a28483fde12", "name": "Wauplin"},
    }

    fake_response = Mock()
    fake_response.raise_for_status.return_value = None
    fake_response.json.return_value = {
        "compute": compute_section,
        "model": model_section,
        "name": "llama-3-2-3b-instruct-eey",
        "provider": {"region": "us-east-1", "vendor": "aws"},
        "status": status_section,
        "type": "protected",
    }

    # `get_session()` -> session; `session.post(...)` -> canned response.
    session_post = mock_post.return_value.post
    session_post.return_value = fake_response

    endpoint_kwargs = {
        "name": "test-endpoint-custom-img",
        "repository": "meta-llama/Llama-2-7b-chat-hf",
        "framework": "pytorch",
        "accelerator": "gpu",
        "instance_size": "medium",
        "instance_type": "nvidia-a10g",
        "region": "us-east-1",
        "vendor": "aws",
        "type": "protected",
        "task": "text-generation",
        "namespace": "Wauplin",
    }
    # Only pass `custom_image` when one is given, so the `None` case exercises the parameter's default.
    if custom_image is not None:
        endpoint_kwargs["custom_image"] = custom_image

    client = HfApi(endpoint=ENDPOINT_STAGING, token=TOKEN)
    client.create_inference_endpoint(**endpoint_kwargs)

    session_post.assert_called_once()
    sent_kwargs = session_post.call_args[1]
    sent_json = sent_kwargs.get("json", {})

    assert "model" in sent_json
    assert "image" in sent_json["model"]
    assert sent_json["model"]["image"] == expected_image_payload
0 commit comments