Skip to content

Simplify the sync SocketBuffer, add type hints #2543

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
* Simplify synchronous SocketBuffer state management
* Fix string cleanse in Redis Graph
* Make PythonParser resumable in case of error (#2510)
* Add `timeout=None` in `SentinelConnectionManager.read_response`
Expand Down
83 changes: 44 additions & 39 deletions redis/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
import socket
import threading
import weakref
from io import SEEK_END
from itertools import chain
from queue import Empty, Full, LifoQueue
from time import time
from typing import Optional
from typing import Optional, Union
from urllib.parse import parse_qs, unquote, urlparse

from redis.backoff import NoBackoff
Expand Down Expand Up @@ -163,39 +164,47 @@ def parse_error(self, response):


class SocketBuffer:
def __init__(self, socket, socket_read_size, socket_timeout):
def __init__(
self, socket: socket.socket, socket_read_size: int, socket_timeout: float
):
self._sock = socket
self.socket_read_size = socket_read_size
self.socket_timeout = socket_timeout
self._buffer = io.BytesIO()
# number of bytes written to the buffer from the socket
self.bytes_written = 0
# number of bytes read from the buffer
self.bytes_read = 0

@property
def length(self):
return self.bytes_written - self.bytes_read
def unread_bytes(self) -> int:
"""
Remaining unread length of buffer
"""
pos = self._buffer.tell()
end = self._buffer.seek(0, SEEK_END)
self._buffer.seek(pos)
return end - pos

def _read_from_socket(self, length=None, timeout=SENTINEL, raise_on_timeout=True):
def _read_from_socket(
self,
length: Optional[int] = None,
timeout: Union[float, object] = SENTINEL,
raise_on_timeout: Optional[bool] = True,
) -> bool:
sock = self._sock
socket_read_size = self.socket_read_size
buf = self._buffer
buf.seek(self.bytes_written)
marker = 0
custom_timeout = timeout is not SENTINEL

buf = self._buffer
current_pos = buf.tell()
buf.seek(0, SEEK_END)
if custom_timeout:
sock.settimeout(timeout)
try:
if custom_timeout:
sock.settimeout(timeout)
while True:
data = self._sock.recv(socket_read_size)
# an empty string indicates the server shutdown the socket
if isinstance(data, bytes) and len(data) == 0:
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
buf.write(data)
data_length = len(data)
self.bytes_written += data_length
marker += data_length

if length is not None and length > marker:
Expand All @@ -215,55 +224,53 @@ def _read_from_socket(self, length=None, timeout=SENTINEL, raise_on_timeout=True
return False
raise ConnectionError(f"Error while reading from socket: {ex.args}")
finally:
buf.seek(current_pos)
if custom_timeout:
sock.settimeout(self.socket_timeout)

def can_read(self, timeout):
return bool(self.length) or self._read_from_socket(
def can_read(self, timeout: float) -> bool:
return bool(self.unread_bytes()) or self._read_from_socket(
timeout=timeout, raise_on_timeout=False
)

def read(self, length):
def read(self, length: int) -> bytes:
length = length + 2 # make sure to read the \r\n terminator
# make sure we've read enough data from the socket
if length > self.length:
self._read_from_socket(length - self.length)

self._buffer.seek(self.bytes_read)
# BufferIO will return less than requested if buffer is short
data = self._buffer.read(length)
self.bytes_read += len(data)
missing = length - len(data)
if missing:
# fill up the buffer and read the remainder
self._read_from_socket(missing)
data += self._buffer.read(missing)
return data[:-2]

def readline(self):
def readline(self) -> bytes:
buf = self._buffer
buf.seek(self.bytes_read)
data = buf.readline()
while not data.endswith(SYM_CRLF):
# there's more data in the socket that we need
self._read_from_socket()
buf.seek(self.bytes_read)
data = buf.readline()
data += buf.readline()

self.bytes_read += len(data)
return data[:-2]

def get_pos(self):
def get_pos(self) -> int:
"""
Get current read position
"""
return self.bytes_read
return self._buffer.tell()

def rewind(self, pos):
def rewind(self, pos: int) -> None:
"""
Rewind the buffer to a specific position, to re-start reading
"""
self.bytes_read = pos
self._buffer.seek(pos)

def purge(self):
def purge(self) -> None:
"""
After a successful read, purge the read part of buffer
"""
unread = self.bytes_written - self.bytes_read
unread = self.unread_bytes()

# Only if we have read all of the buffer do we truncate, to
# reduce the amount of memory thrashing. This heuristic
Expand All @@ -276,13 +283,10 @@ def purge(self):
view = self._buffer.getbuffer()
view[:unread] = view[-unread:]
self._buffer.truncate(unread)
self.bytes_written = unread
self.bytes_read = 0
self._buffer.seek(0)

def close(self):
def close(self) -> None:
try:
self.bytes_written = self.bytes_read = 0
self._buffer.close()
except Exception:
# issue #633 suggests the purge/close somehow raised a
Expand Down Expand Up @@ -498,6 +502,7 @@ def read_response(self, disable_decoding=False):
return response


DefaultParser: BaseParser
if HIREDIS_AVAILABLE:
DefaultParser = HiredisParser
else:
Expand Down