@@ -828,7 +828,7 @@ async def on_connect(self) -> None:
828
828
if str_if_bytes (await self .read_response ()) != "OK" :
829
829
raise ConnectionError ("Invalid Database" )
830
830
831
- async def disconnect (self ) -> None :
831
+ async def disconnect (self , nowait : bool = False ) -> None :
832
832
"""Disconnects from the Redis server"""
833
833
try :
834
834
async with async_timeout .timeout (self .socket_connect_timeout ):
@@ -838,8 +838,9 @@ async def disconnect(self) -> None:
838
838
try :
839
839
if os .getpid () == self .pid :
840
840
self ._writer .close () # type: ignore[union-attr]
841
- # py3.6 doesn't have this method
842
- if hasattr (self ._writer , "wait_closed" ):
841
+ # wait for close to finish, except when handling errors and
842
+ # forcecully disconnecting.
843
+ if not nowait :
843
844
await self ._writer .wait_closed () # type: ignore[union-attr]
844
845
except OSError :
845
846
pass
@@ -894,10 +895,10 @@ async def send_packed_command(
894
895
self ._writer .writelines (command )
895
896
await self ._writer .drain ()
896
897
except asyncio .TimeoutError :
897
- await self .disconnect ()
898
+ await self .disconnect (nowait = True )
898
899
raise TimeoutError ("Timeout writing to socket" ) from None
899
900
except OSError as e :
900
- await self .disconnect ()
901
+ await self .disconnect (nowait = True )
901
902
if len (e .args ) == 1 :
902
903
err_no , errmsg = "UNKNOWN" , e .args [0 ]
903
904
else :
@@ -907,7 +908,7 @@ async def send_packed_command(
907
908
f"Error { err_no } while writing to socket. { errmsg } ."
908
909
) from e
909
910
except BaseException :
910
- await self .disconnect ()
911
+ await self .disconnect (nowait = True )
911
912
raise
912
913
913
914
async def send_command (self , * args : Any , ** kwargs : Any ) -> None :
@@ -923,7 +924,7 @@ async def can_read(self, timeout: float = 0):
923
924
try :
924
925
return await self ._parser .can_read (timeout )
925
926
except OSError as e :
926
- await self .disconnect ()
927
+ await self .disconnect (nowait = True )
927
928
raise ConnectionError (
928
929
f"Error while reading from { self .host } :{ self .port } : { e .args } "
929
930
)
@@ -974,15 +975,15 @@ async def read_response_without_lock(self, disable_decoding: bool = False):
974
975
disable_decoding = disable_decoding
975
976
)
976
977
except asyncio .TimeoutError :
977
- await self .disconnect ()
978
+ await self .disconnect (nowait = True )
978
979
raise TimeoutError (f"Timeout reading from { self .host } :{ self .port } " )
979
980
except OSError as e :
980
- await self .disconnect ()
981
+ await self .disconnect (nowait = True )
981
982
raise ConnectionError (
982
983
f"Error while reading from { self .host } :{ self .port } : { e .args } "
983
984
)
984
985
except BaseException :
985
- await self .disconnect ()
986
+ await self .disconnect (nowait = True )
986
987
raise
987
988
988
989
if self .health_check_interval :
0 commit comments