23
23
from time import perf_counter
24
24
25
25
from ..._async_compat .network import AsyncBoltSocket
26
- from ..._exceptions import BoltHandshakeError
26
+ from ..._async_compat .util import AsyncUtil
27
+ from ..._exceptions import (
28
+ BoltError ,
29
+ BoltHandshakeError ,
30
+ SocketDeadlineExceeded ,
31
+ )
27
32
from ...addressing import Address
28
33
from ...api import (
29
34
ServerInfo ,
32
37
from ...conf import PoolConfig
33
38
from ...exceptions import (
34
39
AuthError ,
40
+ DriverError ,
35
41
IncompleteCommit ,
36
42
ServiceUnavailable ,
37
43
SessionExpired ,
@@ -76,6 +82,7 @@ class AsyncBolt:
76
82
idle_since = float ("-inf" )
77
83
78
84
# The socket
85
+ _closing = False
79
86
_closed = False
80
87
81
88
# The socket
@@ -260,24 +267,42 @@ async def ping(cls, address, *, timeout=None, **config):
260
267
261
268
@classmethod
262
269
async def open (
263
- cls , address , * , auth = None , timeout = None , routing_context = None , ** pool_config
270
+ cls , address , * , auth = None , timeout = None , routing_context = None ,
271
+ ** pool_config
264
272
):
265
- """ Open a new Bolt connection to a given server address.
273
+ """Open a new Bolt connection to a given server address.
266
274
267
275
:param address:
268
276
:param auth:
269
277
:param timeout: the connection timeout in seconds
270
278
:param routing_context: dict containing routing context
271
279
:param pool_config:
272
- :return:
273
- :raise BoltHandshakeError: raised if the Bolt Protocol can not negotiate a protocol version.
280
+
281
+ :return: connected AsyncBolt instance
282
+
283
+ :raise BoltHandshakeError:
284
+ raised if the Bolt Protocol can not negotiate a protocol version.
274
285
:raise ServiceUnavailable: raised if there was a connection issue.
275
286
"""
287
+ def time_remaining ():
288
+ if timeout is None :
289
+ return None
290
+ t = timeout - (perf_counter () - t0 )
291
+ return t if t > 0 else 0
292
+
293
+ t0 = perf_counter ()
276
294
pool_config = PoolConfig .consume (pool_config )
295
+
296
+ socket_connection_timeout = pool_config .connection_timeout
297
+ if socket_connection_timeout is None :
298
+ socket_connection_timeout = time_remaining ()
299
+ elif timeout is not None :
300
+ socket_connection_timeout = min (pool_config .connection_timeout ,
301
+ time_remaining ())
277
302
s , pool_config .protocol_version , handshake , data = \
278
303
await AsyncBoltSocket .connect (
279
304
address ,
280
- timeout = timeout ,
305
+ timeout = socket_connection_timeout ,
281
306
custom_resolver = pool_config .resolver ,
282
307
ssl_context = pool_config .get_ssl_context (),
283
308
keep_alive = pool_config .keep_alive ,
@@ -308,17 +333,31 @@ async def open(
308
333
AsyncBoltSocket .close_socket (s )
309
334
310
335
supported_versions = cls .protocol_handlers ().keys ()
311
- raise BoltHandshakeError ("The Neo4J server does not support communication with this driver. This driver have support for Bolt Protocols {}" .format (supported_versions ), address = address , request_data = handshake , response_data = data )
336
+ raise BoltHandshakeError (
337
+ "The Neo4J server does not support communication with this "
338
+ "driver. This driver have support for Bolt Protocols {}"
339
+ "" .format (supported_versions ),
340
+ address = address , request_data = handshake , response_data = data
341
+ )
312
342
313
343
connection = bolt_cls (
314
344
address , s , pool_config .max_connection_lifetime , auth = auth ,
315
345
user_agent = pool_config .user_agent , routing_context = routing_context
316
346
)
317
347
318
348
try :
319
- await connection .hello ()
349
+ connection .socket .set_deadline (time_remaining ())
350
+ try :
351
+ await connection .hello ()
352
+ except SocketDeadlineExceeded as e :
353
+ # connection._defunct = True
354
+ raise ServiceUnavailable (
355
+ "Timeout during initial handshake occurred"
356
+ ) from e
357
+ finally :
358
+ connection .socket .set_deadline (None )
320
359
except Exception :
321
- await connection .close ()
360
+ await connection .close_non_blocking ()
322
361
raise
323
362
324
363
return connection
@@ -440,6 +479,11 @@ async def reset(self):
440
479
"""
441
480
pass
442
481
482
+ @abc .abstractmethod
483
+ def goodbye (self ):
484
+ """Append a GOODBYE message to the outgoing queued."""
485
+ pass
486
+
443
487
def _append (self , signature , fields = (), response = None ):
444
488
""" Appends a message to the outgoing queue.
445
489
@@ -481,7 +525,8 @@ async def send_all(self):
481
525
await self ._send_all ()
482
526
483
527
@abc .abstractmethod
484
- async def _fetch_message (self ):
528
+ async def _process_message (self , details , summary_signature ,
529
+ summary_metadata ):
485
530
""" Receive at most one message from the server, if available.
486
531
487
532
:return: 2-tuple of number of detail messages and number of summary
@@ -505,7 +550,12 @@ async def fetch_message(self):
505
550
if not self .responses :
506
551
return 0 , 0
507
552
508
- res = await self ._fetch_message ()
553
+ # Receive exactly one message
554
+ details , summary_signature , summary_metadata = \
555
+ await AsyncUtil .next (self .inbox )
556
+ res = await self ._process_message (
557
+ details , summary_signature , summary_metadata
558
+ )
509
559
self .idle_since = perf_counter ()
510
560
return res
511
561
@@ -548,9 +598,13 @@ async def _set_defunct(self, message, error=None, silent=False):
548
598
# connection from the client side, and remove the address
549
599
# from the connection pool.
550
600
self ._defunct = True
551
- await self .close ()
552
- if self .pool :
553
- await self .pool .deactivate (address = self .unresolved_address )
601
+ if not self ._closing :
602
+ # If we fail while closing the connection, there is no need to
603
+ # remove the connection from the pool, nor to try to close the
604
+ # connection again.
605
+ await self .close ()
606
+ if self .pool :
607
+ await self .pool .deactivate (address = self .unresolved_address )
554
608
# Iterate through the outstanding responses, and if any correspond
555
609
# to COMMIT requests then raise an error to signal that we are
556
610
# unable to confirm that the COMMIT completed successfully.
@@ -584,11 +638,37 @@ def stale(self):
584
638
def set_stale (self ):
585
639
self ._stale = True
586
640
587
- @abc .abstractmethod
588
641
async def close (self ):
589
- """ Close the connection.
642
+ """Close the connection."""
643
+ if self ._closed or self ._closing :
644
+ return
645
+ self ._closing = True
646
+ if not self ._defunct :
647
+ self .goodbye ()
648
+ try :
649
+ await self ._send_all ()
650
+ except (OSError , BoltError , DriverError ):
651
+ pass
652
+ log .debug ("[#%04X] C: <CLOSE>" , self .local_port )
653
+ try :
654
+ self .socket .close ()
655
+ except OSError :
656
+ pass
657
+ finally :
658
+ self ._closed = True
659
+
660
+ async def close_non_blocking (self ):
661
+ """Set the socket to non-blocking and close it.
662
+
663
+ This will try to send the `GOODBYE` message (given the socket is not
664
+ marked as defunct). However, should the write operation require
665
+ blocking (e.g., a full network buffer), then the socket will be closed
666
+ immediately (without `GOODBYE` message).
590
667
"""
591
- pass
668
+ if self ._closed or self ._closing :
669
+ return
670
+ self .socket .settimeout (0 )
671
+ await self .close ()
592
672
593
673
@abc .abstractmethod
594
674
def closed (self ):
0 commit comments