Skip to content

Commit 3d18149

Browse files
committed
Proper reset handling
1 parent 30fb5e5 commit 3d18149

File tree

6 files changed

+152
-60
lines changed

6 files changed

+152
-60
lines changed

neo4j/__main__.py

Lines changed: 1 addition & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -26,53 +26,10 @@
2626
from json import loads as json_loads
2727
from sys import stdout, stderr
2828

29+
from .util import Watcher
2930
from .v1.session import GraphDatabase, CypherError
3031

3132

32-
class ColourFormatter(logging.Formatter):
33-
""" Colour formatter for pretty log output.
34-
"""
35-
36-
def format(self, record):
37-
s = super(ColourFormatter, self).format(record)
38-
if record.levelno == logging.CRITICAL:
39-
return "\x1b[31;1m%s\x1b[0m" % s # bright red
40-
elif record.levelno == logging.ERROR:
41-
return "\x1b[33;1m%s\x1b[0m" % s # bright yellow
42-
elif record.levelno == logging.WARNING:
43-
return "\x1b[33m%s\x1b[0m" % s # yellow
44-
elif record.levelno == logging.INFO:
45-
return "\x1b[36m%s\x1b[0m" % s # cyan
46-
elif record.levelno == logging.DEBUG:
47-
return "\x1b[34m%s\x1b[0m" % s # blue
48-
else:
49-
return s
50-
51-
52-
class Watcher(object):
53-
""" Log watcher for debug output.
54-
"""
55-
56-
handlers = {}
57-
58-
def __init__(self, logger_name):
59-
super(Watcher, self).__init__()
60-
self.logger_name = logger_name
61-
self.logger = logging.getLogger(self.logger_name)
62-
self.formatter = ColourFormatter("%(asctime)s %(message)s")
63-
64-
def watch(self, level=logging.INFO, out=stdout):
65-
try:
66-
self.logger.removeHandler(self.handlers[self.logger_name])
67-
except KeyError:
68-
pass
69-
handler = logging.StreamHandler(out)
70-
handler.setFormatter(self.formatter)
71-
self.handlers[self.logger_name] = handler
72-
self.logger.addHandler(handler)
73-
self.logger.setLevel(level)
74-
75-
7633
def main():
7734
parser = ArgumentParser(description="Execute one or more Cypher statements using Bolt.")
7835
parser.add_argument("statement", nargs="+")

neo4j/util.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
#!/usr/bin/env python
2+
# -*- encoding: utf-8 -*-
3+
4+
# Copyright (c) 2002-2016 "Neo Technology,"
5+
# Network Engine for Objects in Lund AB [http://neotechnology.com]
6+
#
7+
# This file is part of Neo4j.
8+
#
9+
# Licensed under the Apache License, Version 2.0 (the "License");
10+
# you may not use this file except in compliance with the License.
11+
# You may obtain a copy of the License at
12+
#
13+
# http://www.apache.org/licenses/LICENSE-2.0
14+
#
15+
# Unless required by applicable law or agreed to in writing, software
16+
# distributed under the License is distributed on an "AS IS" BASIS,
17+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18+
# See the License for the specific language governing permissions and
19+
# limitations under the License.
20+
21+
22+
from __future__ import unicode_literals
23+
24+
import logging
25+
from argparse import ArgumentParser
26+
from json import loads as json_loads
27+
from sys import stdout, stderr
28+
29+
from .v1.session import GraphDatabase, CypherError
30+
31+
32+
class ColourFormatter(logging.Formatter):
33+
""" Colour formatter for pretty log output.
34+
"""
35+
36+
def format(self, record):
37+
s = super(ColourFormatter, self).format(record)
38+
if record.levelno == logging.CRITICAL:
39+
return "\x1b[31;1m%s\x1b[0m" % s # bright red
40+
elif record.levelno == logging.ERROR:
41+
return "\x1b[33;1m%s\x1b[0m" % s # bright yellow
42+
elif record.levelno == logging.WARNING:
43+
return "\x1b[33m%s\x1b[0m" % s # yellow
44+
elif record.levelno == logging.INFO:
45+
return "\x1b[36m%s\x1b[0m" % s # cyan
46+
elif record.levelno == logging.DEBUG:
47+
return "\x1b[34m%s\x1b[0m" % s # blue
48+
else:
49+
return s
50+
51+
52+
class Watcher(object):
53+
""" Log watcher for debug output.
54+
"""
55+
56+
handlers = {}
57+
58+
def __init__(self, logger_name):
59+
super(Watcher, self).__init__()
60+
self.logger_name = logger_name
61+
self.logger = logging.getLogger(self.logger_name)
62+
self.formatter = ColourFormatter("%(asctime)s %(message)s")
63+
64+
def watch(self, level=logging.INFO, out=stdout):
65+
self.stop()
66+
handler = logging.StreamHandler(out)
67+
handler.setFormatter(self.formatter)
68+
self.handlers[self.logger_name] = handler
69+
self.logger.addHandler(handler)
70+
self.logger.setLevel(level)
71+
72+
def stop(self):
73+
try:
74+
self.logger.removeHandler(self.handlers[self.logger_name])
75+
except KeyError:
76+
pass

neo4j/v1/connection.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,10 @@ def on_ignored(self, metadata=None):
200200
pass
201201

202202

203+
class Completable(object):
204+
complete = False
205+
206+
203207
class Connection(object):
204208
""" Server connection through which all protocol messages
205209
are sent and received. This class is designed for protocol
@@ -246,17 +250,22 @@ def append(self, signature, fields=(), response=None):
246250
self.channel.flush(end_of_message=True)
247251
self.responses.append(response)
248252

249-
def append_reset(self):
250-
""" Add a RESET message to the outgoing queue.
253+
def reset(self):
254+
""" Add a RESET message to the outgoing queue, send
255+
it and consume all remaining messages.
251256
"""
257+
response = Response(self)
252258

253259
def on_failure(metadata):
254260
raise ProtocolError("Reset failed")
255261

256-
response = Response(self)
257262
response.on_failure = on_failure
258263

259264
self.append(RESET, response=response)
265+
self.send()
266+
fetch_next = self.fetch_next
267+
while not response.complete:
268+
fetch_next()
260269

261270
def send(self):
262271
""" Send all queued messages to the server.
@@ -280,18 +289,18 @@ def fetch_next(self):
280289
for signature, fields in unpack():
281290
if __debug__:
282291
log_info("S: %s %s", message_names[signature], " ".join(map(repr, fields)))
292+
if signature in SUMMARY:
293+
response.complete = True
294+
self.responses.popleft()
295+
if signature == FAILURE:
296+
self.reset()
283297
handler_name = "on_%s" % message_names[signature].lower()
284298
try:
285299
handler = getattr(response, handler_name)
286300
except AttributeError:
287301
pass
288302
else:
289303
handler(*fields)
290-
if signature in SUMMARY:
291-
response.complete = True
292-
self.responses.popleft()
293-
if signature == FAILURE:
294-
self.append_reset()
295304
raw.close()
296305

297306
def close(self):

neo4j/v1/session.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -97,14 +97,18 @@ def session(self):
9797
""" Create a new session based on the graph database details
9898
specified within this driver:
9999
100+
>>> from neo4j.v1 import GraphDatabase
101+
>>> driver = GraphDatabase.driver("bolt://localhost")
100102
>>> session = driver.session()
101103
102104
"""
103105
try:
104106
session = self.sessions.pop()
105-
session.reset()
106107
except IndexError:
107-
return Session(self)
108+
session = Session(self)
109+
else:
110+
session.reset()
111+
return session
108112

109113

110114
class Result(list):
@@ -123,7 +127,7 @@ def __init__(self, session, statement, parameters):
123127
self.statement = statement
124128
self.parameters = parameters
125129
self.keys = None
126-
self.complete = False
130+
self.more = True
127131
self.summary = None
128132
self.bench_test = None
129133

@@ -142,7 +146,7 @@ def on_record(self, values):
142146
def on_footer(self, metadata):
143147
""" Called on receipt of the result footer.
144148
"""
145-
self.complete = True
149+
self.more = False
146150
self.summary = ResultSummary(self.statement, self.parameters, **metadata)
147151
if self.bench_test:
148152
self.bench_test.end_recv = perf_counter()
@@ -157,7 +161,7 @@ def consume(self):
157161
callback functions.
158162
"""
159163
fetch_next = self.session.connection.fetch_next
160-
while not self.complete:
164+
while self.more:
161165
fetch_next()
162166

163167
def summarize(self):
@@ -340,6 +344,7 @@ def __init__(self, driver):
340344
self.connection = connect(driver.host, driver.port, **driver.config)
341345
self.transaction = None
342346
self.bench_tests = []
347+
self.closed = False
343348

344349
def __del__(self):
345350
self.connection.close()
@@ -353,7 +358,7 @@ def __exit__(self, exc_type, exc_value, traceback):
353358
def reset(self):
354359
""" Reset the connection so it can be reused from a clean state.
355360
"""
356-
self.connection.append_reset()
361+
self.connection.reset()
357362

358363
def run(self, statement, parameters=None):
359364
""" Run a parameterised Cypher statement.
@@ -407,9 +412,12 @@ def run(self, statement, parameters=None):
407412
return result
408413

409414
def close(self):
410-
""" Return this session to the driver pool it came from.
415+
""" If still usable, return this session to the driver pool it came from.
411416
"""
412-
self.driver.sessions.appendleft(self)
417+
self.reset()
418+
if not self.connection.defunct:
419+
self.driver.sessions.appendleft(self)
420+
self.closed = True
413421

414422
def begin_transaction(self):
415423
""" Create a new :class:`.Transaction` within this session.
@@ -487,6 +495,7 @@ def close(self):
487495
self.closed = True
488496
self.session.transaction = None
489497

498+
490499
class Record(object):
491500
""" Record is an ordered collection of fields.
492501

test/test_session.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
from neo4j.v1.session import GraphDatabase, CypherError, Record, record
2525
from neo4j.v1.typesystem import Node, Relationship, Path
26+
from test.util import watch
2627

2728

2829
class RunTestCase(TestCase):

test/util.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
#!/usr/bin/env python
2+
# -*- encoding: utf-8 -*-
3+
4+
# Copyright (c) 2002-2016 "Neo Technology,"
5+
# Network Engine for Objects in Lund AB [http://neotechnology.com]
6+
#
7+
# This file is part of Neo4j.
8+
#
9+
# Licensed under the Apache License, Version 2.0 (the "License");
10+
# you may not use this file except in compliance with the License.
11+
# You may obtain a copy of the License at
12+
#
13+
# http://www.apache.org/licenses/LICENSE-2.0
14+
#
15+
# Unless required by applicable law or agreed to in writing, software
16+
# distributed under the License is distributed on an "AS IS" BASIS,
17+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18+
# See the License for the specific language governing permissions and
19+
# limitations under the License.
20+
21+
22+
import functools
23+
24+
from neo4j.util import Watcher
25+
26+
27+
def watch(f):
28+
""" Decorator to enable log watching for the lifetime of a function.
29+
Useful for debugging unit tests.
30+
31+
:param f: the function to decorate
32+
:return: a decorated function
33+
"""
34+
@functools.wraps(f)
35+
def wrapper(*args, **kwargs):
36+
watcher = Watcher("neo4j")
37+
watcher.watch()
38+
f(*args, **kwargs)
39+
watcher.stop()
40+
return wrapper

0 commit comments

Comments
 (0)