From f98de6552097699326fdd1ee9079cb4975c6baf2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Boros=20G=C3=A1bor?= Date: Sun, 5 Aug 2018 11:46:56 +0200 Subject: [PATCH] Refactor net_gevent --- rebirthdb/gevent_net/net_gevent.py | 180 +++++------------------------ 1 file changed, 27 insertions(+), 153 deletions(-) diff --git a/rebirthdb/gevent_net/net_gevent.py b/rebirthdb/gevent_net/net_gevent.py index 9635c7d7..1aaf848b 100644 --- a/rebirthdb/gevent_net/net_gevent.py +++ b/rebirthdb/gevent_net/net_gevent.py @@ -15,29 +15,31 @@ # This file incorporates work covered by the following copyright: # Copyright 2010-2016 RethinkDB, all rights reserved. -import errno -import ssl import struct import gevent -import gevent.socket as socket from gevent.event import AsyncResult, Event from gevent.lock import Semaphore + from rebirthdb import net, ql2_pb2 -from rebirthdb.errors import ReqlAuthError, ReqlCursorEmpty, ReqlDriverError, ReqlTimeoutError, RqlDriverError, \ - RqlTimeoutError -from rebirthdb.logger import default_logger +from rebirthdb.errors import ReqlCursorEmpty, RqlDriverError, RqlTimeoutError +from rebirthdb.net import SocketWrapper + __all__ = ['Connection'] -pResponse = ql2_pb2.Response.ResponseType -pQuery = ql2_pb2.Query.QueryType +PROTO_RESPONSE_TYPE = ql2_pb2.Response.ResponseType +PROTO_QUERY_TYPE = ql2_pb2.Query.QueryType class GeventCursorEmpty(ReqlCursorEmpty, StopIteration): pass +class Connection(net.Connection): + pass + + # TODO: allow users to set sync/async? class GeventCursor(net.Cursor): def __init__(self, *args, **kwargs): @@ -68,136 +70,8 @@ def _get_next(self, timeout): return self.items.popleft() -# TODO: would be nice to share this code with net.py -# TODO(grandquista): code seems to already be a duplicate of superclass -# revisit this after testing is inplace. -class SocketWrapper(net.SocketWrapper): - def __init__(self, parent): - self.host = parent._parent.host - self.port = parent._parent.port - self._read_buffer = None - self._socket = None - self.ssl = parent._parent.ssl - - try: - self._socket = socket.create_connection((self.host, self.port)) - self._socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) - self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) - - if len(self.ssl) > 0: - try: - if hasattr(ssl, 'SSLContext'): # Python2.7 and 3.2+, or backports.ssl - ssl_context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) - if hasattr(ssl_context, "options"): - ssl_context.options |= getattr(ssl, "OP_NO_SSLv2", 0) - ssl_context.options |= getattr(ssl, "OP_NO_SSLv3", 0) - self.ssl_context.verify_mode = ssl.CERT_REQUIRED - self.ssl_context.check_hostname = True # redundant with match_hostname - self.ssl_context.load_verify_locations(self.ssl["ca_certs"]) - self._socket = ssl_context.wrap_socket(self._socket, server_hostname=self.host) - else: # this does not disable SSLv2 or SSLv3 - self._socket = ssl.wrap_socket( - self._socket, cert_reqs=ssl.CERT_REQUIRED, ssl_version=ssl.PROTOCOL_SSLv23, - ca_certs=self.ssl["ca_certs"]) - except IOError as exc: - self._socket.close() - raise ReqlDriverError("SSL handshake failed (see server log for more information): %s" % str(exc)) - try: - ssl.match_hostname(self._socket.getpeercert(), hostname=self.host) - except ssl.CertificateError: - self._socket.close() - raise - - parent._parent.handshake.reset() - response = None - while True: - request = parent._parent.handshake.next_message(response) - if request is None: - break - # This may happen in the `V1_0` protocol where we send two requests as - # an optimization, then need to read each separately - if request is not "": - self.sendall(request) - - # The response from the server is a null-terminated string - response = b'' - while True: - char = self.recvall(1) - if char == b'\0': - break - response += char - except (ReqlAuthError, ReqlTimeoutError): - self.close() - raise - except ReqlDriverError as ex: - self.close() - error = str(ex) \ - .replace('receiving from', 'during handshake with') \ - .replace('sending to', 'during handshake with') - raise ReqlDriverError(error) - except Exception as ex: - self.close() - raise ReqlDriverError("Could not connect to %s:%s. Error: %s" % - (self.host, self.port, ex)) - - def close(self): - if self._socket is not None: - try: - self._socket.shutdown(socket.SHUT_RDWR) - self._socket.close() - except Exception as ex: - default_logger.error(ex.message) - finally: - self._socket = None - - def recvall(self, length): - res = b'' if self._read_buffer is None else self._read_buffer - while len(res) < length: - while True: - try: - chunk = self._socket.recv(length - len(res)) - break - except ReqlTimeoutError: - raise - except IOError as ex: - if ex.errno == errno.ECONNRESET: - self.close() - raise ReqlDriverError("Connection is closed.") - elif ex.errno != errno.EINTR: - self.close() - raise ReqlDriverError( - 'Connection interrupted receiving from %s:%s - %s' % (self.host, self.port, str(ex)) - ) - except Exception as ex: - self.close() - raise ReqlDriverError('Error receiving from %s:%s - %s' % (self.host, self.port, str(ex))) - if len(chunk) == 0: - self.close() - raise ReqlDriverError("Connection is closed.") - res += chunk - return res - - def sendall(self, data): - offset = 0 - while offset < len(data): - try: - offset += self._socket.send(data[offset:]) - except IOError as ex: - if ex.errno == errno.ECONNRESET: - self.close() - raise ReqlDriverError("Connection is closed.") - elif ex.errno != errno.EINTR: - self.close() - raise ReqlDriverError(('Connection interrupted ' + - 'sending to %s:%s - %s') % - (self.host, self.port, str(ex))) - except Exception as ex: - self.close() - raise ReqlDriverError('Error sending to %s:%s - %s' % (self.host, self.port, str(ex))) - - class ConnectionInstance(object): - def __init__(self, parent, io_loop=None): + def __init__(self, parent): self._parent = parent self._closing = False self._user_queries = {} @@ -205,10 +79,14 @@ def __init__(self, parent, io_loop=None): self._write_mutex = Semaphore() self._socket = None + self.timeout = None def connect(self, timeout): + if not self.timeout: + self.timeout = timeout + with gevent.Timeout(timeout, RqlTimeoutError(self._parent.host, self._parent.port)) as timeout: - self._socket = SocketWrapper(self) + self._socket = SocketWrapper(self, timeout) # Start a parallel coroutine to perform reads gevent.spawn(self._reader) @@ -217,8 +95,9 @@ def connect(self, timeout): def is_open(self): return self._socket is not None and self._socket.is_open() - def close(self, noreply_wait=False, token=None, exception=None): + def close(self, no_reply_wait=False, token=None, exception=None): self._closing = True + if exception is not None: err_message = "Connection is closed (%s)." % str(exception) else: @@ -234,9 +113,9 @@ def close(self, noreply_wait=False, token=None, exception=None): self._user_queries = {} self._cursor_cache = {} - if noreply_wait: - noreply = net.Query(pQuery.NOREPLY_WAIT, token, None, None) - self.run_query(noreply, False) + if no_reply_wait: + no_reply = net.Query(PROTO_QUERY_TYPE.NOREPLY_WAIT, token, None, None) + self.run_query(no_reply, False) try: self._socket.close() @@ -267,9 +146,9 @@ def run_query(self, query, noreply): def _reader(self): try: while True: - buf = self._socket.recvall(12) + buf = self._socket.recvall(12, self.timeout) (token, length,) = struct.unpack("