Skip to content
This repository was archived by the owner on Mar 20, 2023. It is now read-only.

Update to async/await syntax #80

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@ junit.xml
test_elasticsearch_async/htmlcov
docs/_build
.cache
.eggs
23 changes: 3 additions & 20 deletions README
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ response.

Sniffing (when requested) is also done via a scheduled coroutine.

Example for python 3.5+
Example for python 3.6+

.. code-block:: python

Expand All @@ -25,22 +25,6 @@ Example for python 3.5+
loop.run_until_complete(client.transport.close())
loop.close()

Example for python 3.4

.. code-block:: python

import asyncio
from elasticsearch_async import AsyncElasticsearch
hosts = ['localhost', 'other-host']

async def print_info():
async with AsyncElasticsearch(hosts=hosts) as client:
print(await client.info())

loop = asyncio.get_event_loop()
loop.run_until_complete(print_info())
loop.close()


Example with SSL Context

Expand All @@ -58,9 +42,8 @@ Example with SSL Context
http_auth=('elastic', 'changeme')
)

@asyncio.coroutine
def print_info():
info = yield from client.info()
async def print_info():
info = await client.info()
print(info)

loop = asyncio.get_event_loop()
Expand Down
8 changes: 3 additions & 5 deletions elasticsearch_async/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,8 @@ class AsyncElasticsearch(Elasticsearch):
def __init__(self, hosts=None, transport_class=AsyncTransport, **kwargs):
super().__init__(hosts, transport_class=transport_class, **kwargs)

@asyncio.coroutine
def __aenter__(self):
async def __aenter__(self):
return self

@asyncio.coroutine
def __aexit__(self, _exc_type, _exc_val, _exc_tb):
yield from self.transport.close()
async def __aexit__(self, _exc_type, _exc_val, _exc_tb):
await self.transport.close()
24 changes: 10 additions & 14 deletions elasticsearch_async/transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def __init__(self, hosts, connection_class=AIOHttpConnection, loop=None,
self.raise_on_sniff_error = raise_on_sniff_error
self.loop = asyncio.get_event_loop() if loop is None else loop
kwargs['loop'] = self.loop
super().__init__(hosts, connection_class=connection_class, sniff_on_start=False,
super().__init__(hosts, connection_class=connection_class, sniff_on_start=False,
connection_pool_class=connection_pool_class, **kwargs)

self.sniffing_task = None
Expand All @@ -46,11 +46,10 @@ def initiate_sniff(self, initial=False):
if self.sniffing_task is None:
self.sniffing_task = ensure_future(self.sniff_hosts(initial), loop=self.loop)

@asyncio.coroutine
def close(self):
async def close(self):
if self.sniffing_task:
self.sniffing_task.cancel()
yield from self.connection_pool.close()
await self.connection_pool.close()

def set_connections(self, hosts):
super().set_connections(hosts)
Expand All @@ -68,8 +67,7 @@ def mark_dead(self, connection):
if self.sniff_on_connection_fail:
self.initiate_sniff()

@asyncio.coroutine
def _get_sniff_data(self, initial=False):
async def _get_sniff_data(self, initial=False):
previous_sniff = self.last_sniff

# reset last_sniff timestamp
Expand All @@ -89,7 +87,7 @@ def _get_sniff_data(self, initial=False):
try:
while tasks:
# execute sniff requests in parallel, wait for first to return
done, tasks = yield from asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED, loop=self.loop)
done, tasks = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED, loop=self.loop)
# go through all the finished tasks
for t in done:
try:
Expand All @@ -112,8 +110,7 @@ def _get_sniff_data(self, initial=False):
for t in chain(done, tasks):
t.cancel()

@asyncio.coroutine
def sniff_hosts(self, initial=False):
async def sniff_hosts(self, initial=False):
"""
Obtain a list of nodes from the cluster and create a new connection
pool using the information retrieved.
Expand All @@ -123,7 +120,7 @@ def sniff_hosts(self, initial=False):
:arg initial: flag indicating if this is during startup
(``sniff_on_start``), ignore the ``sniff_timeout`` if ``True``
"""
node_info = yield from self._get_sniff_data(initial)
node_info = await self._get_sniff_data(initial)

hosts = list(filter(None, (self._get_host_info(n) for n in node_info)))

Expand All @@ -138,15 +135,14 @@ def sniff_hosts(self, initial=False):
# close those connections that are not in use any more
for c in orig_connections:
if c not in self.connection_pool.connections:
yield from c.close()
await c.close()

@asyncio.coroutine
def main_loop(self, method, url, params, body, headers=None, ignore=(), timeout=None):
async def main_loop(self, method, url, params, body, headers=None, ignore=(), timeout=None):
for attempt in range(self.max_retries + 1):
connection = self.get_connection()

try:
status, headers, data = yield from connection.perform_request(
status, headers, data = await connection.perform_request(
method, url, params, body, headers=headers, ignore=ignore, timeout=timeout)
except TransportError as e:
if method == 'HEAD' and e.status_code == 404:
Expand Down
28 changes: 14 additions & 14 deletions test_elasticsearch_async/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,46 +3,46 @@
from elasticsearch import NotFoundError

@mark.asyncio
def test_custom_body(server, client):
async def test_custom_body(server, client):
server.register_response('/', {'custom': 'body'})
data = yield from client.info()
data = await client.info()

assert [('GET', '/', '', {})] == server.calls
assert {'custom': 'body'} == data

@mark.asyncio
def test_info_works(server, client):
data = yield from client.info()
async def test_info_works(server, client):
data = await client.info()

assert [('GET', '/', '', {})] == server.calls
assert {'body': '', 'method': 'GET', 'params': {}, 'path': '/'} == data

@mark.asyncio
def test_ping_works(server, client):
data = yield from client.ping()
async def test_ping_works(server, client):
data = await client.ping()

assert [('HEAD', '/', '', {})] == server.calls
assert data is True

@mark.asyncio
def test_exists_with_404_returns_false(server, client):
async def test_exists_with_404_returns_false(server, client):
server.register_response('/not-there', status=404)
data = yield from client.indices.exists(index='not-there')
data = await client.indices.exists(index='not-there')

assert data is False

@mark.asyncio
def test_404_properly_raised(server, client):
async def test_404_properly_raised(server, client):
server.register_response('/i/t/42', status=404)
with raises(NotFoundError):
yield from client.get(index='i', doc_type='t', id=42)
await client.get(index='i', doc_type='t', id=42)

@mark.asyncio
def test_body_gets_passed_properly(client):
data = yield from client.index(index='i', doc_type='t', id='42', body={'some': 'data'})
async def test_body_gets_passed_properly(client):
data = await client.index(index='i', doc_type='t', id='42', body={'some': 'data'})
assert {'body': {'some': 'data'}, 'method': 'PUT', 'params': {}, 'path': '/i/t/42'} == data

@mark.asyncio
def test_params_get_passed_properly(client):
data = yield from client.info(params={'some': 'data'})
async def test_params_get_passed_properly(client):
data = await client.info(params={'some': 'data'})
assert {'body': '', 'method': 'GET', 'params': {'some': 'data'}, 'path': '/'} == data
21 changes: 10 additions & 11 deletions test_elasticsearch_async/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
from elasticsearch_async.connection import AIOHttpConnection

@mark.asyncio
def test_info(connection):
status, headers, data = yield from connection.perform_request('GET', '/')
async def test_info(connection):
status, headers, data = await connection.perform_request('GET', '/')

data = json.loads(data)

Expand Down Expand Up @@ -41,9 +41,9 @@ def test_ssl_context_is_correctly(event_loop):


@mark.asyncio
def test_request_is_properly_logged(connection, caplog, port, server):
async def test_request_is_properly_logged(connection, caplog, port, server):
server.register_response('/_cat/indices', {'cat': 'indices'})
yield from connection.perform_request('GET', '/_cat/indices', body=b'{}', params={"format": "json"})
await connection.perform_request('GET', '/_cat/indices', body=b'{}', params={"format": "json"})

for logger, level, message in caplog.record_tuples:
if logger == 'elasticsearch' and level == logging.INFO:
Expand All @@ -56,10 +56,10 @@ def test_request_is_properly_logged(connection, caplog, port, server):
assert ('elasticsearch', logging.DEBUG, '< {"cat": "indices"}') in caplog.record_tuples

@mark.asyncio
def test_error_is_properly_logged(connection, caplog, port, server):
async def test_error_is_properly_logged(connection, caplog, port, server):
server.register_response('/i', status=404)
with raises(NotFoundError):
yield from connection.perform_request('GET', '/i', params={'some': 'data'})
await connection.perform_request('GET', '/i', params={'some': 'data'})

for logger, level, message in caplog.record_tuples:
if logger == 'elasticsearch' and level == logging.WARNING:
Expand All @@ -69,15 +69,14 @@ def test_error_is_properly_logged(connection, caplog, port, server):
assert False, "Log not received"

@mark.asyncio
def test_timeout_is_properly_raised(connection, server):
@asyncio.coroutine
def slow_request():
yield from asyncio.sleep(0.01)
async def test_timeout_is_properly_raised(connection, server):
async def slow_request():
await asyncio.sleep(0.01)
return {}
server.register_response('/_search', slow_request())

with raises(ConnectionTimeout):
yield from connection.perform_request('GET', '/_search', timeout=0.0001)
await connection.perform_request('GET', '/_search', timeout=0.0001)


def test_dns_cache_is_enabled_by_default(event_loop):
Expand Down
16 changes: 8 additions & 8 deletions test_elasticsearch_async/test_connection_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,31 +6,31 @@


@mark.asyncio
def test_single_host_makes_async_dummy_pool(server, client, event_loop, port):
async def test_single_host_makes_async_dummy_pool(server, client, event_loop, port):
client = AsyncElasticsearch(port=port, loop=event_loop)
assert isinstance(client.transport.connection_pool, AsyncDummyConnectionPool)
yield from client.transport.close()
await client.transport.close()

@mark.asyncio
def test_multiple_hosts_make_async_pool(server, event_loop, port):
async def test_multiple_hosts_make_async_pool(server, event_loop, port):
client = AsyncElasticsearch(
hosts=['localhost', 'localhost'], port=port, loop=event_loop)
assert isinstance(client.transport.connection_pool, AsyncConnectionPool)
assert len(client.transport.connection_pool.orig_connections) == 2
yield from client.transport.close()
await client.transport.close()

@mark.asyncio
def test_async_dummy_pool_is_closed_properly(server, event_loop, port):
async def test_async_dummy_pool_is_closed_properly(server, event_loop, port):
client = AsyncElasticsearch(port=port, loop=event_loop)
assert isinstance(client.transport.connection_pool, AsyncDummyConnectionPool)
yield from client.transport.close()
await client.transport.close()
assert client.transport.connection_pool.connection.session.closed

@mark.asyncio
def test_async_pool_is_closed_properly(server, event_loop, port):
async def test_async_pool_is_closed_properly(server, event_loop, port):
client = AsyncElasticsearch(
hosts=['localhost', 'localhost'], port=port, loop=event_loop)
assert isinstance(client.transport.connection_pool, AsyncConnectionPool)
yield from client.transport.close()
await client.transport.close()
for conn in client.transport.connection_pool.orig_connections:
assert conn.session.closed
12 changes: 6 additions & 6 deletions test_elasticsearch_async/test_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,32 +6,32 @@


@mark.asyncio
def test_sniff_on_start_sniffs(server, event_loop, port, sniff_data):
async def test_sniff_on_start_sniffs(server, event_loop, port, sniff_data):
server.register_response('/_nodes/_all/http', sniff_data)

client = AsyncElasticsearch(
port=port, sniff_on_start=True, loop=event_loop)

# sniff has been called in the background
assert client.transport.sniffing_task is not None
yield from client.transport.sniffing_task
await client.transport.sniffing_task

assert [('GET', '/_nodes/_all/http', '', {})] == server.calls
connections = client.transport.connection_pool.connections

assert 1 == len(connections)
assert 'http://node1:9200' == connections[0].host
yield from client.transport.close()
await client.transport.close()


@mark.asyncio
def test_retry_will_work(port, server, event_loop):
async def test_retry_will_work(port, server, event_loop):
client = AsyncElasticsearch(
hosts=['not-an-es-host', 'localhost'],
port=port,
loop=event_loop,
randomize_hosts=False)

data = yield from client.info()
data = await client.info()
assert {'body': '', 'method': 'GET', 'params': {}, 'path': '/'} == data
yield from client.transport.close()
await client.transport.close()