From 57ac783bb8ebab629a6e2fe76c7d6bdb399624ca Mon Sep 17 00:00:00 2001 From: Nigel Small Date: Sat, 2 Jan 2016 10:46:44 +0000 Subject: [PATCH 1/8] Session pool --- neo4j/v1/session.py | 22 +++++++++++++++------- test/test_session.py | 16 ++++++++++++++++ 2 files changed, 31 insertions(+), 7 deletions(-) diff --git a/neo4j/v1/session.py b/neo4j/v1/session.py index 2fda936d..de96a75b 100644 --- a/neo4j/v1/session.py +++ b/neo4j/v1/session.py @@ -28,7 +28,7 @@ class which can be used to obtain `Driver` instances that are used for from __future__ import division -from collections import namedtuple +from collections import deque, namedtuple from .compat import integer, perf_counter, string, urlparse from .connection import connect, Response, RUN, PULL_ALL @@ -91,15 +91,19 @@ def __init__(self, url, **config): else: raise ValueError("Unsupported URL scheme: %s" % parsed.scheme) self.config = config + self.sessions = deque() - def session(self, **config): + def session(self): """ Create a new session based on the graph database details specified within this driver: >>> session = driver.session() """ - return Session(connect(self.host, self.port, **dict(self.config, **config))) + try: + return self.sessions.pop() + except IndexError: + return Session(self) class Result(list): @@ -330,11 +334,15 @@ class Session(object): method. """ - def __init__(self, connection): - self.connection = connection + def __init__(self, driver): + self.driver = driver + self.connection = connect(driver.host, driver.port, **driver.config) self.transaction = None self.bench_tests = [] + def __del__(self): + self.connection.close() + def __enter__(self): return self @@ -393,9 +401,9 @@ def run(self, statement, parameters=None): return result def close(self): - """ Shut down and close the session. + """ Return this session to the driver pool it came from. """ - self.connection.close() + self.driver.sessions.appendleft(self) def begin_transaction(self): """ Create a new :class:`.Transaction` within this session. diff --git a/test/test_session.py b/test/test_session.py index 16e2e19b..92d6aeeb 100644 --- a/test/test_session.py +++ b/test/test_session.py @@ -30,6 +30,22 @@ def test_must_use_valid_url_scheme(self): with self.assertRaises(ValueError): GraphDatabase.driver("x://xxx") + def test_sessions_are_reused(self): + driver = GraphDatabase.driver("bolt://localhost") + session_1 = driver.session() + session_1.close() + session_2 = driver.session() + session_2.close() + assert session_1 is session_2 + + def test_sessions_are_not_reused_if_still_in_use(self): + driver = GraphDatabase.driver("bolt://localhost") + session_1 = driver.session() + session_2 = driver.session() + session_2.close() + session_1.close() + assert session_1 is not session_2 + def test_can_run_simple_statement(self): session = GraphDatabase.driver("bolt://localhost").session() count = 0 From ca6d757d583d2a99c1b0bcdf9c68eed7888d50cc Mon Sep 17 00:00:00 2001 From: Nigel Small Date: Fri, 8 Jan 2016 21:02:15 +0000 Subject: [PATCH 2/8] Only run behave if unit tests pass --- runtests.sh | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/runtests.sh b/runtests.sh index b109a8b2..38a2923b 100755 --- a/runtests.sh +++ b/runtests.sh @@ -73,12 +73,12 @@ else if [ ${EXIT_STATUS} -eq 0 ] then coverage report --show-missing + python -c 'from test.tck.configure_feature_files import *; set_up()' + echo "Feature files downloaded" + neokit/neorun ${NEORUN_OPTIONS} "${BEHAVE_RUNNER}" ${VERSIONS} + python -c 'from test.tck.configure_feature_files import *; clean_up()' + echo "Feature files removed" fi - python -c 'from test.tck.configure_feature_files import *; set_up()' - echo "Feature files downloaded" - neokit/neorun ${NEORUN_OPTIONS} "${BEHAVE_RUNNER}" ${VERSIONS} - python -c 'from test.tck.configure_feature_files import *; clean_up()' - echo "Feature files removed" fi # Exit correctly From c74f0a4df8ddd39a004739bda55bca079ae8a178 Mon Sep 17 00:00:00 2001 From: Nigel Small Date: Fri, 8 Jan 2016 22:03:11 +0000 Subject: [PATCH 3/8] Almost working but needs bleeding edge server (DO NOT MERGE YET!!) --- neo4j/v1/connection.py | 37 ++++++++++++++++++++++++++----------- neo4j/v1/session.py | 8 +++++++- test/test_session.py | 21 +++++++++++++++++++++ 3 files changed, 54 insertions(+), 12 deletions(-) diff --git a/neo4j/v1/connection.py b/neo4j/v1/connection.py index 3ba6398b..3647fe34 100644 --- a/neo4j/v1/connection.py +++ b/neo4j/v1/connection.py @@ -42,7 +42,7 @@ # Signature bytes for each message type INIT = b"\x01" # 0000 0001 // INIT -ACK_FAILURE = b"\x0F" # 0000 1111 // ACK_FAILURE +RESET = b"\x0F" # 0000 1111 // RESET RUN = b"\x10" # 0001 0000 // RUN DISCARD_ALL = b"\x2F" # 0010 1111 // DISCARD * PULL_ALL = b"\x3F" # 0011 1111 // PULL * @@ -56,7 +56,7 @@ message_names = { INIT: "INIT", - ACK_FAILURE: "ACK_FAILURE", + RESET: "RESET", RUN: "RUN", DISCARD_ALL: "DISCARD_ALL", PULL_ALL: "PULL_ALL", @@ -200,12 +200,6 @@ def on_ignored(self, metadata=None): pass -class AckFailureResponse(Response): - - def on_failure(self, metadata): - raise ProtocolError("Could not acknowledge failure") - - class Connection(object): """ Server connection through which all protocol messages are sent and received. This class is designed for protocol @@ -215,6 +209,7 @@ class Connection(object): """ def __init__(self, sock, **config): + self.defunct = False self.channel = ChunkChannel(sock) self.packer = Packer(self.channel) self.responses = deque() @@ -237,6 +232,10 @@ def on_failure(metadata): def append(self, signature, fields=(), response=None): """ Add a message to the outgoing queue. + + :arg signature: the signature of the message + :arg fields: the fields of the message as a tuple + :arg response: a response object to handle callbacks """ if __debug__: log_info("C: %s %s", message_names[signature], " ".join(map(repr, fields))) @@ -247,6 +246,18 @@ def append(self, signature, fields=(), response=None): self.channel.flush(end_of_message=True) self.responses.append(response) + def append_reset(self): + """ Add a RESET message to the outgoing queue. + """ + + def on_failure(metadata): + raise ProtocolError("Reset failed") + + response = Response(self) + response.on_failure = on_failure + + self.append(RESET, response=response) + def send(self): """ Send all queued messages to the server. """ @@ -257,8 +268,12 @@ def fetch_next(self): """ raw = BytesIO() unpack = Unpacker(raw).unpack - raw.writelines(self.channel.chunk_reader()) - + try: + raw.writelines(self.channel.chunk_reader()) + except ProtocolError: + self.defunct = True + self.close() + return # Unpack from the raw byte stream and call the relevant message handler(s) raw.seek(0) response = self.responses[0] @@ -276,7 +291,7 @@ def fetch_next(self): response.complete = True self.responses.popleft() if signature == FAILURE: - self.append(ACK_FAILURE, response=AckFailureResponse(self)) + self.append_reset() raw.close() def close(self): diff --git a/neo4j/v1/session.py b/neo4j/v1/session.py index de96a75b..ee6d981c 100644 --- a/neo4j/v1/session.py +++ b/neo4j/v1/session.py @@ -101,7 +101,8 @@ def session(self): """ try: - return self.sessions.pop() + session = self.sessions.pop() + session.reset() except IndexError: return Session(self) @@ -349,6 +350,11 @@ def __enter__(self): def __exit__(self, exc_type, exc_value, traceback): self.close() + def reset(self): + """ Reset the connection so it can be reused from a clean state. + """ + self.connection.append_reset() + def run(self, statement, parameters=None): """ Run a parameterised Cypher statement. diff --git a/test/test_session.py b/test/test_session.py index 92d6aeeb..979b75c7 100644 --- a/test/test_session.py +++ b/test/test_session.py @@ -220,6 +220,27 @@ def test_can_obtain_notification_info(self): assert position.column == 1 +class ResetTestCase(TestCase): + + def test_explicit_reset(self): + with GraphDatabase.driver("bolt://localhost").session() as session: + result = session.run("RETURN 1") + assert result[0][0] == 1 + session.reset() + result = session.run("RETURN 1") + assert result[0][0] == 1 + + def test_automatic_reset_after_failure(self): + with GraphDatabase.driver("bolt://localhost").session() as session: + try: + session.run("X") + except CypherError: + result = session.run("RETURN 1") + assert result[0][0] == 1 + else: + assert False, "A Cypher error should have occurred" + + class RecordTestCase(TestCase): def test_record_equality(self): record1 = Record(["name","empire"], ["Nigel", "The British Empire"]) From b5f169bb8902afdbaf666a493c5336ce9cf48be2 Mon Sep 17 00:00:00 2001 From: Nigel Small Date: Mon, 11 Jan 2016 14:28:44 +0000 Subject: [PATCH 4/8] Proper reset handling --- neo4j/__main__.py | 45 +------------------------ neo4j/util.py | 76 ++++++++++++++++++++++++++++++++++++++++++ neo4j/v1/connection.py | 25 +++++++++----- neo4j/v1/session.py | 25 +++++++++----- test/test_session.py | 1 + test/util.py | 40 ++++++++++++++++++++++ 6 files changed, 152 insertions(+), 60 deletions(-) create mode 100644 neo4j/util.py create mode 100644 test/util.py diff --git a/neo4j/__main__.py b/neo4j/__main__.py index 366eb09a..b1c64829 100644 --- a/neo4j/__main__.py +++ b/neo4j/__main__.py @@ -26,53 +26,10 @@ from json import loads as json_loads from sys import stdout, stderr +from .util import Watcher from .v1.session import GraphDatabase, CypherError -class ColourFormatter(logging.Formatter): - """ Colour formatter for pretty log output. - """ - - def format(self, record): - s = super(ColourFormatter, self).format(record) - if record.levelno == logging.CRITICAL: - return "\x1b[31;1m%s\x1b[0m" % s # bright red - elif record.levelno == logging.ERROR: - return "\x1b[33;1m%s\x1b[0m" % s # bright yellow - elif record.levelno == logging.WARNING: - return "\x1b[33m%s\x1b[0m" % s # yellow - elif record.levelno == logging.INFO: - return "\x1b[36m%s\x1b[0m" % s # cyan - elif record.levelno == logging.DEBUG: - return "\x1b[34m%s\x1b[0m" % s # blue - else: - return s - - -class Watcher(object): - """ Log watcher for debug output. - """ - - handlers = {} - - def __init__(self, logger_name): - super(Watcher, self).__init__() - self.logger_name = logger_name - self.logger = logging.getLogger(self.logger_name) - self.formatter = ColourFormatter("%(asctime)s %(message)s") - - def watch(self, level=logging.INFO, out=stdout): - try: - self.logger.removeHandler(self.handlers[self.logger_name]) - except KeyError: - pass - handler = logging.StreamHandler(out) - handler.setFormatter(self.formatter) - self.handlers[self.logger_name] = handler - self.logger.addHandler(handler) - self.logger.setLevel(level) - - def main(): parser = ArgumentParser(description="Execute one or more Cypher statements using Bolt.") parser.add_argument("statement", nargs="+") diff --git a/neo4j/util.py b/neo4j/util.py new file mode 100644 index 00000000..ddaa1e2e --- /dev/null +++ b/neo4j/util.py @@ -0,0 +1,76 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +# Copyright (c) 2002-2016 "Neo Technology," +# Network Engine for Objects in Lund AB [http://neotechnology.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from __future__ import unicode_literals + +import logging +from argparse import ArgumentParser +from json import loads as json_loads +from sys import stdout, stderr + +from .v1.session import GraphDatabase, CypherError + + +class ColourFormatter(logging.Formatter): + """ Colour formatter for pretty log output. + """ + + def format(self, record): + s = super(ColourFormatter, self).format(record) + if record.levelno == logging.CRITICAL: + return "\x1b[31;1m%s\x1b[0m" % s # bright red + elif record.levelno == logging.ERROR: + return "\x1b[33;1m%s\x1b[0m" % s # bright yellow + elif record.levelno == logging.WARNING: + return "\x1b[33m%s\x1b[0m" % s # yellow + elif record.levelno == logging.INFO: + return "\x1b[36m%s\x1b[0m" % s # cyan + elif record.levelno == logging.DEBUG: + return "\x1b[34m%s\x1b[0m" % s # blue + else: + return s + + +class Watcher(object): + """ Log watcher for debug output. + """ + + handlers = {} + + def __init__(self, logger_name): + super(Watcher, self).__init__() + self.logger_name = logger_name + self.logger = logging.getLogger(self.logger_name) + self.formatter = ColourFormatter("%(asctime)s %(message)s") + + def watch(self, level=logging.INFO, out=stdout): + self.stop() + handler = logging.StreamHandler(out) + handler.setFormatter(self.formatter) + self.handlers[self.logger_name] = handler + self.logger.addHandler(handler) + self.logger.setLevel(level) + + def stop(self): + try: + self.logger.removeHandler(self.handlers[self.logger_name]) + except KeyError: + pass diff --git a/neo4j/v1/connection.py b/neo4j/v1/connection.py index 3647fe34..268a2e79 100644 --- a/neo4j/v1/connection.py +++ b/neo4j/v1/connection.py @@ -200,6 +200,10 @@ def on_ignored(self, metadata=None): pass +class Completable(object): + complete = False + + class Connection(object): """ Server connection through which all protocol messages are sent and received. This class is designed for protocol @@ -246,17 +250,22 @@ def append(self, signature, fields=(), response=None): self.channel.flush(end_of_message=True) self.responses.append(response) - def append_reset(self): - """ Add a RESET message to the outgoing queue. + def reset(self): + """ Add a RESET message to the outgoing queue, send + it and consume all remaining messages. """ + response = Response(self) def on_failure(metadata): raise ProtocolError("Reset failed") - response = Response(self) response.on_failure = on_failure self.append(RESET, response=response) + self.send() + fetch_next = self.fetch_next + while not response.complete: + fetch_next() def send(self): """ Send all queued messages to the server. @@ -280,6 +289,11 @@ def fetch_next(self): for signature, fields in unpack(): if __debug__: log_info("S: %s %s", message_names[signature], " ".join(map(repr, fields))) + if signature in SUMMARY: + response.complete = True + self.responses.popleft() + if signature == FAILURE: + self.reset() handler_name = "on_%s" % message_names[signature].lower() try: handler = getattr(response, handler_name) @@ -287,11 +301,6 @@ def fetch_next(self): pass else: handler(*fields) - if signature in SUMMARY: - response.complete = True - self.responses.popleft() - if signature == FAILURE: - self.append_reset() raw.close() def close(self): diff --git a/neo4j/v1/session.py b/neo4j/v1/session.py index ee6d981c..cb485709 100644 --- a/neo4j/v1/session.py +++ b/neo4j/v1/session.py @@ -97,14 +97,18 @@ def session(self): """ Create a new session based on the graph database details specified within this driver: + >>> from neo4j.v1 import GraphDatabase + >>> driver = GraphDatabase.driver("bolt://localhost") >>> session = driver.session() """ try: session = self.sessions.pop() - session.reset() except IndexError: - return Session(self) + session = Session(self) + else: + session.reset() + return session class Result(list): @@ -123,7 +127,7 @@ def __init__(self, session, statement, parameters): self.statement = statement self.parameters = parameters self.keys = None - self.complete = False + self.more = True self.summary = None self.bench_test = None @@ -142,7 +146,7 @@ def on_record(self, values): def on_footer(self, metadata): """ Called on receipt of the result footer. """ - self.complete = True + self.more = False self.summary = ResultSummary(self.statement, self.parameters, **metadata) if self.bench_test: self.bench_test.end_recv = perf_counter() @@ -157,7 +161,7 @@ def consume(self): callback functions. """ fetch_next = self.session.connection.fetch_next - while not self.complete: + while self.more: fetch_next() def summarize(self): @@ -340,6 +344,7 @@ def __init__(self, driver): self.connection = connect(driver.host, driver.port, **driver.config) self.transaction = None self.bench_tests = [] + self.closed = False def __del__(self): self.connection.close() @@ -353,7 +358,7 @@ def __exit__(self, exc_type, exc_value, traceback): def reset(self): """ Reset the connection so it can be reused from a clean state. """ - self.connection.append_reset() + self.connection.reset() def run(self, statement, parameters=None): """ Run a parameterised Cypher statement. @@ -407,9 +412,12 @@ def run(self, statement, parameters=None): return result def close(self): - """ Return this session to the driver pool it came from. + """ If still usable, return this session to the driver pool it came from. """ - self.driver.sessions.appendleft(self) + self.reset() + if not self.connection.defunct: + self.driver.sessions.appendleft(self) + self.closed = True def begin_transaction(self): """ Create a new :class:`.Transaction` within this session. @@ -487,6 +495,7 @@ def close(self): self.closed = True self.session.transaction = None + class Record(object): """ Record is an ordered collection of fields. diff --git a/test/test_session.py b/test/test_session.py index 979b75c7..87f88b23 100644 --- a/test/test_session.py +++ b/test/test_session.py @@ -23,6 +23,7 @@ from neo4j.v1.session import GraphDatabase, CypherError, Record, record from neo4j.v1.typesystem import Node, Relationship, Path +from test.util import watch class RunTestCase(TestCase): diff --git a/test/util.py b/test/util.py new file mode 100644 index 00000000..19e81332 --- /dev/null +++ b/test/util.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +# Copyright (c) 2002-2016 "Neo Technology," +# Network Engine for Objects in Lund AB [http://neotechnology.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import functools + +from neo4j.util import Watcher + + +def watch(f): + """ Decorator to enable log watching for the lifetime of a function. + Useful for debugging unit tests. + + :param f: the function to decorate + :return: a decorated function + """ + @functools.wraps(f) + def wrapper(*args, **kwargs): + watcher = Watcher("neo4j") + watcher.watch() + f(*args, **kwargs) + watcher.stop() + return wrapper From ad0df8402e096e6e38b6ecd8b3b7f044b93977f7 Mon Sep 17 00:00:00 2001 From: Nigel Small Date: Thu, 14 Jan 2016 10:10:31 +0000 Subject: [PATCH 5/8] Test for defunct (VERY BROKEN STILL) --- neo4j/v1/connection.py | 24 ++++++++++-------------- neo4j/v1/session.py | 12 ++++-------- test/test_session.py | 18 ++++++++++-------- test_requirements.txt | 1 + 4 files changed, 25 insertions(+), 30 deletions(-) diff --git a/neo4j/v1/connection.py b/neo4j/v1/connection.py index 268a2e79..270a1362 100644 --- a/neo4j/v1/connection.py +++ b/neo4j/v1/connection.py @@ -169,14 +169,6 @@ def chunk_reader(self): data = self._recv(chunk_size) yield data - def close(self): - """ Shut down and close the connection. - """ - if __debug__: log_info("~~ [CLOSE]") - socket = self.socket - socket.shutdown(SHUT_RDWR) - socket.close() - class Response(object): """ Subscriber object for a full response (zero or @@ -200,10 +192,6 @@ def on_ignored(self, metadata=None): pass -class Completable(object): - complete = False - - class Connection(object): """ Server connection through which all protocol messages are sent and received. This class is designed for protocol @@ -217,6 +205,7 @@ def __init__(self, sock, **config): self.channel = ChunkChannel(sock) self.packer = Packer(self.channel) self.responses = deque() + self.closed = False # Determine the user agent and ensure it is a Unicode value user_agent = config.get("user_agent", DEFAULT_USER_AGENT) @@ -234,6 +223,9 @@ def on_failure(metadata): while not response.complete: self.fetch_next() + def __del__(self): + self.close() + def append(self, signature, fields=(), response=None): """ Add a message to the outgoing queue. @@ -304,9 +296,13 @@ def fetch_next(self): raw.close() def close(self): - """ Shut down and close the connection. + """ Close the connection. """ - self.channel.close() + if not self.closed: + if __debug__: + log_info("~~ [CLOSE]") + self.channel.socket.close() + self.closed = True def connect(host, port=None, **config): diff --git a/neo4j/v1/session.py b/neo4j/v1/session.py index cb485709..57e95d3d 100644 --- a/neo4j/v1/session.py +++ b/neo4j/v1/session.py @@ -107,7 +107,7 @@ def session(self): except IndexError: session = Session(self) else: - session.reset() + session.connection.reset() return session @@ -347,7 +347,9 @@ def __init__(self, driver): self.closed = False def __del__(self): - self.connection.close() + if not self.closed: + self.connection.close() + self.closed = True def __enter__(self): return self @@ -355,11 +357,6 @@ def __enter__(self): def __exit__(self, exc_type, exc_value, traceback): self.close() - def reset(self): - """ Reset the connection so it can be reused from a clean state. - """ - self.connection.reset() - def run(self, statement, parameters=None): """ Run a parameterised Cypher statement. @@ -414,7 +411,6 @@ def run(self, statement, parameters=None): def close(self): """ If still usable, return this session to the driver pool it came from. """ - self.reset() if not self.connection.defunct: self.driver.sessions.appendleft(self) self.closed = True diff --git a/test/test_session.py b/test/test_session.py index 87f88b23..85afd724 100644 --- a/test/test_session.py +++ b/test/test_session.py @@ -21,6 +21,7 @@ from unittest import TestCase +from mock import patch from neo4j.v1.session import GraphDatabase, CypherError, Record, record from neo4j.v1.typesystem import Node, Relationship, Path from test.util import watch @@ -223,14 +224,6 @@ def test_can_obtain_notification_info(self): class ResetTestCase(TestCase): - def test_explicit_reset(self): - with GraphDatabase.driver("bolt://localhost").session() as session: - result = session.run("RETURN 1") - assert result[0][0] == 1 - session.reset() - result = session.run("RETURN 1") - assert result[0][0] == 1 - def test_automatic_reset_after_failure(self): with GraphDatabase.driver("bolt://localhost").session() as session: try: @@ -241,6 +234,15 @@ def test_automatic_reset_after_failure(self): else: assert False, "A Cypher error should have occurred" + @watch + def test_defunct(self): + from neo4j.v1.connection import ChunkChannel, ProtocolError + with GraphDatabase.driver("bolt://localhost").session() as session: + assert not session.connection.defunct + with patch.object(ChunkChannel, "chunk_reader", side_effect=ProtocolError()): + session.run("RETURN 1") + assert session.connection.defunct + class RecordTestCase(TestCase): def test_record_equality(self): diff --git a/test_requirements.txt b/test_requirements.txt index 85278c30..b5c3dcfd 100644 --- a/test_requirements.txt +++ b/test_requirements.txt @@ -1,3 +1,4 @@ behave coverage +mock teamcity-messages From 02f4ddadabc07bad7c5c7bf66d31bbb4fa4bf172 Mon Sep 17 00:00:00 2001 From: Nigel Small Date: Thu, 14 Jan 2016 16:19:16 +0000 Subject: [PATCH 6/8] Defunct test working :-) --- neo4j/v1/connection.py | 10 +++++++++- neokit | 2 +- test/test_session.py | 5 +++-- test/util.py | 3 ++- 4 files changed, 15 insertions(+), 5 deletions(-) diff --git a/neo4j/v1/connection.py b/neo4j/v1/connection.py index 270a1362..70dbca1e 100644 --- a/neo4j/v1/connection.py +++ b/neo4j/v1/connection.py @@ -262,11 +262,19 @@ def on_failure(metadata): def send(self): """ Send all queued messages to the server. """ + if self.closed: + raise ProtocolError("Cannot write to a closed connection") + if self.defunct: + raise ProtocolError("Cannot write to a defunct connection") self.channel.send() def fetch_next(self): """ Receive exactly one message from the server. """ + if self.closed: + raise ProtocolError("Cannot read from a closed connection") + if self.defunct: + raise ProtocolError("Cannot read from a defunct connection") raw = BytesIO() unpack = Unpacker(raw).unpack try: @@ -274,7 +282,7 @@ def fetch_next(self): except ProtocolError: self.defunct = True self.close() - return + raise # Unpack from the raw byte stream and call the relevant message handler(s) raw.seek(0) response = self.responses[0] diff --git a/neokit b/neokit index 23ffc81c..db7ab358 160000 --- a/neokit +++ b/neokit @@ -1 +1 @@ -Subproject commit 23ffc81c7a1a1f16369aa1cea71d77d256e57c8d +Subproject commit db7ab3580be3f0e09fcb2352408750f45c14a70e diff --git a/test/test_session.py b/test/test_session.py index 85afd724..9bac2ab4 100644 --- a/test/test_session.py +++ b/test/test_session.py @@ -234,14 +234,15 @@ def test_automatic_reset_after_failure(self): else: assert False, "A Cypher error should have occurred" - @watch def test_defunct(self): from neo4j.v1.connection import ChunkChannel, ProtocolError with GraphDatabase.driver("bolt://localhost").session() as session: assert not session.connection.defunct with patch.object(ChunkChannel, "chunk_reader", side_effect=ProtocolError()): - session.run("RETURN 1") + with self.assertRaises(ProtocolError): + session.run("RETURN 1") assert session.connection.defunct + assert session.connection.closed class RecordTestCase(TestCase): diff --git a/test/util.py b/test/util.py index 19e81332..793fadb6 100644 --- a/test/util.py +++ b/test/util.py @@ -26,7 +26,8 @@ def watch(f): """ Decorator to enable log watching for the lifetime of a function. - Useful for debugging unit tests. + Useful for debugging unit tests, simply add `@watch` to the top of + the test function. :param f: the function to decorate :return: a decorated function From e28cd52f95bac8141e42a53b2ced410e07a3473a Mon Sep 17 00:00:00 2001 From: Nigel Small Date: Fri, 15 Jan 2016 00:42:59 +0000 Subject: [PATCH 7/8] max_pool_size --- neo4j/v1/session.py | 52 +++++++++++++++++++++++++++++++++----------- test/test_session.py | 27 ++++++++++++++++++++++- 2 files changed, 65 insertions(+), 14 deletions(-) diff --git a/neo4j/v1/session.py b/neo4j/v1/session.py index 57e95d3d..f7d0fac2 100644 --- a/neo4j/v1/session.py +++ b/neo4j/v1/session.py @@ -36,6 +36,8 @@ class which can be used to obtain `Driver` instances that are used for from .typesystem import hydrated +DEFAULT_MAX_POOL_SIZE = 50 + STATEMENT_TYPE_READ_ONLY = "r" STATEMENT_TYPE_READ_WRITE = "rw" STATEMENT_TYPE_WRITE_ONLY = "w" @@ -91,7 +93,8 @@ def __init__(self, url, **config): else: raise ValueError("Unsupported URL scheme: %s" % parsed.scheme) self.config = config - self.sessions = deque() + self.max_pool_size = config.get("max_pool_size", DEFAULT_MAX_POOL_SIZE) + self.session_pool = deque() def session(self): """ Create a new session based on the graph database details @@ -102,14 +105,33 @@ def session(self): >>> session = driver.session() """ - try: - session = self.sessions.pop() - except IndexError: - session = Session(self) - else: - session.connection.reset() + session = None + done = False + while not done: + try: + session = self.session_pool.pop() + except IndexError: + session = Session(self) + done = True + else: + if session.healthy: + session.connection.reset() + done = session.healthy return session + def recycle(self, session): + """ Pass a session back to the driver for recycling, if healthy. + + :param session: + :return: + """ + pool = self.session_pool + for s in pool: + if not s.healthy: + pool.remove(s) + if session.healthy and len(pool) < self.max_pool_size and session not in pool: + pool.appendleft(session) + class Result(list): """ A handler for the result of Cypher statement execution. @@ -344,12 +366,10 @@ def __init__(self, driver): self.connection = connect(driver.host, driver.port, **driver.config) self.transaction = None self.bench_tests = [] - self.closed = False def __del__(self): - if not self.closed: + if not self.connection.closed: self.connection.close() - self.closed = True def __enter__(self): return self @@ -357,6 +377,14 @@ def __enter__(self): def __exit__(self, exc_type, exc_value, traceback): self.close() + @property + def healthy(self): + """ Return ``True`` if this session is healthy, ``False`` if + unhealthy and ``None`` if closed. + """ + connection = self.connection + return None if connection.closed else not connection.defunct + def run(self, statement, parameters=None): """ Run a parameterised Cypher statement. @@ -411,9 +439,7 @@ def run(self, statement, parameters=None): def close(self): """ If still usable, return this session to the driver pool it came from. """ - if not self.connection.defunct: - self.driver.sessions.appendleft(self) - self.closed = True + self.driver.recycle(self) def begin_transaction(self): """ Create a new :class:`.Transaction` within this session. diff --git a/test/test_session.py b/test/test_session.py index 9bac2ab4..1c0408cb 100644 --- a/test/test_session.py +++ b/test/test_session.py @@ -24,10 +24,35 @@ from mock import patch from neo4j.v1.session import GraphDatabase, CypherError, Record, record from neo4j.v1.typesystem import Node, Relationship, Path -from test.util import watch + + +class DriverTestCase(TestCase): + + def test_healthy_session_will_be_returned_to_the_pool_on_close(self): + driver = GraphDatabase.driver("bolt://localhost") + assert len(driver.session_pool) == 0 + driver.session().close() + assert len(driver.session_pool) == 1 + + def test_unhealthy_session_will_not_be_returned_to_the_pool_on_close(self): + driver = GraphDatabase.driver("bolt://localhost") + assert len(driver.session_pool) == 0 + session = driver.session() + session.connection.defunct = True + session.close() + assert len(driver.session_pool) == 0 + + def session_pool_cannot_exceed_max_size(self): + driver = GraphDatabase.driver("bolt://localhost", max_pool_size=1) + assert len(driver.session_pool) == 0 + driver.session().close() + assert len(driver.session_pool) == 1 + driver.session().close() + assert len(driver.session_pool) == 1 class RunTestCase(TestCase): + def test_must_use_valid_url_scheme(self): with self.assertRaises(ValueError): GraphDatabase.driver("x://xxx") From 309826f4059f4620bd054c419decfa5653a18d41 Mon Sep 17 00:00:00 2001 From: Nigel Small Date: Fri, 15 Jan 2016 00:48:06 +0000 Subject: [PATCH 8/8] Test sessions expiring in the pool --- test/test_session.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/test/test_session.py b/test/test_session.py index 1c0408cb..625675ef 100644 --- a/test/test_session.py +++ b/test/test_session.py @@ -50,6 +50,15 @@ def session_pool_cannot_exceed_max_size(self): driver.session().close() assert len(driver.session_pool) == 1 + def test_session_that_dies_in_the_pool_will_not_be_given_out(self): + driver = GraphDatabase.driver("bolt://localhost") + session_1 = driver.session() + session_1.close() + assert len(driver.session_pool) == 1 + session_1.connection.close() + session_2 = driver.session() + assert session_2 is not session_1 + class RunTestCase(TestCase):