5
5
import socket
6
6
import threading
7
7
import weakref
8
+ from io import SEEK_END
8
9
from itertools import chain
9
10
from queue import Empty , Full , LifoQueue
10
11
from time import time
11
- from typing import Optional
12
+ from typing import Optional , Union
12
13
from urllib .parse import parse_qs , unquote , urlparse
13
14
14
15
from redis .backoff import NoBackoff
@@ -163,39 +164,47 @@ def parse_error(self, response):
163
164
164
165
165
166
class SocketBuffer :
166
- def __init__ (self , socket , socket_read_size , socket_timeout ):
167
+ def __init__ (
168
+ self , socket : socket .socket , socket_read_size : int , socket_timeout : float
169
+ ):
167
170
self ._sock = socket
168
171
self .socket_read_size = socket_read_size
169
172
self .socket_timeout = socket_timeout
170
173
self ._buffer = io .BytesIO ()
171
- # number of bytes written to the buffer from the socket
172
- self .bytes_written = 0
173
- # number of bytes read from the buffer
174
- self .bytes_read = 0
175
174
176
- @property
177
- def length (self ):
178
- return self .bytes_written - self .bytes_read
175
+ def unread_bytes (self ) -> int :
176
+ """
177
+ Remaining unread length of buffer
178
+ """
179
+ pos = self ._buffer .tell ()
180
+ end = self ._buffer .seek (0 , SEEK_END )
181
+ self ._buffer .seek (pos )
182
+ return end - pos
179
183
180
- def _read_from_socket (self , length = None , timeout = SENTINEL , raise_on_timeout = True ):
184
+ def _read_from_socket (
185
+ self ,
186
+ length : Optional [int ] = None ,
187
+ timeout : Union [float , object ] = SENTINEL ,
188
+ raise_on_timeout : Optional [bool ] = True ,
189
+ ) -> bool :
181
190
sock = self ._sock
182
191
socket_read_size = self .socket_read_size
183
- buf = self ._buffer
184
- buf .seek (self .bytes_written )
185
192
marker = 0
186
193
custom_timeout = timeout is not SENTINEL
187
194
195
+ buf = self ._buffer
196
+ current_pos = buf .tell ()
197
+ buf .seek (0 , SEEK_END )
198
+ if custom_timeout :
199
+ sock .settimeout (timeout )
188
200
try :
189
- if custom_timeout :
190
- sock .settimeout (timeout )
191
201
while True :
192
202
data = self ._sock .recv (socket_read_size )
193
203
# an empty string indicates the server shutdown the socket
194
204
if isinstance (data , bytes ) and len (data ) == 0 :
195
205
raise ConnectionError (SERVER_CLOSED_CONNECTION_ERROR )
196
206
buf .write (data )
197
207
data_length = len (data )
198
- self .bytes_written += data_length
199
208
marker += data_length
200
209
201
210
if length is not None and length > marker :
@@ -215,55 +224,53 @@ def _read_from_socket(self, length=None, timeout=SENTINEL, raise_on_timeout=True
215
224
return False
216
225
raise ConnectionError (f"Error while reading from socket: { ex .args } " )
217
226
finally :
227
+ buf .seek (current_pos )
218
228
if custom_timeout :
219
229
sock .settimeout (self .socket_timeout )
220
230
221
- def can_read (self , timeout ) :
222
- return bool (self .length ) or self ._read_from_socket (
231
+ def can_read (self , timeout : float ) -> bool :
232
+ return bool (self .unread_bytes () ) or self ._read_from_socket (
223
233
timeout = timeout , raise_on_timeout = False
224
234
)
225
235
226
- def read (self , length ) :
236
+ def read (self , length : int ) -> bytes :
227
237
length = length + 2 # make sure to read the \r\n terminator
228
- # make sure we've read enough data from the socket
229
- if length > self .length :
230
- self ._read_from_socket (length - self .length )
231
-
232
- self ._buffer .seek (self .bytes_read )
238
+ # BufferIO will return less than requested if buffer is short
233
239
data = self ._buffer .read (length )
234
- self .bytes_read += len (data )
240
+ missing = length - len (data )
241
+ if missing :
242
+ # fill up the buffer and read the remainder
243
+ self ._read_from_socket (missing )
244
+ data += self ._buffer .read (missing )
235
245
return data [:- 2 ]
236
246
237
- def readline (self ):
247
+ def readline (self ) -> bytes :
238
248
buf = self ._buffer
239
- buf .seek (self .bytes_read )
240
249
data = buf .readline ()
241
250
while not data .endswith (SYM_CRLF ):
242
251
# there's more data in the socket that we need
243
252
self ._read_from_socket ()
244
- buf .seek (self .bytes_read )
245
- data = buf .readline ()
253
+ data += buf .readline ()
246
254
247
- self .bytes_read += len (data )
248
255
return data [:- 2 ]
249
256
250
- def get_pos (self ):
257
+ def get_pos (self ) -> int :
251
258
"""
252
259
Get current read position
253
260
"""
254
- return self .bytes_read
261
+ return self ._buffer . tell ()
255
262
256
- def rewind (self , pos ) :
263
+ def rewind (self , pos : int ) -> None :
257
264
"""
258
265
Rewind the buffer to a specific position, to re-start reading
259
266
"""
260
- self .bytes_read = pos
267
+ self ._buffer . seek ( pos )
261
268
262
- def purge (self ):
269
+ def purge (self ) -> None :
263
270
"""
264
271
After a successful read, purge the read part of buffer
265
272
"""
266
- unread = self .bytes_written - self . bytes_read
273
+ unread = self .unread_bytes ()
267
274
268
275
# Only if we have read all of the buffer do we truncate, to
269
276
# reduce the amount of memory thrashing. This heuristic
@@ -276,13 +283,10 @@ def purge(self):
276
283
view = self ._buffer .getbuffer ()
277
284
view [:unread ] = view [- unread :]
278
285
self ._buffer .truncate (unread )
279
- self .bytes_written = unread
280
- self .bytes_read = 0
281
286
self ._buffer .seek (0 )
282
287
283
- def close (self ):
288
+ def close (self ) -> None :
284
289
try :
285
- self .bytes_written = self .bytes_read = 0
286
290
self ._buffer .close ()
287
291
except Exception :
288
292
# issue #633 suggests the purge/close somehow raised a
@@ -498,6 +502,7 @@ def read_response(self, disable_decoding=False):
498
502
return response
499
503
500
504
505
+ DefaultParser : BaseParser
501
506
if HIREDIS_AVAILABLE :
502
507
DefaultParser = HiredisParser
503
508
else :
0 commit comments