
Commit ec4fc8b

Add AsyncTransport
1 parent bb78767 commit ec4fc8b

5 files changed: +830 additions, −38 deletions


elasticsearch/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -72,7 +72,8 @@
         raise ImportError
 
     from ._async.http_aiohttp import AIOHttpConnection
+    from ._async.transport import AsyncTransport
 
-    __all__ += ["AIOHttpConnection"]
+    __all__ += ["AIOHttpConnection", "AsyncTransport"]
 except (ImportError, SyntaxError):
     pass
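
With this export in place, the async transport is importable straight from the package root on Python 3.6+ (when aiohttp is available). A minimal sketch — the host entry is a hypothetical local node, not part of the commit:

    from elasticsearch import AIOHttpConnection, AsyncTransport

    # Hypothetical single-node setup; connection_class defaults to
    # AIOHttpConnection anyway, shown here only to make the wiring explicit.
    transport = AsyncTransport(
        hosts=[{"host": "localhost", "port": 9200}],
        connection_class=AIOHttpConnection,
    )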

elasticsearch/_async/transport.py

Lines changed: 332 additions & 0 deletions (new file)
# Licensed to Elasticsearch B.V under one or more agreements.
# Elasticsearch B.V licenses this file to you under the Apache 2.0 License.
# See the LICENSE file in the project root for more information

import asyncio
import logging
from itertools import chain

from .compat import get_running_loop
from .http_aiohttp import AIOHttpConnection
from ..transport import Transport
from ..exceptions import (
    TransportError,
    ConnectionTimeout,
    ConnectionError,
    SerializationError,
)


logger = logging.getLogger("elasticsearch")


async def connection_pool_aclose(self):
    """Shim for closing ConnectionPools that contain async connections"""
    for connection in self.connections:
        await connection.close()


class AsyncTransport(Transport):
    """
    Encapsulation of transport-related logic. Handles instantiation of the
    individual connections as well as creating a connection pool to hold them.

    Main interface is the `perform_request` method.
    """

    DEFAULT_CONNECTION_CLASS = AIOHttpConnection

    def __init__(self, hosts, *args, sniff_on_start=False, **kwargs):
        """
        :arg hosts: list of dictionaries, each containing keyword arguments to
            create a `connection_class` instance
        :arg connection_class: subclass of :class:`~elasticsearch.Connection` to use
        :arg connection_pool_class: subclass of :class:`~elasticsearch.ConnectionPool` to use
        :arg host_info_callback: callback responsible for taking the node information from
            `/_cluster/nodes`, along with already extracted information, and
            producing a list of arguments (same as `hosts` parameter)
        :arg sniff_on_start: flag indicating whether to obtain a list of nodes
            from the cluster at startup time
        :arg sniffer_timeout: number of seconds between automatic sniffs
        :arg sniff_on_connection_fail: flag controlling if connection failure triggers a sniff
        :arg sniff_timeout: timeout used for the sniff request - it should be a
            fast API call and we are potentially talking to more nodes, so we want
            to fail quickly. Not used during initial sniffing (if
            ``sniff_on_start`` is on) when the connection still isn't
            initialized.
        :arg serializer: serializer instance
        :arg serializers: optional dict of serializer instances that will be
            used for deserializing data coming from the server. (key is the mimetype)
        :arg default_mimetype: when no mimetype is specified by the server
            response assume this mimetype, defaults to `'application/json'`
        :arg max_retries: maximum number of retries before an exception is propagated
        :arg retry_on_status: set of HTTP status codes on which we should retry
            on a different node. defaults to ``(502, 503, 504)``
        :arg retry_on_timeout: should timeout trigger a retry on different
            node? (default `False`)
        :arg send_get_body_as: for GET requests with body this option allows
            you to specify an alternate way of execution for environments that
            don't support passing bodies with GET requests. If you set this to
            'POST' a POST method will be used instead, if to 'source' then the body
            will be serialized and passed as a query parameter `source`.

        Any extra keyword arguments will be passed to the `connection_class`
        when creating an instance unless overridden by that connection's
        options provided as part of the hosts parameter.
        """
        self.sniffing_task = None
        self.loop = None
        self._async_init_called = False

        super(AsyncTransport, self).__init__(
            *args, hosts=[], sniff_on_start=False, **kwargs
        )

        # Don't enable sniffing on Cloud instances.
        if kwargs.get("cloud_id", False):
            sniff_on_start = False

        # Since connection creation and sniffing are deferred out of the
        # constructor, never signal the parent class to 'sniff_on_start'
        # or hand it a non-empty 'hosts'.
        self.hosts = hosts
        self.sniff_on_start = sniff_on_start
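
The constructor above deliberately does no event-loop work: __init__() may run before any loop exists, so everything loop-dependent is deferred to the first async call. A minimal sketch of the same pattern, independent of this codebase (all names are illustrative):

    import asyncio

    class NeedsLoop:
        def __init__(self):
            # No running loop is guaranteed here; only record state.
            self.loop = None

        async def first_async_entry(self):
            # Inside a coroutine a running loop must exist, so it is
            # safe to capture it now.
            if self.loop is None:
                self.loop = asyncio.get_event_loop()

    async def main():
        obj = NeedsLoop()              # safe to construct anywhere
        await obj.first_async_entry()  # loop captured on first async use

    asyncio.get_event_loop().run_until_complete(main())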
    async def _async_init(self):
        """This is our stand-in for an async constructor. Everything
        that was deferred within __init__() should be done here now.

        This method will only be called once per AsyncTransport instance
        and is called from one of AsyncElasticsearch.__aenter__(),
        AsyncTransport.perform_request() or AsyncTransport.get_connection()
        """
        # Detect the async loop we're running in and set it
        # on all already created HTTP connections.
        self.loop = get_running_loop()
        self.kwargs["loop"] = self.loop

        # Now that we have a loop we can create all our HTTP connections
        self.set_connections(self.hosts)
        self.seed_connections = list(self.connection_pool.connections[:])

        # ... and we can start sniffing in the background.
        if self.sniffing_task is None and self.sniff_on_start:
            self.last_sniff = self.loop.time()
            self.create_sniff_task(initial=True)

    async def _async_call(self):
        """This method is called within any async method of AsyncTransport
        where the transport is not closing. It checks whether we should
        call _async_init() or create a new sniffing task.
        """
        if not self._async_init_called:
            self._async_init_called = True
            await self._async_init()

        if self.sniffer_timeout:
            # The interval between automatic sniffs is 'sniffer_timeout';
            # 'sniff_timeout' is the per-request timeout.
            if self.loop.time() >= self.last_sniff + self.sniffer_timeout:
                self.create_sniff_task()
    async def _get_node_info(self, conn, initial):
        try:
            # use small timeout for the sniffing request, should be a fast api call
            _, headers, node_info = await conn.perform_request(
                "GET",
                "/_nodes/_all/http",
                timeout=self.sniff_timeout if not initial else None,
            )
            return self.deserializer.loads(node_info, headers.get("content-type"))
        except Exception:
            pass
        return None

    async def _get_sniff_data(self, initial=False):
        previous_sniff = self.last_sniff

        # reset last_sniff timestamp
        self.last_sniff = self.loop.time()

        # use small timeout for the sniffing request, should be a fast api call
        timeout = self.sniff_timeout if not initial else None

        def _sniff_request(conn):
            return self.loop.create_task(
                conn.perform_request("GET", "/_nodes/_all/http", timeout=timeout)
            )

        # Go through all current connections as well as the
        # seed_connections for good measure
        tasks = []
        for conn in self.connection_pool.connections:
            tasks.append(_sniff_request(conn))
        for conn in self.seed_connections:
            # Ensure that we don't have any duplication within seed_connections.
            if conn in self.connection_pool.connections:
                continue
            tasks.append(_sniff_request(conn))

        done = ()
        try:
            while tasks:
                # execute sniff requests in parallel, wait for first to return
                done, tasks = await asyncio.wait(
                    tasks, return_when=asyncio.FIRST_COMPLETED, loop=self.loop
                )
                # go through all the finished tasks
                for t in done:
                    try:
                        _, headers, node_info = t.result()
                        node_info = self.deserializer.loads(
                            node_info, headers.get("content-type")
                        )
                    except (ConnectionError, SerializationError):
                        continue
                    node_info = list(node_info["nodes"].values())
                    return node_info
            else:
                # no task finished successfully
                raise TransportError("N/A", "Unable to sniff hosts.")
        except Exception:
            # keep the previous value on error
            self.last_sniff = previous_sniff
            raise
        finally:
            # Cancel all the pending tasks
            for task in chain(done, tasks):
                task.cancel()
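
_get_sniff_data uses the classic asyncio.wait(..., return_when=asyncio.FIRST_COMPLETED) race: fire a request at every node, take the first usable answer, and cancel the stragglers in the finally block. A stripped-down sketch of the same idea (illustrative names, no Elasticsearch involved):

    import asyncio

    async def query(delay, result):
        await asyncio.sleep(delay)
        return result

    async def first_usable_answer():
        tasks = {
            asyncio.ensure_future(query(0.3, "slow node")),
            asyncio.ensure_future(query(0.1, "fast node")),
        }
        try:
            while tasks:
                done, tasks = await asyncio.wait(
                    tasks, return_when=asyncio.FIRST_COMPLETED
                )
                for task in done:
                    return task.result()  # first answer wins
        finally:
            for task in tasks:
                task.cancel()  # don't leak the losers

    loop = asyncio.get_event_loop()
    print(loop.run_until_complete(first_usable_answer()))  # "fast node"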
    async def sniff_hosts(self, initial=False):
        """Either spawns a sniffing_task which does regular sniffing
        over time or does a single sniffing session and awaits the results.
        """
        # Without a loop we can't do anything.
        if not self.loop:
            return

        node_info = await self._get_sniff_data(initial)
        hosts = list(filter(None, (self._get_host_info(n) for n in node_info)))

        # we weren't able to get any nodes, maybe using an incompatible
        # transport_schema or host_info_callback blocked all - raise error.
        if not hosts:
            raise TransportError(
                "N/A", "Unable to sniff hosts - no viable hosts found."
            )

        # remember current live connections
        orig_connections = self.connection_pool.connections[:]
        self.set_connections(hosts)
        # close those connections that are not in use any more
        for c in orig_connections:
            if c not in self.connection_pool.connections:
                await c.close()

    def create_sniff_task(self, initial=False):
        """
        Initiate a sniffing task. Make sure we only have one sniff request
        running at any given time. If a finished sniffing request is around,
        collect its result (which can raise its exception).
        """
        if self.sniffing_task and self.sniffing_task.done():
            try:
                if self.sniffing_task is not None:
                    self.sniffing_task.result()
            finally:
                self.sniffing_task = None

        if self.sniffing_task is None:
            self.sniffing_task = self.loop.create_task(self.sniff_hosts(initial))

    def mark_dead(self, connection):
        """
        Mark a connection as dead (failed) in the connection pool. If sniffing
        on failure is enabled this will initiate the sniffing process.

        :arg connection: instance of :class:`~elasticsearch.Connection` that failed
        """
        self.connection_pool.mark_dead(connection)
        if self.sniff_on_connection_fail:
            self.create_sniff_task()

    def get_connection(self):
        return self.connection_pool.get_connection()
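
Note how create_sniff_task keeps at most one background sniff alive and harvests the previous task's outcome lazily: task.result() on a finished task re-raises whatever exception it died with. A compact sketch of that collect-on-next-kick pattern (names are illustrative):

    import asyncio

    class BackgroundJob:
        def __init__(self):
            self.task = None

        def kick(self, loop):
            # Surface the previous run's exception, if any, before
            # scheduling a new run.
            if self.task and self.task.done():
                try:
                    self.task.result()
                finally:
                    self.task = None
            if self.task is None:
                self.task = loop.create_task(self._work())

        async def _work(self):
            await asyncio.sleep(0)  # stand-in for a real sniff request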
    async def perform_request(self, method, url, headers=None, params=None, body=None):
        """
        Perform the actual request. Retrieve a connection from the connection
        pool, pass all the information to its perform_request method and
        return the data.

        If an exception was raised, mark the connection as failed and retry (up
        to `max_retries` times).

        If the operation was successful and the connection used was previously
        marked as dead, mark it as live, resetting its failure count.

        :arg method: HTTP method to use
        :arg url: absolute url (without host) to target
        :arg headers: dictionary of headers, will be handed over to the
            underlying :class:`~elasticsearch.Connection` class
        :arg params: dictionary of query parameters, will be handed over to the
            underlying :class:`~elasticsearch.Connection` class for serialization
        :arg body: body of the request, will be serialized using serializer and
            passed to the connection
        """
        await self._async_call()

        method, params, body, ignore, timeout = self._resolve_request_args(
            method, params, body
        )

        for attempt in range(self.max_retries + 1):
            connection = self.get_connection()

            try:
                status, headers, data = await connection.perform_request(
                    method,
                    url,
                    params,
                    body,
                    headers=headers,
                    ignore=ignore,
                    timeout=timeout,
                )
            except TransportError as e:
                if method == "HEAD" and e.status_code == 404:
                    return False

                retry = False
                if isinstance(e, ConnectionTimeout):
                    retry = self.retry_on_timeout
                elif isinstance(e, ConnectionError):
                    retry = True
                elif e.status_code in self.retry_on_status:
                    retry = True

                if retry:
                    # only mark as dead if we are retrying
                    self.mark_dead(connection)
                    # raise exception on last retry
                    if attempt == self.max_retries:
                        raise
                else:
                    raise

            else:
                if method == "HEAD":
                    return 200 <= status < 300

                # connection didn't fail, confirm its live status
                self.connection_pool.mark_live(connection)
                if data:
                    data = self.deserializer.loads(data, headers.get("content-type"))
                return data

    async def close(self):
        """
        Explicitly closes connections
        """
        if self.sniffing_task:
            self.sniffing_task.cancel()
            self.sniffing_task = None
        await self.connection_pool.aclose()
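
Taken together, one perform_request() call may touch up to max_retries + 1 nodes before an exception propagates, and the very first call also runs _async_init(). A hedged end-to-end sketch — the host and endpoint are assumptions for illustration, not part of the commit:

    import asyncio
    from elasticsearch import AsyncTransport

    async def main():
        transport = AsyncTransport(hosts=[{"host": "localhost", "port": 9200}])
        try:
            # First call triggers _async_init() before the request goes out.
            info = await transport.perform_request("GET", "/")
            print(info["version"]["number"])
        finally:
            # Cancels any sniffing task and closes all pooled connections.
            await transport.close()

    asyncio.get_event_loop().run_until_complete(main())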

elasticsearch/connection_pool.py

Lines changed: 32 additions & 1 deletion
@@ -2,6 +2,7 @@
 # Elasticsearch B.V licenses this file to you under the Apache 2.0 License.
 # See the LICENSE file in the project root for more information
 
+import sys
 import time
 import random
 import logging
@@ -256,9 +257,23 @@ def close(self):
         """
         Explicitly closes connections
         """
-        for conn in self.orig_connections:
+        for conn in self.connections:
             conn.close()
 
+    def __repr__(self):
+        return "<%s: %r>" % (type(self).__name__, self.connections)
+
+    # We add the 'aclose()' method conditionally: it's only used by
+    # AsyncTransport, and since it's an async function it can't be
+    # defined directly in 'connection_pool.py' :-(
+    if sys.version_info >= (3, 6):
+        try:
+            from ._async.transport import connection_pool_aclose
+
+            aclose = connection_pool_aclose
+        except (ImportError, SyntaxError):
+            pass
+
 
 class DummyConnectionPool(ConnectionPool):
     def __init__(self, connections, **kwargs):
@@ -284,3 +299,19 @@ def _noop(self, *args, **kwargs):
         pass
 
     mark_dead = mark_live = resurrect = _noop
+
+
+class EmptyConnectionPool(ConnectionPool):
+    """A connection pool that is empty. Errors out if used."""
+
+    def __init__(self, *_, **__):
+        self.connections = []
+        self.connection_opts = []
+
+    def get_connection(self):
+        raise ImproperlyConfigured("No connections were configured")
+
+    def _noop(self, *args, **kwargs):
+        pass
+
+    close = mark_dead = mark_live = resurrect = _noop
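
With the shim attached, a pool holding async connections can be shut down from async code via await pool.aclose(), while synchronous callers keep using close(). EmptyConnectionPool presumably serves as the transport's placeholder before any real connections exist: every lifecycle method is a no-op, and only get_connection() raises. A minimal sketch of the async shutdown path (illustrative):

    async def shutdown(pool):
        # Available on Python 3.6+ via the connection_pool_aclose shim;
        # awaits connection.close() on every pooled async connection.
        await pool.aclose()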
