@@ -825,7 +825,7 @@ async def on_connect(self) -> None:
825
825
if str_if_bytes (await self .read_response ()) != "OK" :
826
826
raise ConnectionError ("Invalid Database" )
827
827
828
- async def disconnect (self ) -> None :
828
+ async def disconnect (self , nowait : bool = False ) -> None :
829
829
"""Disconnects from the Redis server"""
830
830
try :
831
831
async with async_timeout .timeout (self .socket_connect_timeout ):
@@ -835,8 +835,9 @@ async def disconnect(self) -> None:
835
835
try :
836
836
if os .getpid () == self .pid :
837
837
self ._writer .close () # type: ignore[union-attr]
838
- # py3.6 doesn't have this method
839
- if hasattr (self ._writer , "wait_closed" ):
838
+ # wait for close to finish, except when handling errors and
839
+ # forcecully disconnecting.
840
+ if not nowait :
840
841
await self ._writer .wait_closed () # type: ignore[union-attr]
841
842
except OSError :
842
843
pass
@@ -936,10 +937,10 @@ async def read_response(self, disable_decoding: bool = False):
936
937
disable_decoding = disable_decoding
937
938
)
938
939
except asyncio .TimeoutError :
939
- await self .disconnect ()
940
+ await self .disconnect (nowait = True )
940
941
raise TimeoutError (f"Timeout reading from { self .host } :{ self .port } " )
941
942
except OSError as e :
942
- await self .disconnect ()
943
+ await self .disconnect (nowait = True )
943
944
raise ConnectionError (
944
945
f"Error while reading from { self .host } :{ self .port } : { e .args } "
945
946
)
@@ -948,7 +949,7 @@ async def read_response(self, disable_decoding: bool = False):
948
949
# is subclass of Exception, not BaseException
949
950
raise
950
951
except Exception :
951
- await self .disconnect ()
952
+ await self .disconnect (nowait = True )
952
953
raise
953
954
954
955
if self .health_check_interval :
0 commit comments