
Upgrade @asyncio.coroutine to async #181

Merged
merged 5 commits on Apr 16, 2020
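For context: the conversion this PR performs is mechanical. Generator-based coroutines declared with `@asyncio.coroutine` and driven by `yield from` (deprecated since Python 3.8, removed in 3.11) become native coroutines declared with `async def` and driven by `await`. A minimal sketch of the pattern, illustrative only and not taken from this diff:

```python
import asyncio

# Old style, as removed by this PR:
#
# @asyncio.coroutine
# def fetch_greeting():
#     yield from asyncio.sleep(0.1)
#     return "hello"

# New style, as applied throughout this PR:
async def fetch_greeting():
    await asyncio.sleep(0.1)  # `await` is the drop-in replacement for `yield from`
    return "hello"            # a plain `return <value>` is allowed in a native coroutine

if __name__ == "__main__":
    print(asyncio.run(fetch_greeting()))
```

Note that the replacement for `yield from` is `await`, not a bare `yield`: a `yield` inside an `async def` turns the function into an async generator rather than a coroutine.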
6 changes: 5 additions & 1 deletion rethinkdb/__init__.py
@@ -13,7 +13,11 @@
# limitations under the License.
# The builtins here defends against re-importing something obscuring `object`.
import builtins
import imp
try:
import imp
except ImportError:
import importlib as imp

import os

import pkg_resources
116 changes: 55 additions & 61 deletions rethinkdb/asyncio_net/net_asyncio.py
@@ -39,20 +39,19 @@
pQuery = ql2_pb2.Query.QueryType


@asyncio.coroutine
def _read_until(streamreader, delimiter):
async def _read_until(streamreader, delimiter):
"""Naive implementation of reading until a delimiter"""
buffer = bytearray()

while True:
c = yield from streamreader.read(1)
c = await streamreader.read(1)
if c == b"":
break # EOF
buffer.append(c[0])
if c == delimiter:
break

return bytes(buffer)


def reusable_waiter(loop, timeout):
@@ -62,22 +61,22 @@ def reusable_waiter(loop, timeout):

waiter = reusable_waiter(event_loop, 10.0)
while some_condition:
yield from waiter(some_future)
await waiter(some_future)
"""
if timeout is not None:
deadline = loop.time() + timeout
else:
deadline = None

@asyncio.coroutine
def wait(future):
async def wait(future):
if deadline is not None:
new_timeout = max(deadline - loop.time(), 0)
else:
new_timeout = None
return (yield from asyncio.wait_for(future, new_timeout, loop=loop))
return await asyncio.wait_for(future, new_timeout, loop=loop)

return wait


@contextlib.contextmanager
@@ -101,20 +100,18 @@ def __init__(self, *args, **kwargs):
def __aiter__(self):
return self

@asyncio.coroutine
def __anext__(self):
async def __anext__(self):
try:
return (yield from self._get_next(None))
return await self._get_next(None)
except ReqlCursorEmpty:
raise StopAsyncIteration

@asyncio.coroutine
def close(self):
async def close(self):
if self.error is None:
self.error = self._empty_error()
if self.conn.is_open():
self.outstanding_requests += 1
yield from self.conn._parent._stop(self)
await self.conn._parent._stop(self)

def _extend(self, res_buf):
Cursor._extend(self, res_buf)
@@ -123,35 +120,35 @@ def _extend(self, res_buf):

# Convenience function so users know when they've hit the end of the cursor
# without having to catch an exception
@asyncio.coroutine
def fetch_next(self, wait=True):
async def fetch_next(self, wait=True):
timeout = Cursor._wait_to_timeout(wait)
waiter = reusable_waiter(self.conn._io_loop, timeout)
while len(self.items) == 0 and self.error is None:
self._maybe_fetch_batch()
if self.error is not None:
raise self.error
with translate_timeout_errors():
yield from waiter(asyncio.shield(self.new_response))
await waiter(asyncio.shield(self.new_response))
# If there is a (non-empty) error to be received, we return True, so the
# user will receive it on the next `next` call.
return len(self.items) != 0 or not isinstance(self.error, RqlCursorEmpty)

def _empty_error(self):
# We do not have RqlCursorEmpty inherit from StopIteration as that interferes
# with mechanisms to return from a coroutine.
return RqlCursorEmpty()

@asyncio.coroutine
def _get_next(self, timeout):
async def _get_next(self, timeout):
waiter = reusable_waiter(self.conn._io_loop, timeout)
while len(self.items) == 0:
self._maybe_fetch_batch()
if self.error is not None:
raise self.error
with translate_timeout_errors():
yield from waiter(asyncio.shield(self.new_response))
return self.items.popleft()
await waiter(asyncio.shield(self.new_response))
return self.items.popleft()

def _maybe_fetch_batch(self):
if (
@@ -186,8 +183,7 @@ def client_address(self):
if self.is_open():
return self._streamwriter.get_extra_info("sockname")[0]

@asyncio.coroutine
def connect(self, timeout):
async def connect(self, timeout):
try:
ssl_context = None
if len(self._parent.ssl) > 0:
@@ -199,7 +195,7 @@ def connect(self, timeout):
ssl_context.check_hostname = True # redundant with match_hostname
ssl_context.load_verify_locations(self._parent.ssl["ca_certs"])

self._streamreader, self._streamwriter = yield from asyncio.open_connection(
self._streamreader, self._streamwriter = await asyncio.open_connection(
self._parent.host,
self._parent.port,
loop=self._io_loop,
@@ -230,23 +226,23 @@ def connect(self, timeout):
if request is not "":
self._streamwriter.write(request)

response = yield from asyncio.wait_for(
response = await asyncio.wait_for(
_read_until(self._streamreader, b"\0"),
timeout,
loop=self._io_loop,
)
response = response[:-1]
except ReqlAuthError:
yield from self.close()
await self.close()
raise
except ReqlTimeoutError as err:
yield from self.close()
await self.close()
raise ReqlDriverError(
"Connection interrupted during handshake with %s:%s. Error: %s"
% (self._parent.host, self._parent.port, str(err))
)
except Exception as err:
yield from self.close()
await self.close()
raise ReqlDriverError(
"Could not connect to %s:%s. Error: %s"
% (self._parent.host, self._parent.port, str(err))
@@ -255,13 +251,13 @@ def connect(self, timeout):
# Start a parallel function to perform reads
# store a reference to it so it doesn't get destroyed
self._reader_task = asyncio.ensure_future(self._reader(), loop=self._io_loop)
return self._parent

def is_open(self):
return not (self._closing or self._streamreader.at_eof())

@asyncio.coroutine
def close(self, noreply_wait=False, token=None, exception=None):
async def close(self, noreply_wait=False, token=None, exception=None):
self._closing = True
if exception is not None:
err_message = "Connection is closed (%s)." % str(exception)
@@ -281,38 +277,37 @@ def close(self, noreply_wait=False, token=None, exception=None):

if noreply_wait:
noreply = Query(pQuery.NOREPLY_WAIT, token, None, None)
yield from self.run_query(noreply, False)
await self.run_query(noreply, False)

self._streamwriter.close()
# We must not wait for the _reader_task if we got an exception, because that
# means that we were called from it. Waiting would lead to a deadlock.
if self._reader_task and exception is None:
yield from self._reader_task
await self._reader_task

return None

@asyncio.coroutine
def run_query(self, query, noreply):
async def run_query(self, query, noreply):
self._streamwriter.write(query.serialize(self._parent._get_json_encoder(query)))
if noreply:
return None

response_future = asyncio.Future()
self._user_queries[query.token] = (query, response_future)
return (yield from response_future)
return await response_future

# The _reader coroutine runs in parallel, reading responses
# off of the socket and forwarding them to the appropriate Future or Cursor.
# This is shut down as a consequence of closing the stream, or an error in the
# socket/protocol from the server. Unexpected errors in this coroutine will
# close the ConnectionInstance and be passed to any open Futures or Cursors.
@asyncio.coroutine
def _reader(self):
async def _reader(self):
try:
while True:
buf = yield from self._streamreader.readexactly(12)
buf = await self._streamreader.readexactly(12)
(token, length,) = struct.unpack("<qL", buf)
buf = yield from self._streamreader.readexactly(length)
buf = await self._streamreader.readexactly(length)

cursor = self._cursor_cache.get(token)
if cursor is not None:
@@ -341,7 +336,7 @@ def _reader(self):
raise ReqlDriverError("Unexpected response received.")
except Exception as ex:
if not self._closing:
yield from self.close(exception=ex)
await self.close(exception=ex)


class Connection(ConnectionBase):
@@ -354,30 +349,29 @@ def __init__(self, *args, **kwargs):
"Could not convert port %s to an integer." % self.port
)

@asyncio.coroutine
def __aenter__(self):
async def __aenter__(self):
return self

@asyncio.coroutine
def __aexit__(self, exception_type, exception_val, traceback):
yield from self.close(False)
async def __aexit__(self, exception_type, exception_val, traceback):
await self.close(False)

@asyncio.coroutine
def _stop(self, cursor):
async def _stop(self, cursor):
self.check_open()
q = Query(pQuery.STOP, cursor.query.token, None, None)
return (yield from self._instance.run_query(q, True))
return await self._instance.run_query(q, True)

@asyncio.coroutine
def reconnect(self, noreply_wait=True, timeout=None):
async def reconnect(self, noreply_wait=True, timeout=None):
# We close before reconnect so reconnect doesn't try to close us
# and then fail to return the Future (this is a little awkward).
yield from self.close(noreply_wait)
await self.close(noreply_wait)
self._instance = self._conn_type(self, **self._child_kwargs)
return (yield from self._instance.connect(timeout))
return await self._instance.connect(timeout)

@asyncio.coroutine
def close(self, noreply_wait=True):
async def close(self, noreply_wait=True):
if self._instance is None:
return None
return (yield from ConnectionBase.close(self, noreply_wait=noreply_wait))
return await ConnectionBase.close(self, noreply_wait=noreply_wait)
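An aside on the cursor changes above: because `AsyncioCursor` defines `__aiter__` and `__anext__`, the converted driver supports the native async iteration protocol, so a cursor can also be consumed with `async for` instead of calling `__anext__` by hand. A minimal sketch, assuming an already-open asyncio connection `conn` and a `table` term (both names are illustrative, not from the diff):

```python
async def print_names(table, conn):
    # `await table.run(conn)` returns an AsyncioCursor; `async for` drives its
    # __anext__ until StopAsyncIteration signals that the cursor is exhausted.
    cursor = await table.run(conn)
    async for row in cursor:
        print(row["name"])
```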
19 changes: 8 additions & 11 deletions tests/integration/test_asyncio.py
@@ -1,6 +1,4 @@
import sys
from asyncio import coroutine

import pytest

from tests.helpers import INTEGRATION_TEST_DB, IntegrationTestCaseBase
@@ -22,21 +20,20 @@ def teardown_method(self):
super(TestAsyncio, self).teardown_method()
self.r.set_loop_type(None)

@coroutine
def test_flow_coroutine_paradigm(self):
connection = yield from self.conn
async def test_flow_coroutine_paradigm(self):
connection = await self.conn

yield from self.r.table_create(self.table_name).run(connection)
await self.r.table_create(self.table_name).run(connection)

table = self.r.table(self.table_name)
yield from table.insert(
await table.insert(
{"id": 1, "name": "Iron Man", "first_appearance": "Tales of Suspense #39"}
).run(connection)

cursor = yield from table.run(connection)
cursor = await table.run(connection)

while (yield from cursor.fetch_next()):
hero = yield from cursor.__anext__()
while await cursor.fetch_next():
hero = await cursor.__anext__()
assert hero["name"] == "Iron Man"

yield from connection.close()
await connection.close()
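For completeness, the flow exercised by the test above translates directly to a standalone script once the driver exposes native coroutines. This is a sketch under stated assumptions, not part of the PR: it assumes a reachable RethinkDB server on localhost:28015, and the table name is illustrative only.

```python
import asyncio

from rethinkdb import r


async def main():
    # Host, port, and table name are assumptions for this example,
    # not values taken from the PR.
    r.set_loop_type("asyncio")
    conn = await r.connect(host="localhost", port=28015)

    await r.table_create("heroes").run(conn)
    await r.table("heroes").insert({"id": 1, "name": "Iron Man"}).run(conn)

    cursor = await r.table("heroes").run(conn)
    while await cursor.fetch_next():
        hero = await cursor.__anext__()
        print(hero["name"])

    await conn.close()


asyncio.run(main())
```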