Skip to content

Commit aadb655

Browse files
authored
Add Inference API (#2393)
1 parent 80a2792 commit aadb655

File tree

7 files changed

+440
-1
lines changed

7 files changed

+440
-1
lines changed

docs/sphinx/api.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ arguments are required for all calls.
1010
.. note::
1111

1212
Some API parameters in Elasticsearch are reserved keywords in Python.
13-
For example the ``from`` query parameter for pagination would be aliased as
13+
For example the ``from`` query parameter for pagination would be aliased as
1414
``from_``.
1515

1616
.. toctree::
@@ -28,6 +28,7 @@ arguments are required for all calls.
2828
api/graph-explore
2929
api/index-lifecycle-management
3030
api/indices
31+
api/inference
3132
api/ingest-pipelines
3233
api/license
3334
api/logstash

docs/sphinx/api/inference.rst

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
.. _inference:
2+
3+
Inference
4+
---------
5+
.. py:module:: elasticsearch.client
6+
:noindex:
7+
8+
.. autoclass:: InferenceClient
9+
:members:

elasticsearch/_async/client/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
from .graph import GraphClient
5555
from .ilm import IlmClient
5656
from .indices import IndicesClient
57+
from .inference import InferenceClient
5758
from .ingest import IngestClient
5859
from .license import LicenseClient
5960
from .logstash import LogstashClient
@@ -434,6 +435,7 @@ def __init__(
434435
self.fleet = FleetClient(self)
435436
self.features = FeaturesClient(self)
436437
self.indices = IndicesClient(self)
438+
self.inference = InferenceClient(self)
437439
self.ingest = IngestClient(self)
438440
self.nodes = NodesClient(self)
439441
self.snapshot = SnapshotClient(self)
Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,212 @@
1+
# Licensed to Elasticsearch B.V. under one or more contributor
2+
# license agreements. See the NOTICE file distributed with
3+
# this work for additional information regarding copyright
4+
# ownership. Elasticsearch B.V. licenses this file to you under
5+
# the Apache License, Version 2.0 (the "License"); you may
6+
# not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
import typing as t
19+
20+
from elastic_transport import ObjectApiResponse
21+
22+
from ._base import NamespacedClient
23+
from .utils import SKIP_IN_PATH, _quote, _rewrite_parameters
24+
25+
26+
class InferenceClient(NamespacedClient):
    """
    Client for the ``/_inference/{task_type}/{model_id}`` endpoints.

    Every method takes keyword-only arguments, validates that both path
    parameters are non-empty and forwards the request through
    ``perform_request``.
    """

    @_rewrite_parameters()
    async def delete_model(
        self,
        *,
        task_type: t.Union["t.Literal['sparse_embedding', 'text_embedding']", str],
        model_id: str,
        error_trace: t.Optional[bool] = None,
        filter_path: t.Optional[t.Union[str, t.Sequence[str]]] = None,
        human: t.Optional[bool] = None,
        pretty: t.Optional[bool] = None,
    ) -> ObjectApiResponse[t.Any]:
        """
        Delete model in the Inference API

        `<https://www.elastic.co/guide/en/elasticsearch/reference/master/delete-inference-api.html>`_

        :param task_type: The model task type
        :param model_id: The unique identifier of the inference model.
        :raises ValueError: If ``task_type`` or ``model_id`` is empty.
        """
        _require_path_params(task_type, model_id)
        __path = f"/_inference/{_quote(task_type)}/{_quote(model_id)}"
        __query = _common_query(error_trace, filter_path, human, pretty)
        __headers = {"accept": "application/json"}
        return await self.perform_request(  # type: ignore[return-value]
            "DELETE", __path, params=__query, headers=__headers
        )

    @_rewrite_parameters()
    async def get_model(
        self,
        *,
        task_type: t.Union["t.Literal['sparse_embedding', 'text_embedding']", str],
        model_id: str,
        error_trace: t.Optional[bool] = None,
        filter_path: t.Optional[t.Union[str, t.Sequence[str]]] = None,
        human: t.Optional[bool] = None,
        pretty: t.Optional[bool] = None,
    ) -> ObjectApiResponse[t.Any]:
        """
        Get a model in the Inference API

        `<https://www.elastic.co/guide/en/elasticsearch/reference/master/get-inference-api.html>`_

        :param task_type: The model task type
        :param model_id: The unique identifier of the inference model.
        :raises ValueError: If ``task_type`` or ``model_id`` is empty.
        """
        _require_path_params(task_type, model_id)
        __path = f"/_inference/{_quote(task_type)}/{_quote(model_id)}"
        __query = _common_query(error_trace, filter_path, human, pretty)
        __headers = {"accept": "application/json"}
        return await self.perform_request(  # type: ignore[return-value]
            "GET", __path, params=__query, headers=__headers
        )

    @_rewrite_parameters(
        body_fields=("input", "task_settings"),
    )
    async def inference(
        self,
        *,
        task_type: t.Union["t.Literal['sparse_embedding', 'text_embedding']", str],
        model_id: str,
        input: t.Optional[t.Union[str, t.Sequence[str]]] = None,
        error_trace: t.Optional[bool] = None,
        filter_path: t.Optional[t.Union[str, t.Sequence[str]]] = None,
        human: t.Optional[bool] = None,
        pretty: t.Optional[bool] = None,
        task_settings: t.Optional[t.Any] = None,
        body: t.Optional[t.Dict[str, t.Any]] = None,
    ) -> ObjectApiResponse[t.Any]:
        """
        Perform inference on a model

        `<https://www.elastic.co/guide/en/elasticsearch/reference/master/post-inference-api.html>`_

        :param task_type: The model task type
        :param model_id: The unique identifier of the inference model.
        :param input: Text input to the model. Either a string or an array of strings.
        :param task_settings: Optional task settings
        :param body: Complete request body. When it is non-empty it is sent
            as-is and ``input`` / ``task_settings`` are not added to it here.
        :raises ValueError: If ``task_type`` or ``model_id`` is empty, or if
            neither ``input`` nor ``body`` is given.
        """
        _require_path_params(task_type, model_id)
        if input is None and body is None:
            raise ValueError("Empty value passed for parameter 'input'")
        __path = f"/_inference/{_quote(task_type)}/{_quote(model_id)}"
        __query = _common_query(error_trace, filter_path, human, pretty)
        # Start from a fresh dict when the caller's body is missing or empty
        # so the field assignments below never mutate the caller's object.
        __body: t.Dict[str, t.Any] = body if body else {}
        if not __body:
            if input is not None:
                __body["input"] = input
            if task_settings is not None:
                __body["task_settings"] = task_settings
        if not __body:
            __body = None  # type: ignore[assignment]
        __headers = {"accept": "application/json"}
        if __body is not None:
            __headers["content-type"] = "application/json"
        return await self.perform_request(  # type: ignore[return-value]
            "POST", __path, params=__query, headers=__headers, body=__body
        )

    @_rewrite_parameters(
        body_name="model_config",
    )
    async def put_model(
        self,
        *,
        task_type: t.Union["t.Literal['sparse_embedding', 'text_embedding']", str],
        model_id: str,
        error_trace: t.Optional[bool] = None,
        filter_path: t.Optional[t.Union[str, t.Sequence[str]]] = None,
        human: t.Optional[bool] = None,
        model_config: t.Optional[t.Mapping[str, t.Any]] = None,
        body: t.Optional[t.Mapping[str, t.Any]] = None,
        pretty: t.Optional[bool] = None,
    ) -> ObjectApiResponse[t.Any]:
        """
        Configure a model for use in the Inference API

        `<https://www.elastic.co/guide/en/elasticsearch/reference/master/put-inference-api.html>`_

        :param task_type: The model task type
        :param model_id: The unique identifier of the inference model.
        :param model_config: The model configuration, sent as the request
            body. Mutually exclusive with ``body``.
        :param body: Alternative way of passing the request body. Mutually
            exclusive with ``model_config``.
        :raises ValueError: If ``task_type`` or ``model_id`` is empty, or if
            ``model_config`` and ``body`` are both set or both missing.
        """
        _require_path_params(task_type, model_id)
        if model_config is None and body is None:
            raise ValueError(
                "Empty value passed for parameters 'model_config' and 'body', one of them should be set."
            )
        elif model_config is not None and body is not None:
            raise ValueError("Cannot set both 'model_config' and 'body'")
        __path = f"/_inference/{_quote(task_type)}/{_quote(model_id)}"
        __query = _common_query(error_trace, filter_path, human, pretty)
        __body = model_config if model_config is not None else body
        if not __body:
            # An empty mapping is sent as "no body" (and no content-type).
            __body = None
        __headers = {"accept": "application/json"}
        if __body is not None:
            __headers["content-type"] = "application/json"
        return await self.perform_request(  # type: ignore[return-value]
            "PUT", __path, params=__query, headers=__headers, body=__body
        )


def _require_path_params(task_type: t.Any, model_id: t.Any) -> None:
    """Raise ``ValueError`` if either inference path parameter is empty."""
    if task_type in SKIP_IN_PATH:
        raise ValueError("Empty value passed for parameter 'task_type'")
    if model_id in SKIP_IN_PATH:
        raise ValueError("Empty value passed for parameter 'model_id'")


def _common_query(
    error_trace: t.Optional[bool],
    filter_path: t.Optional[t.Union[str, t.Sequence[str]]],
    human: t.Optional[bool],
    pretty: t.Optional[bool],
) -> t.Dict[str, t.Any]:
    """Return the query-string dict holding only the options that are not None."""
    candidates = (
        ("error_trace", error_trace),
        ("filter_path", filter_path),
        ("human", human),
        ("pretty", pretty),
    )
    return {name: value for name, value in candidates if value is not None}

elasticsearch/_sync/client/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
from .graph import GraphClient
5555
from .ilm import IlmClient
5656
from .indices import IndicesClient
57+
from .inference import InferenceClient
5758
from .ingest import IngestClient
5859
from .license import LicenseClient
5960
from .logstash import LogstashClient
@@ -434,6 +435,7 @@ def __init__(
434435
self.fleet = FleetClient(self)
435436
self.features = FeaturesClient(self)
436437
self.indices = IndicesClient(self)
438+
self.inference = InferenceClient(self)
437439
self.ingest = IngestClient(self)
438440
self.nodes = NodesClient(self)
439441
self.snapshot = SnapshotClient(self)

0 commit comments

Comments
 (0)