Skip to content

Commit b559c40

Browse files
committed
Invalidate writers per database
This should improve the performance of the driver in multi database use-cases. The driver now only removes a server as a writer for a single database (before for all databases) if that server returns an error that notifies the driver that the server is no longer a writer (`Neo.ClientError.Cluster.NotALeader` or `Neo.ClientError.General.ForbiddenOnReadOnlyDatabase`).
1 parent a09e25f commit b559c40

31 files changed

+1609
-187
lines changed

src/neo4j/_async/io/_bolt.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
import abc
2222
import asyncio
23+
import typing as t
2324
from collections import deque
2425
from logging import getLogger
2526
from time import perf_counter
@@ -74,6 +75,16 @@ def failed(self):
7475
...
7576

7677

78+
class ClientStateManagerBase(abc.ABC):
79+
@abc.abstractmethod
80+
def __init__(self, init_state, on_change=None):
81+
...
82+
83+
@abc.abstractmethod
84+
def transition(self, message):
85+
...
86+
87+
7788
class AsyncBolt:
7889
""" Server connection for Bolt protocol.
7990
@@ -103,12 +114,13 @@ class AsyncBolt:
103114

104115
# When the connection was last put back into the pool
105116
idle_since = float("-inf")
117+
# The database name the connection was last used with
118+
# (BEGIN for explicit transactions, RUN for auto-commit transactions)
119+
last_database: t.Optional[str] = None
106120

107121
# The socket
108122
_closing = False
109123
_closed = False
110-
111-
# The socket
112124
_defunct = False
113125

114126
#: The pool of which this connection is a member
@@ -173,6 +185,10 @@ def __del__(self):
173185
def _get_server_state_manager(self) -> ServerStateManagerBase:
174186
...
175187

188+
@abc.abstractmethod
189+
def _get_client_state_manager(self) -> ClientStateManagerBase:
190+
...
191+
176192
@classmethod
177193
def _to_auth_dict(cls, auth):
178194
# Determine auth details
@@ -753,6 +769,8 @@ def _append(self, signature, fields=(), response=None,
753769
"""
754770
self.outbox.append_message(signature, fields, dehydration_hooks)
755771
self.responses.append(response)
772+
if response:
773+
self._get_client_state_manager().transition(response.message)
756774

757775
async def _send_all(self):
758776
if await self.outbox.flush():

src/neo4j/_async/io/_bolt3.py

Lines changed: 69 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
)
3939
from ._bolt import (
4040
AsyncBolt,
41+
ClientStateManagerBase,
4142
ServerStateManagerBase,
4243
tx_timeout_as_ms,
4344
)
@@ -52,7 +53,7 @@
5253
log = getLogger("neo4j")
5354

5455

55-
class ServerStates(Enum):
56+
class BoltStates(Enum):
5657
CONNECTED = "CONNECTED"
5758
READY = "READY"
5859
STREAMING = "STREAMING"
@@ -62,25 +63,25 @@ class ServerStates(Enum):
6263

6364
class ServerStateManager(ServerStateManagerBase):
6465
_STATE_TRANSITIONS: t.Dict[Enum, t.Dict[str, Enum]] = {
65-
ServerStates.CONNECTED: {
66-
"hello": ServerStates.READY,
66+
BoltStates.CONNECTED: {
67+
"hello": BoltStates.READY,
6768
},
68-
ServerStates.READY: {
69-
"run": ServerStates.STREAMING,
70-
"begin": ServerStates.TX_READY_OR_TX_STREAMING,
69+
BoltStates.READY: {
70+
"run": BoltStates.STREAMING,
71+
"begin": BoltStates.TX_READY_OR_TX_STREAMING,
7172
},
72-
ServerStates.STREAMING: {
73-
"pull": ServerStates.READY,
74-
"discard": ServerStates.READY,
75-
"reset": ServerStates.READY,
73+
BoltStates.STREAMING: {
74+
"pull": BoltStates.READY,
75+
"discard": BoltStates.READY,
76+
"reset": BoltStates.READY,
7677
},
77-
ServerStates.TX_READY_OR_TX_STREAMING: {
78-
"commit": ServerStates.READY,
79-
"rollback": ServerStates.READY,
80-
"reset": ServerStates.READY,
78+
BoltStates.TX_READY_OR_TX_STREAMING: {
79+
"commit": BoltStates.READY,
80+
"rollback": BoltStates.READY,
81+
"reset": BoltStates.READY,
8182
},
82-
ServerStates.FAILED: {
83-
"reset": ServerStates.READY,
83+
BoltStates.FAILED: {
84+
"reset": BoltStates.READY,
8485
}
8586
}
8687

@@ -99,7 +100,40 @@ def transition(self, message, metadata):
99100
self._on_change(state_before, self.state)
100101

101102
def failed(self):
102-
return self.state == ServerStates.FAILED
103+
return self.state == BoltStates.FAILED
104+
105+
106+
class ClientStateManager(ClientStateManagerBase):
107+
_STATE_TRANSITIONS: t.Dict[Enum, t.Dict[str, Enum]] = {
108+
BoltStates.CONNECTED: {
109+
"hello": BoltStates.READY,
110+
},
111+
BoltStates.READY: {
112+
"run": BoltStates.STREAMING,
113+
"begin": BoltStates.TX_READY_OR_TX_STREAMING,
114+
},
115+
BoltStates.STREAMING: {
116+
"begin": BoltStates.TX_READY_OR_TX_STREAMING,
117+
"reset": BoltStates.READY,
118+
},
119+
BoltStates.TX_READY_OR_TX_STREAMING: {
120+
"commit": BoltStates.READY,
121+
"rollback": BoltStates.READY,
122+
"reset": BoltStates.READY,
123+
},
124+
}
125+
126+
def __init__(self, init_state, on_change=None):
127+
self.state = init_state
128+
self._on_change = on_change
129+
130+
def transition(self, message):
131+
state_before = self.state
132+
self.state = self._STATE_TRANSITIONS \
133+
.get(self.state, {}) \
134+
.get(message, self.state)
135+
if state_before != self.state and callable(self._on_change):
136+
self._on_change(state_before, self.state)
103137

104138

105139
class AsyncBolt3(AsyncBolt):
@@ -121,25 +155,34 @@ class AsyncBolt3(AsyncBolt):
121155
def __init__(self, *args, **kwargs):
122156
super().__init__(*args, **kwargs)
123157
self._server_state_manager = ServerStateManager(
124-
ServerStates.CONNECTED, on_change=self._on_server_state_change
158+
BoltStates.CONNECTED, on_change=self._on_server_state_change
159+
)
160+
self._client_state_manager = ClientStateManager(
161+
BoltStates.CONNECTED, on_change=self._on_client_state_change
125162
)
126163

127164
def _on_server_state_change(self, old_state, new_state):
128-
log.debug("[#%04X] _: <CONNECTION> state: %s > %s", self.local_port,
129-
old_state.name, new_state.name)
165+
log.debug("[#%04X] _: <CONNECTION> server state: %s > %s",
166+
self.local_port, old_state.name, new_state.name)
130167

131168
def _get_server_state_manager(self) -> ServerStateManagerBase:
132169
return self._server_state_manager
133170

171+
def _on_client_state_change(self, old_state, new_state):
172+
log.debug("[#%04X] _: <CONNECTION> client state: %s > %s",
173+
self.local_port, old_state.name, new_state.name)
174+
175+
def _get_client_state_manager(self) -> ClientStateManagerBase:
176+
return self._client_state_manager
177+
134178
@property
135179
def is_reset(self):
136180
# We can't be sure of the server's state if there are still pending
137181
# responses. Unless the last message we sent was RESET. In that case
138182
# the server state will always be READY when we're done.
139-
if (self.responses and self.responses[-1]
140-
and self.responses[-1].message == "reset"):
141-
return True
142-
return self._server_state_manager.state == ServerStates.READY
183+
if self.responses:
184+
return self.responses[-1] and self.responses[-1].message == "reset"
185+
return self._server_state_manager.state == BoltStates.READY
143186

144187
@property
145188
def encrypted(self):
@@ -216,7 +259,7 @@ async def route(
216259
hydration_hooks=hydration_hooks,
217260
on_success=metadata.update
218261
)
219-
self.pull(dehydration_hooks = None, hydration_hooks = None,
262+
self.pull(dehydration_hooks=None, hydration_hooks=None,
220263
on_success=metadata.update, on_records=records.extend)
221264
await self.send_all()
222265
await self.fetch_all()
@@ -398,7 +441,7 @@ async def _process_message(self, tag, fields):
398441
await response.on_ignored(summary_metadata or {})
399442
elif summary_signature == b"\x7F":
400443
log.debug("[#%04X] S: FAILURE %r", self.local_port, summary_metadata)
401-
self._server_state_manager.state = ServerStates.FAILED
444+
self._server_state_manager.state = BoltStates.FAILED
402445
try:
403446
await response.on_failure(summary_metadata or {})
404447
except (ServiceUnavailable, DatabaseUnavailable):

src/neo4j/_async/io/_bolt4.py

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,14 @@
3535
)
3636
from ._bolt import (
3737
AsyncBolt,
38+
ClientStateManagerBase,
3839
ServerStateManagerBase,
3940
tx_timeout_as_ms,
4041
)
4142
from ._bolt3 import (
43+
BoltStates,
44+
ClientStateManager,
4245
ServerStateManager,
43-
ServerStates,
4446
)
4547
from ._common import (
4648
check_supported_server_product,
@@ -72,25 +74,34 @@ class AsyncBolt4x0(AsyncBolt):
7274
def __init__(self, *args, **kwargs):
7375
super().__init__(*args, **kwargs)
7476
self._server_state_manager = ServerStateManager(
75-
ServerStates.CONNECTED, on_change=self._on_server_state_change
77+
BoltStates.CONNECTED, on_change=self._on_server_state_change
78+
)
79+
self._client_state_manager = ClientStateManager(
80+
BoltStates.CONNECTED, on_change=self._on_client_state_change
7681
)
7782

7883
def _on_server_state_change(self, old_state, new_state):
79-
log.debug("[#%04X] _: <CONNECTION> state: %s > %s", self.local_port,
80-
old_state.name, new_state.name)
84+
log.debug("[#%04X] _: <CONNECTION> server state: %s > %s",
85+
self.local_port, old_state.name, new_state.name)
8186

8287
def _get_server_state_manager(self) -> ServerStateManagerBase:
8388
return self._server_state_manager
8489

90+
def _on_client_state_change(self, old_state, new_state):
91+
log.debug("[#%04X] _: <CONNECTION> client state: %s > %s",
92+
self.local_port, old_state.name, new_state.name)
93+
94+
def _get_client_state_manager(self) -> ClientStateManagerBase:
95+
return self._client_state_manager
96+
8597
@property
8698
def is_reset(self):
8799
# We can't be sure of the server's state if there are still pending
88100
# responses. Unless the last message we sent was RESET. In that case
89101
# the server state will always be READY when we're done.
90-
if (self.responses and self.responses[-1]
91-
and self.responses[-1].message == "reset"):
92-
return True
93-
return self._server_state_manager.state == ServerStates.READY
102+
if self.responses:
103+
return self.responses[-1] and self.responses[-1].message == "reset"
104+
return self._server_state_manager.state == BoltStates.READY
94105

95106
@property
96107
def encrypted(self):
@@ -202,6 +213,8 @@ def run(self, query, parameters=None, mode=None, bookmarks=None,
202213
extra["mode"] = "r" # It will default to mode "w" if nothing is specified
203214
if db:
204215
extra["db"] = db
216+
if self._client_state_manager.state != BoltStates.TX_READY_OR_TX_STREAMING:
217+
self.last_database = db
205218
if bookmarks:
206219
try:
207220
extra["bookmarks"] = list(bookmarks)
@@ -261,6 +274,7 @@ def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None,
261274
extra["mode"] = "r" # It will default to mode "w" if nothing is specified
262275
if db:
263276
extra["db"] = db
277+
self.last_database = db
264278
if bookmarks:
265279
try:
266280
extra["bookmarks"] = list(bookmarks)
@@ -347,7 +361,7 @@ async def _process_message(self, tag, fields):
347361
await response.on_ignored(summary_metadata or {})
348362
elif summary_signature == b"\x7F":
349363
log.debug("[#%04X] S: FAILURE %r", self.local_port, summary_metadata)
350-
self._server_state_manager.state = ServerStates.FAILED
364+
self._server_state_manager.state = BoltStates.FAILED
351365
try:
352366
await response.on_failure(summary_metadata or {})
353367
except (ServiceUnavailable, DatabaseUnavailable):
@@ -357,7 +371,8 @@ async def _process_message(self, tag, fields):
357371
except (NotALeader, ForbiddenOnReadOnlyDatabase):
358372
if self.pool:
359373
await self.pool.on_write_failure(
360-
address=self.unresolved_address
374+
address=self.unresolved_address,
375+
database=self.last_database
361376
)
362377
raise
363378
except Neo4jError as e:
@@ -535,6 +550,11 @@ def run(self, query, parameters=None, mode=None, bookmarks=None,
535550
extra["mode"] = "r"
536551
if db:
537552
extra["db"] = db
553+
if (
554+
self._client_state_manager.state
555+
!= BoltStates.TX_READY_OR_TX_STREAMING
556+
):
557+
self.last_database = db
538558
if imp_user:
539559
extra["imp_user"] = imp_user
540560
if bookmarks:
@@ -571,6 +591,7 @@ def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None,
571591
extra["mode"] = "r"
572592
if db:
573593
extra["db"] = db
594+
self.last_database = db
574595
if imp_user:
575596
extra["imp_user"] = imp_user
576597
if bookmarks:

0 commit comments

Comments
 (0)