Skip to content

Commit 0823655

Browse files
robsdedudebigmontz
andauthored
[4.4] Invalidate writers per database (#1039)
* Invalidate writers per database This should improve the performance of the driver in multi database use-cases. The driver now only removes a server as a writer for a single database (before for all databases) if that server returns an error that notifies the driver that the server is no longer a writer (`Neo.ClientError.Cluster.NotALeader` or `Neo.ClientError.General.ForbiddenOnReadOnlyDatabase`). * Minor code clean-up Co-authored-by: Antonio Barcélos <antonio.barcelos@neo4j.com>
1 parent 9f5c495 commit 0823655

File tree

10 files changed

+539
-37
lines changed

10 files changed

+539
-37
lines changed

neo4j/io/__init__.py

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,16 @@
105105
log = getLogger("neo4j")
106106

107107

108+
class ClientStateManagerBase(abc.ABC):
109+
@abc.abstractmethod
110+
def __init__(self, init_state, on_change=None):
111+
...
112+
113+
@abc.abstractmethod
114+
def transition(self, message):
115+
...
116+
117+
108118
class Bolt(abc.ABC):
109119
""" Server connection for Bolt protocol.
110120
@@ -125,6 +135,10 @@ class Bolt(abc.ABC):
125135
# The socket
126136
in_use = False
127137

138+
# The database name the connection was last used with
139+
# (BEGIN for explicit transactions, RUN for auto-commit transactions)
140+
last_database = None
141+
128142
# The socket
129143
_closing = False
130144
_closed = False
@@ -399,6 +413,10 @@ def __del__(self):
399413
except OSError:
400414
pass
401415

416+
@abc.abstractmethod
417+
def _get_client_state_manager(self):
418+
...
419+
402420
@abc.abstractmethod
403421
def route(self, database=None, imp_user=None, bookmarks=None):
404422
""" Fetch a routing table from the server for the given
@@ -504,6 +522,8 @@ def _append(self, signature, fields=(), response=None):
504522
self.packer.pack_struct(signature, fields)
505523
self.outbox.wrap_message()
506524
self.responses.append(response)
525+
if response:
526+
self._get_client_state_manager().transition(response.message)
507527

508528
def _send_all(self):
509529
with self.outbox.view() as data:
@@ -867,8 +887,10 @@ def deactivate(self, address):
867887
if not self.connections[address]:
868888
del self.connections[address]
869889

870-
def on_write_failure(self, address):
871-
raise WriteServiceUnavailable("No write service available for pool {}".format(self))
890+
def on_write_failure(self, address, database):
891+
raise WriteServiceUnavailable(
892+
"No write service available for pool {}".format(self)
893+
)
872894

873895
def close(self):
874896
""" Close all connections and empty the pool.
@@ -1342,13 +1364,15 @@ def deactivate(self, address):
13421364
log.debug("[#0000] C: <ROUTING> table=%r", self.routing_tables)
13431365
super(Neo4jPool, self).deactivate(address)
13441366

1345-
def on_write_failure(self, address):
1367+
def on_write_failure(self, address, database):
13461368
""" Remove a writer address from the routing table, if present.
13471369
"""
1348-
log.debug("[#0000] C: <ROUTING> Removing writer %r", address)
1370+
log.debug("[#0000] C: <ROUTING> Removing writer %r for database %r",
1371+
address, database)
13491372
with self.refresh_lock:
1350-
for database in self.routing_tables.keys():
1351-
self.routing_tables[database].writers.discard(address)
1373+
table = self.routing_tables.get(database)
1374+
if table is not None:
1375+
table.writers.discard(address)
13521376
log.debug("[#0000] C: <ROUTING> table=%r", self.routing_tables)
13531377

13541378

neo4j/io/_bolt.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# Copyright (c) "Neo4j"
2+
# Neo4j Sweden AB [http://neo4j.com]
3+
#
4+
# This file is part of Neo4j.
5+
#
6+
# Licensed under the Apache License, Version 2.0 (the "License");
7+
# you may not use this file except in compliance with the License.
8+
# You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
18+
19+
import abc
20+
21+
22+
class ClientStateManagerBase(abc.ABC):
23+
@abc.abstractmethod
24+
def __init__(self, init_state, on_change=None):
25+
...
26+
27+
@abc.abstractmethod
28+
def transition(self, message):
29+
...

neo4j/io/_bolt3.py

Lines changed: 68 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
Bolt,
4444
check_supported_server_product,
4545
)
46+
from neo4j.io._bolt import ClientStateManagerBase
4647
from neo4j.io._common import (
4748
CommitResponse,
4849
InitResponse,
@@ -55,7 +56,7 @@
5556
log = getLogger("neo4j")
5657

5758

58-
class ServerStates(Enum):
59+
class BoltStates(Enum):
5960
CONNECTED = "CONNECTED"
6061
READY = "READY"
6162
STREAMING = "STREAMING"
@@ -65,25 +66,25 @@ class ServerStates(Enum):
6566

6667
class ServerStateManager:
6768
_STATE_TRANSITIONS = {
68-
ServerStates.CONNECTED: {
69-
"hello": ServerStates.READY,
69+
BoltStates.CONNECTED: {
70+
"hello": BoltStates.READY,
7071
},
71-
ServerStates.READY: {
72-
"run": ServerStates.STREAMING,
73-
"begin": ServerStates.TX_READY_OR_TX_STREAMING,
72+
BoltStates.READY: {
73+
"run": BoltStates.STREAMING,
74+
"begin": BoltStates.TX_READY_OR_TX_STREAMING,
7475
},
75-
ServerStates.STREAMING: {
76-
"pull": ServerStates.READY,
77-
"discard": ServerStates.READY,
78-
"reset": ServerStates.READY,
76+
BoltStates.STREAMING: {
77+
"pull": BoltStates.READY,
78+
"discard": BoltStates.READY,
79+
"reset": BoltStates.READY,
7980
},
80-
ServerStates.TX_READY_OR_TX_STREAMING: {
81-
"commit": ServerStates.READY,
82-
"rollback": ServerStates.READY,
83-
"reset": ServerStates.READY,
81+
BoltStates.TX_READY_OR_TX_STREAMING: {
82+
"commit": BoltStates.READY,
83+
"rollback": BoltStates.READY,
84+
"reset": BoltStates.READY,
8485
},
85-
ServerStates.FAILED: {
86-
"reset": ServerStates.READY,
86+
BoltStates.FAILED: {
87+
"reset": BoltStates.READY,
8788
}
8889
}
8990

@@ -102,6 +103,39 @@ def transition(self, message, metadata):
102103
self._on_change(state_before, self.state)
103104

104105

106+
class ClientStateManager(ClientStateManagerBase):
107+
_STATE_TRANSITIONS = {
108+
BoltStates.CONNECTED: {
109+
"hello": BoltStates.READY,
110+
},
111+
BoltStates.READY: {
112+
"run": BoltStates.STREAMING,
113+
"begin": BoltStates.TX_READY_OR_TX_STREAMING,
114+
},
115+
BoltStates.STREAMING: {
116+
"begin": BoltStates.TX_READY_OR_TX_STREAMING,
117+
"reset": BoltStates.READY,
118+
},
119+
BoltStates.TX_READY_OR_TX_STREAMING: {
120+
"commit": BoltStates.READY,
121+
"rollback": BoltStates.READY,
122+
"reset": BoltStates.READY,
123+
},
124+
}
125+
126+
def __init__(self, init_state, on_change=None):
127+
self.state = init_state
128+
self._on_change = on_change
129+
130+
def transition(self, message):
131+
state_before = self.state
132+
self.state = self._STATE_TRANSITIONS \
133+
.get(self.state, {}) \
134+
.get(message, self.state)
135+
if state_before != self.state and callable(self._on_change):
136+
self._on_change(state_before, self.state)
137+
138+
105139
class Bolt3(Bolt):
106140
""" Protocol handler for Bolt 3.
107141
@@ -117,13 +151,23 @@ class Bolt3(Bolt):
117151
def __init__(self, *args, **kwargs):
118152
super().__init__(*args, **kwargs)
119153
self._server_state_manager = ServerStateManager(
120-
ServerStates.CONNECTED, on_change=self._on_server_state_change
154+
BoltStates.CONNECTED, on_change=self._on_server_state_change
155+
)
156+
self._client_state_manager = ClientStateManager(
157+
BoltStates.CONNECTED, on_change=self._on_client_state_change
121158
)
122159

123160
def _on_server_state_change(self, old_state, new_state):
124-
log.debug("[#%04X] State: %s > %s", self.local_port,
161+
log.debug("[#%04X] Server State: %s > %s", self.local_port,
125162
old_state.name, new_state.name)
126163

164+
def _on_client_state_change(self, old_state, new_state):
165+
log.debug("[#%04X] Client state: %s > %s",
166+
self.local_port, old_state.name, new_state.name)
167+
168+
def _get_client_state_manager(self):
169+
return self._client_state_manager
170+
127171
@property
128172
def is_reset(self):
129173
# We can't be sure of the server's state if there are still pending
@@ -132,7 +176,7 @@ def is_reset(self):
132176
if (self.responses and self.responses[-1]
133177
and self.responses[-1].message == "reset"):
134178
return True
135-
return self._server_state_manager.state == ServerStates.READY
179+
return self._server_state_manager.state == BoltStates.READY
136180

137181
@property
138182
def encrypted(self):
@@ -342,7 +386,7 @@ def fetch_message(self):
342386
response.on_ignored(summary_metadata or {})
343387
elif summary_signature == b"\x7F":
344388
log.debug("[#%04X] S: FAILURE %r", self.local_port, summary_metadata)
345-
self._server_state_manager.state = ServerStates.FAILED
389+
self._server_state_manager.state = BoltStates.FAILED
346390
try:
347391
response.on_failure(summary_metadata or {})
348392
except (ServiceUnavailable, DatabaseUnavailable):
@@ -351,7 +395,10 @@ def fetch_message(self):
351395
raise
352396
except (NotALeader, ForbiddenOnReadOnlyDatabase):
353397
if self.pool:
354-
self.pool.on_write_failure(address=self.unresolved_address),
398+
self.pool.on_write_failure(
399+
address=self.unresolved_address,
400+
database=self.last_database,
401+
),
355402
raise
356403
except Neo4jError as e:
357404
if self.pool and e.invalidates_all_connections():

neo4j/io/_bolt4.py

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,9 @@
5151
tx_timeout_as_ms,
5252
)
5353
from neo4j.io._bolt3 import (
54+
BoltStates,
55+
ClientStateManager,
5456
ServerStateManager,
55-
ServerStates,
5657
)
5758

5859

@@ -74,13 +75,23 @@ class Bolt4x0(Bolt):
7475
def __init__(self, *args, **kwargs):
7576
super().__init__(*args, **kwargs)
7677
self._server_state_manager = ServerStateManager(
77-
ServerStates.CONNECTED, on_change=self._on_server_state_change
78+
BoltStates.CONNECTED, on_change=self._on_server_state_change
79+
)
80+
self._client_state_manager = ClientStateManager(
81+
BoltStates.CONNECTED, on_change=self._on_client_state_change
7882
)
7983

8084
def _on_server_state_change(self, old_state, new_state):
81-
log.debug("[#%04X] State: %s > %s", self.local_port,
85+
log.debug("[#%04X] Server state: %s > %s", self.local_port,
8286
old_state.name, new_state.name)
8387

88+
def _on_client_state_change(self, old_state, new_state):
89+
log.debug("[#%04X] Client state: %s > %s",
90+
self.local_port, old_state.name, new_state.name)
91+
92+
def _get_client_state_manager(self):
93+
return self._client_state_manager
94+
8495
@property
8596
def is_reset(self):
8697
# We can't be sure of the server's state if there are still pending
@@ -89,7 +100,7 @@ def is_reset(self):
89100
if (self.responses and self.responses[-1]
90101
and self.responses[-1].message == "reset"):
91102
return True
92-
return self._server_state_manager.state == ServerStates.READY
103+
return self._server_state_manager.state == BoltStates.READY
93104

94105
@property
95106
def encrypted(self):
@@ -169,6 +180,9 @@ def run(self, query, parameters=None, mode=None, bookmarks=None,
169180
extra["mode"] = "r" # It will default to mode "w" if nothing is specified
170181
if db:
171182
extra["db"] = db
183+
client_state = self._client_state_manager.state
184+
if client_state != BoltStates.TX_READY_OR_TX_STREAMING:
185+
self.last_database = db
172186
if bookmarks:
173187
try:
174188
extra["bookmarks"] = list(bookmarks)
@@ -217,6 +231,7 @@ def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None,
217231
extra["mode"] = "r" # It will default to mode "w" if nothing is specified
218232
if db:
219233
extra["db"] = db
234+
self.last_database = db
220235
if bookmarks:
221236
try:
222237
extra["bookmarks"] = list(bookmarks)
@@ -294,7 +309,7 @@ def fetch_message(self):
294309
response.on_ignored(summary_metadata or {})
295310
elif summary_signature == b"\x7F":
296311
log.debug("[#%04X] S: FAILURE %r", self.local_port, summary_metadata)
297-
self._server_state_manager.state = ServerStates.FAILED
312+
self._server_state_manager.state = BoltStates.FAILED
298313
try:
299314
response.on_failure(summary_metadata or {})
300315
except (ServiceUnavailable, DatabaseUnavailable):
@@ -303,7 +318,10 @@ def fetch_message(self):
303318
raise
304319
except (NotALeader, ForbiddenOnReadOnlyDatabase):
305320
if self.pool:
306-
self.pool.on_write_failure(address=self.unresolved_address),
321+
self.pool.on_write_failure(
322+
address=self.unresolved_address,
323+
database=self.last_database,
324+
),
307325
raise
308326
except Neo4jError as e:
309327
if self.pool and e.invalidates_all_connections():
@@ -471,6 +489,9 @@ def run(self, query, parameters=None, mode=None, bookmarks=None,
471489
extra["mode"] = "r"
472490
if db:
473491
extra["db"] = db
492+
client_state = self._client_state_manager.state
493+
if client_state != BoltStates.TX_READY_OR_TX_STREAMING:
494+
self.last_database = db
474495
if imp_user:
475496
extra["imp_user"] = imp_user
476497
if bookmarks:
@@ -502,6 +523,7 @@ def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None,
502523
extra["mode"] = "r"
503524
if db:
504525
extra["db"] = db
526+
self.last_database = db
505527
if imp_user:
506528
extra["imp_user"] = imp_user
507529
if bookmarks:

0 commit comments

Comments
 (0)