Skip to content

Invalidate writers per database #959

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions src/neo4j/_async/io/_bolt.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import abc
import asyncio
import typing as t
from collections import deque
from logging import getLogger
from time import perf_counter
Expand Down Expand Up @@ -74,6 +75,16 @@ def failed(self):
...


class ClientStateManagerBase(abc.ABC):
@abc.abstractmethod
def __init__(self, init_state, on_change=None):
...

@abc.abstractmethod
def transition(self, message):
...


class AsyncBolt:
""" Server connection for Bolt protocol.

Expand Down Expand Up @@ -103,12 +114,13 @@ class AsyncBolt:

# When the connection was last put back into the pool
idle_since = float("-inf")
# The database name the connection was last used with
# (BEGIN for explicit transactions, RUN for auto-commit transactions)
last_database: t.Optional[str] = None

# The socket
_closing = False
_closed = False

# The socket
_defunct = False

#: The pool of which this connection is a member
Expand Down Expand Up @@ -173,6 +185,10 @@ def __del__(self):
def _get_server_state_manager(self) -> ServerStateManagerBase:
...

@abc.abstractmethod
def _get_client_state_manager(self) -> ClientStateManagerBase:
...

@classmethod
def _to_auth_dict(cls, auth):
# Determine auth details
Expand Down Expand Up @@ -753,6 +769,8 @@ def _append(self, signature, fields=(), response=None,
"""
self.outbox.append_message(signature, fields, dehydration_hooks)
self.responses.append(response)
if response:
self._get_client_state_manager().transition(response.message)

async def _send_all(self):
if await self.outbox.flush():
Expand Down
98 changes: 71 additions & 27 deletions src/neo4j/_async/io/_bolt3.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
)
from ._bolt import (
AsyncBolt,
ClientStateManagerBase,
ServerStateManagerBase,
tx_timeout_as_ms,
)
Expand All @@ -52,7 +53,7 @@
log = getLogger("neo4j")


class ServerStates(Enum):
class BoltStates(Enum):
CONNECTED = "CONNECTED"
READY = "READY"
STREAMING = "STREAMING"
Expand All @@ -62,25 +63,25 @@ class ServerStates(Enum):

class ServerStateManager(ServerStateManagerBase):
_STATE_TRANSITIONS: t.Dict[Enum, t.Dict[str, Enum]] = {
ServerStates.CONNECTED: {
"hello": ServerStates.READY,
BoltStates.CONNECTED: {
"hello": BoltStates.READY,
},
ServerStates.READY: {
"run": ServerStates.STREAMING,
"begin": ServerStates.TX_READY_OR_TX_STREAMING,
BoltStates.READY: {
"run": BoltStates.STREAMING,
"begin": BoltStates.TX_READY_OR_TX_STREAMING,
},
ServerStates.STREAMING: {
"pull": ServerStates.READY,
"discard": ServerStates.READY,
"reset": ServerStates.READY,
BoltStates.STREAMING: {
"pull": BoltStates.READY,
"discard": BoltStates.READY,
"reset": BoltStates.READY,
},
ServerStates.TX_READY_OR_TX_STREAMING: {
"commit": ServerStates.READY,
"rollback": ServerStates.READY,
"reset": ServerStates.READY,
BoltStates.TX_READY_OR_TX_STREAMING: {
"commit": BoltStates.READY,
"rollback": BoltStates.READY,
"reset": BoltStates.READY,
},
ServerStates.FAILED: {
"reset": ServerStates.READY,
BoltStates.FAILED: {
"reset": BoltStates.READY,
}
}

Expand All @@ -99,7 +100,40 @@ def transition(self, message, metadata):
self._on_change(state_before, self.state)

def failed(self):
return self.state == ServerStates.FAILED
return self.state == BoltStates.FAILED


class ClientStateManager(ClientStateManagerBase):
_STATE_TRANSITIONS: t.Dict[Enum, t.Dict[str, Enum]] = {
BoltStates.CONNECTED: {
"hello": BoltStates.READY,
},
BoltStates.READY: {
"run": BoltStates.STREAMING,
"begin": BoltStates.TX_READY_OR_TX_STREAMING,
},
BoltStates.STREAMING: {
"begin": BoltStates.TX_READY_OR_TX_STREAMING,
"reset": BoltStates.READY,
},
BoltStates.TX_READY_OR_TX_STREAMING: {
"commit": BoltStates.READY,
"rollback": BoltStates.READY,
"reset": BoltStates.READY,
},
}

def __init__(self, init_state, on_change=None):
self.state = init_state
self._on_change = on_change

def transition(self, message):
state_before = self.state
self.state = self._STATE_TRANSITIONS \
.get(self.state, {}) \
.get(message, self.state)
if state_before != self.state and callable(self._on_change):
self._on_change(state_before, self.state)


class AsyncBolt3(AsyncBolt):
Expand All @@ -121,25 +155,34 @@ class AsyncBolt3(AsyncBolt):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._server_state_manager = ServerStateManager(
ServerStates.CONNECTED, on_change=self._on_server_state_change
BoltStates.CONNECTED, on_change=self._on_server_state_change
)
self._client_state_manager = ClientStateManager(
BoltStates.CONNECTED, on_change=self._on_client_state_change
)

def _on_server_state_change(self, old_state, new_state):
log.debug("[#%04X] _: <CONNECTION> state: %s > %s", self.local_port,
old_state.name, new_state.name)
log.debug("[#%04X] _: <CONNECTION> server state: %s > %s",
self.local_port, old_state.name, new_state.name)

def _get_server_state_manager(self) -> ServerStateManagerBase:
return self._server_state_manager

def _on_client_state_change(self, old_state, new_state):
log.debug("[#%04X] _: <CONNECTION> client state: %s > %s",
self.local_port, old_state.name, new_state.name)

def _get_client_state_manager(self) -> ClientStateManagerBase:
return self._client_state_manager

@property
def is_reset(self):
# We can't be sure of the server's state if there are still pending
# responses. Unless the last message we sent was RESET. In that case
# the server state will always be READY when we're done.
if (self.responses and self.responses[-1]
and self.responses[-1].message == "reset"):
return True
return self._server_state_manager.state == ServerStates.READY
if self.responses:
return self.responses[-1] and self.responses[-1].message == "reset"
return self._server_state_manager.state == BoltStates.READY

@property
def encrypted(self):
Expand Down Expand Up @@ -216,7 +259,7 @@ async def route(
hydration_hooks=hydration_hooks,
on_success=metadata.update
)
self.pull(dehydration_hooks = None, hydration_hooks = None,
self.pull(dehydration_hooks=None, hydration_hooks=None,
on_success=metadata.update, on_records=records.extend)
await self.send_all()
await self.fetch_all()
Expand Down Expand Up @@ -398,7 +441,7 @@ async def _process_message(self, tag, fields):
await response.on_ignored(summary_metadata or {})
elif summary_signature == b"\x7F":
log.debug("[#%04X] S: FAILURE %r", self.local_port, summary_metadata)
self._server_state_manager.state = ServerStates.FAILED
self._server_state_manager.state = BoltStates.FAILED
try:
await response.on_failure(summary_metadata or {})
except (ServiceUnavailable, DatabaseUnavailable):
Expand All @@ -408,7 +451,8 @@ async def _process_message(self, tag, fields):
except (NotALeader, ForbiddenOnReadOnlyDatabase):
if self.pool:
await self.pool.on_write_failure(
address=self.unresolved_address
address=self.unresolved_address,
database=self.last_database,
)
raise
except Neo4jError as e:
Expand Down
41 changes: 31 additions & 10 deletions src/neo4j/_async/io/_bolt4.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,14 @@
)
from ._bolt import (
AsyncBolt,
ClientStateManagerBase,
ServerStateManagerBase,
tx_timeout_as_ms,
)
from ._bolt3 import (
BoltStates,
ClientStateManager,
ServerStateManager,
ServerStates,
)
from ._common import (
check_supported_server_product,
Expand Down Expand Up @@ -72,25 +74,34 @@ class AsyncBolt4x0(AsyncBolt):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._server_state_manager = ServerStateManager(
ServerStates.CONNECTED, on_change=self._on_server_state_change
BoltStates.CONNECTED, on_change=self._on_server_state_change
)
self._client_state_manager = ClientStateManager(
BoltStates.CONNECTED, on_change=self._on_client_state_change
)

def _on_server_state_change(self, old_state, new_state):
log.debug("[#%04X] _: <CONNECTION> state: %s > %s", self.local_port,
old_state.name, new_state.name)
log.debug("[#%04X] _: <CONNECTION> server state: %s > %s",
self.local_port, old_state.name, new_state.name)

def _get_server_state_manager(self) -> ServerStateManagerBase:
return self._server_state_manager

def _on_client_state_change(self, old_state, new_state):
log.debug("[#%04X] _: <CONNECTION> client state: %s > %s",
self.local_port, old_state.name, new_state.name)

def _get_client_state_manager(self) -> ClientStateManagerBase:
return self._client_state_manager

@property
def is_reset(self):
# We can't be sure of the server's state if there are still pending
# responses. Unless the last message we sent was RESET. In that case
# the server state will always be READY when we're done.
if (self.responses and self.responses[-1]
and self.responses[-1].message == "reset"):
return True
return self._server_state_manager.state == ServerStates.READY
if self.responses:
return self.responses[-1] and self.responses[-1].message == "reset"
return self._server_state_manager.state == BoltStates.READY

@property
def encrypted(self):
Expand Down Expand Up @@ -202,6 +213,8 @@ def run(self, query, parameters=None, mode=None, bookmarks=None,
extra["mode"] = "r" # It will default to mode "w" if nothing is specified
if db:
extra["db"] = db
if self._client_state_manager.state != BoltStates.TX_READY_OR_TX_STREAMING:
self.last_database = db
if bookmarks:
try:
extra["bookmarks"] = list(bookmarks)
Expand Down Expand Up @@ -261,6 +274,7 @@ def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None,
extra["mode"] = "r" # It will default to mode "w" if nothing is specified
if db:
extra["db"] = db
self.last_database = db
if bookmarks:
try:
extra["bookmarks"] = list(bookmarks)
Expand Down Expand Up @@ -347,7 +361,7 @@ async def _process_message(self, tag, fields):
await response.on_ignored(summary_metadata or {})
elif summary_signature == b"\x7F":
log.debug("[#%04X] S: FAILURE %r", self.local_port, summary_metadata)
self._server_state_manager.state = ServerStates.FAILED
self._server_state_manager.state = BoltStates.FAILED
try:
await response.on_failure(summary_metadata or {})
except (ServiceUnavailable, DatabaseUnavailable):
Expand All @@ -357,7 +371,8 @@ async def _process_message(self, tag, fields):
except (NotALeader, ForbiddenOnReadOnlyDatabase):
if self.pool:
await self.pool.on_write_failure(
address=self.unresolved_address
address=self.unresolved_address,
database=self.last_database
)
raise
except Neo4jError as e:
Expand Down Expand Up @@ -535,6 +550,11 @@ def run(self, query, parameters=None, mode=None, bookmarks=None,
extra["mode"] = "r"
if db:
extra["db"] = db
if (
self._client_state_manager.state
!= BoltStates.TX_READY_OR_TX_STREAMING
):
self.last_database = db
if imp_user:
extra["imp_user"] = imp_user
if bookmarks:
Expand Down Expand Up @@ -571,6 +591,7 @@ def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None,
extra["mode"] = "r"
if db:
extra["db"] = db
self.last_database = db
if imp_user:
extra["imp_user"] = imp_user
if bookmarks:
Expand Down
Loading