Skip to content

Discard transaction on disconnect #518

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Apr 6, 2021
Merged
58 changes: 25 additions & 33 deletions neo4j/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,60 +38,52 @@
from logging import getLogger
from random import choice
from select import select
from time import perf_counter

from socket import (
AF_INET,
AF_INET6,
SHUT_RDWR,
SO_KEEPALIVE,
socket,
SOL_SOCKET,
SO_KEEPALIVE,
SHUT_RDWR,
timeout as SocketTimeout,
AF_INET,
AF_INET6,
)

from ssl import (
HAS_SNI,
SSLError,
)

from struct import (
pack as struct_pack,
)

from threading import (
Condition,
Lock,
RLock,
Condition,
)
from time import perf_counter

from neo4j.addressing import Address
from neo4j.conf import PoolConfig
from neo4j._exceptions import (
BoltHandshakeError,
BoltProtocolError,
BoltRoutingError,
BoltSecurityError,
BoltProtocolError,
BoltHandshakeError,
)
from neo4j.exceptions import (
ServiceUnavailable,
ClientError,
SessionExpired,
ReadServiceUnavailable,
WriteServiceUnavailable,
ConfigurationError,
UnsupportedServerProduct,
from neo4j.addressing import Address
from neo4j.api import (
READ_ACCESS,
Version,
WRITE_ACCESS,
)
from neo4j.routing import RoutingTable
from neo4j.conf import (
PoolConfig,
WorkspaceConfig,
)
from neo4j.api import (
READ_ACCESS,
WRITE_ACCESS,
Version,
from neo4j.exceptions import (
ClientError,
ConfigurationError,
ReadServiceUnavailable,
ServiceUnavailable,
SessionExpired,
UnsupportedServerProduct,
WriteServiceUnavailable,
)
from neo4j.routing import RoutingTable

# Set up logger
log = getLogger("neo4j")
Expand Down Expand Up @@ -258,7 +250,7 @@ def open(cls, address, *, auth=None, timeout=None, routing_context=None, **pool_
except Exception as error:
log.debug("[#%04X] C: <CLOSE> %s", s.getsockname()[1], str(error))
_close_socket(s)
raise error
raise

return connection

Expand Down Expand Up @@ -522,7 +514,7 @@ def deactivate(self, address):
connections.remove(conn)
try:
conn.close()
except IOError:
except OSError:
pass
if not connections:
self.remove(address)
Expand All @@ -538,7 +530,7 @@ def remove(self, address):
for connection in self.connections.pop(address, ()):
try:
connection.close()
except IOError:
except OSError:
pass

def close(self):
Expand Down
97 changes: 47 additions & 50 deletions neo4j/io/_bolt3.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,44 +19,49 @@
# limitations under the License.

from collections import deque
from logging import getLogger
from ssl import SSLSocket
from time import perf_counter

from neo4j._exceptions import (
BoltError,
BoltProtocolError,
)
from neo4j.addressing import Address
from neo4j.api import (
Version,
READ_ACCESS,
ServerInfo,
Version,
)
from neo4j.io._common import (
Inbox,
Outbox,
Response,
InitResponse,
CommitResponse,
)
from neo4j.meta import get_user_agent
from neo4j.exceptions import (
AuthError,
DatabaseUnavailable,
ConfigurationError,
DatabaseUnavailable,
DriverError,
ForbiddenOnReadOnlyDatabase,
IncompleteCommit,
NotALeader,
ServiceUnavailable,
SessionExpired,
)
from neo4j._exceptions import BoltProtocolError
from neo4j.packstream import (
Unpacker,
Packer,
)
from neo4j.io import (
check_supported_server_product,
Bolt,
BoltPool,
check_supported_server_product,
)
from neo4j.api import ServerInfo
from neo4j.addressing import Address
from neo4j.io._common import (
CommitResponse,
Inbox,
InitResponse,
Outbox,
Response,
)
from neo4j.meta import get_user_agent
from neo4j.packstream import (
Packer,
Unpacker,
)

from logging import getLogger
log = getLogger("neo4j")


Expand Down Expand Up @@ -85,7 +90,7 @@ def __init__(self, unresolved_address, sock, max_connection_lifetime, *, auth=No
self.socket = sock
self.server_info = ServerInfo(Address(sock.getpeername()), self.PROTOCOL_VERSION)
self.outbox = Outbox()
self.inbox = Inbox(self.socket, on_error=self._set_defunct)
self.inbox = Inbox(self.socket, on_error=self._set_defunct_read)
self.packer = Packer(self.outbox)
self.unpacker = Unpacker(self.inbox)
self.responses = deque()
Expand Down Expand Up @@ -135,7 +140,7 @@ def der_encoded_server_certificate(self):
def local_port(self):
try:
return self.socket.getsockname()[1]
except IOError:
except OSError:
return 0

def get_base_headers(self):
Expand Down Expand Up @@ -292,7 +297,10 @@ def fail(metadata):
def _send_all(self):
data = self.outbox.view()
if data:
self.socket.sendall(data)
try:
self.socket.sendall(data)
except OSError as error:
self._set_defunct_write(error)
self.outbox.clear()

def send_all(self):
Expand All @@ -306,17 +314,7 @@ def send_all(self):
raise ServiceUnavailable("Failed to write to defunct connection {!r} ({!r})".format(
self.unresolved_address, self.server_info.address))

try:
self._send_all()
except (IOError, OSError) as error:
log.error("Failed to write data to connection "
"{!r} ({!r}); ({!r})".
format(self.unresolved_address,
self.server_info.address,
"; ".join(map(repr, error.args))))
if self.pool:
self.pool.deactivate(address=self.unresolved_address)
raise
self._send_all()

def fetch_message(self):
""" Receive at least one message from the server, if available.
Expand All @@ -336,17 +334,7 @@ def fetch_message(self):
return 0, 0

# Receive exactly one message
try:
details, summary_signature, summary_metadata = next(self.inbox)
except (IOError, OSError) as error:
log.error("Failed to read data from connection "
"{!r} ({!r}); ({!r})".
format(self.unresolved_address,
self.server_info.address,
"; ".join(map(repr, error.args))))
if self.pool:
self.pool.deactivate(address=self.unresolved_address)
raise
details, summary_signature, summary_metadata = next(self.inbox)

if details:
log.debug("[#%04X] S: RECORD * %d", self.local_port, len(details)) # Do not log any data
Expand Down Expand Up @@ -380,11 +368,20 @@ def fetch_message(self):

return len(details), 1

def _set_defunct(self, error=None):
direct_driver = isinstance(self.pool, BoltPool)
def _set_defunct_read(self, error=None):
message = "Failed to read from defunct connection {!r} ({!r})".format(
self.unresolved_address, self.server_info.address
)
self._set_defunct(message, error=error)

message = ("Failed to read from defunct connection {!r} ({!r})".format(
self.unresolved_address, self.server_info.address))
def _set_defunct_write(self, error=None):
message = "Failed to write data to connection {!r} ({!r})".format(
self.unresolved_address, self.server_info.address
)
self._set_defunct(message, error=error)

def _set_defunct(self, message, error=None):
direct_driver = isinstance(self.pool, BoltPool)

if error:
log.error(str(error))
Expand Down Expand Up @@ -445,12 +442,12 @@ def close(self):
self._append(b"\x02", ())
try:
self._send_all()
except:
except (OSError, BoltError, DriverError):
pass
log.debug("[#%04X] C: <CLOSE>", self.local_port)
try:
self.socket.close()
except IOError:
except OSError:
pass
finally:
self._closed = True
Expand Down
Loading