Skip to content

Don't send RESET on READY (clean) connections #572

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Aug 5, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions neo4j/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ class Bolt(abc.ABC):
PROTOCOL_VERSION = None

# flag if connection needs RESET to go back to READY state
_is_reset = True
is_reset = False

# The socket
in_use = False
Expand Down Expand Up @@ -460,10 +460,6 @@ def rollback(self, **handlers):
""" Appends a ROLLBACK message to the output queue."""
pass

@property
def is_reset(self):
return self._is_reset

@abc.abstractmethod
def reset(self):
""" Appends a RESET message to the outgoing queue, sends it and consumes
Expand Down
94 changes: 81 additions & 13 deletions neo4j/io/_bolt3.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from enum import Enum
from logging import getLogger
from ssl import SSLSocket

Expand Down Expand Up @@ -52,6 +53,53 @@
log = getLogger("neo4j")


class ServerStates(Enum):
CONNECTED = "CONNECTED"
READY = "READY"
STREAMING = "STREAMING"
TX_READY_OR_TX_STREAMING = "TX_READY||TX_STREAMING"
FAILED = "FAILED"


class ServerStateManager:
_STATE_TRANSITIONS = {
ServerStates.CONNECTED: {
"hello": ServerStates.READY,
},
ServerStates.READY: {
"run": ServerStates.STREAMING,
"begin": ServerStates.TX_READY_OR_TX_STREAMING,
},
ServerStates.STREAMING: {
"pull": ServerStates.READY,
"discard": ServerStates.READY,
"reset": ServerStates.READY,
},
ServerStates.TX_READY_OR_TX_STREAMING: {
"commit": ServerStates.READY,
"rollback": ServerStates.READY,
"reset": ServerStates.READY,
},
ServerStates.FAILED: {
"reset": ServerStates.READY,
}
}

def __init__(self, init_state, on_change=None):
self.state = init_state
self._on_change = on_change

def transition(self, message, metadata):
if metadata.get("has_more"):
return
state_before = self.state
self.state = self._STATE_TRANSITIONS\
.get(self.state, {})\
.get(message, self.state)
if state_before != self.state and callable(self._on_change):
self._on_change(state_before, self.state)


class Bolt3(Bolt):
""" Protocol handler for Bolt 3.

Expand All @@ -64,6 +112,25 @@ class Bolt3(Bolt):

supports_multiple_databases = False

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._server_state_manager = ServerStateManager(
ServerStates.CONNECTED, on_change=self._on_server_state_change
)

def _on_server_state_change(self, old_state, new_state):
log.debug("[#%04X] State: %s > %s", self.local_port,
old_state.name, new_state.name)

@property
def is_reset(self):
if self.responses:
# We can't be sure of the server's state as there are still pending
# responses. Unless the last message we sent was RESET. In that case
# the server state will always be READY when we're done.
return self.responses[-1].message == "reset"
return self._server_state_manager.state == ServerStates.READY

@property
def encrypted(self):
return isinstance(self.socket, SSLSocket)
Expand Down Expand Up @@ -92,7 +159,8 @@ def hello(self):
logged_headers["credentials"] = "*******"
log.debug("[#%04X] C: HELLO %r", self.local_port, logged_headers)
self._append(b"\x01", (headers,),
response=InitResponse(self, on_success=self.server_info.update))
response=InitResponse(self, "hello",
on_success=self.server_info.update))
self.send_all()
self.fetch_all()
check_supported_server_product(self.server_info.agent)
Expand Down Expand Up @@ -155,21 +223,20 @@ def run(self, query, parameters=None, mode=None, bookmarks=None, metadata=None,
fields = (query, parameters, extra)
log.debug("[#%04X] C: RUN %s", self.local_port, " ".join(map(repr, fields)))
if query.upper() == u"COMMIT":
self._append(b"\x10", fields, CommitResponse(self, **handlers))
self._append(b"\x10", fields, CommitResponse(self, "run",
**handlers))
else:
self._append(b"\x10", fields, Response(self, **handlers))
self._is_reset = False
self._append(b"\x10", fields, Response(self, "run", **handlers))

def discard(self, n=-1, qid=-1, **handlers):
# Just ignore n and qid, it is not supported in the Bolt 3 Protocol.
log.debug("[#%04X] C: DISCARD_ALL", self.local_port)
self._append(b"\x2F", (), Response(self, **handlers))
self._append(b"\x2F", (), Response(self, "discard", **handlers))

def pull(self, n=-1, qid=-1, **handlers):
# Just ignore n and qid, it is not supported in the Bolt 3 Protocol.
log.debug("[#%04X] C: PULL_ALL", self.local_port)
self._append(b"\x3F", (), Response(self, **handlers))
self._is_reset = False
self._append(b"\x3F", (), Response(self, "pull", **handlers))

def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, db=None, **handlers):
if db is not None:
Expand All @@ -193,16 +260,15 @@ def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, db=None,
except TypeError:
raise TypeError("Timeout must be specified as a number of seconds")
log.debug("[#%04X] C: BEGIN %r", self.local_port, extra)
self._append(b"\x11", (extra,), Response(self, **handlers))
self._is_reset = False
self._append(b"\x11", (extra,), Response(self, "begin", **handlers))

def commit(self, **handlers):
log.debug("[#%04X] C: COMMIT", self.local_port)
self._append(b"\x12", (), CommitResponse(self, **handlers))
self._append(b"\x12", (), CommitResponse(self, "commit", **handlers))

def rollback(self, **handlers):
log.debug("[#%04X] C: ROLLBACK", self.local_port)
self._append(b"\x13", (), Response(self, **handlers))
self._append(b"\x13", (), Response(self, "rollback", **handlers))

def reset(self):
""" Add a RESET message to the outgoing queue, send
Expand All @@ -213,10 +279,9 @@ def fail(metadata):
raise BoltProtocolError("RESET failed %r" % metadata, address=self.unresolved_address)

log.debug("[#%04X] C: RESET", self.local_port)
self._append(b"\x0F", response=Response(self, on_failure=fail))
self._append(b"\x0F", response=Response(self, "reset", on_failure=fail))
self.send_all()
self.fetch_all()
self._is_reset = True

def fetch_message(self):
""" Receive at most one message from the server, if available.
Expand Down Expand Up @@ -249,12 +314,15 @@ def fetch_message(self):
response.complete = True
if summary_signature == b"\x70":
log.debug("[#%04X] S: SUCCESS %r", self.local_port, summary_metadata)
self._server_state_manager.transition(response.message,
summary_metadata)
response.on_success(summary_metadata or {})
elif summary_signature == b"\x7E":
log.debug("[#%04X] S: IGNORED", self.local_port)
response.on_ignored(summary_metadata or {})
elif summary_signature == b"\x7F":
log.debug("[#%04X] S: FAILURE %r", self.local_port, summary_metadata)
self._server_state_manager.state = ServerStates.FAILED
try:
response.on_failure(summary_metadata or {})
except (ServiceUnavailable, DatabaseUnavailable):
Expand Down
59 changes: 43 additions & 16 deletions neo4j/io/_bolt4.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from enum import Enum
from logging import getLogger
from ssl import SSLSocket

Expand All @@ -37,7 +38,6 @@
Neo4jError,
NotALeader,
ServiceUnavailable,
SessionExpired,
)
from neo4j.io import (
Bolt,
Expand All @@ -48,6 +48,10 @@
InitResponse,
Response,
)
from neo4j.io._bolt3 import (
ServerStateManager,
ServerStates,
)


log = getLogger("neo4j")
Expand All @@ -65,6 +69,25 @@ class Bolt4x0(Bolt):

supports_multiple_databases = True

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._server_state_manager = ServerStateManager(
ServerStates.CONNECTED, on_change=self._on_server_state_change
)

def _on_server_state_change(self, old_state, new_state):
log.debug("[#%04X] State: %s > %s", self.local_port,
old_state.name, new_state.name)

@property
def is_reset(self):
if self.responses:
# We can't be sure of the server's state as there are still pending
# responses. Unless the last message we sent was RESET. In that case
# the server state will always be READY when we're done.
return self.responses[-1].message == "reset"
return self._server_state_manager.state == ServerStates.READY

@property
def encrypted(self):
return isinstance(self.socket, SSLSocket)
Expand Down Expand Up @@ -93,7 +116,8 @@ def hello(self):
logged_headers["credentials"] = "*******"
log.debug("[#%04X] C: HELLO %r", self.local_port, logged_headers)
self._append(b"\x01", (headers,),
response=InitResponse(self, on_success=self.server_info.update))
response=InitResponse(self, "hello",
on_success=self.server_info.update))
self.send_all()
self.fetch_all()
check_supported_server_product(self.server_info.agent)
Expand Down Expand Up @@ -162,25 +186,24 @@ def run(self, query, parameters=None, mode=None, bookmarks=None, metadata=None,
fields = (query, parameters, extra)
log.debug("[#%04X] C: RUN %s", self.local_port, " ".join(map(repr, fields)))
if query.upper() == u"COMMIT":
self._append(b"\x10", fields, CommitResponse(self, **handlers))
self._append(b"\x10", fields, CommitResponse(self, "run",
**handlers))
else:
self._append(b"\x10", fields, Response(self, **handlers))
self._is_reset = False
self._append(b"\x10", fields, Response(self, "run", **handlers))

def discard(self, n=-1, qid=-1, **handlers):
extra = {"n": n}
if qid != -1:
extra["qid"] = qid
log.debug("[#%04X] C: DISCARD %r", self.local_port, extra)
self._append(b"\x2F", (extra,), Response(self, **handlers))
self._append(b"\x2F", (extra,), Response(self, "discard", **handlers))

def pull(self, n=-1, qid=-1, **handlers):
extra = {"n": n}
if qid != -1:
extra["qid"] = qid
log.debug("[#%04X] C: PULL %r", self.local_port, extra)
self._append(b"\x3F", (extra,), Response(self, **handlers))
self._is_reset = False
self._append(b"\x3F", (extra,), Response(self, "pull", **handlers))

def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None,
db=None, **handlers):
Expand All @@ -205,16 +228,15 @@ def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None,
except TypeError:
raise TypeError("Timeout must be specified as a number of seconds")
log.debug("[#%04X] C: BEGIN %r", self.local_port, extra)
self._append(b"\x11", (extra,), Response(self, **handlers))
self._is_reset = False
self._append(b"\x11", (extra,), Response(self, "begin", **handlers))

def commit(self, **handlers):
log.debug("[#%04X] C: COMMIT", self.local_port)
self._append(b"\x12", (), CommitResponse(self, **handlers))
self._append(b"\x12", (), CommitResponse(self, "commit", **handlers))

def rollback(self, **handlers):
log.debug("[#%04X] C: ROLLBACK", self.local_port)
self._append(b"\x13", (), Response(self, **handlers))
self._append(b"\x13", (), Response(self, "rollback", **handlers))

def reset(self):
""" Add a RESET message to the outgoing queue, send
Expand All @@ -225,10 +247,9 @@ def fail(metadata):
raise BoltProtocolError("RESET failed %r" % metadata, self.unresolved_address)

log.debug("[#%04X] C: RESET", self.local_port)
self._append(b"\x0F", response=Response(self, on_failure=fail))
self._append(b"\x0F", response=Response(self, "reset", on_failure=fail))
self.send_all()
self.fetch_all()
self._is_reset = True

def fetch_message(self):
""" Receive at most one message from the server, if available.
Expand Down Expand Up @@ -261,12 +282,15 @@ def fetch_message(self):
response.complete = True
if summary_signature == b"\x70":
log.debug("[#%04X] S: SUCCESS %r", self.local_port, summary_metadata)
self._server_state_manager.transition(response.message,
summary_metadata)
response.on_success(summary_metadata or {})
elif summary_signature == b"\x7E":
log.debug("[#%04X] S: IGNORED", self.local_port)
response.on_ignored(summary_metadata or {})
elif summary_signature == b"\x7F":
log.debug("[#%04X] S: FAILURE %r", self.local_port, summary_metadata)
self._server_state_manager.state = ServerStates.FAILED
try:
response.on_failure(summary_metadata or {})
except (ServiceUnavailable, DatabaseUnavailable):
Expand Down Expand Up @@ -372,7 +396,9 @@ def fail(md):
else:
bookmarks = list(bookmarks)
self._append(b"\x66", (routing_context, bookmarks, database),
response=Response(self, on_success=metadata.update, on_failure=fail))
response=Response(self, "route",
on_success=metadata.update,
on_failure=fail))
self.send_all()
self.fetch_all()
return [metadata.get("rt")]
Expand Down Expand Up @@ -400,7 +426,8 @@ def on_success(metadata):
logged_headers["credentials"] = "*******"
log.debug("[#%04X] C: HELLO %r", self.local_port, logged_headers)
self._append(b"\x01", (headers,),
response=InitResponse(self, on_success=on_success))
response=InitResponse(self, "hello",
on_success=on_success))
self.send_all()
self.fetch_all()
check_supported_server_product(self.server_info.agent)
Expand Down
3 changes: 2 additions & 1 deletion neo4j/io/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,10 @@ class Response:
more detail messages followed by one summary message).
"""

def __init__(self, connection, **handlers):
def __init__(self, connection, message, **handlers):
self.connection = connection
self.handlers = handlers
self.message = message
self.complete = False

def on_records(self, records):
Expand Down
6 changes: 2 additions & 4 deletions testkitbackend/test_config.json
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,12 @@
"stub.session_run_parameters.test_session_run_parameters.TestSessionRunParameters.test_empty_query":
"Driver rejects empty queries before sending it to the server",
"tls.tlsversions.TestTlsVersions.test_1_1":
"TLSv1.1 and below are disabled in the driver",
"stub.disconnects.test_disconnects.TestDisconnects.test_fail_on_reset":
"Driver silently ignores all errors on releasing connections back into the pool."
"TLSv1.1 and below are disabled in the driver"
},
"features": {
"AuthorizationExpiredTreatment": true,
"Optimization:ImplicitDefaultArguments": true,
"Optimization:MinimalResets": "Driver resets some clean connections when put back into pool",
"Optimization:MinimalResets": true,
"Optimization:ConnectionReuse": true,
"Optimization:PullPipelining": true,
"ConfHint:connection.recv_timeout_seconds": true,
Expand Down