5
5
import socket
6
6
import threading
7
7
import weakref
8
+ from io import SEEK_END
8
9
from itertools import chain
9
10
from queue import Empty , Full , LifoQueue
10
11
from time import time
11
- from typing import Optional
12
+ from typing import Optional , Union
12
13
from urllib .parse import parse_qs , unquote , urlparse
13
14
14
15
from redis .backoff import NoBackoff
@@ -163,39 +164,47 @@ def parse_error(self, response):
163
164
164
165
165
166
class SocketBuffer :
166
- def __init__ (self , socket , socket_read_size , socket_timeout ):
167
+ def __init__ (
168
+ self , socket : socket .socket , socket_read_size : int , socket_timeout : float
169
+ ):
167
170
self ._sock = socket
168
171
self .socket_read_size = socket_read_size
169
172
self .socket_timeout = socket_timeout
170
173
self ._buffer = io .BytesIO ()
171
- # number of bytes written to the buffer from the socket
172
- self .bytes_written = 0
173
- # number of bytes read from the buffer
174
- self .bytes_read = 0
175
174
176
- @property
177
- def length (self ):
178
- return self .bytes_written - self .bytes_read
175
+ def unread_bytes (self ) -> int :
176
+ """
177
+ Remaining unread length of buffer
178
+ """
179
+ pos = self ._buffer .tell ()
180
+ end = self ._buffer .seek (0 , SEEK_END )
181
+ self ._buffer .seek (pos )
182
+ return end - pos
179
183
180
- def _read_from_socket (self , length = None , timeout = SENTINEL , raise_on_timeout = True ):
184
+ def _read_from_socket (
185
+ self ,
186
+ length : Optional [int ] = None ,
187
+ timeout : Union [float , object ] = SENTINEL ,
188
+ raise_on_timeout : Optional [bool ] = True ,
189
+ ) -> bool :
181
190
sock = self ._sock
182
191
socket_read_size = self .socket_read_size
183
- buf = self ._buffer
184
- buf .seek (self .bytes_written )
185
192
marker = 0
186
193
custom_timeout = timeout is not SENTINEL
187
194
195
+ buf = self ._buffer
196
+ current_pos = buf .tell ()
197
+ buf .seek (0 , SEEK_END )
198
+ if custom_timeout :
199
+ sock .settimeout (timeout )
188
200
try :
189
- if custom_timeout :
190
- sock .settimeout (timeout )
191
201
while True :
192
202
data = self ._sock .recv (socket_read_size )
193
203
# an empty string indicates the server shutdown the socket
194
204
if isinstance (data , bytes ) and len (data ) == 0 :
195
205
raise ConnectionError (SERVER_CLOSED_CONNECTION_ERROR )
196
206
buf .write (data )
197
207
data_length = len (data )
198
- self .bytes_written += data_length
199
208
marker += data_length
200
209
201
210
if length is not None and length > marker :
@@ -215,55 +224,51 @@ def _read_from_socket(self, length=None, timeout=SENTINEL, raise_on_timeout=True
215
224
return False
216
225
raise ConnectionError (f"Error while reading from socket: { ex .args } " )
217
226
finally :
227
+ buf .seek (current_pos )
218
228
if custom_timeout :
219
229
sock .settimeout (self .socket_timeout )
220
230
221
- def can_read (self , timeout ) :
222
- return bool (self .length ) or self ._read_from_socket (
231
+ def can_read (self , timeout : float ) -> bool :
232
+ return bool (self .unread_bytes () ) or self ._read_from_socket (
223
233
timeout = timeout , raise_on_timeout = False
224
234
)
225
235
226
- def read (self , length ) :
236
+ def read (self , length : int ) -> bytes :
227
237
length = length + 2 # make sure to read the \r\n terminator
228
238
# make sure we've read enough data from the socket
229
- if length > self .length :
230
- self ._read_from_socket (length - self .length )
239
+ if length > self .unread_bytes :
240
+ self ._read_from_socket (length - self .unread_bytes )
231
241
232
- self ._buffer .seek (self .bytes_read )
233
242
data = self ._buffer .read (length )
234
- self .bytes_read += len (data )
235
243
return data [:- 2 ]
236
244
237
- def readline (self ):
245
+ def readline (self ) -> bytes :
238
246
buf = self ._buffer
239
- buf .seek (self .bytes_read )
240
247
data = buf .readline ()
241
248
while not data .endswith (SYM_CRLF ):
242
249
# there's more data in the socket that we need
243
250
self ._read_from_socket ()
244
- buf .seek (self .bytes_read )
245
251
data = buf .readline ()
246
252
247
- self .bytes_read += len (data )
248
253
return data [:- 2 ]
249
254
250
- def get_pos (self ):
255
+ def get_pos (self ) -> int :
251
256
"""
252
257
Get current read position
253
258
"""
254
- return self .bytes_read
259
+ return self ._buffer . tell ()
255
260
256
- def rewind (self , pos ) :
261
+ def rewind (self , pos : int ) -> None :
257
262
"""
258
263
Rewind the buffer to a specific position, to re-start reading
259
264
"""
260
- self .bytes_read = pos
265
+ self ._buffer . seek ( pos )
261
266
262
- def purge (self ):
267
+ def purge (self ) -> None :
263
268
"""
264
269
After a successful read, purge the read part of buffer
265
270
"""
266
- unread = self .bytes_written - self . bytes_read
271
+ unread = self .unread_bytes ()
267
272
268
273
# Only if we have read all of the buffer do we truncate, to
269
274
# reduce the amount of memory thrashing. This heuristic
@@ -276,13 +281,10 @@ def purge(self):
276
281
view = self ._buffer .getbuffer ()
277
282
view [:unread ] = view [- unread :]
278
283
self ._buffer .truncate (unread )
279
- self .bytes_written = unread
280
- self .bytes_read = 0
281
284
self ._buffer .seek (0 )
282
285
283
- def close (self ):
286
+ def close (self ) -> None :
284
287
try :
285
- self .bytes_written = self .bytes_read = 0
286
288
self ._buffer .close ()
287
289
except Exception :
288
290
# issue #633 suggests the purge/close somehow raised a
@@ -498,6 +500,7 @@ def read_response(self, disable_decoding=False):
498
500
return response
499
501
500
502
503
+ DefaultParser : BaseParser
501
504
if HIREDIS_AVAILABLE :
502
505
DefaultParser = HiredisParser
503
506
else :
0 commit comments