Skip to content

Replace asyncio.coroutine attributes with async def #298

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Oct 9, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 37 additions & 52 deletions rethinkdb/asyncio_net/net_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,12 @@
pQuery = ql2_pb2.Query.QueryType


@asyncio.coroutine
def _read_until(streamreader, delimiter):
async def _read_until(streamreader, delimiter):
"""Naive implementation of reading until a delimiter"""
buffer = bytearray()

while True:
c = yield from streamreader.read(1)
c = await streamreader.read(1)
if c == b"":
break # EOF
buffer.append(c[0])
Expand All @@ -69,13 +68,12 @@ def reusable_waiter(loop, timeout):
else:
deadline = None

@asyncio.coroutine
def wait(future):
async def wait(future):
if deadline is not None:
new_timeout = max(deadline - loop.time(), 0)
else:
new_timeout = None
return (yield from asyncio.wait_for(future, new_timeout))
return (await asyncio.wait_for(future, new_timeout))

return wait

Expand All @@ -101,20 +99,18 @@ def __init__(self, *args, **kwargs):
def __aiter__(self):
return self

@asyncio.coroutine
def __anext__(self):
async def __anext__(self):
try:
return (yield from self._get_next(None))
return (await self._get_next(None))
except ReqlCursorEmpty:
raise StopAsyncIteration

@asyncio.coroutine
def close(self):
async def close(self):
if self.error is None:
self.error = self._empty_error()
if self.conn.is_open():
self.outstanding_requests += 1
yield from self.conn._parent._stop(self)
await self.conn._parent._stop(self)

def _extend(self, res_buf):
Cursor._extend(self, res_buf)
Expand All @@ -123,16 +119,15 @@ def _extend(self, res_buf):

# Convenience function so users know when they've hit the end of the cursor
# without having to catch an exception
@asyncio.coroutine
def fetch_next(self, wait=True):
async def fetch_next(self, wait=True):
timeout = Cursor._wait_to_timeout(wait)
waiter = reusable_waiter(self.conn._io_loop, timeout)
while len(self.items) == 0 and self.error is None:
self._maybe_fetch_batch()
if self.error is not None:
raise self.error
with translate_timeout_errors():
yield from waiter(asyncio.shield(self.new_response))
await waiter(asyncio.shield(self.new_response))
# If there is a (non-empty) error to be received, we return True, so the
# user will receive it on the next `next` call.
return len(self.items) != 0 or not isinstance(self.error, RqlCursorEmpty)
Expand All @@ -142,15 +137,14 @@ def _empty_error(self):
# with mechanisms to return from a coroutine.
return RqlCursorEmpty()

@asyncio.coroutine
def _get_next(self, timeout):
async def _get_next(self, timeout):
waiter = reusable_waiter(self.conn._io_loop, timeout)
while len(self.items) == 0:
self._maybe_fetch_batch()
if self.error is not None:
raise self.error
with translate_timeout_errors():
yield from waiter(asyncio.shield(self.new_response))
await waiter(asyncio.shield(self.new_response))
return self.items.popleft()

def _maybe_fetch_batch(self):
Expand Down Expand Up @@ -186,8 +180,7 @@ def client_address(self):
if self.is_open():
return self._streamwriter.get_extra_info("sockname")[0]

@asyncio.coroutine
def connect(self, timeout):
async def connect(self, timeout):
try:
ssl_context = None
if len(self._parent.ssl) > 0:
Expand All @@ -199,7 +192,7 @@ def connect(self, timeout):
ssl_context.check_hostname = True # redundant with match_hostname
ssl_context.load_verify_locations(self._parent.ssl["ca_certs"])

self._streamreader, self._streamwriter = yield from asyncio.open_connection(
self._streamreader, self._streamwriter = await asyncio.open_connection(
self._parent.host,
self._parent.port,
ssl=ssl_context,
Expand Down Expand Up @@ -229,22 +222,22 @@ def connect(self, timeout):
if request != "":
self._streamwriter.write(request)

response = yield from asyncio.wait_for(
response = await asyncio.wait_for(
_read_until(self._streamreader, b"\0"),
timeout,
)
response = response[:-1]
except ReqlAuthError:
yield from self.close()
await self.close()
raise
except ReqlTimeoutError as err:
yield from self.close()
await self.close()
raise ReqlDriverError(
"Connection interrupted during handshake with %s:%s. Error: %s"
% (self._parent.host, self._parent.port, str(err))
)
except Exception as err:
yield from self.close()
await self.close()
raise ReqlDriverError(
"Could not connect to %s:%s. Error: %s"
% (self._parent.host, self._parent.port, str(err))
Expand All @@ -258,8 +251,7 @@ def connect(self, timeout):
def is_open(self):
return not (self._closing or self._streamreader.at_eof())

@asyncio.coroutine
def close(self, noreply_wait=False, token=None, exception=None):
async def close(self, noreply_wait=False, token=None, exception=None):
self._closing = True
if exception is not None:
err_message = "Connection is closed (%s)." % str(exception)
Expand All @@ -279,38 +271,36 @@ def close(self, noreply_wait=False, token=None, exception=None):

if noreply_wait:
noreply = Query(pQuery.NOREPLY_WAIT, token, None, None)
yield from self.run_query(noreply, False)
await self.run_query(noreply, False)

self._streamwriter.close()
# We must not wait for the _reader_task if we got an exception, because that
# means that we were called from it. Waiting would lead to a deadlock.
if self._reader_task and exception is None:
yield from self._reader_task
await self._reader_task

return None

@asyncio.coroutine
def run_query(self, query, noreply):
async def run_query(self, query, noreply):
self._streamwriter.write(query.serialize(self._parent._get_json_encoder(query)))
if noreply:
return None

response_future = asyncio.Future()
self._user_queries[query.token] = (query, response_future)
return (yield from response_future)
return (await response_future)

# The _reader coroutine runs in parallel, reading responses
# off of the socket and forwarding them to the appropriate Future or Cursor.
# This is shut down as a consequence of closing the stream, or an error in the
# socket/protocol from the server. Unexpected errors in this coroutine will
# close the ConnectionInstance and be passed to any open Futures or Cursors.
@asyncio.coroutine
def _reader(self):
async def _reader(self):
try:
while True:
buf = yield from self._streamreader.readexactly(12)
buf = await self._streamreader.readexactly(12)
(token, length,) = struct.unpack("<qL", buf)
buf = yield from self._streamreader.readexactly(length)
buf = await self._streamreader.readexactly(length)

cursor = self._cursor_cache.get(token)
if cursor is not None:
Expand Down Expand Up @@ -339,7 +329,7 @@ def _reader(self):
raise ReqlDriverError("Unexpected response received.")
except Exception as ex:
if not self._closing:
yield from self.close(exception=ex)
await self.close(exception=ex)


class Connection(ConnectionBase):
Expand All @@ -352,30 +342,25 @@ def __init__(self, *args, **kwargs):
"Could not convert port %s to an integer." % self.port
)

@asyncio.coroutine
def __aenter__(self):
async def __aenter__(self):
return self

@asyncio.coroutine
def __aexit__(self, exception_type, exception_val, traceback):
yield from self.close(False)
async def __aexit__(self, exception_type, exception_val, traceback):
await self.close(False)

@asyncio.coroutine
def _stop(self, cursor):
async def _stop(self, cursor):
self.check_open()
q = Query(pQuery.STOP, cursor.query.token, None, None)
return (yield from self._instance.run_query(q, True))
return (await self._instance.run_query(q, True))

@asyncio.coroutine
def reconnect(self, noreply_wait=True, timeout=None):
async def reconnect(self, noreply_wait=True, timeout=None):
# We close before reconnect so reconnect doesn't try to close us
# and then fail to return the Future (this is a little awkward).
yield from self.close(noreply_wait)
await self.close(noreply_wait)
self._instance = self._conn_type(self, **self._child_kwargs)
return (yield from self._instance.connect(timeout))
return (await self._instance.connect(timeout))

@asyncio.coroutine
def close(self, noreply_wait=True):
async def close(self, noreply_wait=True):
if self._instance is None:
return None
return (yield from ConnectionBase.close(self, noreply_wait=noreply_wait))
return (await ConnectionBase.close(self, noreply_wait=noreply_wait))