diff --git a/neo4j/__main__.py b/neo4j/__main__.py index 366eb09a..b1c64829 100644 --- a/neo4j/__main__.py +++ b/neo4j/__main__.py @@ -26,53 +26,10 @@ from json import loads as json_loads from sys import stdout, stderr +from .util import Watcher from .v1.session import GraphDatabase, CypherError -class ColourFormatter(logging.Formatter): - """ Colour formatter for pretty log output. - """ - - def format(self, record): - s = super(ColourFormatter, self).format(record) - if record.levelno == logging.CRITICAL: - return "\x1b[31;1m%s\x1b[0m" % s # bright red - elif record.levelno == logging.ERROR: - return "\x1b[33;1m%s\x1b[0m" % s # bright yellow - elif record.levelno == logging.WARNING: - return "\x1b[33m%s\x1b[0m" % s # yellow - elif record.levelno == logging.INFO: - return "\x1b[36m%s\x1b[0m" % s # cyan - elif record.levelno == logging.DEBUG: - return "\x1b[34m%s\x1b[0m" % s # blue - else: - return s - - -class Watcher(object): - """ Log watcher for debug output. - """ - - handlers = {} - - def __init__(self, logger_name): - super(Watcher, self).__init__() - self.logger_name = logger_name - self.logger = logging.getLogger(self.logger_name) - self.formatter = ColourFormatter("%(asctime)s %(message)s") - - def watch(self, level=logging.INFO, out=stdout): - try: - self.logger.removeHandler(self.handlers[self.logger_name]) - except KeyError: - pass - handler = logging.StreamHandler(out) - handler.setFormatter(self.formatter) - self.handlers[self.logger_name] = handler - self.logger.addHandler(handler) - self.logger.setLevel(level) - - def main(): parser = ArgumentParser(description="Execute one or more Cypher statements using Bolt.") parser.add_argument("statement", nargs="+") diff --git a/neo4j/util.py b/neo4j/util.py new file mode 100644 index 00000000..ddaa1e2e --- /dev/null +++ b/neo4j/util.py @@ -0,0 +1,76 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +# Copyright (c) 2002-2016 "Neo Technology," +# Network Engine for Objects in Lund AB [http://neotechnology.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from __future__ import unicode_literals + +import logging +from argparse import ArgumentParser +from json import loads as json_loads +from sys import stdout, stderr + +from .v1.session import GraphDatabase, CypherError + + +class ColourFormatter(logging.Formatter): + """ Colour formatter for pretty log output. + """ + + def format(self, record): + s = super(ColourFormatter, self).format(record) + if record.levelno == logging.CRITICAL: + return "\x1b[31;1m%s\x1b[0m" % s # bright red + elif record.levelno == logging.ERROR: + return "\x1b[33;1m%s\x1b[0m" % s # bright yellow + elif record.levelno == logging.WARNING: + return "\x1b[33m%s\x1b[0m" % s # yellow + elif record.levelno == logging.INFO: + return "\x1b[36m%s\x1b[0m" % s # cyan + elif record.levelno == logging.DEBUG: + return "\x1b[34m%s\x1b[0m" % s # blue + else: + return s + + +class Watcher(object): + """ Log watcher for debug output. + """ + + handlers = {} + + def __init__(self, logger_name): + super(Watcher, self).__init__() + self.logger_name = logger_name + self.logger = logging.getLogger(self.logger_name) + self.formatter = ColourFormatter("%(asctime)s %(message)s") + + def watch(self, level=logging.INFO, out=stdout): + self.stop() + handler = logging.StreamHandler(out) + handler.setFormatter(self.formatter) + self.handlers[self.logger_name] = handler + self.logger.addHandler(handler) + self.logger.setLevel(level) + + def stop(self): + try: + self.logger.removeHandler(self.handlers[self.logger_name]) + except KeyError: + pass diff --git a/neo4j/v1/connection.py b/neo4j/v1/connection.py index 3ba6398b..70dbca1e 100644 --- a/neo4j/v1/connection.py +++ b/neo4j/v1/connection.py @@ -42,7 +42,7 @@ # Signature bytes for each message type INIT = b"\x01" # 0000 0001 // INIT -ACK_FAILURE = b"\x0F" # 0000 1111 // ACK_FAILURE +RESET = b"\x0F" # 0000 1111 // RESET RUN = b"\x10" # 0001 0000 // RUN DISCARD_ALL = b"\x2F" # 0010 1111 // DISCARD * PULL_ALL = b"\x3F" # 0011 1111 // PULL * @@ -56,7 +56,7 @@ message_names = { INIT: "INIT", - ACK_FAILURE: "ACK_FAILURE", + RESET: "RESET", RUN: "RUN", DISCARD_ALL: "DISCARD_ALL", PULL_ALL: "PULL_ALL", @@ -169,14 +169,6 @@ def chunk_reader(self): data = self._recv(chunk_size) yield data - def close(self): - """ Shut down and close the connection. - """ - if __debug__: log_info("~~ [CLOSE]") - socket = self.socket - socket.shutdown(SHUT_RDWR) - socket.close() - class Response(object): """ Subscriber object for a full response (zero or @@ -200,12 +192,6 @@ def on_ignored(self, metadata=None): pass -class AckFailureResponse(Response): - - def on_failure(self, metadata): - raise ProtocolError("Could not acknowledge failure") - - class Connection(object): """ Server connection through which all protocol messages are sent and received. This class is designed for protocol @@ -215,9 +201,11 @@ class Connection(object): """ def __init__(self, sock, **config): + self.defunct = False self.channel = ChunkChannel(sock) self.packer = Packer(self.channel) self.responses = deque() + self.closed = False # Determine the user agent and ensure it is a Unicode value user_agent = config.get("user_agent", DEFAULT_USER_AGENT) @@ -235,8 +223,15 @@ def on_failure(metadata): while not response.complete: self.fetch_next() + def __del__(self): + self.close() + def append(self, signature, fields=(), response=None): """ Add a message to the outgoing queue. + + :arg signature: the signature of the message + :arg fields: the fields of the message as a tuple + :arg response: a response object to handle callbacks """ if __debug__: log_info("C: %s %s", message_names[signature], " ".join(map(repr, fields))) @@ -247,24 +242,58 @@ def append(self, signature, fields=(), response=None): self.channel.flush(end_of_message=True) self.responses.append(response) + def reset(self): + """ Add a RESET message to the outgoing queue, send + it and consume all remaining messages. + """ + response = Response(self) + + def on_failure(metadata): + raise ProtocolError("Reset failed") + + response.on_failure = on_failure + + self.append(RESET, response=response) + self.send() + fetch_next = self.fetch_next + while not response.complete: + fetch_next() + def send(self): """ Send all queued messages to the server. """ + if self.closed: + raise ProtocolError("Cannot write to a closed connection") + if self.defunct: + raise ProtocolError("Cannot write to a defunct connection") self.channel.send() def fetch_next(self): """ Receive exactly one message from the server. """ + if self.closed: + raise ProtocolError("Cannot read from a closed connection") + if self.defunct: + raise ProtocolError("Cannot read from a defunct connection") raw = BytesIO() unpack = Unpacker(raw).unpack - raw.writelines(self.channel.chunk_reader()) - + try: + raw.writelines(self.channel.chunk_reader()) + except ProtocolError: + self.defunct = True + self.close() + raise # Unpack from the raw byte stream and call the relevant message handler(s) raw.seek(0) response = self.responses[0] for signature, fields in unpack(): if __debug__: log_info("S: %s %s", message_names[signature], " ".join(map(repr, fields))) + if signature in SUMMARY: + response.complete = True + self.responses.popleft() + if signature == FAILURE: + self.reset() handler_name = "on_%s" % message_names[signature].lower() try: handler = getattr(response, handler_name) @@ -272,17 +301,16 @@ def fetch_next(self): pass else: handler(*fields) - if signature in SUMMARY: - response.complete = True - self.responses.popleft() - if signature == FAILURE: - self.append(ACK_FAILURE, response=AckFailureResponse(self)) raw.close() def close(self): - """ Shut down and close the connection. + """ Close the connection. """ - self.channel.close() + if not self.closed: + if __debug__: + log_info("~~ [CLOSE]") + self.channel.socket.close() + self.closed = True def connect(host, port=None, **config): diff --git a/neo4j/v1/session.py b/neo4j/v1/session.py index 2fda936d..f7d0fac2 100644 --- a/neo4j/v1/session.py +++ b/neo4j/v1/session.py @@ -28,7 +28,7 @@ class which can be used to obtain `Driver` instances that are used for from __future__ import division -from collections import namedtuple +from collections import deque, namedtuple from .compat import integer, perf_counter, string, urlparse from .connection import connect, Response, RUN, PULL_ALL @@ -36,6 +36,8 @@ class which can be used to obtain `Driver` instances that are used for from .typesystem import hydrated +DEFAULT_MAX_POOL_SIZE = 50 + STATEMENT_TYPE_READ_ONLY = "r" STATEMENT_TYPE_READ_WRITE = "rw" STATEMENT_TYPE_WRITE_ONLY = "w" @@ -91,15 +93,44 @@ def __init__(self, url, **config): else: raise ValueError("Unsupported URL scheme: %s" % parsed.scheme) self.config = config + self.max_pool_size = config.get("max_pool_size", DEFAULT_MAX_POOL_SIZE) + self.session_pool = deque() - def session(self, **config): + def session(self): """ Create a new session based on the graph database details specified within this driver: + >>> from neo4j.v1 import GraphDatabase + >>> driver = GraphDatabase.driver("bolt://localhost") >>> session = driver.session() """ - return Session(connect(self.host, self.port, **dict(self.config, **config))) + session = None + done = False + while not done: + try: + session = self.session_pool.pop() + except IndexError: + session = Session(self) + done = True + else: + if session.healthy: + session.connection.reset() + done = session.healthy + return session + + def recycle(self, session): + """ Pass a session back to the driver for recycling, if healthy. + + :param session: + :return: + """ + pool = self.session_pool + for s in pool: + if not s.healthy: + pool.remove(s) + if session.healthy and len(pool) < self.max_pool_size and session not in pool: + pool.appendleft(session) class Result(list): @@ -118,7 +149,7 @@ def __init__(self, session, statement, parameters): self.statement = statement self.parameters = parameters self.keys = None - self.complete = False + self.more = True self.summary = None self.bench_test = None @@ -137,7 +168,7 @@ def on_record(self, values): def on_footer(self, metadata): """ Called on receipt of the result footer. """ - self.complete = True + self.more = False self.summary = ResultSummary(self.statement, self.parameters, **metadata) if self.bench_test: self.bench_test.end_recv = perf_counter() @@ -152,7 +183,7 @@ def consume(self): callback functions. """ fetch_next = self.session.connection.fetch_next - while not self.complete: + while self.more: fetch_next() def summarize(self): @@ -330,17 +361,30 @@ class Session(object): method. """ - def __init__(self, connection): - self.connection = connection + def __init__(self, driver): + self.driver = driver + self.connection = connect(driver.host, driver.port, **driver.config) self.transaction = None self.bench_tests = [] + def __del__(self): + if not self.connection.closed: + self.connection.close() + def __enter__(self): return self def __exit__(self, exc_type, exc_value, traceback): self.close() + @property + def healthy(self): + """ Return ``True`` if this session is healthy, ``False`` if + unhealthy and ``None`` if closed. + """ + connection = self.connection + return None if connection.closed else not connection.defunct + def run(self, statement, parameters=None): """ Run a parameterised Cypher statement. @@ -393,9 +437,9 @@ def run(self, statement, parameters=None): return result def close(self): - """ Shut down and close the session. + """ If still usable, return this session to the driver pool it came from. """ - self.connection.close() + self.driver.recycle(self) def begin_transaction(self): """ Create a new :class:`.Transaction` within this session. @@ -473,6 +517,7 @@ def close(self): self.closed = True self.session.transaction = None + class Record(object): """ Record is an ordered collection of fields. diff --git a/neokit b/neokit index 23ffc81c..db7ab358 160000 --- a/neokit +++ b/neokit @@ -1 +1 @@ -Subproject commit 23ffc81c7a1a1f16369aa1cea71d77d256e57c8d +Subproject commit db7ab3580be3f0e09fcb2352408750f45c14a70e diff --git a/runtests.sh b/runtests.sh index b109a8b2..38a2923b 100755 --- a/runtests.sh +++ b/runtests.sh @@ -73,12 +73,12 @@ else if [ ${EXIT_STATUS} -eq 0 ] then coverage report --show-missing + python -c 'from test.tck.configure_feature_files import *; set_up()' + echo "Feature files downloaded" + neokit/neorun ${NEORUN_OPTIONS} "${BEHAVE_RUNNER}" ${VERSIONS} + python -c 'from test.tck.configure_feature_files import *; clean_up()' + echo "Feature files removed" fi - python -c 'from test.tck.configure_feature_files import *; set_up()' - echo "Feature files downloaded" - neokit/neorun ${NEORUN_OPTIONS} "${BEHAVE_RUNNER}" ${VERSIONS} - python -c 'from test.tck.configure_feature_files import *; clean_up()' - echo "Feature files removed" fi # Exit correctly diff --git a/test/test_session.py b/test/test_session.py index 16e2e19b..625675ef 100644 --- a/test/test_session.py +++ b/test/test_session.py @@ -21,15 +21,67 @@ from unittest import TestCase +from mock import patch from neo4j.v1.session import GraphDatabase, CypherError, Record, record from neo4j.v1.typesystem import Node, Relationship, Path +class DriverTestCase(TestCase): + + def test_healthy_session_will_be_returned_to_the_pool_on_close(self): + driver = GraphDatabase.driver("bolt://localhost") + assert len(driver.session_pool) == 0 + driver.session().close() + assert len(driver.session_pool) == 1 + + def test_unhealthy_session_will_not_be_returned_to_the_pool_on_close(self): + driver = GraphDatabase.driver("bolt://localhost") + assert len(driver.session_pool) == 0 + session = driver.session() + session.connection.defunct = True + session.close() + assert len(driver.session_pool) == 0 + + def session_pool_cannot_exceed_max_size(self): + driver = GraphDatabase.driver("bolt://localhost", max_pool_size=1) + assert len(driver.session_pool) == 0 + driver.session().close() + assert len(driver.session_pool) == 1 + driver.session().close() + assert len(driver.session_pool) == 1 + + def test_session_that_dies_in_the_pool_will_not_be_given_out(self): + driver = GraphDatabase.driver("bolt://localhost") + session_1 = driver.session() + session_1.close() + assert len(driver.session_pool) == 1 + session_1.connection.close() + session_2 = driver.session() + assert session_2 is not session_1 + + class RunTestCase(TestCase): + def test_must_use_valid_url_scheme(self): with self.assertRaises(ValueError): GraphDatabase.driver("x://xxx") + def test_sessions_are_reused(self): + driver = GraphDatabase.driver("bolt://localhost") + session_1 = driver.session() + session_1.close() + session_2 = driver.session() + session_2.close() + assert session_1 is session_2 + + def test_sessions_are_not_reused_if_still_in_use(self): + driver = GraphDatabase.driver("bolt://localhost") + session_1 = driver.session() + session_2 = driver.session() + session_2.close() + session_1.close() + assert session_1 is not session_2 + def test_can_run_simple_statement(self): session = GraphDatabase.driver("bolt://localhost").session() count = 0 @@ -204,6 +256,29 @@ def test_can_obtain_notification_info(self): assert position.column == 1 +class ResetTestCase(TestCase): + + def test_automatic_reset_after_failure(self): + with GraphDatabase.driver("bolt://localhost").session() as session: + try: + session.run("X") + except CypherError: + result = session.run("RETURN 1") + assert result[0][0] == 1 + else: + assert False, "A Cypher error should have occurred" + + def test_defunct(self): + from neo4j.v1.connection import ChunkChannel, ProtocolError + with GraphDatabase.driver("bolt://localhost").session() as session: + assert not session.connection.defunct + with patch.object(ChunkChannel, "chunk_reader", side_effect=ProtocolError()): + with self.assertRaises(ProtocolError): + session.run("RETURN 1") + assert session.connection.defunct + assert session.connection.closed + + class RecordTestCase(TestCase): def test_record_equality(self): record1 = Record(["name","empire"], ["Nigel", "The British Empire"]) diff --git a/test/util.py b/test/util.py new file mode 100644 index 00000000..793fadb6 --- /dev/null +++ b/test/util.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +# Copyright (c) 2002-2016 "Neo Technology," +# Network Engine for Objects in Lund AB [http://neotechnology.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import functools + +from neo4j.util import Watcher + + +def watch(f): + """ Decorator to enable log watching for the lifetime of a function. + Useful for debugging unit tests, simply add `@watch` to the top of + the test function. + + :param f: the function to decorate + :return: a decorated function + """ + @functools.wraps(f) + def wrapper(*args, **kwargs): + watcher = Watcher("neo4j") + watcher.watch() + f(*args, **kwargs) + watcher.stop() + return wrapper diff --git a/test_requirements.txt b/test_requirements.txt index 85278c30..b5c3dcfd 100644 --- a/test_requirements.txt +++ b/test_requirements.txt @@ -1,3 +1,4 @@ behave coverage +mock teamcity-messages