diff --git a/neo4j/io/__init__.py b/neo4j/io/__init__.py index 20ac1122..7d52c336 100644 --- a/neo4j/io/__init__.py +++ b/neo4j/io/__init__.py @@ -58,7 +58,6 @@ ) from threading import ( Condition, - Lock, RLock, ) from time import perf_counter @@ -875,7 +874,7 @@ def __init__(self, opener, pool_config, workspace_config, address): log.debug("[#0000] C: routing address %r", address) self.address = address self.routing_tables = {workspace_config.database: RoutingTable(database=workspace_config.database, routers=[address])} - self.refresh_lock = Lock() + self.refresh_lock = RLock() def __repr__(self): """ The representation shows the initial routing addresses. @@ -1109,23 +1108,25 @@ def _select_address(self, *, access_mode, database, bookmarks): from neo4j.api import READ_ACCESS """ Selects the address with the fewest in-use connections. """ - self.create_routing_table(database) - self.ensure_routing_table_is_fresh( - access_mode=access_mode, database=database, bookmarks=bookmarks - ) - log.debug("[#0000] C: %r", self.routing_tables) - if access_mode == READ_ACCESS: - addresses = self.routing_tables[database].readers - else: - addresses = self.routing_tables[database].writers - addresses_by_usage = {} - for address in addresses: - addresses_by_usage.setdefault(self.in_use_connection_count(address), []).append(address) + with self.refresh_lock: + if access_mode == READ_ACCESS: + addresses = self.routing_tables[database].readers + else: + addresses = self.routing_tables[database].writers + addresses_by_usage = {} + for address in addresses: + addresses_by_usage.setdefault( + self.in_use_connection_count(address), [] + ).append(address) if not addresses_by_usage: if access_mode == READ_ACCESS: - raise ReadServiceUnavailable("No read service currently available") + raise ReadServiceUnavailable( + "No read service currently available" + ) else: - raise WriteServiceUnavailable("No write service currently available") + raise WriteServiceUnavailable( + "No write service currently available" + ) return choice(addresses_by_usage[min(addresses_by_usage)]) def acquire(self, access_mode=None, timeout=None, database=None, @@ -1137,6 +1138,13 @@ def acquire(self, access_mode=None, timeout=None, database=None, from neo4j.api import check_access_mode access_mode = check_access_mode(access_mode) + with self.refresh_lock: + self.create_routing_table(database) + log.debug("[#0000] C: %r", self.routing_tables) + self.ensure_routing_table_is_fresh( + access_mode=access_mode, database=database, bookmarks=bookmarks + ) + while True: try: # Get an address for a connection that have the fewest in-use connections. @@ -1144,10 +1152,10 @@ def acquire(self, access_mode=None, timeout=None, database=None, access_mode=access_mode, database=database, bookmarks=bookmarks ) - log.debug("[#0000] C: database=%r address=%r", database, address) except (ReadServiceUnavailable, WriteServiceUnavailable) as err: raise SessionExpired("Failed to obtain connection towards '%s' server." % access_mode) from err try: + log.debug("[#0000] C: database=%r address=%r", database, address) connection = self._acquire(address, timeout=timeout) # should always be a resolved address except ServiceUnavailable: self.deactivate(address=address) diff --git a/testkitbackend/requests.py b/testkitbackend/requests.py index 70036e82..a23091fb 100644 --- a/testkitbackend/requests.py +++ b/testkitbackend/requests.py @@ -75,13 +75,18 @@ def NewDriver(backend, data): if data["resolverRegistered"] or data["domainNameResolverRegistered"]: resolver = resolution_func(backend, data["resolverRegistered"], data["domainNameResolverRegistered"]) - connection_timeout = data.get("connectionTimeoutMs", None) + connection_timeout = data.get("connectionTimeoutMs") if connection_timeout is not None: connection_timeout /= 1000 + max_transaction_retry_time = data.get("maxTxRetryTimeMs") + if max_transaction_retry_time is not None: + max_transaction_retry_time /= 1000 data.mark_item_as_read("domainNameResolverRegistered") driver = neo4j.GraphDatabase.driver( data["uri"], auth=auth, user_agent=data["userAgent"], - resolver=resolver, connection_timeout=connection_timeout + resolver=resolver, connection_timeout=connection_timeout, + fetch_size=data.get("fetchSize"), + max_transaction_retry_time=max_transaction_retry_time, ) key = backend.next_key() backend.drivers[key] = driver @@ -304,6 +309,13 @@ def TransactionRollback(backend, data): backend.send_response("Transaction", {"id": key}) +def TransactionClose(backend, data): + key = data["txId"] + tx = backend.transactions[key] + tx.close() + backend.send_response("Transaction", {"id": key}) + + def ResultNext(backend, data): result = backend.results[data["resultId"]] try: diff --git a/testkitbackend/test_config.json b/testkitbackend/test_config.json index 49287b8a..5c1291ab 100644 --- a/testkitbackend/test_config.json +++ b/testkitbackend/test_config.json @@ -32,13 +32,18 @@ "Feature:Auth:Custom": true, "Feature:Auth:Kerberos": true, "AuthorizationExpiredTreatment": true, + "Optimization:ConnectionReuse": true, + "Optimization:EagerTransactionBegin": true, "Optimization:ImplicitDefaultArguments": true, "Optimization:MinimalResets": true, - "Optimization:ConnectionReuse": true, "Optimization:PullPipelining": true, "ConfHint:connection.recv_timeout_seconds": true, - "Temporary:ResultKeys": true, + "Temporary:CypherPathAndRelationship": true, + "Temporary:DriverFetchSize": true, + "Temporary:DriverMaxTxRetryTime": true, "Temporary:FullSummary": true, - "Temporary:CypherPathAndRelationship": true + "Temporary:ResultKeys": true, + "Temporary:ResultList": "requires further specification/discussion in the team", + "Temporary:TransactionClose": true } }