Skip to content

Commit c74f0a4

Browse files
committed
Almost working but needs bleeding edge server (DO NOT MERGE YET!!)
1 parent ca6d757 commit c74f0a4

File tree

3 files changed

+54
-12
lines changed

3 files changed

+54
-12
lines changed

neo4j/v1/connection.py

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242

4343
# Signature bytes for each message type
4444
INIT = b"\x01" # 0000 0001 // INIT <user_agent>
45-
ACK_FAILURE = b"\x0F" # 0000 1111 // ACK_FAILURE
45+
RESET = b"\x0F" # 0000 1111 // RESET
4646
RUN = b"\x10" # 0001 0000 // RUN <statement> <parameters>
4747
DISCARD_ALL = b"\x2F" # 0010 1111 // DISCARD *
4848
PULL_ALL = b"\x3F" # 0011 1111 // PULL *
@@ -56,7 +56,7 @@
5656

5757
message_names = {
5858
INIT: "INIT",
59-
ACK_FAILURE: "ACK_FAILURE",
59+
RESET: "RESET",
6060
RUN: "RUN",
6161
DISCARD_ALL: "DISCARD_ALL",
6262
PULL_ALL: "PULL_ALL",
@@ -200,12 +200,6 @@ def on_ignored(self, metadata=None):
200200
pass
201201

202202

203-
class AckFailureResponse(Response):
204-
205-
def on_failure(self, metadata):
206-
raise ProtocolError("Could not acknowledge failure")
207-
208-
209203
class Connection(object):
210204
""" Server connection through which all protocol messages
211205
are sent and received. This class is designed for protocol
@@ -215,6 +209,7 @@ class Connection(object):
215209
"""
216210

217211
def __init__(self, sock, **config):
212+
self.defunct = False
218213
self.channel = ChunkChannel(sock)
219214
self.packer = Packer(self.channel)
220215
self.responses = deque()
@@ -237,6 +232,10 @@ def on_failure(metadata):
237232

238233
def append(self, signature, fields=(), response=None):
239234
""" Add a message to the outgoing queue.
235+
236+
:arg signature: the signature of the message
237+
:arg fields: the fields of the message as a tuple
238+
:arg response: a response object to handle callbacks
240239
"""
241240
if __debug__:
242241
log_info("C: %s %s", message_names[signature], " ".join(map(repr, fields)))
@@ -247,6 +246,18 @@ def append(self, signature, fields=(), response=None):
247246
self.channel.flush(end_of_message=True)
248247
self.responses.append(response)
249248

249+
def append_reset(self):
250+
""" Add a RESET message to the outgoing queue.
251+
"""
252+
253+
def on_failure(metadata):
254+
raise ProtocolError("Reset failed")
255+
256+
response = Response(self)
257+
response.on_failure = on_failure
258+
259+
self.append(RESET, response=response)
260+
250261
def send(self):
251262
""" Send all queued messages to the server.
252263
"""
@@ -257,8 +268,12 @@ def fetch_next(self):
257268
"""
258269
raw = BytesIO()
259270
unpack = Unpacker(raw).unpack
260-
raw.writelines(self.channel.chunk_reader())
261-
271+
try:
272+
raw.writelines(self.channel.chunk_reader())
273+
except ProtocolError:
274+
self.defunct = True
275+
self.close()
276+
return
262277
# Unpack from the raw byte stream and call the relevant message handler(s)
263278
raw.seek(0)
264279
response = self.responses[0]
@@ -276,7 +291,7 @@ def fetch_next(self):
276291
response.complete = True
277292
self.responses.popleft()
278293
if signature == FAILURE:
279-
self.append(ACK_FAILURE, response=AckFailureResponse(self))
294+
self.append_reset()
280295
raw.close()
281296

282297
def close(self):

neo4j/v1/session.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,8 @@ def session(self):
101101
102102
"""
103103
try:
104-
return self.sessions.pop()
104+
session = self.sessions.pop()
105+
session.reset()
105106
except IndexError:
106107
return Session(self)
107108

@@ -349,6 +350,11 @@ def __enter__(self):
349350
def __exit__(self, exc_type, exc_value, traceback):
350351
self.close()
351352

353+
def reset(self):
354+
""" Reset the connection so it can be reused from a clean state.
355+
"""
356+
self.connection.append_reset()
357+
352358
def run(self, statement, parameters=None):
353359
""" Run a parameterised Cypher statement.
354360

test/test_session.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,27 @@ def test_can_obtain_notification_info(self):
220220
assert position.column == 1
221221

222222

223+
class ResetTestCase(TestCase):
224+
225+
def test_explicit_reset(self):
226+
with GraphDatabase.driver("bolt://localhost").session() as session:
227+
result = session.run("RETURN 1")
228+
assert result[0][0] == 1
229+
session.reset()
230+
result = session.run("RETURN 1")
231+
assert result[0][0] == 1
232+
233+
def test_automatic_reset_after_failure(self):
234+
with GraphDatabase.driver("bolt://localhost").session() as session:
235+
try:
236+
session.run("X")
237+
except CypherError:
238+
result = session.run("RETURN 1")
239+
assert result[0][0] == 1
240+
else:
241+
assert False, "A Cypher error should have occurred"
242+
243+
223244
class RecordTestCase(TestCase):
224245
def test_record_equality(self):
225246
record1 = Record(["name","empire"], ["Nigel", "The British Empire"])

0 commit comments

Comments
 (0)