diff --git a/.coveragerc b/.coveragerc index 4a797bf5..69400d1f 100644 --- a/.coveragerc +++ b/.coveragerc @@ -2,4 +2,6 @@ branch = True omit = test/* + neo4j/util.py neo4j/v1/compat.py + neo4j/v1/ssl_compat.py diff --git a/.gitignore b/.gitignore index d28affbf..b27cc7bd 100644 --- a/.gitignore +++ b/.gitignore @@ -13,3 +13,8 @@ docs/build dist *.egg-info build + +*/run/* + +neo4j-community-* +neo4j-enterprise-* diff --git a/.gitmodules b/.gitmodules index 940f8c3d..e69de29b 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +0,0 @@ -[submodule "neokit"] - path = neokit - url = https://github.com/neo-technology/neokit.git diff --git a/examples/test_examples.py b/examples/test_examples.py index ba68fac0..ffbe7859 100644 --- a/examples/test_examples.py +++ b/examples/test_examples.py @@ -32,7 +32,13 @@ # (* "good reason" is defined as knowing what you are doing) -auth_token = basic_auth("neo4j", "neo4j") +auth_token = basic_auth("neotest", "neotest") + + +# Deliberately shadow the built-in print function to +# mute noise from example code. +def print(*args, **kwargs): + pass class FreshDatabaseTestCase(ServerTestCase): @@ -48,7 +54,7 @@ class MinimalWorkingExampleTestCase(FreshDatabaseTestCase): def test_minimal_working_example(self): # tag::minimal-example[] - driver = GraphDatabase.driver("bolt://localhost", auth=basic_auth("neo4j", "neo4j")) + driver = GraphDatabase.driver("bolt://localhost", auth=basic_auth("neotest", "neotest")) session = driver.session() session.run("CREATE (a:Person {name:'Arthur', title:'King'})") @@ -65,33 +71,33 @@ class ExamplesTestCase(FreshDatabaseTestCase): def test_construct_driver(self): # tag::construct-driver[] - driver = GraphDatabase.driver("bolt://localhost", auth=basic_auth("neo4j", "neo4j")) + driver = GraphDatabase.driver("bolt://localhost", auth=basic_auth("neotest", "neotest")) # end::construct-driver[] return driver def test_configuration(self): # tag::configuration[] - driver = GraphDatabase.driver("bolt://localhost", auth=basic_auth("neo4j", "neo4j"), max_pool_size=10) + driver = GraphDatabase.driver("bolt://localhost", auth=basic_auth("neotest", "neotest"), max_pool_size=10) # end::configuration[] return driver @skipUnless(SSL_AVAILABLE, "Bolt over TLS is not supported by this version of Python") def test_tls_require_encryption(self): # tag::tls-require-encryption[] - driver = GraphDatabase.driver("bolt://localhost", auth=basic_auth("neo4j", "neo4j"), encrypted=True) + driver = GraphDatabase.driver("bolt://localhost", auth=basic_auth("neotest", "neotest"), encrypted=True) # end::tls-require-encryption[] @skipUnless(SSL_AVAILABLE, "Bolt over TLS is not supported by this version of Python") def test_tls_trust_on_first_use(self): # tag::tls-trust-on-first-use[] - driver = GraphDatabase.driver("bolt://localhost", auth=basic_auth("neo4j", "neo4j"), encrypted=True, trust=TRUST_ON_FIRST_USE) + driver = GraphDatabase.driver("bolt://localhost", auth=basic_auth("neotest", "neotest"), encrypted=True, trust=TRUST_ON_FIRST_USE) # end::tls-trust-on-first-use[] assert driver @skip("testing verified certificates not yet supported ") def test_tls_signed(self): # tag::tls-signed[] - driver = GraphDatabase.driver("bolt://localhost", auth=basic_auth("neo4j", "neo4j"), encrypted=True, trust=TRUST_SIGNED_CERTIFICATES) + driver = GraphDatabase.driver("bolt://localhost", auth=basic_auth("neotest", "neotest"), encrypted=True, trust=TRUST_SIGNED_CERTIFICATES) # end::tls-signed[] assert driver diff --git a/neo4j/util.py b/neo4j/util.py index 0bdbbf96..2ff112b0 100644 --- a/neo4j/util.py +++ b/neo4j/util.py @@ -19,6 +19,12 @@ # limitations under the License. +try: + from collections.abc import MutableSet +except ImportError: + from collections import MutableSet, OrderedDict +else: + from collections import OrderedDict import logging from sys import stdout @@ -55,6 +61,13 @@ def __init__(self, logger_name): self.logger = logging.getLogger(self.logger_name) self.formatter = ColourFormatter("%(asctime)s %(message)s") + def __enter__(self): + self.watch() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.stop() + def watch(self, level=logging.INFO, out=stdout): self.stop() handler = logging.StreamHandler(out) @@ -81,3 +94,61 @@ def watch(logger_name, level=logging.INFO, out=stdout): watcher = Watcher(logger_name) watcher.watch(level, out) return watcher + + +class RoundRobinSet(MutableSet): + + def __init__(self, elements=()): + self._elements = OrderedDict.fromkeys(elements) + self._current = None + + def __repr__(self): + return "{%s}" % ", ".join(map(repr, self._elements)) + + def __contains__(self, element): + return element in self._elements + + def __next__(self): + current = None + if self._elements: + if self._current is None: + self._current = 0 + else: + self._current = (self._current + 1) % len(self._elements) + current = list(self._elements.keys())[self._current] + return current + + def __iter__(self): + return iter(self._elements) + + def __len__(self): + return len(self._elements) + + def add(self, element): + self._elements[element] = None + + def clear(self): + self._elements.clear() + + def discard(self, element): + try: + del self._elements[element] + except KeyError: + pass + + def next(self): + return self.__next__() + + def remove(self, element): + try: + del self._elements[element] + except KeyError: + raise ValueError(element) + + def update(self, elements=()): + self._elements.update(OrderedDict.fromkeys(elements)) + + def replace(self, elements=()): + e = self._elements + e.clear() + e.update(OrderedDict.fromkeys(elements)) diff --git a/neo4j/v1/bolt.py b/neo4j/v1/bolt.py index d55b8a93..453c1ec0 100644 --- a/neo4j/v1/bolt.py +++ b/neo4j/v1/bolt.py @@ -37,9 +37,10 @@ from select import select from socket import create_connection, SHUT_RDWR, error as SocketError from struct import pack as struct_pack, unpack as struct_unpack, unpack_from as struct_unpack_from +from threading import Lock from .constants import DEFAULT_USER_AGENT, KNOWN_HOSTS, MAGIC_PREAMBLE, TRUST_DEFAULT, TRUST_ON_FIRST_USE -from .exceptions import ProtocolError, Unauthorized +from .exceptions import ProtocolError, Unauthorized, ServiceUnavailable from .packstream import Packer, Unpacker from .ssl_compat import SSL_AVAILABLE, HAS_SNI, SSLError @@ -83,6 +84,7 @@ class BufferingSocket(object): def __init__(self, socket): + self.address = socket.getpeername() self.socket = socket self.buffer = bytearray() @@ -90,12 +92,11 @@ def fill(self): ready_to_read, _, _ = select((self.socket,), (), (), 0) received = self.socket.recv(65539) if received: - if __debug__: - log_debug("S: b%r", received) + log_debug("S: b%r", received) self.buffer[len(self.buffer):] = received else: if ready_to_read is not None: - raise ProtocolError("Server closed connection") + raise ServiceUnavailable("Failed to read from connection %r" % (self.address,)) def read_message(self): message_data = bytearray() @@ -126,6 +127,7 @@ class ChunkChannel(object): def __init__(self, sock): self.socket = sock + self.address = sock.getpeername() self.raw = BytesIO() self.output_buffer = [] self.output_size = 0 @@ -171,8 +173,7 @@ def send(self): """ Send all queued messages to the server. """ data = self.raw.getvalue() - if __debug__: - log_debug("C: b%r", data) + log_debug("C: b%r", data) self.socket.sendall(data) self.raw.seek(self.raw.truncate(0)) @@ -201,9 +202,11 @@ def on_ignored(self, metadata=None): class Connection(object): - """ Server connection through which all protocol messages - are sent and received. This class is designed for protocol - version 1. + """ Server connection for Bolt protocol v1. + + A :class:`.Connection` should be constructed following a + successful Bolt handshake and takes the socket over which + the handshake was carried out. .. note:: logs at INFO level """ @@ -211,12 +214,14 @@ class Connection(object): def __init__(self, sock, **config): self.socket = sock self.buffering_socket = BufferingSocket(sock) - self.defunct = False + self.address = sock.getpeername() self.channel = ChunkChannel(sock) self.packer = Packer(self.channel) self.unpacker = Unpacker() self.responses = deque() + self.in_use = False self.closed = False + self.defunct = False # Determine the user agent and ensure it is a Unicode value user_agent = config.get("user_agent", DEFAULT_USER_AGENT) @@ -235,7 +240,8 @@ def __init__(self, sock, **config): def on_failure(metadata): code = metadata.get("code") - error = Unauthorized if code == "Neo.ClientError.Security.Unauthorized" else ProtocolError + error = (Unauthorized if code == "Neo.ClientError.Security.Unauthorized" else + ServiceUnavailable) raise error(metadata.get("message", "INIT failed")) response = Response(self) @@ -249,13 +255,6 @@ def on_failure(metadata): def __del__(self): self.close() - @property - def healthy(self): - """ Return ``True`` if this connection is healthy, ``False`` if - unhealthy and ``None`` if closed. - """ - return None if self.closed else not self.defunct - def append(self, signature, fields=(), response=None): """ Add a message to the outgoing queue. @@ -263,8 +262,7 @@ def append(self, signature, fields=(), response=None): :arg fields: the fields of the message as a tuple :arg response: a response object to handle callbacks """ - if __debug__: - log_info("C: %s %s", message_names[signature], " ".join(map(repr, fields))) + log_info("C: %s %r", message_names[signature], fields) self.packer.pack_struct_header(len(fields), signature) for field in fields: @@ -310,47 +308,54 @@ def send(self): """ Send all queued messages to the server. """ if self.closed: - raise ProtocolError("Cannot write to a closed connection") + raise ServiceUnavailable("Failed to write to closed connection %r" % (self.address,)) if self.defunct: - raise ProtocolError("Cannot write to a defunct connection") + raise ServiceUnavailable("Failed to write to defunct connection %r" % (self.address,)) self.channel.send() def fetch(self): """ Receive exactly one message from the server. """ if self.closed: - raise ProtocolError("Cannot read from a closed connection") + raise ServiceUnavailable("Failed to read from closed connection %r" % (self.address,)) if self.defunct: - raise ProtocolError("Cannot read from a defunct connection") + raise ServiceUnavailable("Failed to read from defunct connection %r" % (self.address,)) try: message_data = self.buffering_socket.read_message() except ProtocolError: self.defunct = True self.close() raise - # Unpack from the raw byte stream and call the relevant message handler(s) - self.unpacker.load(message_data) - size, signature = self.unpacker.unpack_structure_header() - fields = [self.unpacker.unpack() for _ in range(size)] - if __debug__: - log_info("S: %s %r", message_names[signature], fields) + unpacker = self.unpacker + unpacker.load(message_data) + size, signature = unpacker.unpack_structure_header() + if size > 1: + raise ProtocolError("Expected one field") if signature == SUCCESS: + metadata = unpacker.unpack_map() + log_info("S: SUCCESS (%r)", metadata) response = self.responses.popleft() response.complete = True - response.on_success(*fields) + response.on_success(metadata or {}) elif signature == RECORD: + data = unpacker.unpack_list() + log_info("S: RECORD (%r)", data) response = self.responses[0] - response.on_record(*fields) + response.on_record(data or []) elif signature == IGNORED: + metadata = unpacker.unpack_map() + log_info("S: IGNORED (%r)", metadata) response = self.responses.popleft() response.complete = True - response.on_ignored(*fields) + response.on_ignored(metadata or {}) elif signature == FAILURE: + metadata = unpacker.unpack_map() + log_info("S: FAILURE (%r)", metadata) response = self.responses.popleft() response.complete = True - response.on_failure(*fields) + response.on_failure(metadata or {}) else: raise ProtocolError("Unexpected response message with signature %02X" % signature) @@ -364,12 +369,62 @@ def close(self): """ Close the connection. """ if not self.closed: - if __debug__: - log_info("~~ [CLOSE]") + log_info("~~ [CLOSE]") self.channel.socket.close() self.closed = True +class ConnectionPool(object): + """ A collection of connections to one or more server addresses. + """ + + def __init__(self, connector): + self.connector = connector + self.connections = {} + self.lock = Lock() + + def acquire(self, address): + """ Acquire a connection to a given address from the pool. + This method is thread safe. + """ + with self.lock: + try: + connections = self.connections[address] + except KeyError: + connections = self.connections[address] = deque() + for connection in list(connections): + if connection.closed or connection.defunct: + connections.remove(connection) + continue + if not connection.in_use: + connection.in_use = True + return connection + connection = self.connector(address) + connection.in_use = True + connections.append(connection) + return connection + + def release(self, connection): + """ Release a connection back into the pool. + This method is thread safe. + """ + with self.lock: + connection.in_use = False + + def close(self): + """ Close all connections and empty the pool. + This method is thread safe. + """ + with self.lock: + for _, connections in self.connections.items(): + for connection in connections: + try: + connection.close() + except IOError: + pass + self.connections.clear() + + class CertificateStore(object): def match_or_trust(self, host, der_encoded_certificate): @@ -416,7 +471,7 @@ def match_or_trust(self, host, der_encoded_certificate): return True -def connect(host_port, ssl_context=None, **config): +def connect(address, ssl_context=None, **config): """ Connect and perform a handshake and return a valid Connection object, assuming a protocol version can be agreed. """ @@ -424,45 +479,47 @@ def connect(host_port, ssl_context=None, **config): # Establish a connection to the host and port specified # Catches refused connections see: # https://docs.python.org/2/library/errno.html - if __debug__: log_info("~~ [CONNECT] %s", host_port) + log_info("~~ [CONNECT] %s", address) try: - s = create_connection(host_port) + s = create_connection(address) except SocketError as error: if error.errno == 111 or error.errno == 61 or error.errno == 10061: - raise ProtocolError("Unable to connect to %s on port %d - is the server running?" % host_port) + raise ServiceUnavailable("Failed to establish connection to %r" % (address,)) else: raise # Secure the connection if an SSL context has been provided if ssl_context and SSL_AVAILABLE: - host, port = host_port - if __debug__: log_info("~~ [SECURE] %s", host) + host, port = address + log_info("~~ [SECURE] %s", host) try: s = ssl_context.wrap_socket(s, server_hostname=host if HAS_SNI else None) except SSLError as cause: - error = ProtocolError("Cannot establish secure connection; %s" % cause.args[1]) + error = ServiceUnavailable("Failed to establish secure " + "connection to %r" % cause.args[1]) error.__cause__ = cause raise error else: # Check that the server provides a certificate der_encoded_server_certificate = s.getpeercert(binary_form=True) if der_encoded_server_certificate is None: - raise ProtocolError("When using a secure socket, the server should always provide a certificate") + raise ProtocolError("When using a secure socket, the server should always " + "provide a certificate") trust = config.get("trust", TRUST_DEFAULT) if trust == TRUST_ON_FIRST_USE: store = PersonalCertificateStore() if not store.match_or_trust(host, der_encoded_server_certificate): - raise ProtocolError("Server certificate does not match known certificate for %r; check " - "details in file %r" % (host, KNOWN_HOSTS)) + raise ProtocolError("Server certificate does not match known certificate " + "for %r; check details in file %r" % (host, KNOWN_HOSTS)) else: der_encoded_server_certificate = None # Send details of the protocol versions supported supported_versions = [1, 0, 0, 0] handshake = [MAGIC_PREAMBLE] + supported_versions - if __debug__: log_info("C: [HANDSHAKE] 0x%X %r", MAGIC_PREAMBLE, supported_versions) + log_info("C: [HANDSHAKE] 0x%X %r", MAGIC_PREAMBLE, supported_versions) data = b"".join(struct_pack(">I", num) for num in handshake) - if __debug__: log_debug("C: b%r", data) + log_debug("C: b%r", data) s.sendall(data) # Handle the handshake response @@ -475,25 +532,25 @@ def connect(host_port, ssl_context=None, **config): # If no data is returned after a successful select # response, the server has closed the connection log_error("S: [CLOSE]") - raise ProtocolError("Server closed connection without responding to handshake") + raise ProtocolError("Connection to %r closed without handshake response" % (address,)) if data_size == 4: - if __debug__: log_debug("S: b%r", data) + log_debug("S: b%r", data) else: # Some other garbled data has been received log_error("S: @*#!") raise ProtocolError("Expected four byte handshake response, received %r instead" % data) agreed_version, = struct_unpack(">I", data) - if __debug__: log_info("S: [HANDSHAKE] %d", agreed_version) + log_info("S: [HANDSHAKE] %d", agreed_version) if agreed_version == 0: - if __debug__: log_info("~~ [CLOSE]") + log_info("~~ [CLOSE]") s.shutdown(SHUT_RDWR) s.close() elif agreed_version == 1: return Connection(s, der_encoded_server_certificate=der_encoded_server_certificate, **config) - elif agreed_version == 1213486160: + elif agreed_version == 0x48545450: log_error("S: [CLOSE]") - raise ProtocolError("Server responded HTTP. Make sure you are not trying to connect to the http endpoint " + - "(HTTP defaults to port 7474 whereas BOLT defaults to port 7687)") + raise ServiceUnavailable("Cannot to connect to Bolt service on %r " + "(looks like HTTP)" % (address,)) else: log_error("S: [CLOSE]") raise ProtocolError("Unknown Bolt protocol version: %d", agreed_version) diff --git a/neo4j/v1/constants.py b/neo4j/v1/constants.py index 660d9293..b6ce0afb 100644 --- a/neo4j/v1/constants.py +++ b/neo4j/v1/constants.py @@ -34,9 +34,15 @@ ENCRYPTION_OFF = 0 ENCRYPTION_ON = 1 -ENCRYPTION_NON_LOCAL = 2 -ENCRYPTION_DEFAULT = ENCRYPTION_NON_LOCAL if SSL_AVAILABLE else ENCRYPTION_OFF - -TRUST_ON_FIRST_USE = 0 -TRUST_SIGNED_CERTIFICATES = 1 -TRUST_DEFAULT = TRUST_ON_FIRST_USE +ENCRYPTION_DEFAULT = ENCRYPTION_ON if SSL_AVAILABLE else ENCRYPTION_OFF + +TRUST_ON_FIRST_USE = 0 # Deprecated +TRUST_SIGNED_CERTIFICATES = 1 # Deprecated +TRUST_ALL_CERTIFICATES = 2 +TRUST_CUSTOM_CA_SIGNED_CERTIFICATES = 3 +TRUST_SYSTEM_CA_SIGNED_CERTIFICATES = 4 +TRUST_DEFAULT = TRUST_ALL_CERTIFICATES + +READ_ACCESS = "READ" +WRITE_ACCESS = "WRITE" +ACCESS_DEFAULT = WRITE_ACCESS diff --git a/neo4j/v1/exceptions.py b/neo4j/v1/exceptions.py index baaed5ad..6aba5521 100644 --- a/neo4j/v1/exceptions.py +++ b/neo4j/v1/exceptions.py @@ -43,6 +43,26 @@ def __init__(self, data): setattr(self, key, value) +class TransactionError(Exception): + """ Raised when an error occurs while using a transaction. + """ + + class ResultError(Exception): """ Raised when an error occurs while consuming a result. """ + + +class ServiceUnavailable(Exception): + """ Raised when no database service is available. + """ + + +class SessionExpired(Exception): + """ Raised when no a session is no longer able to fulfil + its purpose. + """ + + def __init__(self, session, *args, **kwargs): + self.session = session + super(SessionExpired, self).__init__(*args, **kwargs) diff --git a/neo4j/v1/packstream.py b/neo4j/v1/packstream.py index f9557688..406b7906 100644 --- a/neo4j/v1/packstream.py +++ b/neo4j/v1/packstream.py @@ -708,64 +708,12 @@ def unpack(self): return self.read(byte_size).decode(ENCODING) # List - elif marker_high == 0x90: - size = marker & 0x0F - return [unpack() for _ in range(size)] - elif marker == 0xD4: # LIST_8: - size = UNPACKED_UINT_8[self.read_bytes(1)] - return [unpack() for _ in range(size)] - elif marker == 0xD5: # LIST_16: - size = UNPACKED_UINT_16[self.read_bytes(2)] - return [unpack() for _ in range(size)] - elif marker == 0xD6: # LIST_32: - size = struct_unpack(UINT_32_STRUCT, self.read(4))[0] - return [unpack() for _ in range(size)] - elif marker == 0xD7: # LIST_STREAM: - value = [] - item = None - while item is not EndOfStream: - item = unpack() - if item is not EndOfStream: - value.append(item) - return value + elif 0x90 <= marker <= 0x9F or 0xD4 <= marker <= 0xD7: + return self._unpack_list(marker) # Map - elif marker_high == 0xA0: - size = marker & 0x0F - value = {} - for _ in range(size): - key = unpack() - value[key] = unpack() - return value - elif marker == 0xD8: # MAP_8: - size = UNPACKED_UINT_8[self.read_bytes(1)] - value = {} - for _ in range(size): - key = unpack() - value[key] = unpack() - return value - elif marker == 0xD9: # MAP_16: - size = UNPACKED_UINT_16[self.read_bytes(2)] - value = {} - for _ in range(size): - key = unpack() - value[key] = unpack() - return value - elif marker == 0xDA: # MAP_32: - size = struct_unpack(UINT_32_STRUCT, self.read(4))[0] - value = {} - for _ in range(size): - key = unpack() - value[key] = unpack() - return value - elif marker == 0xDB: # MAP_STREAM: - value = {} - key = None - while key is not EndOfStream: - key = unpack() - if key is not EndOfStream: - value[key] = unpack() - return value + elif 0xA0 <= marker <= 0xAF or 0xD8 <= marker <= 0xDB: + return self._unpack_map(marker) # Structure elif marker_high == 0xB0: @@ -793,6 +741,87 @@ def unpack(self): else: raise RuntimeError("Unknown PackStream marker %02X" % marker) + def unpack_list(self): + marker = self.read_marker() + return self._unpack_list(marker) + + def _unpack_list(self, marker): + marker_high = marker & 0xF0 + unpack = self.unpack + if marker_high == 0x90: + size = marker & 0x0F + if size == 0: + return [] + elif size == 1: + return [unpack()] + else: + return [unpack() for _ in range(size)] + elif marker == 0xD4: # LIST_8: + size = UNPACKED_UINT_8[self.read_bytes(1)] + return [unpack() for _ in range(size)] + elif marker == 0xD5: # LIST_16: + size = UNPACKED_UINT_16[self.read_bytes(2)] + return [unpack() for _ in range(size)] + elif marker == 0xD6: # LIST_32: + size = struct_unpack(UINT_32_STRUCT, self.read_bytes(4))[0] + return [unpack() for _ in range(size)] + elif marker == 0xD7: # LIST_STREAM: + value = [] + item = None + while item is not EndOfStream: + item = unpack() + if item is not EndOfStream: + value.append(item) + return value + else: + return None + + def unpack_map(self): + marker = self.read_marker() + return self._unpack_map(marker) + + def _unpack_map(self, marker): + marker_high = marker & 0xF0 + unpack = self.unpack + if marker_high == 0xA0: + size = marker & 0x0F + value = {} + for _ in range(size): + key = unpack() + value[key] = unpack() + return value + elif marker == 0xD8: # MAP_8: + size = UNPACKED_UINT_8[self.read_bytes(1)] + value = {} + for _ in range(size): + key = unpack() + value[key] = unpack() + return value + elif marker == 0xD9: # MAP_16: + size = UNPACKED_UINT_16[self.read_bytes(2)] + value = {} + for _ in range(size): + key = unpack() + value[key] = unpack() + return value + elif marker == 0xDA: # MAP_32: + size = struct_unpack(UINT_32_STRUCT, self.read_bytes(4))[0] + value = {} + for _ in range(size): + key = unpack() + value[key] = unpack() + return value + elif marker == 0xDB: # MAP_STREAM: + value = {} + key = None + while key is not EndOfStream: + key = unpack() + if key is not EndOfStream: + value[key] = unpack() + return value + else: + return None + def unpack_structure_header(self): marker = self.read_marker() if marker == -1: diff --git a/neo4j/v1/session.py b/neo4j/v1/session.py index 24db614e..95af0ab5 100644 --- a/neo4j/v1/session.py +++ b/neo4j/v1/session.py @@ -29,19 +29,27 @@ from collections import deque import re +from threading import Lock +try: + from time import monotonic as clock +except ImportError: + from time import clock +from warnings import warn -from .bolt import connect, Response, RUN, PULL_ALL +from neo4j.util import RoundRobinSet + +from .bolt import connect, Response, RUN, PULL_ALL, ConnectionPool from .compat import integer, string, urlparse -from .constants import DEFAULT_PORT, ENCRYPTION_DEFAULT, TRUST_DEFAULT, TRUST_SIGNED_CERTIFICATES, ENCRYPTION_ON, \ - ENCRYPTION_NON_LOCAL -from .exceptions import CypherError, ProtocolError, ResultError +from .constants import DEFAULT_PORT, ENCRYPTION_DEFAULT, TRUST_DEFAULT, TRUST_SIGNED_CERTIFICATES, \ + TRUST_ON_FIRST_USE, READ_ACCESS, TRUST_SYSTEM_CA_SIGNED_CERTIFICATES, \ + TRUST_ALL_CERTIFICATES, TRUST_CUSTOM_CA_SIGNED_CERTIFICATES +from .exceptions import CypherError, ProtocolError, ResultError, TransactionError, \ + ServiceUnavailable, SessionExpired from .ssl_compat import SSL_AVAILABLE, SSLContext, PROTOCOL_SSLv23, OP_NO_SSLv2, CERT_REQUIRED from .summary import ResultSummary from .types import hydrated -DEFAULT_MAX_POOL_SIZE = 50 - localhost = re.compile(r"^(localhost|127(\.\d+){3})$", re.IGNORECASE) @@ -62,243 +70,230 @@ class GraphDatabase(object): """ @staticmethod - def driver(url, **config): + def driver(uri, **config): """ Acquire a :class:`.Driver` instance for the given URL and configuration: >>> from neo4j.v1 import GraphDatabase >>> driver = GraphDatabase.driver("bolt://localhost") - """ - return Driver(url, **config) - + :param uri: URI for a graph database + :param config: configuration and authentication details (valid keys are listed below) -class Driver(object): - """ A :class:`.Driver` is an accessor for a specific graph database - resource. It provides both a template for sessions and a container - for the session pool. All configuration and authentication settings - are collected by the `Driver` constructor; should different settings - be required, a new `Driver` instance should be created. + `auth` + An authentication token for the server, for example + ``basic_auth("neo4j", "password")``. - :param address: address of the remote server as either a `bolt` URI - or a `host:port` string - :param config: configuration and authentication details (valid keys are listed below) + `der_encoded_server_certificate` + The server certificate in DER format, if required. - `auth` - An authentication token for the server, for example - ``basic_auth("neo4j", "password")``. + `encrypted` + Encryption level: one of :attr:`.ENCRYPTION_ON`, :attr:`.ENCRYPTION_OFF` + or :attr:`.ENCRYPTION_NON_LOCAL`. The default setting varies + depending on whether SSL is available or not. If it is, + :attr:`.ENCRYPTION_NON_LOCAL` is the default. - `der_encoded_server_certificate` - The server certificate in DER format, if required. + `trust` + Trust level: one of :attr:`.TRUST_ON_FIRST_USE` (default) or + :attr:`.TRUST_SIGNED_CERTIFICATES`. - `encrypted` - Encryption level: one of :attr:`.ENCRYPTION_ON`, :attr:`.ENCRYPTION_OFF` - or :attr:`.ENCRYPTION_NON_LOCAL`. The default setting varies - depending on whether SSL is available or not. If it is, - :attr:`.ENCRYPTION_NON_LOCAL` is the default. + `user_agent` + A custom user agent string, if required. - `max_pool_size` - The maximum number of sessions to keep idle in the session - pool. - - `trust` - Trust level: one of :attr:`.TRUST_ON_FIRST_USE` (default) or - :attr:`.TRUST_SIGNED_CERTIFICATES`. + """ + parsed = urlparse(uri) + if parsed.scheme == "bolt": + return DirectDriver((parsed.hostname, parsed.port or DEFAULT_PORT), **config) + elif parsed.scheme == "bolt+routing": + return RoutingDriver((parsed.hostname, parsed.port or DEFAULT_PORT), **config) + else: + raise ProtocolError("URI scheme %r not supported" % parsed.scheme) - `user_agent` - A custom user agent string, if required. - """ +class SecurityPlan(object): - def __init__(self, address, **config): - if "://" in address: - parsed = urlparse(address) - if parsed.scheme == "bolt": - host = parsed.hostname - port = parsed.port or DEFAULT_PORT - else: - raise ProtocolError("Only the 'bolt' URI scheme is supported [%s]" % address) - elif ":" in address: - host, port = address.split(":") - port = int(port) - else: - host = address - port = DEFAULT_PORT - self.address = (host, port) - self.config = config - self.max_pool_size = config.get("max_pool_size", DEFAULT_MAX_POOL_SIZE) - self.session_pool = deque() + @classmethod + def build(cls, address, **config): encrypted = config.get("encrypted", None) if encrypted is None: - _warn_about_insecure_default() - encrypted = ENCRYPTION_DEFAULT - self.encrypted = encrypted - self.trust = trust = config.get("trust", TRUST_DEFAULT) - if encrypted == ENCRYPTION_ON or \ - encrypted == ENCRYPTION_NON_LOCAL and not localhost.match(host): + encrypted = _encryption_default() + trust = config.get("trust", TRUST_DEFAULT) + if encrypted: if not SSL_AVAILABLE: - raise RuntimeError("Bolt over TLS is only available in Python 2.7.9+ and Python 3.3+") + raise RuntimeError("Bolt over TLS is only available in Python 2.7.9+ and " + "Python 3.3+") ssl_context = SSLContext(PROTOCOL_SSLv23) ssl_context.options |= OP_NO_SSLv2 - if trust >= TRUST_SIGNED_CERTIFICATES: + if trust == TRUST_ON_FIRST_USE: + warn("TRUST_ON_FIRST_USE is deprecated, please use " + "TRUST_ALL_CERTIFICATES instead") + elif trust == TRUST_SIGNED_CERTIFICATES: + warn("TRUST_SIGNED_CERTIFICATES is deprecated, please use " + "TRUST_SYSTEM_CA_SIGNED_CERTIFICATES instead") + ssl_context.verify_mode = CERT_REQUIRED + elif trust == TRUST_ALL_CERTIFICATES: + pass + elif trust == TRUST_CUSTOM_CA_SIGNED_CERTIFICATES: + raise NotImplementedError("Custom CA support is not implemented") + elif trust == TRUST_SYSTEM_CA_SIGNED_CERTIFICATES: ssl_context.verify_mode = CERT_REQUIRED + else: + raise ValueError("Unknown trust mode") ssl_context.set_default_verify_paths() - self.ssl_context = ssl_context else: - self.ssl_context = None - - def session(self): - """ Create a new session based on the graph database details - specified within this driver: + ssl_context = None + return cls(encrypted, ssl_context, trust != TRUST_ON_FIRST_USE) - >>> from neo4j.v1 import GraphDatabase - >>> driver = GraphDatabase.driver("bolt://localhost") - >>> session = driver.session() - """ - session = None - connected = False - while not connected: - try: - session = self.session_pool.pop() - except IndexError: - connection = connect(self.address, self.ssl_context, **self.config) - session = Session(self, connection) - connected = True - else: - if session.healthy: - connected = session.healthy - return session + def __init__(self, requires_encryption, ssl_context, routing_compatible): + self.encrypted = bool(requires_encryption) + self.ssl_context = ssl_context + self.routing_compatible = routing_compatible - def recycle(self, session): - """ Accept a session for recycling, if healthy. - - :param session: - :return: - """ - pool = self.session_pool - for s in list(pool): # freezing the pool into a list for iteration allows pool mutation inside the loop - if not s.healthy: - pool.remove(s) - if session.healthy and len(pool) < self.max_pool_size and session not in pool: - pool.appendleft(session) +class Driver(object): + """ A :class:`.Driver` is an accessor for a specific graph database + resource. It is thread-safe, acts as a template for sessions and hosts + a connection pool. -class StatementResult(object): - """ A handler for the result of Cypher statement execution. + All configuration and authentication settings are held immutably by the + `Driver`. Should different settings be required, a new `Driver` instance + should be created via the :meth:`.GraphDatabase.driver` method. """ - #: The statement text that was executed to produce this result. - statement = None - - #: Dictionary of parameters passed with the statement. - parameters = None + pool = None - def __init__(self, connection, run_response, pull_all_response): - super(StatementResult, self).__init__() + def __init__(self, connector): + self.pool = ConnectionPool(connector) - # The Connection instance behind this result. - self.connection = connection + def __enter__(self): + return self - # The keys for the records in the result stream. These are - # lazily populated on request. - self._keys = None + def __exit__(self, exc_type, exc_value, traceback): + self.close() - # Buffer for incoming records to be queued before yielding. If - # the result is used immediately, this buffer will be ignored. - self._buffer = deque() + def session(self, access_mode=None): + """ Create a new session using a connection from the driver connection + pool. Session creation is a lightweight operation and sessions are + not thread safe, therefore a session should generally be short-lived + within a single thread. + """ + pass - # The result summary (populated after the records have been - # fully consumed). - self._summary = None + def close(self): + if self.pool: + self.pool.close() - # Flag to indicate whether the entire stream has been consumed - # from the network (but not necessarily yielded). - self._consumed = False - def on_header(metadata): - # Called on receipt of the result header. - self._keys = metadata["fields"] +class DirectDriver(Driver): + """ A :class:`.DirectDriver` is created from a `bolt` URI and addresses + a single database instance. + """ - def on_record(values): - # Called on receipt of each result record. - self._buffer.append(values) + def __init__(self, address, **config): + self.address = address + self.security_plan = security_plan = SecurityPlan.build(address, **config) + self.encrypted = security_plan.encrypted + Driver.__init__(self, lambda a: connect(a, security_plan.ssl_context, **config)) - def on_footer(metadata): - # Called on receipt of the result footer. - self._summary = ResultSummary(self.statement, self.parameters, **metadata) - self._consumed = True + def session(self, access_mode=None): + return Session(self, self.pool.acquire(self.address)) - def on_failure(metadata): - # Called on execution failure. - self.connection.acknowledge_failure() - self._consumed = True - raise CypherError(metadata) - run_response.on_success = on_header - run_response.on_failure = on_failure +class ConnectionRouter(object): + """ The `Router` class contains logic for discovering servers within a + cluster that supports routing. + """ - pull_all_response.on_record = on_record - pull_all_response.on_success = on_footer - pull_all_response.on_failure = on_failure + timer = clock - def __iter__(self): - while self._buffer: - values = self._buffer.popleft() - yield Record(self.keys(), tuple(map(hydrated, values))) - while not self._consumed: - self.connection.fetch() - while self._buffer: - values = self._buffer.popleft() - yield Record(self.keys(), tuple(map(hydrated, values))) + def __init__(self, pool, *routers): + self.pool = pool + self.lock = Lock() + self.expiry_time = None + self.routers = RoundRobinSet(routers) + self.readers = RoundRobinSet() + self.writers = RoundRobinSet() - def keys(self): - """ Return the keys for the records. + def stale(self): + """ Indicator for whether routing information is out of date or + incomplete. """ - # Fetch messages until we have the header or a failure - while self._keys is None and not self._consumed: - self.connection.fetch() - return tuple(self._keys) + expired = self.expiry_time is None or self.expiry_time <= self.timer() + return expired or len(self.routers) <= 1 or not self.readers or not self.writers - def buffer(self): - if self.connection and not self.connection.closed: - while not self._consumed: - self.connection.fetch() - self.connection = None - - def consume(self): - """ Consume the remainder of this result and return the - summary. + def discover(self): + """ Perform cluster member discovery. """ - if self.connection and not self.connection.closed: - list(self) - self.connection = None - return self._summary - - def single(self): - """ Return the next record, failing if none or more than one remain. + with self.lock: + if not self.routers: + raise ServiceUnavailable("No routers available") + for router in list(self.routers): + connection = self.pool.acquire(router) + with Session(self, connection) as session: + try: + record = session.run("CALL dbms.cluster.routing.getServers").single() + except CypherError as error: + if error.code == "Neo.ClientError.Procedure.ProcedureNotFound": + raise ServiceUnavailable("Server %r does not support " + "routing" % (router,)) + raise + except ResultError: + raise ServiceUnavailable("Server %r returned no record from " + "discovery procedure" % (router,)) + else: + new_expiry_time = self.timer() + record["ttl"] + servers = record["servers"] + new_routers = [s["addresses"] for s in servers if s["role"] == "ROUTE"][0] + new_readers = [s["addresses"] for s in servers if s["role"] == "READ"][0] + new_writers = [s["addresses"] for s in servers if s["role"] == "WRITE"][0] + if new_routers and new_readers and new_writers: + self.expiry_time = new_expiry_time + self.routers.replace(map(parse_address, new_routers)) + self.readers.replace(map(parse_address, new_readers)) + self.writers.replace(map(parse_address, new_writers)) + return router + raise ServiceUnavailable("Unable to establish routing information") + + def acquire_read_connection(self): + """ Acquire a connection to a read server. """ - records = list(self) - num_records = len(records) - if num_records == 0: - raise ResultError("Cannot retrieve a single record, because this result is empty.") - elif num_records != 1: - raise ResultError("Expected a result with a single record, but this result contains at least one more.") - else: - return records[0] + if self.stale(): + self.discover() + return self.pool.acquire(next(self.readers)) - def peek(self): - """ Return the next record without advancing the cursor. Fails - if no records remain. + def acquire_write_connection(self): + """ Acquire a connection to a write server. """ - if self._buffer: - values = self._buffer[0] - return Record(self.keys(), tuple(map(hydrated, values))) - while not self._buffer and not self._consumed: - self.connection.fetch() - if self._buffer: - values = self._buffer[0] - return Record(self.keys(), tuple(map(hydrated, values))) - raise ResultError("End of stream") + if self.stale(): + self.discover() + return self.pool.acquire(next(self.writers)) + + +class RoutingDriver(Driver): + """ A :class:`.RoutingDriver` is created from a `bolt+routing` URI. + """ + + def __init__(self, address, **config): + self.security_plan = security_plan = SecurityPlan.build(address, **config) + self.encrypted = security_plan.encrypted + if not security_plan.routing_compatible: + # this error message is case-specific as there is only one incompatible + # scenario right now + raise ValueError("TRUST_ON_FIRST_USE is not compatible with routing") + Driver.__init__(self, lambda a: connect(a, security_plan.ssl_context, **config)) + try: + self.router = ConnectionRouter(self.pool, address) + self.router.discover() + except: + self.close() + raise + + def session(self, access_mode=None): + if access_mode == READ_ACCESS: + connection = self.router.acquire_read_connection() + else: + connection = self.router.acquire_write_connection() + return Session(self, connection, access_mode) class Session(object): @@ -307,10 +302,15 @@ class Session(object): method. """ - def __init__(self, driver, connection): + transaction = None + + def __init__(self, driver, connection, access_mode=None): self.driver = driver self.connection = connection - self.transaction = None + self.access_mode = access_mode + + def __del__(self): + self.close() def __enter__(self): return self @@ -318,34 +318,56 @@ def __enter__(self): def __exit__(self, exc_type, exc_value, traceback): self.close() - @property - def healthy(self): - """ Return ``True`` if this session is healthy, ``False`` if - unhealthy and ``None`` if closed. - """ - return self.connection.healthy - - def run(self, statement, parameters=None): - """ Run a parameterised Cypher statement. + def run(self, statement, parameters=None, **kwparameters): + """ Run a parameterised Cypher statement. If an explicit transaction + has been created, the statement will be executed within that + transactional context. Otherwise, this will take place within an + auto-commit transaction. :param statement: Cypher statement to execute :param parameters: dictionary of parameters :return: Cypher result :rtype: :class:`.StatementResult` """ - if self.transaction: - raise ProtocolError("Statements cannot be run directly on a session with an open transaction;" - " either run from within the transaction or use a different session.") - return run(self.connection, statement, parameters) + statement = _norm_statement(statement) + parameters = _norm_parameters(parameters, **kwparameters) + + run_response = Response(self.connection) + pull_all_response = Response(self.connection) + result = StatementResult(self, run_response, pull_all_response) + result.statement = statement + result.parameters = parameters + + self.connection.append(RUN, (statement, parameters), response=run_response) + self.connection.append(PULL_ALL, response=pull_all_response) + self.connection.send() + + return result + + def fetch(self): + try: + return self.connection.fetch() + except ServiceUnavailable as cause: + self.connection.in_use = False + self.connection = None + if self.access_mode: + exception = SessionExpired(self, "Session %r is no longer valid for " + "%r work" % (self, self.access_mode)) + exception.__cause__ = cause + raise exception + else: + raise def close(self): - """ Recycle this session through the driver it came from. + """ Close the session. """ - if self.connection and not self.connection.closed: - self.connection.fetch_all() if self.transaction: self.transaction.close() - self.driver.recycle(self) + if self.connection: + if not self.connection.closed: + self.connection.fetch_all() + self.connection.in_use = False + self.connection = None def begin_transaction(self): """ Create a new :class:`.Transaction` within this session. @@ -353,15 +375,21 @@ def begin_transaction(self): :return: new :class:`.Transaction` instance. """ if self.transaction: - raise ProtocolError("You cannot begin a transaction on a session with an open transaction;" - " either run from within the transaction or use a different session.") + raise TransactionError("Explicit transaction already open") def clear_transaction(): self.transaction = None - self.transaction = Transaction(self.connection, on_close=clear_transaction) + self.run("BEGIN") + self.transaction = Transaction(self, on_close=clear_transaction) return self.transaction + def commit_transaction(self): + self.run("COMMIT") + + def rollback_transaction(self): + self.run("ROLLBACK") + class Transaction(object): """ Container for multiple Cypher queries to be executed within @@ -384,10 +412,9 @@ class Transaction(object): #: with commit or rollback. closed = False - def __init__(self, connection, on_close): - self.connection = connection + def __init__(self, session, on_close): + self.session = session self.on_close = on_close - run(self.connection, "BEGIN") def __enter__(self): return self @@ -397,7 +424,7 @@ def __exit__(self, exc_type, exc_value, traceback): self.success = False self.close() - def run(self, statement, parameters=None): + def run(self, statement, parameters=None, **kwparameters): """ Run a Cypher statement within the context of this transaction. :param statement: Cypher statement @@ -405,7 +432,7 @@ def run(self, statement, parameters=None): :return: result object """ assert not self.closed - return run(self.connection, statement, parameters) + return self.session.run(statement, parameters, **kwparameters) def commit(self): """ Mark this transaction as successful and close in order to @@ -426,13 +453,140 @@ def close(self): """ assert not self.closed if self.success: - run(self.connection, "COMMIT") + self.session.commit_transaction() else: - run(self.connection, "ROLLBACK") + self.session.rollback_transaction() self.closed = True self.on_close() +class StatementResult(object): + """ A handler for the result of Cypher statement execution. + """ + + #: The statement text that was executed to produce this result. + statement = None + + #: Dictionary of parameters passed with the statement. + parameters = None + + def __init__(self, session, run_response, pull_all_response): + super(StatementResult, self).__init__() + + # The Session behind this result. When all data has been + # received, this is set to :const:`None` and can therefore + # be used as a "consumed" indicator. + self.session = session + + # The keys for the records in the result stream. These are + # lazily populated on request. + self._keys = None + + # Buffer for incoming records to be queued before yielding. If + # the result is used immediately, this buffer will be ignored. + self._buffer = deque() + + # The result summary (populated after the records have been + # fully consumed). + self._summary = None + + def on_header(metadata): + # Called on receipt of the result header. + self._keys = metadata["fields"] + + def on_record(values): + # Called on receipt of each result record. + self._buffer.append(values) + + def on_footer(metadata): + # Called on receipt of the result footer. + self._summary = ResultSummary(self.statement, self.parameters, **metadata) + self.session = None + + def on_failure(metadata): + # Called on execution failure. + self.session.connection.acknowledge_failure() + self.session = None + raise CypherError(metadata) + + run_response.on_success = on_header + run_response.on_failure = on_failure + + pull_all_response.on_record = on_record + pull_all_response.on_success = on_footer + pull_all_response.on_failure = on_failure + + def __iter__(self): + while self._buffer: + values = self._buffer.popleft() + yield Record(self.keys(), tuple(map(hydrated, values))) + while self.online(): + self.session.fetch() + while self._buffer: + values = self._buffer.popleft() + yield Record(self.keys(), tuple(map(hydrated, values))) + + def online(self): + """ True if this result is still attached to an active Session. + """ + return self.session and not self.session.connection.closed + + def keys(self): + """ Return the keys for the records. + """ + # Fetch messages until we have the header or a failure + while self._keys is None and self.online(): + self.session.fetch() + return tuple(self._keys) + + def buffer(self): + """ Fetch the remainder of the result from the network and buffer + it for future consumption. + """ + while self.online(): + self.session.fetch() + + def consume(self): + """ Consume the remainder of this result and return the summary. + """ + if self.online(): + list(self) + return self._summary + + def summary(self): + """ Return the summary, buffering any remaining records. + """ + self.buffer() + return self._summary + + def single(self): + """ Return the next record, failing if none or more than one remain. + """ + records = list(self) + num_records = len(records) + if num_records == 0: + raise ResultError("Cannot retrieve a single record, because this result is empty.") + elif num_records != 1: + raise ResultError("Expected a result with a single record, but this result contains " + "at least one more.") + else: + return records[0] + + def peek(self): + """ Return the next record without advancing the cursor. Fails + if no records remain. + """ + if self._buffer: + values = self._buffer[0] + return Record(self.keys(), tuple(map(hydrated, values))) + while not self._buffer and self.online(): + self.session.fetch() + if self._buffer: + values = self._buffer[0] + return Record(self.keys(), tuple(map(hydrated, values))) + raise ResultError("End of stream") + + class Record(object): """ Record is an ordered collection of fields. @@ -521,48 +675,40 @@ def basic_auth(user, password): return AuthToken("basic", user, password) -def run(connection, statement, parameters=None): - """ Run a Cypher statement on a given connection. - - :param connection: connection to carry the request and response - :param statement: Cypher statement - :param parameters: optional dictionary of parameters - :return: statement result +def parse_address(address): + """ Convert an address string to a tuple. """ - # Ensure the statement is a Unicode value - if isinstance(statement, bytes): - statement = statement.decode("UTF-8") - - params = {} - for key, value in (parameters or {}).items(): - if isinstance(key, bytes): - key = key.decode("UTF-8") - if isinstance(value, bytes): - params[key] = value.decode("UTF-8") - else: - params[key] = value - parameters = params - - run_response = Response(connection) - pull_all_response = Response(connection) - result = StatementResult(connection, run_response, pull_all_response) - result.statement = statement - result.parameters = parameters - - connection.append(RUN, (statement, parameters), response=run_response) - connection.append(PULL_ALL, response=pull_all_response) - connection.send() - - return result + host, _, port = address.partition(":") + return host, int(port) _warned_about_insecure_default = False -def _warn_about_insecure_default(): +def _encryption_default(): global _warned_about_insecure_default if not SSL_AVAILABLE and not _warned_about_insecure_default: - from warnings import warn warn("Bolt over TLS is only available in Python 2.7.9+ and Python 3.3+ " "so communications are not secure") _warned_about_insecure_default = True + return ENCRYPTION_DEFAULT + + +def _norm_statement(statement): + if isinstance(statement, bytes): + statement = statement.decode("UTF-8") + return statement + + +def _norm_parameters(parameters=None, **kwparameters): + params_in = parameters or {} + params_in.update(kwparameters) + params_out = {} + for key, value in params_in.items(): + if isinstance(key, bytes): + key = key.decode("UTF-8") + if isinstance(value, bytes): + params_out[key] = value.decode("UTF-8") + else: + params_out[key] = value + return params_out diff --git a/neo4j/v1/summary.py b/neo4j/v1/summary.py index f6fabfbb..9bf6c0ec 100644 --- a/neo4j/v1/summary.py +++ b/neo4j/v1/summary.py @@ -59,6 +59,7 @@ class ResultSummary(object): def __init__(self, statement, parameters, **metadata): self.statement = statement self.parameters = parameters + self.metadata = metadata self.statement_type = metadata.get("type") self.counters = SummaryCounters(metadata.get("stats", {})) if "plan" in metadata: diff --git a/neokit b/neokit deleted file mode 160000 index 6506ce94..00000000 --- a/neokit +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 6506ce944a47f78e071f49b600df588a3fe91e28 diff --git a/runexamples.sh b/runexamples.sh new file mode 100755 index 00000000..48e78b6a --- /dev/null +++ b/runexamples.sh @@ -0,0 +1,6 @@ +#!/usr/bin/env bash + +RUN=$(dirname "$0")/examples/run/ + +# Export DIST_HOST=localhost if local web server hosts server packages +neotest -e 3.0.6:3.1.0-M09 ${RUN} python -m unittest discover -vs examples diff --git a/example.py b/runtck.sh old mode 100644 new mode 100755 similarity index 50% rename from example.py rename to runtck.sh index 4d328243..bedd973a --- a/example.py +++ b/runtck.sh @@ -1,5 +1,4 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- +#!/usr/bin/env bash # Copyright (c) 2002-2016 "Neo Technology," # Network Engine for Objects in Lund AB [http://neotechnology.com] @@ -18,22 +17,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from neo4j.v1.session import GraphDatabase, basic_auth - - -driver = GraphDatabase.driver("bolt://localhost", auth=basic_auth("neo4j", "neo4j")) -session = driver.session() - -session.run("MERGE (a:Person {name:'Alice'})") - -friends = ["Bob", "Carol", "Dave", "Eve", "Frank"] -with session.begin_transaction() as tx: - for friend in friends: - tx.run("MATCH (a:Person {name:'Alice'}) " - "MERGE (a)-[:KNOWS]->(x:Person {name:{n}})", {"n": friend}) - tx.success = True - -for friend, in session.run("MATCH (a:Person {name:'Alice'})-[:KNOWS]->(x) RETURN x"): - print('Alice says, "hello, %s"' % friend["name"]) - -session.close() +python -c "from tck.configure_feature_files import *; set_up()" +neotest 3.1.0-M09 $(dirname "$0")/tck/run/ behave --format=progress --tags=-db --tags=-tls --tags=-fixed_session_pool tck +python -c "from tck.configure_feature_files import *; clean_up()" diff --git a/runtests.py b/runtests.py index f78de346..7e32c619 100644 --- a/runtests.py +++ b/runtests.py @@ -96,7 +96,7 @@ def main(): stdout.write("Using python version:\n") runcommand('python --version') - runpymodule('pip install --upgrade -r ./test_requirements.txt') + runpymodule('pip install --upgrade -r ./test/requirements.txt') retcode = 0 register(neorun, '--stop=' + NEO4J_HOME) diff --git a/runtests.sh b/runtests.sh index b2698e98..48e16910 100755 --- a/runtests.sh +++ b/runtests.sh @@ -1,25 +1,13 @@ #!/usr/bin/env bash -# Copyright (c) 2002-2016 "Neo Technology," -# Network Engine for Objects in Lund AB [http://neotechnology.com] -# -# This file is part of Neo4j. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +RUN=$(dirname "$0")/test/run/ -if [ "$1" == "" ]; then - python ./runtests.py --tests --examples --tck +# Export DIST_HOST=localhost if local web server hosts server packages +if [ -z $1 ] +then + # Full test (with coverage) + neotest -e 3.0.7:3.1.0-M13-beta3 ${RUN} coverage run --source neo4j -m unittest discover -vs test && coverage report --show-missing else - #Example: NEORUN_START_ARGS="-n 3.1 -p neo4j" python ./runtests.py --tests --examples --tck - NEORUN_START_ARGS="$1" python ./runtests.py --tests --examples --tck + # Partial test + neotest -e 3.0.7:3.1.0-M13-beta3 ${RUN} python -m unittest -v $1 fi diff --git a/test/tck/__init__.py b/tck/__init__.py similarity index 100% rename from test/tck/__init__.py rename to tck/__init__.py diff --git a/test/tck/configure_feature_files.py b/tck/configure_feature_files.py similarity index 100% rename from test/tck/configure_feature_files.py rename to tck/configure_feature_files.py diff --git a/test/tck/environment.py b/tck/environment.py similarity index 98% rename from test/tck/environment.py rename to tck/environment.py index f6402578..f16412ad 100644 --- a/test/tck/environment.py +++ b/tck/environment.py @@ -18,7 +18,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from test.tck import tck_util +from tck import tck_util failing_features = {} diff --git a/test_requirements.txt b/tck/requirements.txt similarity index 100% rename from test_requirements.txt rename to tck/requirements.txt diff --git a/test/tck/resultparser.py b/tck/resultparser.py similarity index 100% rename from test/tck/resultparser.py rename to tck/resultparser.py diff --git a/tck/steps/__init__.py b/tck/steps/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/test/tck/steps/bolt_compability_steps.py b/tck/steps/bolt_compability_steps.py similarity index 97% rename from test/tck/steps/bolt_compability_steps.py rename to tck/steps/bolt_compability_steps.py index e56af162..46a75a4c 100644 --- a/test/tck/steps/bolt_compability_steps.py +++ b/tck/steps/bolt_compability_steps.py @@ -18,17 +18,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +import copy import random import string -import copy from behave import * - -from test.tck import tck_util -from test.tck.resultparser import parse_values -from test.tck.tck_util import to_unicode, Type, string_to_type +from tck.resultparser import parse_values from neo4j.v1 import compat +from tck import tck_util +from tck.tck_util import to_unicode, Type, string_to_type + use_step_matcher("re") diff --git a/test/tck/steps/cypher_compability_steps.py b/tck/steps/cypher_compability_steps.py similarity index 96% rename from test/tck/steps/cypher_compability_steps.py rename to tck/steps/cypher_compability_steps.py index 1231b663..541b71b4 100644 --- a/test/tck/steps/cypher_compability_steps.py +++ b/tck/steps/cypher_compability_steps.py @@ -19,9 +19,9 @@ # limitations under the License. from behave import * +from tck.resultparser import parse_values -from test.tck import tck_util -from test.tck.resultparser import parse_values +from tck import tck_util use_step_matcher("re") diff --git a/test/tck/steps/driver_auth_steps.py b/tck/steps/driver_auth_steps.py similarity index 84% rename from test/tck/steps/driver_auth_steps.py rename to tck/steps/driver_auth_steps.py index edc9d266..3079e4d3 100644 --- a/test/tck/steps/driver_auth_steps.py +++ b/tck/steps/driver_auth_steps.py @@ -21,21 +21,22 @@ from behave import * from neo4j.v1 import GraphDatabase, basic_auth, exceptions +from tck.tck_util import BOLT_URI, AUTH_TOKEN @given("a driver configured with auth disabled") def step_impl(context): - context.driver = GraphDatabase.driver("bolt://localhost", encrypted=False) + context.driver = GraphDatabase.driver(BOLT_URI, encrypted=False) @given("a driver is configured with auth enabled and correct password is provided") def step_impl(context): - context.driver = GraphDatabase.driver("bolt://localhost", auth=basic_auth("neo4j", "neo4j"), encrypted=False) + context.driver = GraphDatabase.driver(BOLT_URI, auth=AUTH_TOKEN, encrypted=False) @given("a driver is configured with auth enabled and the wrong password is provided") def step_impl(context): - context.driver = GraphDatabase.driver("bolt://localhost", auth=basic_auth("neo4j", "wrong"), encrypted=False) + context.driver = GraphDatabase.driver(BOLT_URI, auth=basic_auth("neo4j", "wrong"), encrypted=False) @step("reading and writing to the database should be possible") diff --git a/test/tck/steps/driver_equality_steps.py b/tck/steps/driver_equality_steps.py similarity index 97% rename from test/tck/steps/driver_equality_steps.py rename to tck/steps/driver_equality_steps.py index 98c0b602..437c3b65 100644 --- a/test/tck/steps/driver_equality_steps.py +++ b/tck/steps/driver_equality_steps.py @@ -20,7 +20,7 @@ from behave import * -from test.tck.tck_util import send_string +from tck.tck_util import send_string use_step_matcher("re") diff --git a/test/tck/steps/driver_result_api_steps.py b/tck/steps/driver_result_api_steps.py similarity index 96% rename from test/tck/steps/driver_result_api_steps.py rename to tck/steps/driver_result_api_steps.py index ddcc606a..52a7176c 100644 --- a/test/tck/steps/driver_result_api_steps.py +++ b/tck/steps/driver_result_api_steps.py @@ -20,10 +20,9 @@ from behave import * -from neo4j.v1 import STATEMENT_TYPE_READ_ONLY, STATEMENT_TYPE_READ_WRITE, STATEMENT_TYPE_WRITE_ONLY, \ - STATEMENT_TYPE_SCHEMA_WRITE - -from test.tck.resultparser import parse_values +from neo4j.v1.summary import STATEMENT_TYPE_READ_ONLY, STATEMENT_TYPE_READ_WRITE, \ + STATEMENT_TYPE_WRITE_ONLY, STATEMENT_TYPE_SCHEMA_WRITE +from tck.resultparser import parse_values use_step_matcher("re") diff --git a/test/tck/steps/error_reporting_steps.py b/tck/steps/error_reporting_steps.py similarity index 98% rename from test/tck/steps/error_reporting_steps.py rename to tck/steps/error_reporting_steps.py index cb19fcca..81bfcd24 100644 --- a/test/tck/steps/error_reporting_steps.py +++ b/tck/steps/error_reporting_steps.py @@ -20,10 +20,9 @@ from behave import * -from neo4j.v1.exceptions import ProtocolError, CypherError -from test.tck import tck_util - from neo4j.v1 import GraphDatabase +from neo4j.v1.exceptions import ProtocolError, CypherError +from tck import tck_util use_step_matcher("re") diff --git a/test/tck/steps/statement_result.py b/tck/steps/statement_result.py similarity index 98% rename from test/tck/steps/statement_result.py rename to tck/steps/statement_result.py index 57a4aaf4..9ac3802e 100644 --- a/test/tck/steps/statement_result.py +++ b/tck/steps/statement_result.py @@ -19,12 +19,11 @@ # limitations under the License. from behave import * +from tck.resultparser import parse_values_to_comparable from neo4j.v1 import Record, ResultSummary -from test.tck import tck_util - from neo4j.v1.exceptions import ResultError -from test.tck.resultparser import parse_values_to_comparable +from tck import tck_util use_step_matcher("re") diff --git a/test/tck/tck_util.py b/tck/tck_util.py similarity index 93% rename from test/tck/tck_util.py rename to tck/tck_util.py index 063e0635..d6106563 100644 --- a/test/tck/tck_util.py +++ b/tck/tck_util.py @@ -19,11 +19,15 @@ # limitations under the License. +from tck.test_value import TestValue + from neo4j.v1 import GraphDatabase, basic_auth -from test.tck.test_value import TestValue -from test.tck.resultparser import parse_values_to_comparable +from tck.resultparser import parse_values_to_comparable + +BOLT_URI = "bolt://localhost:7687" +AUTH_TOKEN = basic_auth("neotest", "neotest") -driver = GraphDatabase.driver("bolt://localhost", auth=basic_auth("neo4j", "neo4j"), encrypted=False) +driver = GraphDatabase.driver(BOLT_URI, auth=AUTH_TOKEN, encrypted=False) runners = [] diff --git a/test/tck/test_value.py b/tck/test_value.py similarity index 100% rename from test/tck/test_value.py rename to tck/test_value.py diff --git a/test/requirements.txt b/test/requirements.txt new file mode 100644 index 00000000..2dc2d80b --- /dev/null +++ b/test/requirements.txt @@ -0,0 +1,3 @@ +coverage +mock +teamcity-messages diff --git a/test/resources/bad_router.script b/test/resources/bad_router.script new file mode 100644 index 00000000..454810fb --- /dev/null +++ b/test/resources/bad_router.script @@ -0,0 +1,8 @@ +!: AUTO INIT +!: AUTO RESET + +C: RUN "CALL dbms.cluster.routing.getServers" {} + PULL_ALL +S: SUCCESS {"fields": ["ttl", "servers"]} + RECORD [300, [{"role":"ROUTE","addresses":[]},{"role":"READ","addresses":[]},{"role":"WRITE","addresses":[]}]] + SUCCESS {} diff --git a/test/resources/create_a.script b/test/resources/create_a.script new file mode 100644 index 00000000..a57cd269 --- /dev/null +++ b/test/resources/create_a.script @@ -0,0 +1,7 @@ +!: AUTO INIT +!: AUTO RESET + +C: RUN "CREATE (a $x)" {"x": {"name": "Alice"}} + PULL_ALL +S: SUCCESS {"fields": []} + SUCCESS {} diff --git a/test/resources/disconnect_on_pull_all.script b/test/resources/disconnect_on_pull_all.script new file mode 100644 index 00000000..5865a61d --- /dev/null +++ b/test/resources/disconnect_on_pull_all.script @@ -0,0 +1,6 @@ +!: AUTO INIT +!: AUTO RESET + +C: RUN "RETURN $x" {"x": 1} + PULL_ALL +S: diff --git a/test/resources/disconnect_on_run.script b/test/resources/disconnect_on_run.script new file mode 100644 index 00000000..35c82766 --- /dev/null +++ b/test/resources/disconnect_on_run.script @@ -0,0 +1,5 @@ +!: AUTO INIT +!: AUTO RESET + +C: RUN "RETURN $x" {"x": 1} +S: diff --git a/test/resources/non_router.script b/test/resources/non_router.script new file mode 100644 index 00000000..26ca2638 --- /dev/null +++ b/test/resources/non_router.script @@ -0,0 +1,9 @@ +!: AUTO INIT +!: AUTO RESET + +C: RUN "CALL dbms.cluster.routing.getServers" {} + PULL_ALL +S: FAILURE {"code": "Neo.ClientError.Procedure.ProcedureNotFound", "message": "Not a router"} + IGNORED +C: ACK_FAILURE +S: SUCCESS {} diff --git a/test/resources/return_1.script b/test/resources/return_1.script new file mode 100644 index 00000000..7d12a1bd --- /dev/null +++ b/test/resources/return_1.script @@ -0,0 +1,8 @@ +!: AUTO INIT +!: AUTO RESET + +C: RUN "RETURN $x" {"x": 1} + PULL_ALL +S: SUCCESS {"fields": ["x"]} + RECORD [1] + SUCCESS {} diff --git a/test/resources/router.script b/test/resources/router.script new file mode 100644 index 00000000..208831ca --- /dev/null +++ b/test/resources/router.script @@ -0,0 +1,8 @@ +!: AUTO INIT +!: AUTO RESET + +C: RUN "CALL dbms.cluster.routing.getServers" {} + PULL_ALL +S: SUCCESS {"fields": ["ttl", "servers"]} + RECORD [300, [{"role":"ROUTE","addresses":["127.0.0.1:9001","127.0.0.1:9002","127.0.0.1:9003"]},{"role":"READ","addresses":["127.0.0.1:9004","127.0.0.1:9005"]},{"role":"WRITE","addresses":["127.0.0.1:9006"]}]] + SUCCESS {} diff --git a/test/resources/silent_router.script b/test/resources/silent_router.script new file mode 100644 index 00000000..a315ae3a --- /dev/null +++ b/test/resources/silent_router.script @@ -0,0 +1,7 @@ +!: AUTO INIT +!: AUTO RESET + +C: RUN "CALL dbms.cluster.routing.getServers" {} + PULL_ALL +S: SUCCESS {"fields": ["ttl", "servers"]} + SUCCESS {} diff --git a/test/test_connection.py b/test/test_connection.py new file mode 100644 index 00000000..ff44c322 --- /dev/null +++ b/test/test_connection.py @@ -0,0 +1,148 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +# Copyright (c) 2002-2016 "Neo Technology," +# Network Engine for Objects in Lund AB [http://neotechnology.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from socket import create_connection + +from neo4j.v1 import basic_auth, ConnectionRouter, ConnectionPool, connect, ServiceUnavailable + +from test.util import ServerTestCase + + +class QuickConnection(object): + + closed = False + defunct = False + + def __init__(self, socket): + self.socket = socket + self.address = socket.getpeername() + + def reset(self): + pass + + def close(self): + self.socket.close() + + +class ConnectionPoolTestCase(ServerTestCase): + + def setUp(self): + self.pool = ConnectionPool(lambda a: QuickConnection(create_connection(a))) + + def tearDown(self): + self.pool.close() + + def assert_pool_size(self, address, expected_active, expected_inactive): + try: + connections = self.pool.connections[address] + except KeyError: + assert 0 == expected_active + assert 0 == expected_inactive + else: + assert len([c for c in connections if c.in_use]) == expected_active + assert len([c for c in connections if not c.in_use]) == expected_inactive + + def test_can_acquire(self): + address = ("127.0.0.1", 7687) + connection = self.pool.acquire(address) + assert connection.address == address + self.assert_pool_size(address, 1, 0) + + def test_can_acquire_twice(self): + address = ("127.0.0.1", 7687) + connection_1 = self.pool.acquire(address) + connection_2 = self.pool.acquire(address) + assert connection_1.address == address + assert connection_2.address == address + assert connection_1 is not connection_2 + self.assert_pool_size(address, 2, 0) + + def test_can_acquire_two_addresses(self): + address_1 = ("127.0.0.1", 7687) + address_2 = ("127.0.0.1", 7474) + connection_1 = self.pool.acquire(address_1) + connection_2 = self.pool.acquire(address_2) + assert connection_1.address == address_1 + assert connection_2.address == address_2 + self.assert_pool_size(address_1, 1, 0) + self.assert_pool_size(address_2, 1, 0) + + def test_can_acquire_and_release(self): + address = ("127.0.0.1", 7687) + connection = self.pool.acquire(address) + self.assert_pool_size(address, 1, 0) + self.pool.release(connection) + self.assert_pool_size(address, 0, 1) + + def test_releasing_twice(self): + address = ("127.0.0.1", 7687) + connection = self.pool.acquire(address) + self.pool.release(connection) + self.assert_pool_size(address, 0, 1) + self.pool.release(connection) + self.assert_pool_size(address, 0, 1) + + +class RouterTestCase(ServerTestCase): + + def setUp(self): + self.pool = ConnectionPool(lambda a: connect(a, auth=basic_auth("neo4j", "password"))) + + def tearDown(self): + self.pool.close() + + def test_router_is_initially_stale(self): + router = ConnectionRouter(self.pool, ("127.0.0.1", 7687)) + assert router.stale() + + def test_discovery(self): + self.start_stub_server(9001, "router.script") + router = ConnectionRouter(self.pool, ("127.0.0.1", 9001)) + router.timer = lambda: 0 + router.discover() + assert router.expiry_time == 300 + assert router.routers == {('127.0.0.1', 9001), ('127.0.0.1', 9002), ('127.0.0.1', 9003)} + assert router.readers == {('127.0.0.1', 9004), ('127.0.0.1', 9005)} + assert router.writers == {('127.0.0.1', 9006)} + + def test_discovery_after_bad_discovery(self): + self.start_stub_server(9001, "bad_router.script") + self.start_stub_server(9002, "router.script") + router = ConnectionRouter(self.pool, ("127.0.0.1", 9001), ("127.0.0.1", 9002)) + router.timer = lambda: 0 + router.discover() + assert router.expiry_time == 300 + assert router.routers == {('127.0.0.1', 9001), ('127.0.0.1', 9002), ('127.0.0.1', 9003)} + assert router.readers == {('127.0.0.1', 9004), ('127.0.0.1', 9005)} + assert router.writers == {('127.0.0.1', 9006)} + + def test_discovery_against_non_router(self): + self.start_stub_server(9001, "non_router.script") + router = ConnectionRouter(self.pool, ("127.0.0.1", 9001)) + with self.assertRaises(ServiceUnavailable): + router.discover() + + def test_running_out_of_good_routers_on_discovery(self): + self.start_stub_server(9001, "bad_router.script") + self.start_stub_server(9002, "bad_router.script") + self.start_stub_server(9003, "bad_router.script") + router = ConnectionRouter(self.pool, ("127.0.0.1", 9001), ("127.0.0.1", 9002), ("127.0.0.1", 9003)) + with self.assertRaises(ServiceUnavailable): + router.discover() diff --git a/test/test_driver.py b/test/test_driver.py new file mode 100644 index 00000000..de17503f --- /dev/null +++ b/test/test_driver.py @@ -0,0 +1,225 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +# Copyright (c) 2002-2016 "Neo Technology," +# Network Engine for Objects in Lund AB [http://neotechnology.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from socket import socket +from ssl import SSLSocket +from unittest import skipUnless + +from neo4j.v1 import ServiceUnavailable, ProtocolError, READ_ACCESS, WRITE_ACCESS, \ + TRUST_ON_FIRST_USE, TRUST_CUSTOM_CA_SIGNED_CERTIFICATES, GraphDatabase, basic_auth, \ + SSL_AVAILABLE, SessionExpired, DirectDriver +from test.util import ServerTestCase + +BOLT_URI = "bolt://localhost:7687" +BOLT_ROUTING_URI = "bolt+routing://localhost:7687" +AUTH_TOKEN = basic_auth("neotest", "neotest") + + +class DriverTestCase(ServerTestCase): + + def test_driver_with_block(self): + with GraphDatabase.driver(BOLT_URI, auth=AUTH_TOKEN, encrypted=False) as driver: + assert isinstance(driver, DirectDriver) + + def test_must_use_valid_url_scheme(self): + with self.assertRaises(ProtocolError): + GraphDatabase.driver("x://xxx", auth=AUTH_TOKEN) + + def test_connections_are_reused(self): + driver = GraphDatabase.driver(BOLT_URI, auth=AUTH_TOKEN) + session_1 = driver.session() + connection_1 = session_1.connection + session_1.close() + session_2 = driver.session() + connection_2 = session_2.connection + session_2.close() + assert connection_1 is connection_2 + + def test_connections_are_not_shared_between_sessions(self): + driver = GraphDatabase.driver(BOLT_URI, auth=AUTH_TOKEN) + session_1 = driver.session() + session_2 = driver.session() + try: + assert session_1.connection is not session_2.connection + finally: + session_1.close() + session_2.close() + + def test_fail_nicely_when_connecting_to_http_port(self): + driver = GraphDatabase.driver("bolt://localhost:7474", auth=AUTH_TOKEN, encrypted=False) + with self.assertRaises(ServiceUnavailable) as context: + driver.session() + + +class DirectDriverTestCase(ServerTestCase): + + def tearDown(self): + self.await_all_servers() + + def test_direct_disconnect_on_run(self): + self.start_stub_server(9001, "disconnect_on_run.script") + uri = "bolt://127.0.0.1:9001" + driver = GraphDatabase.driver(uri, auth=basic_auth("neo4j", "password"), encrypted=False) + try: + with driver.session() as session: + with self.assertRaises(ServiceUnavailable): + session.run("RETURN $x", {"x": 1}).consume() + finally: + driver.close() + + def test_direct_disconnect_on_pull_all(self): + self.start_stub_server(9001, "disconnect_on_pull_all.script") + uri = "bolt://127.0.0.1:9001" + driver = GraphDatabase.driver(uri, auth=basic_auth("neo4j", "password"), encrypted=False) + try: + with driver.session() as session: + with self.assertRaises(ServiceUnavailable): + session.run("RETURN $x", {"x": 1}).consume() + finally: + driver.close() + + +class RoutingDriverTestCase(ServerTestCase): + + def tearDown(self): + self.await_all_servers() + + def test_cannot_discover_servers_on_non_router(self): + self.start_stub_server(9001, "non_router.script") + uri = "bolt+routing://127.0.0.1:9001" + with self.assertRaises(ServiceUnavailable): + GraphDatabase.driver(uri, auth=basic_auth("neo4j", "password"), encrypted=False) + + def test_cannot_discover_servers_on_silent_router(self): + self.start_stub_server(9001, "silent_router.script") + uri = "bolt+routing://127.0.0.1:9001" + with self.assertRaises(ServiceUnavailable): + GraphDatabase.driver(uri, auth=basic_auth("neo4j", "password"), encrypted=False) + + def test_should_discover_servers_on_driver_construction(self): + self.start_stub_server(9001, "router.script") + uri = "bolt+routing://127.0.0.1:9001" + driver = GraphDatabase.driver(uri, auth=basic_auth("neo4j", "password"), encrypted=False) + router = driver.router + assert router.routers == {('127.0.0.1', 9001), ('127.0.0.1', 9002), ('127.0.0.1', 9003)} + assert router.readers == {('127.0.0.1', 9004), ('127.0.0.1', 9005)} + assert router.writers == {('127.0.0.1', 9006)} + + def test_should_be_able_to_read(self): + self.start_stub_server(9001, "router.script") + self.start_stub_server(9004, "return_1.script") + uri = "bolt+routing://127.0.0.1:9001" + driver = GraphDatabase.driver(uri, auth=basic_auth("neo4j", "password"), encrypted=False) + try: + with driver.session(READ_ACCESS) as session: + result = session.run("RETURN $x", {"x": 1}) + for record in result: + assert record["x"] == 1 + assert session.connection.address == ('127.0.0.1', 9004) + finally: + driver.close() + + def test_should_be_able_to_write(self): + self.start_stub_server(9001, "router.script") + self.start_stub_server(9006, "create_a.script") + uri = "bolt+routing://127.0.0.1:9001" + driver = GraphDatabase.driver(uri, auth=basic_auth("neo4j", "password"), encrypted=False) + try: + with driver.session(WRITE_ACCESS) as session: + result = session.run("CREATE (a $x)", {"x": {"name": "Alice"}}) + assert not list(result) + assert session.connection.address == ('127.0.0.1', 9006) + finally: + driver.close() + + def test_should_be_able_to_write_as_default(self): + self.start_stub_server(9001, "router.script") + self.start_stub_server(9006, "create_a.script") + uri = "bolt+routing://127.0.0.1:9001" + driver = GraphDatabase.driver(uri, auth=basic_auth("neo4j", "password"), encrypted=False) + try: + with driver.session() as session: + result = session.run("CREATE (a $x)", {"x": {"name": "Alice"}}) + assert not list(result) + assert session.connection.address == ('127.0.0.1', 9006) + finally: + driver.close() + + def test_routing_disconnect_on_run(self): + self.start_stub_server(9001, "router.script") + self.start_stub_server(9004, "disconnect_on_run.script") + uri = "bolt+routing://127.0.0.1:9001" + driver = GraphDatabase.driver(uri, auth=basic_auth("neo4j", "password"), encrypted=False) + try: + with driver.session(READ_ACCESS) as session: + with self.assertRaises(SessionExpired): + session.run("RETURN $x", {"x": 1}).consume() + finally: + driver.close() + + def test_routing_disconnect_on_pull_all(self): + self.start_stub_server(9001, "router.script") + self.start_stub_server(9004, "disconnect_on_pull_all.script") + uri = "bolt+routing://127.0.0.1:9001" + driver = GraphDatabase.driver(uri, auth=basic_auth("neo4j", "password"), encrypted=False) + try: + with driver.session(READ_ACCESS) as session: + with self.assertRaises(SessionExpired): + session.run("RETURN $x", {"x": 1}).consume() + finally: + driver.close() + + +class SecurityTestCase(ServerTestCase): + + def test_insecure_session_uses_normal_socket(self): + driver = GraphDatabase.driver(BOLT_URI, auth=AUTH_TOKEN, encrypted=False) + with driver.session() as session: + connection = session.connection + assert isinstance(connection.channel.socket, socket) + assert connection.der_encoded_server_certificate is None + + @skipUnless(SSL_AVAILABLE, "Bolt over TLS is not supported by this version of Python") + def test_tofu_session_uses_secure_socket(self): + driver = GraphDatabase.driver(BOLT_URI, auth=AUTH_TOKEN, encrypted=True, trust=TRUST_ON_FIRST_USE) + with driver.session() as session: + connection = session.connection + assert isinstance(connection.channel.socket, SSLSocket) + assert connection.der_encoded_server_certificate is not None + + @skipUnless(SSL_AVAILABLE, "Bolt over TLS is not supported by this version of Python") + def test_tofu_session_trusts_certificate_after_first_use(self): + driver = GraphDatabase.driver(BOLT_URI, auth=AUTH_TOKEN, encrypted=True, trust=TRUST_ON_FIRST_USE) + with driver.session() as session: + connection = session.connection + certificate = connection.der_encoded_server_certificate + with driver.session() as session: + connection = session.connection + assert connection.der_encoded_server_certificate == certificate + + def test_routing_driver_not_compatible_with_tofu(self): + with self.assertRaises(ValueError): + GraphDatabase.driver(BOLT_ROUTING_URI, auth=AUTH_TOKEN, trust=TRUST_ON_FIRST_USE) + + def test_custom_ca_not_implemented(self): + with self.assertRaises(NotImplementedError): + GraphDatabase.driver(BOLT_URI, auth=AUTH_TOKEN, + trust=TRUST_CUSTOM_CA_SIGNED_CERTIFICATES) diff --git a/test/test_session.py b/test/test_session.py index 09e785b4..b8d2022c 100644 --- a/test/test_session.py +++ b/test/test_session.py @@ -18,123 +18,26 @@ # See the License for the specific language governing permissions and # limitations under the License. - -from socket import socket -from ssl import SSLSocket -from unittest import skipUnless - from mock import patch -from neo4j.v1.constants import TRUST_ON_FIRST_USE -from neo4j.v1.exceptions import CypherError, ProtocolError, ResultError -from neo4j.v1.session import GraphDatabase, basic_auth, Record, SSL_AVAILABLE +from neo4j.v1.exceptions import CypherError, ResultError +from neo4j.v1.session import GraphDatabase, basic_auth, Record from neo4j.v1.types import Node, Relationship, Path from test.util import ServerTestCase -auth_token = basic_auth("neo4j", "neo4j") +BOLT_URI = "bolt://localhost:7687" +AUTH_TOKEN = basic_auth("neotest", "neotest") -class DriverTestCase(ServerTestCase): +class AutoCommitTransactionTestCase(ServerTestCase): - def test_healthy_session_will_be_returned_to_the_pool_on_close(self): - driver = GraphDatabase.driver("bolt://localhost", auth=auth_token) - assert len(driver.session_pool) == 0 - driver.session().close() - assert len(driver.session_pool) == 1 - - def test_unhealthy_session_will_not_be_returned_to_the_pool_on_close(self): - driver = GraphDatabase.driver("bolt://localhost", auth=auth_token) - assert len(driver.session_pool) == 0 - session = driver.session() - session.connection.defunct = True - session.close() - assert len(driver.session_pool) == 0 - - def session_pool_cannot_exceed_max_size(self): - driver = GraphDatabase.driver("bolt://localhost", auth=auth_token, max_pool_size=1) - assert len(driver.session_pool) == 0 - driver.session().close() - assert len(driver.session_pool) == 1 - driver.session().close() - assert len(driver.session_pool) == 1 - - def test_session_that_dies_in_the_pool_will_not_be_given_out(self): - driver = GraphDatabase.driver("bolt://localhost", auth=auth_token) - session_1 = driver.session() - session_1.close() - assert len(driver.session_pool) == 1 - session_1.connection.close() - session_2 = driver.session() - assert session_2 is not session_1 - - def test_must_use_valid_url_scheme(self): - with self.assertRaises(ProtocolError): - GraphDatabase.driver("x://xxx", auth=auth_token) - - def test_sessions_are_reused(self): - driver = GraphDatabase.driver("bolt://localhost", auth=auth_token) - session_1 = driver.session() - session_1.close() - session_2 = driver.session() - session_2.close() - assert session_1 is session_2 - - def test_sessions_are_not_reused_if_still_in_use(self): - driver = GraphDatabase.driver("bolt://localhost", auth=auth_token) - session_1 = driver.session() - session_2 = driver.session() - session_2.close() - session_1.close() - assert session_1 is not session_2 - - def test_fail_nicely_when_connecting_to_http_port(self): - driver = GraphDatabase.driver("bolt://localhost:7474", auth=auth_token, encrypted=False) - with self.assertRaises(ProtocolError) as context: - driver.session() - - assert str(context.exception) == "Server responded HTTP. Make sure you are not trying to connect to the http " \ - "endpoint (HTTP defaults to port 7474 whereas BOLT defaults to port 7687)" - - - -class SecurityTestCase(ServerTestCase): - - def test_insecure_session_uses_normal_socket(self): - driver = GraphDatabase.driver("bolt://localhost", auth=auth_token, encrypted=False) - session = driver.session() - connection = session.connection - assert isinstance(connection.channel.socket, socket) - assert connection.der_encoded_server_certificate is None - session.close() - - @skipUnless(SSL_AVAILABLE, "Bolt over TLS is not supported by this version of Python") - def test_tofu_session_uses_secure_socket(self): - driver = GraphDatabase.driver("bolt://localhost", auth=auth_token, encrypted=True, trust=TRUST_ON_FIRST_USE) - session = driver.session() - connection = session.connection - assert isinstance(connection.channel.socket, SSLSocket) - assert connection.der_encoded_server_certificate is not None - session.close() - - @skipUnless(SSL_AVAILABLE, "Bolt over TLS is not supported by this version of Python") - def test_tofu_session_trusts_certificate_after_first_use(self): - driver = GraphDatabase.driver("bolt://localhost", auth=auth_token, encrypted=True, trust=TRUST_ON_FIRST_USE) - session = driver.session() - connection = session.connection - certificate = connection.der_encoded_server_certificate - session.close() - session = driver.session() - connection = session.connection - assert connection.der_encoded_server_certificate == certificate - session.close() - - -class RunTestCase(ServerTestCase): + def setUp(self): + self.driver = GraphDatabase.driver(BOLT_URI, auth=AUTH_TOKEN) def test_can_run_simple_statement(self): - session = GraphDatabase.driver("bolt://localhost", auth=auth_token).session() + session = self.driver.session() result = session.run("RETURN 1 AS n") for record in result: assert record[0] == 1 @@ -151,7 +54,7 @@ def test_can_run_simple_statement(self): session.close() def test_can_run_simple_statement_with_params(self): - session = GraphDatabase.driver("bolt://localhost", auth=auth_token).session() + session = self.driver.session() count = 0 for record in session.run("RETURN {x} AS n", {"x": {"abc": ["d", "e", "f"]}}): assert record[0] == {"abc": ["d", "e", "f"]} @@ -163,17 +66,17 @@ def test_can_run_simple_statement_with_params(self): assert count == 1 def test_fails_on_bad_syntax(self): - session = GraphDatabase.driver("bolt://localhost", auth=auth_token).session() + session = self.driver.session() with self.assertRaises(CypherError): session.run("X").consume() def test_fails_on_missing_parameter(self): - session = GraphDatabase.driver("bolt://localhost", auth=auth_token).session() + session = self.driver.session() with self.assertRaises(CypherError): session.run("RETURN {x}").consume() def test_can_run_simple_statement_from_bytes_string(self): - session = GraphDatabase.driver("bolt://localhost", auth=auth_token).session() + session = self.driver.session() count = 0 for record in session.run(b"RETURN 1 AS n"): assert record[0] == 1 @@ -185,7 +88,7 @@ def test_can_run_simple_statement_from_bytes_string(self): assert count == 1 def test_can_run_statement_that_returns_multiple_records(self): - session = GraphDatabase.driver("bolt://localhost", auth=auth_token).session() + session = self.driver.session() count = 0 for record in session.run("unwind(range(1, 10)) AS z RETURN z"): assert 1 <= record[0] <= 10 @@ -194,15 +97,15 @@ def test_can_run_statement_that_returns_multiple_records(self): assert count == 10 def test_can_use_with_to_auto_close_session(self): - with GraphDatabase.driver("bolt://localhost", auth=auth_token).session() as session: + with self.driver.session() as session: record_list = list(session.run("RETURN 1")) assert len(record_list) == 1 for record in record_list: assert record[0] == 1 def test_can_return_node(self): - with GraphDatabase.driver("bolt://localhost", auth=auth_token).session() as session: - record_list = list(session.run("MERGE (a:Person {name:'Alice'}) RETURN a")) + with self.driver.session() as session: + record_list = list(session.run("CREATE (a:Person {name:'Alice'}) RETURN a")) assert len(record_list) == 1 for record in record_list: alice = record[0] @@ -211,17 +114,17 @@ def test_can_return_node(self): assert alice.properties == {"name": "Alice"} def test_can_return_relationship(self): - with GraphDatabase.driver("bolt://localhost", auth=auth_token).session() as session: - reocrd_list = list(session.run("MERGE ()-[r:KNOWS {since:1999}]->() RETURN r")) - assert len(reocrd_list) == 1 - for record in reocrd_list: + with self.driver.session() as session: + record_list = list(session.run("MERGE ()-[r:KNOWS {since:1999}]->() RETURN r")) + assert len(record_list) == 1 + for record in record_list: rel = record[0] assert isinstance(rel, Relationship) assert rel.type == "KNOWS" assert rel.properties == {"since": 1999} def test_can_return_path(self): - with GraphDatabase.driver("bolt://localhost", auth=auth_token).session() as session: + with self.driver.session() as session: record_list = list(session.run("MERGE p=({name:'Alice'})-[:KNOWS]->({name:'Bob'}) RETURN p")) assert len(record_list) == 1 for record in record_list: @@ -234,19 +137,19 @@ def test_can_return_path(self): assert len(path.relationships) == 1 def test_can_handle_cypher_error(self): - with GraphDatabase.driver("bolt://localhost", auth=auth_token).session() as session: + with self.driver.session() as session: with self.assertRaises(CypherError): session.run("X").consume() def test_keys_are_available_before_and_after_stream(self): - with GraphDatabase.driver("bolt://localhost", auth=auth_token).session() as session: + with self.driver.session() as session: result = session.run("UNWIND range(1, 10) AS n RETURN n") assert list(result.keys()) == ["n"] list(result) assert list(result.keys()) == ["n"] def test_keys_with_an_error(self): - with GraphDatabase.driver("bolt://localhost", auth=auth_token).session() as session: + with self.driver.session() as session: result = session.run("X") with self.assertRaises(CypherError): list(result.keys()) @@ -254,59 +157,62 @@ def test_keys_with_an_error(self): class SummaryTestCase(ServerTestCase): + def setUp(self): + self.driver = GraphDatabase.driver(BOLT_URI, auth=AUTH_TOKEN) + def test_can_obtain_summary_after_consuming_result(self): - with GraphDatabase.driver("bolt://localhost", auth=auth_token).session() as session: + with self.driver.session() as session: result = session.run("CREATE (n) RETURN n") - summary = result.consume() + summary = result.summary() assert summary.statement == "CREATE (n) RETURN n" assert summary.parameters == {} assert summary.statement_type == "rw" assert summary.counters.nodes_created == 1 def test_no_plan_info(self): - with GraphDatabase.driver("bolt://localhost", auth=auth_token).session() as session: + with self.driver.session() as session: result = session.run("CREATE (n) RETURN n") - summary = result.consume() + summary = result.summary() assert summary.plan is None assert summary.profile is None def test_can_obtain_plan_info(self): - with GraphDatabase.driver("bolt://localhost", auth=auth_token).session() as session: + with self.driver.session() as session: result = session.run("EXPLAIN CREATE (n) RETURN n") - summary = result.consume() + summary = result.summary() plan = summary.plan assert plan.operator_type == "ProduceResults" assert plan.identifiers == ["n"] - assert plan.arguments == {"planner": "COST", "EstimatedRows": 1.0, "version": "CYPHER 3.0", - "KeyNames": "n", "runtime-impl": "INTERPRETED", "planner-impl": "IDP", - "runtime": "INTERPRETED"} + known_keys = ["planner", "EstimatedRows", "version", "KeyNames", "runtime-impl", + "planner-impl", "runtime"] + assert all(key in plan.arguments for key in known_keys) assert len(plan.children) == 1 def test_can_obtain_profile_info(self): - with GraphDatabase.driver("bolt://localhost", auth=auth_token).session() as session: + with self.driver.session() as session: result = session.run("PROFILE CREATE (n) RETURN n") - summary = result.consume() + summary = result.summary() profile = summary.profile assert profile.db_hits == 0 assert profile.rows == 1 assert profile.operator_type == "ProduceResults" assert profile.identifiers == ["n"] - assert profile.arguments == {"planner": "COST", "EstimatedRows": 1.0, "version": "CYPHER 3.0", - "KeyNames": "n", "runtime-impl": "INTERPRETED", "planner-impl": "IDP", - "runtime": "INTERPRETED", "Rows": 1, "DbHits": 0} + known_keys = ["planner", "EstimatedRows", "version", "KeyNames", "runtime-impl", + "planner-impl", "runtime", "Rows", "DbHits"] + assert all(key in profile.arguments for key in known_keys) assert len(profile.children) == 1 def test_no_notification_info(self): - with GraphDatabase.driver("bolt://localhost", auth=auth_token).session() as session: + with self.driver.session() as session: result = session.run("CREATE (n) RETURN n") - summary = result.consume() + summary = result.summary() notifications = summary.notifications assert notifications == [] def test_can_obtain_notification_info(self): - with GraphDatabase.driver("bolt://localhost", auth=auth_token).session() as session: + with self.driver.session() as session: result = session.run("EXPLAIN MATCH (n), (m) RETURN n, m") - summary = result.consume() + summary = result.summary() notifications = summary.notifications assert len(notifications) == 1 @@ -333,8 +239,11 @@ def test_can_obtain_notification_info(self): class ResetTestCase(ServerTestCase): + def setUp(self): + self.driver = GraphDatabase.driver(BOLT_URI, auth=AUTH_TOKEN) + def test_automatic_reset_after_failure(self): - with GraphDatabase.driver("bolt://localhost", auth=auth_token).session() as session: + with self.driver.session() as session: try: session.run("X").consume() except CypherError: @@ -346,7 +255,7 @@ def test_automatic_reset_after_failure(self): def test_defunct(self): from neo4j.v1.bolt import BufferingSocket, ProtocolError - with GraphDatabase.driver("bolt://localhost", auth=auth_token).session() as session: + with self.driver.session() as session: assert not session.connection.defunct with patch.object(BufferingSocket, "fill", side_effect=ProtocolError()): with self.assertRaises(ProtocolError): @@ -425,10 +334,10 @@ def test_record_repr(self): assert repr(a_record) == "" -class TransactionTestCase(ServerTestCase): +class ExplicitTransactionTestCase(ServerTestCase): def test_can_commit_transaction(self): - with GraphDatabase.driver("bolt://localhost", auth=auth_token).session() as session: + with GraphDatabase.driver(BOLT_URI, auth=AUTH_TOKEN).session() as session: tx = session.begin_transaction() # Create a node @@ -451,7 +360,7 @@ def test_can_commit_transaction(self): assert value == "bar" def test_can_rollback_transaction(self): - with GraphDatabase.driver("bolt://localhost", auth=auth_token).session() as session: + with GraphDatabase.driver(BOLT_URI, auth=AUTH_TOKEN).session() as session: tx = session.begin_transaction() # Create a node @@ -472,7 +381,7 @@ def test_can_rollback_transaction(self): assert len(list(result)) == 0 def test_can_commit_transaction_using_with_block(self): - with GraphDatabase.driver("bolt://localhost", auth=auth_token).session() as session: + with GraphDatabase.driver(BOLT_URI, auth=AUTH_TOKEN).session() as session: with session.begin_transaction() as tx: # Create a node result = tx.run("CREATE (a) RETURN id(a)") @@ -494,7 +403,7 @@ def test_can_commit_transaction_using_with_block(self): assert value == "bar" def test_can_rollback_transaction_using_with_block(self): - with GraphDatabase.driver("bolt://localhost", auth=auth_token).session() as session: + with GraphDatabase.driver(BOLT_URI, auth=AUTH_TOKEN).session() as session: with session.begin_transaction() as tx: # Create a node result = tx.run("CREATE (a) RETURN id(a)") @@ -515,7 +424,7 @@ def test_can_rollback_transaction_using_with_block(self): class ResultConsumptionTestCase(ServerTestCase): def setUp(self): - self.driver = GraphDatabase.driver("bolt://localhost", auth=auth_token, encrypted=False) + self.driver = GraphDatabase.driver(BOLT_URI, auth=AUTH_TOKEN, encrypted=False) def test_can_consume_result_immediately(self): session = self.driver.session() @@ -607,14 +516,14 @@ def test_single_consumes_entire_result_if_one_record(self): session = self.driver.session() result = session.run("UNWIND range(1, 1) AS n RETURN n") _ = result.single() - assert result._consumed + assert not result.online() def test_single_consumes_entire_result_if_multiple_records(self): session = self.driver.session() result = session.run("UNWIND range(1, 3) AS n RETURN n") with self.assertRaises(ResultError): _ = result.single() - assert result._consumed + assert not result.online() def test_peek_can_look_one_ahead(self): session = self.driver.session() diff --git a/test/util.py b/test/util.py index 5d7b230a..dacd7d0e 100644 --- a/test/util.py +++ b/test/util.py @@ -21,9 +21,10 @@ import functools from os import getenv, remove, rename -from os.path import isfile +from os.path import isfile, dirname, join as path_join from socket import create_connection from subprocess import check_call, CalledProcessError +from threading import Thread from time import sleep from unittest import TestCase @@ -84,6 +85,7 @@ class ServerTestCase(TestCase): known_hosts = KNOWN_HOSTS known_hosts_backup = known_hosts + ".backup" + servers = [] def setUp(self): if isfile(self.known_hosts): @@ -96,3 +98,25 @@ def tearDown(self): if isfile(self.known_hosts): remove(self.known_hosts) rename(self.known_hosts_backup, self.known_hosts) + + def start_stub_server(self, port, script): + server = StubServer(port, script) + server.start() + sleep(0.5) + self.servers.append(server) + + def await_all_servers(self): + while self.servers: + server = self.servers.pop() + server.join() + + +class StubServer(Thread): + + def __init__(self, port, script): + super(StubServer, self).__init__() + self.port = port + self.script = path_join(dirname(__file__), "resources", script) + + def run(self): + check_call(["boltstub", str(self.port), self.script])