diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 00000000..823ad91d --- /dev/null +++ b/.editorconfig @@ -0,0 +1,22 @@ +# top-most EditorConfig file +root = true + +# Unix-style newlines with a newline ending every file +[*] +end_of_line = lf +insert_final_newline = true +charset = utf-8 + +[*.{py,js,rst,txt,sh,bat}] +trim_trailing_whitespace = true + +[{Makefile,Dockerfile}] +trim_trailing_whitespace = true + +[*.bat] +end_of_line = crlf + +[*.py] +max_line_length = 79 +indent_style = space +indent_size = 4 diff --git a/.gitignore b/.gitignore index 6f5d483d..f485e71b 100644 --- a/.gitignore +++ b/.gitignore @@ -29,3 +29,4 @@ neo4j-enterprise-* *.so testkit/CAs +testkit/CustomCAs diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000..c4eafcb5 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,32 @@ +# See https://pre-commit.com for more information +# See https://pre-commit.com/hooks.html for more hooks +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.0.1 + hooks: + - id: check-case-conflict + - id: check-docstring-first + - id: check-symlinks + - id: destroyed-symlinks + - id: end-of-file-fixer + exclude_types: + - image + - id: fix-encoding-pragma + args: [ --remove ] + - id: mixed-line-ending + args: [ --fix=lf ] + exclude_types: + - batch + - id: trailing-whitespace + args: [ --markdown-linebreak-ext=md ] + - repo: https://github.com/pycqa/isort + rev: 5.10.0 + hooks: + - id: isort + - repo: local + hooks: + - id: unasync + name: unasync + entry: bin/make-unasync + language: system + files: "^(neo4j/_async|tests/unit/async_|testkitbackend/_async)/.*" diff --git a/CHANGELOG.md b/CHANGELOG.md index e98eeeb8..4f6e2000 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,7 +4,10 @@ - Python 3.10 support added - Python 3.6 support has been dropped. - +- `Result`, `Session`, and `Transaction` can no longer be imported from + `neo4j.work`. They should've been imported from `neo4j` all along. +- Experimental pipelines feature has been removed. +- Experimental async driver has been added. ## Version 4.4 diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index b6ebbe8d..d2656fd6 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -51,12 +51,32 @@ Occasionally, we might also have logistical, commercial, or legal reasons why we Remember that many community members have become regular contributors and some are now even Neo employees! +## Specifically for this project: + +All code in `_sync` or `sync` folders is auto-generated. Don't change it, but +install the pre-commit hooks as described below insted. They will take care of +updating the code if necessary. + +Setting up the development environment: + * Install Python 3.7+ + * Install the requirements + ```bash + $ python3 -m pip install -U pip + $ python3 -m pip install -Ur requirements-dev.txt + ``` + * Install the pre-commit hook, that will do some code-format-checking everytime + you commit. + ```bash + $ pre-commit install + ``` + + ## Got an idea for a new project? If you have an idea for a new tool or library, start by talking to other people in the community. Chances are that someone has a similar idea or may have already started working on it. The best software comes from getting like minds together to solve a problem. -And we'll do our best to help you promote and co-ordinate your Neo ecosystem projects. +And we'll do our best to help you promote and co-ordinate your Neo4j ecosystem projects. ## Further reading diff --git a/ISSUE_TEMPLATE.md b/ISSUE_TEMPLATE.md index d413cc9c..0a9addf8 100644 --- a/ISSUE_TEMPLATE.md +++ b/ISSUE_TEMPLATE.md @@ -9,7 +9,7 @@ If you simply want to get started or have a question on how to use a particular [StackOverflow](http://stackoverflow.com/questions/tagged/neo4j) also hosts a ton of questions and might already have a discussion around your problem. Make sure you have a look there too. -If you want to make a feature request, please prefix your issue title with `[Feature Request]` so that it is clear to us. +If you want to make a feature request, please prefix your issue title with `[Feature Request]` so that it is clear to us. If you have a bug report however, please continue reading. To help us understand your issue, please specify important details, primarily: diff --git a/bin/make-unasync b/bin/make-unasync new file mode 100755 index 00000000..c217a539 --- /dev/null +++ b/bin/make-unasync @@ -0,0 +1,311 @@ +#!/usr/bin/env python + +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import collections +import errno +import os +from pathlib import Path +import re +import sys +import tokenize as std_tokenize + +import isort +import isort.files +import unasync + + +ROOT_DIR = Path(__file__).parents[1].absolute() +ASYNC_DIR = ROOT_DIR / "neo4j" / "_async" +SYNC_DIR = ROOT_DIR / "neo4j" / "_sync" +ASYNC_TEST_DIR = ROOT_DIR / "tests" / "unit" / "async_" +SYNC_TEST_DIR = ROOT_DIR / "tests" / "unit" / "sync" +ASYNC_TESTKIT_BACKEND_DIR = ROOT_DIR / "testkitbackend" / "_async" +SYNC_TESTKIT_BACKEND_DIR = ROOT_DIR / "testkitbackend" / "_sync" +UNASYNC_SUFFIX = ".unasync" + +PY_FILE_EXTENSIONS = {".py", ".pyi"} + + +# copy from unasync for customization ----------------------------------------- +# https://github.com/python-trio/unasync +# License: MIT and Apache2 + + +Token = collections.namedtuple( + "Token", ["type", "string", "start", "end", "line"] +) + + +def _makedirs_existok(dir): + try: + os.makedirs(dir) + except OSError as e: + if e.errno != errno.EEXIST: + raise + + +def _get_tokens(f): + if sys.version_info[0] == 2: + for tok in std_tokenize.generate_tokens(f.readline): + type_, string, start, end, line = tok + yield Token(type_, string, start, end, line) + else: + for tok in std_tokenize.tokenize(f.readline): + if tok.type == std_tokenize.ENCODING: + continue + yield tok + + +def _tokenize(f): + last_end = (1, 0) + for tok in _get_tokens(f): + if last_end[0] < tok.start[0]: + yield "", std_tokenize.STRING, " \\\n" + last_end = (tok.start[0], 0) + + space = "" + if tok.start > last_end: + assert tok.start[0] == last_end[0] + space = " " * (tok.start[1] - last_end[1]) + yield space, tok.type, tok.string + + last_end = tok.end + if tok.type in [std_tokenize.NEWLINE, std_tokenize.NL]: + last_end = (tok.end[0] + 1, 0) + + +def _untokenize(tokens): + return "".join(space + tokval for space, tokval in tokens) + + +# end of copy ----------------------------------------------------------------- + + +class CustomRule(unasync.Rule): + def __init__(self, *args, **kwargs): + super(CustomRule, self).__init__(*args, **kwargs) + self.out_files = [] + + def _unasync_tokens(self, tokens): + # copy from unasync to hook into string handling + # https://github.com/python-trio/unasync + # License: MIT and Apache2 + # TODO __await__, ...? + used_space = None + for space, toknum, tokval in tokens: + if tokval in ["async", "await"]: + # When removing async or await, we want to use the whitespace + # that was before async/await before the next token so that + # `print(await stuff)` becomes `print(stuff)` and not + # `print( stuff)` + used_space = space + else: + if toknum == std_tokenize.NAME: + tokval = self._unasync_name(tokval) + elif toknum == std_tokenize.STRING: + if tokval[0] == tokval[1] and len(tokval) > 2: + # multiline string (`"""..."""` or `'''...'''`) + left_quote, name, right_quote = \ + tokval[:3], tokval[3:-3], tokval[-3:] + else: + # simple string (`"..."` or `'...'`) + left_quote, name, right_quote = \ + tokval[:1], tokval[1:-1], tokval[-1:] + tokval = \ + left_quote + self._unasync_string(name) + right_quote + if used_space is None: + used_space = space + yield (used_space, tokval) + used_space = None + + def _unasync_string(self, name): + start = 0 + end = 1 + out = "" + while end < len(name): + sub_name = name[start:end] + if sub_name.isidentifier(): + end += 1 + else: + if end == start + 1: + out += sub_name + start += 1 + end += 1 + else: + out += self._unasync_prefix(name[start:(end - 1)]) + start = end - 1 + + sub_name = name[start:] + if sub_name.isidentifier(): + out += self._unasync_prefix(name[start:]) + else: + out += sub_name + + # very boiled down unasync version that removes "async" and "await" + # substrings. + out = re.subn(r"(^|\s+|(?<=\W))(?:async|await)\s+", r"\1", out, + flags=re.MULTILINE)[0] + # Convert doc-reference names from 'async-xyz' to 'xyz' + out = re.subn(r":ref:`async-", ":ref:`", out)[0] + return out + + def _unasync_prefix(self, name): + # Convert class names from 'AsyncXyz' to 'Xyz' + if len(name) > 5 and name.startswith("Async") and name[5].isupper(): + return name[5:] + # Convert variable/method/function names from 'async_xyz' to 'xyz' + elif len(name) > 6 and name.startswith("async_"): + return name[6:] + return name + + def _unasync_name(self, name): + # copy from unasync to customize renaming rules + # https://github.com/python-trio/unasync + # License: MIT and Apache2 + if name in self.token_replacements: + return self.token_replacements[name] + return self._unasync_prefix(name) + + def _unasync_file(self, filepath): + # copy from unasync to append file suffix to out path + # https://github.com/python-trio/unasync + # License: MIT and Apache2 + with open(filepath, "rb") as f: + write_kwargs = {} + if sys.version_info[0] >= 3: + encoding, _ = std_tokenize.detect_encoding(f.readline) + write_kwargs["encoding"] = encoding + f.seek(0) + tokens = _tokenize(f) + tokens = self._unasync_tokens(tokens) + result = _untokenize(tokens) + outfile_path = filepath.replace(self.fromdir, self.todir) + outfile_path += UNASYNC_SUFFIX + self.out_files.append(outfile_path) + _makedirs_existok(os.path.dirname(outfile_path)) + with open(outfile_path, "w", **write_kwargs) as f: + print(result, file=f, end="") + + +def apply_unasync(files): + """Generate sync code from async code.""" + + additional_main_replacements = {} + additional_test_replacements = { + "_async": "_sync", + "mark_async_test": "mark_sync_test", + } + additional_testkit_backend_replacements = {} + rules = [ + CustomRule( + fromdir=str(ASYNC_DIR), + todir=str(SYNC_DIR), + additional_replacements=additional_main_replacements, + ), + CustomRule( + fromdir=str(ASYNC_TEST_DIR), + todir=str(SYNC_TEST_DIR), + additional_replacements=additional_test_replacements, + ), + CustomRule( + fromdir=str(ASYNC_TESTKIT_BACKEND_DIR), + todir=str(SYNC_TESTKIT_BACKEND_DIR), + additional_replacements=additional_testkit_backend_replacements, + ), + ] + + if not files: + paths = list(ASYNC_DIR.rglob("*")) + paths += list(ASYNC_TEST_DIR.rglob("*")) + paths += list(ASYNC_TESTKIT_BACKEND_DIR.rglob("*")) + else: + paths = [ROOT_DIR / Path(f) for f in files] + filtered_paths = [] + for path in paths: + if path.suffix in PY_FILE_EXTENSIONS: + filtered_paths.append(path) + + unasync.unasync_files(map(str, filtered_paths), rules) + + return [Path(path) for rule in rules for path in rule.out_files] + + +def apply_isort(paths): + """Sort imports in generated sync code. + + Since classes in imports are renamed from AsyncXyz to Xyz, the alphabetical + order of the import can change. + """ + isort_config = isort.Config(settings_path=str(ROOT_DIR), quiet=True) + + for path in paths: + isort.file(str(path), config=isort_config) + + return paths + + +def apply_changes(paths): + def files_equal(path1, path2): + with open(path1, "rb") as f1: + with open(path2, "rb") as f2: + data1 = f1.read(1024) + data2 = f2.read(1024) + while data1 or data2: + if data1 != data2: + changed_paths[path1] = path2 + return False + data1 = f1.read(1024) + data2 = f2.read(1024) + return True + + changed_paths = {} + + for in_path in paths: + out_path = Path(str(in_path)[:-len(UNASYNC_SUFFIX)]) + if not out_path.is_file(): + changed_paths[in_path] = out_path + continue + if not files_equal(in_path, out_path): + changed_paths[in_path] = out_path + continue + in_path.unlink() + + for in_path, out_path in changed_paths.items(): + in_path.replace(out_path) + + return list(changed_paths.values()) + + +def main(): + files = None + if len(sys.argv) >= 1: + files = sys.argv[1:] + paths = apply_unasync(files) + paths = apply_isort(paths) + changed_paths = apply_changes(paths) + + if changed_paths: + for path in changed_paths: + print("Updated " + str(path)) + exit(1) + + +if __name__ == "__main__": + main() diff --git a/docs/source/_static/anchor_new_target.js b/docs/source/_static/anchor_new_target.js index 367af0a2..015df714 100644 --- a/docs/source/_static/anchor_new_target.js +++ b/docs/source/_static/anchor_new_target.js @@ -1,4 +1,4 @@ $(document).ready(function () { $('a[href^="http://"], a[href^="https://"]').not('a[class*=internal]').attr('target', '_blank'); -}); \ No newline at end of file +}); diff --git a/docs/source/api.rst b/docs/source/api.rst index 3481480f..f10c71b9 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -11,34 +11,35 @@ GraphDatabase Driver Construction =================== -The :class:`neo4j.Driver` construction is via a `classmethod` on the :class:`neo4j.GraphDatabase` class. +The :class:`neo4j.Driver` construction is done via a `classmethod` on the :class:`neo4j.GraphDatabase` class. .. autoclass:: neo4j.GraphDatabase :members: driver -Example, driver creation: +Driver creation example: .. code-block:: python from neo4j import GraphDatabase uri = "neo4j://example.com:7687" - driver = GraphDatabase.driver(uri, auth=("neo4j", "password"), max_connection_lifetime=1000) + driver = GraphDatabase.driver(uri, auth=("neo4j", "password")) driver.close() # close the driver object -For basic auth, this can be a simple tuple, for example: +For basic authentication, `auth` can be a simple tuple, for example: .. code-block:: python auth = ("neo4j", "password") -This will implicitly create a :class:`neo4j.Auth` with a ``scheme="basic"`` +This will implicitly create a :class:`neo4j.Auth` with a ``scheme="basic"``. +Other authentication methods are described under :ref:`auth-ref`. -Example, with block context: +``with`` block context example: .. code-block:: python @@ -330,7 +331,8 @@ For example: Connection details held by the :class:`neo4j.Driver` are immutable. Therefore if, for example, a password is changed, a replacement :class:`neo4j.Driver` object must be created. -More than one :class:`.Driver` may be required if connections to multiple databases, or connections as multiple users, are required. +More than one :class:`.Driver` may be required if connections to multiple databases, or connections as multiple users, are required, +unless when using impersonation (:ref:`impersonated-user-ref`). :class:`neo4j.Driver` objects are thread-safe but cannot be shared across processes. Therefore, ``multithreading`` should generally be preferred over ``multiprocessing`` for parallel database access. @@ -345,11 +347,9 @@ BoltDriver URI schemes: ``bolt``, ``bolt+ssc``, ``bolt+s`` -Driver subclass: - :class:`neo4j.BoltDriver` +Will result in: -.. - .. autoclass:: neo4j.BoltDriver +.. autoclass:: neo4j.BoltDriver .. _neo4j-driver-ref: @@ -360,11 +360,9 @@ Neo4jDriver URI schemes: ``neo4j``, ``neo4j+ssc``, ``neo4j+s`` -Driver subclass: - :class:`neo4j.Neo4jDriver` +Will result in: -.. - .. autoclass:: neo4j.Neo4jDriver +.. autoclass:: neo4j.Neo4jDriver *********************** @@ -374,7 +372,7 @@ All database activity is co-ordinated through two mechanisms: the :class:`neo4j. A :class:`neo4j.Session` is a logical container for any number of causally-related transactional units of work. Sessions automatically provide guarantees of causal consistency within a clustered environment but multiple sessions can also be causally chained if required. -Sessions provide the top-level of containment for database activity. +Sessions provide the top level of containment for database activity. Session creation is a lightweight operation and *sessions are not thread safe*. Connections are drawn from the :class:`neo4j.Driver` connection pool as required. @@ -604,7 +602,8 @@ Example: def create_person(driver, name): with driver.session(default_access_mode=neo4j.WRITE_ACCESS) as session: - result = session.run("CREATE (a:Person { name: $name }) RETURN id(a) AS node_id", name=name) + query = "CREATE (a:Person { name: $name }) RETURN id(a) AS node_id" + result = session.run(query, name=name) record = result.single() return record["node_id"] @@ -665,13 +664,15 @@ Example: tx.close() def create_person_node(tx): + query = "CREATE (a:Person { name: $name }) RETURN id(a) AS node_id" name = "default_name" - result = tx.run("CREATE (a:Person { name: $name }) RETURN id(a) AS node_id", name=name) + result = tx.run(query, name=name) record = result.single() return record["node_id"] def set_person_name(tx, node_id, name): - result = tx.run("MATCH (a:Person) WHERE id(a) = $id SET a.name = $name", id=node_id, name=name) + query = "MATCH (a:Person) WHERE id(a) = $id SET a.name = $name" + result = tx.run(query, id=node_id, name=name) info = result.consume() # use the info for logging etc. @@ -698,7 +699,8 @@ Example: node_id = session.write_transaction(create_person_tx, name) def create_person_tx(tx, name): - result = tx.run("CREATE (a:Person { name: $name }) RETURN id(a) AS node_id", name=name) + query = "CREATE (a:Person { name: $name }) RETURN id(a) AS node_id" + result = tx.run(query, name=name) record = result.single() return record["node_id"] @@ -708,12 +710,6 @@ To exert more control over how a transaction function is carried out, the :func: - - - - - - ****** Result ****** diff --git a/docs/source/async_api.rst b/docs/source/async_api.rst new file mode 100644 index 00000000..989ab200 --- /dev/null +++ b/docs/source/async_api.rst @@ -0,0 +1,498 @@ +.. _async-api-documentation: + +####################### +Async API Documentation +####################### + +.. warning:: + The whole async API is currently in experimental phase. + + This means everything documented on this page might be removed or change + its API at any time (including in patch releases). + +****************** +AsyncGraphDatabase +****************** + +Async Driver Construction +========================= + +The :class:`neo4j.AsyncDriver` construction is done via a `classmethod` on the :class:`neo4j.AsyncGraphDatabase` class. + +.. autoclass:: neo4j.AsyncGraphDatabase + :members: driver + + +Driver creation example: + +.. code-block:: python + + import asyncio + + from neo4j import AsyncGraphDatabase + + async def main(): + uri = "neo4j://example.com:7687" + driver = AsyncGraphDatabase.driver(uri, auth=("neo4j", "password")) + + await driver.close() # close the driver object + + asyncio.run(main()) + + +For basic authentication, ``auth`` can be a simple tuple, for example: + +.. code-block:: python + + auth = ("neo4j", "password") + +This will implicitly create a :class:`neo4j.Auth` with a ``scheme="basic"``. +Other authentication methods are described under :ref:`auth-ref`. + +``with`` block context example: + +.. code-block:: python + + import asyncio + + from neo4j import AsyncGraphDatabase + + async def main(): + uri = "neo4j://example.com:7687" + auth = ("neo4j", "password") + async with AsyncGraphDatabase.driver(uri, auth=auth) as driver: + # use the driver + ... + + asyncio.run(main()) + + +.. _async-uri-ref: + +URI +=== + +On construction, the `scheme` of the URI determines the type of :class:`neo4j.AsyncDriver` object created. + +Available valid URIs: + ++ ``bolt://host[:port]`` ++ ``bolt+ssc://host[:port]`` ++ ``bolt+s://host[:port]`` ++ ``neo4j://host[:port][?routing_context]`` ++ ``neo4j+ssc://host[:port][?routing_context]`` ++ ``neo4j+s://host[:port][?routing_context]`` + +.. code-block:: python + + uri = "bolt://example.com:7687" + +.. code-block:: python + + uri = "neo4j://example.com:7687" + +Each supported scheme maps to a particular :class:`neo4j.AsyncDriver` subclass that implements a specific behaviour. + ++------------------------+---------------------------------------------------------------------------------------------------------------------------------------------+ +| URI Scheme | Driver Object and Setting | ++========================+=============================================================================================================================================+ +| bolt | :ref:`async-bolt-driver-ref` with no encryption. | ++------------------------+---------------------------------------------------------------------------------------------------------------------------------------------+ +| bolt+ssc | :ref:`async-bolt-driver-ref` with encryption (accepts self signed certificates). | ++------------------------+---------------------------------------------------------------------------------------------------------------------------------------------+ +| bolt+s | :ref:`async-bolt-driver-ref` with encryption (accepts only certificates signed by a certificate authority), full certificate checks. | ++------------------------+---------------------------------------------------------------------------------------------------------------------------------------------+ +| neo4j | :ref:`async-neo4j-driver-ref` with no encryption. | ++------------------------+---------------------------------------------------------------------------------------------------------------------------------------------+ +| neo4j+ssc | :ref:`async-neo4j-driver-ref` with encryption (accepts self signed certificates). | ++------------------------+---------------------------------------------------------------------------------------------------------------------------------------------+ +| neo4j+s | :ref:`async-neo4j-driver-ref` with encryption (accepts only certificates signed by a certificate authority), full certificate checks. | ++------------------------+---------------------------------------------------------------------------------------------------------------------------------------------+ + +.. note:: + + See https://neo4j.com/docs/operations-manual/current/configuration/ports/ for Neo4j ports. + + + +*********** +AsyncDriver +*********** + +Every Neo4j-backed application will require a :class:`neo4j.AsyncDriver` object. + +This object holds the details required to establish connections with a Neo4j database, including server URIs, credentials and other configuration. +:class:`neo4j.AsyncDriver` objects hold a connection pool from which :class:`neo4j.AsyncSession` objects can borrow connections. +Closing a driver will immediately shut down all connections in the pool. + +.. autoclass:: neo4j.AsyncDriver() + :members: session, close + + +.. _async-driver-configuration-ref: + +Async Driver Configuration +========================== + +:class:`neo4j.AsyncDriver` is configured exactly like :class:`neo4j.Driver` +(see :ref:`driver-configuration-ref`). The only difference is that the async +driver accepts an async custom resolver function: + +.. _async-resolver-ref: + +``resolver`` +------------ +A custom resolver function to resolve host and port values ahead of DNS resolution. +This function is called with a 2-tuple of (host, port) and should return an iterable of 2-tuples (host, port). + +If no custom resolver function is supplied, the internal resolver moves straight to regular DNS resolution. + +The custom resolver function can but does not have to be a coroutine. + +For example: + +.. code-block:: python + + from neo4j import AsyncGraphDatabase + + async def custom_resolver(socket_address): + if socket_address == ("example.com", 9999): + yield "::1", 7687 + yield "127.0.0.1", 7687 + else: + from socket import gaierror + raise gaierror("Unexpected socket address %r" % socket_address) + + # alternatively + def custom_resolver(socket_address): + ... + + driver = AsyncGraphDatabase.driver("neo4j://example.com:9999", + auth=("neo4j", "password"), + resolver=custom_resolver) + + +:Default: ``None`` + + + +Driver Object Lifetime +====================== + +For general applications, it is recommended to create one top-level :class:`neo4j.AsyncDriver` object that lives for the lifetime of the application. + +For example: + +.. code-block:: python + + from neo4j import AsyncGraphDatabase + + class Application: + + def __init__(self, uri, user, password) + self.driver = AsyncGraphDatabase.driver(uri, auth=(user, password)) + + async def close(self): + await self.driver.close() + +Connection details held by the :class:`neo4j.AsyncDriver` are immutable. +Therefore if, for example, a password is changed, a replacement :class:`neo4j.AsyncDriver` object must be created. +More than one :class:`.AsyncDriver` may be required if connections to multiple databases, or connections as multiple users, are required, +unless when using impersonation (:ref:`impersonated-user-ref`). + +:class:`neo4j.AsyncDriver` objects are safe to be used in concurrent coroutines. +They are not thread-safe. + + +.. _async-bolt-driver-ref: + +AsyncBoltDriver +=============== + +URI schemes: + ``bolt``, ``bolt+ssc``, ``bolt+s`` + +Will result in: + +.. autoclass:: neo4j.AsyncBoltDriver + + +.. _async-neo4j-driver-ref: + +AsyncNeo4jDriver +================ + +URI schemes: + ``neo4j``, ``neo4j+ssc``, ``neo4j+s`` + +Will result in: + +.. autoclass:: neo4j.AsyncNeo4jDriver + + +********************************* +AsyncSessions & AsyncTransactions +********************************* +All database activity is co-ordinated through two mechanisms: the :class:`neo4j.AsyncSession` and the :class:`neo4j.AsyncTransaction`. + +A :class:`neo4j.AsyncSession` is a logical container for any number of causally-related transactional units of work. +Sessions automatically provide guarantees of causal consistency within a clustered environment but multiple sessions can also be causally chained if required. +Sessions provide the top level of containment for database activity. +Session creation is a lightweight operation and *sessions cannot be shared between coroutines*. + +Connections are drawn from the :class:`neo4j.AsyncDriver` connection pool as required. + +A :class:`neo4j.AsyncTransaction` is a unit of work that is either committed in its entirety or is rolled back on failure. + + +.. _async-session-construction-ref: + +************************* +AsyncSession Construction +************************* + +To construct a :class:`neo4j.AsyncSession` use the :meth:`neo4j.AsyncDriver.session` method. + +.. code-block:: python + + import asyncio + + from neo4j import AsyncGraphDatabase + + async def main(): + driver = AsyncGraphDatabase(uri, auth=(user, password)) + session = driver.session() + result = await session.run("MATCH (a:Person) RETURN a.name AS name") + names = [record["name"] async for record in result] + await session.close() + await driver.close() + + asyncio.run(main()) + + +Sessions will often be created and destroyed using a *with block context*. + +.. code-block:: python + + async with driver.session() as session: + result = await session.run("MATCH (a:Person) RETURN a.name AS name") + # do something with the result... + + +Sessions will often be created with some configuration settings, see :ref:`async-session-configuration-ref`. + +.. code-block:: python + + async with driver.session(database="example_database", + fetch_size=100) as session: + result = await session.run("MATCH (a:Person) RETURN a.name AS name") + # do something with the result... + + +************ +AsyncSession +************ + +.. autoclass:: neo4j.AsyncSession() + + .. automethod:: close + + .. automethod:: run + + .. automethod:: last_bookmark + + .. automethod:: begin_transaction + + .. automethod:: read_transaction + + .. automethod:: write_transaction + + + +.. _async-session-configuration-ref: + +Session Configuration +===================== + +:class:`neo4j.AsyncSession` is configured exactly like :class:`neo4j.Session` +(see :ref:`session-configuration-ref`). + + +**************** +AsyncTransaction +**************** + +Neo4j supports three kinds of async transaction: + ++ :ref:`async-auto-commit-transactions-ref` ++ :ref:`async-explicit-transactions-ref` ++ :ref:`async-managed-transactions-ref` + +Each has pros and cons but if in doubt, use a managed transaction with a `transaction function`. + + +.. _async-auto-commit-transactions-ref: + +Async Auto-commit Transactions +============================== +Auto-commit transactions are the simplest form of transaction, available via :py:meth:`neo4j.AsyncSession.run`. + +These are easy to use but support only one statement per transaction and are not automatically retried on failure. +Auto-commit transactions are also the only way to run ``PERIODIC COMMIT`` statements, since this Cypher clause manages its own transactions internally. + +Example: + +.. code-block:: python + + import neo4j + + async def create_person(driver, name): + async with driver.session( + default_access_mode=neo4j.WRITE_ACCESS + ) as session: + query = "CREATE (a:Person { name: $name }) RETURN id(a) AS node_id" + result = await session.run(query, name=name) + record = await result.single() + return record["node_id"] + +Example: + +.. code-block:: python + + import neo4j + + async def get_numbers(driver): + numbers = [] + async with driver.session( + default_access_mode=neo4j.READ_ACCESS + ) as session: + result = await session.run("UNWIND [1, 2, 3] AS x RETURN x") + async for record in result: + numbers.append(record["x"]) + return numbers + + +.. _async-explicit-transactions-ref: + +Explicit Async Transactions +=========================== +Explicit transactions support multiple statements and must be created with an explicit :py:meth:`neo4j.AsyncSession.begin_transaction` call. + +This creates a new :class:`neo4j.AsyncTransaction` object that can be used to run Cypher. + +It also gives applications the ability to directly control `commit` and `rollback` activity. + +.. autoclass:: neo4j.AsyncTransaction() + + .. automethod:: run + + .. automethod:: close + + .. automethod:: closed + + .. automethod:: commit + + .. automethod:: rollback + +Closing an explicit transaction can either happen automatically at the end of a ``async with`` block, +or can be explicitly controlled through the :py:meth:`neo4j.AsyncTransaction.commit`, :py:meth:`neo4j.AsyncTransaction.rollback` or :py:meth:`neo4j.AsyncTransaction.close` methods. + +Explicit transactions are most useful for applications that need to distribute Cypher execution across multiple functions for the same transaction. + +Example: + +.. code-block:: python + + import neo4j + + async def create_person(driver, name): + async with driver.session( + default_access_mode=neo4j.WRITE_ACCESS + ) as session: + tx = await session.begin_transaction() + node_id = await create_person_node(tx) + await set_person_name(tx, node_id, name) + await tx.commit() + await tx.close() + + async def create_person_node(tx): + query = "CREATE (a:Person { name: $name }) RETURN id(a) AS node_id" + name = "default_name" + result = await tx.run(query, name=name) + record = await result.single() + return record["node_id"] + + async def set_person_name(tx, node_id, name): + query = "MATCH (a:Person) WHERE id(a) = $id SET a.name = $name" + result = await tx.run(query, id=node_id, name=name) + info = await result.consume() + # use the info for logging etc. + +.. _async-managed-transactions-ref: + + +Managed Async Transactions (`transaction functions`) +==================================================== +Transaction functions are the most powerful form of transaction, providing access mode override and retry capabilities. + ++ :py:meth:`neo4j.AsyncSession.write_transaction` ++ :py:meth:`neo4j.AsyncSession.read_transaction` + +These allow a function object representing the transactional unit of work to be passed as a parameter. +This function is called one or more times, within a configurable time limit, until it succeeds. +Results should be fully consumed within the function and only aggregate or status values should be returned. +Returning a live result object would prevent the driver from correctly managing connections and would break retry guarantees. + +Example: + +.. code-block:: python + + async def create_person(driver, name) + async with driver.session() as session: + node_id = await session.write_transaction(create_person_tx, name) + + async def create_person_tx(tx, name): + query = "CREATE (a:Person { name: $name }) RETURN id(a) AS node_id" + result = await tx.run(query, name=name) + record = await result.single() + return record["node_id"] + +To exert more control over how a transaction function is carried out, the :func:`neo4j.unit_of_work` decorator can be used. + + + +*********** +AsyncResult +*********** + +Every time a query is executed, a :class:`neo4j.AsyncResult` is returned. + +This provides a handle to the result of the query, giving access to the records within it as well as the result metadata. + +Results also contain a buffer that automatically stores unconsumed records when results are consumed out of order. + +A :class:`neo4j.AsyncResult` is attached to an active connection, through a :class:`neo4j.AsyncSession`, until all its content has been buffered or consumed. + +.. autoclass:: neo4j.AsyncResult() + + .. describe:: iter(result) + + .. automethod:: keys + + .. automethod:: consume + + .. automethod:: single + + .. automethod:: peek + + .. automethod:: graph + + **This is experimental.** (See :ref:`filter-warnings-ref`) + + .. automethod:: value + + .. automethod:: values + + .. automethod:: data + +See https://neo4j.com/docs/driver-manual/current/cypher-workflow/#driver-type-mapping for more about type mapping. diff --git a/docs/source/conf.py b/docs/source/conf.py index dae62ce1..72111d76 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# # Neo4j Bolt Driver for Python documentation build configuration file, created by # sphinx-quickstart on Mon Sep 21 11:48:02 2015. # @@ -13,16 +10,20 @@ # All configuration values have a default; values that are commented out # serve to show the default. -import sys + import os -import shlex +import sys + # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. sys.path.insert(0, os.path.abspath(os.path.join("..", ".."))) + + from neo4j.meta import version as project_version + # -- General configuration ------------------------------------------------ # If your documentation needs a minimal Sphinx version, state it here. diff --git a/docs/source/errors.rst b/docs/source/errors.rst index 0a4d5927..14db7d5a 100644 --- a/docs/source/errors.rst +++ b/docs/source/errors.rst @@ -43,4 +43,4 @@ Python Version: Python Driver Version: Neo4j Version: -the code block with a description that produced the error and the error message. \ No newline at end of file +the code block with a description that produced the error and the error message. diff --git a/docs/source/index.rst b/docs/source/index.rst index ac4bda16..24ef2386 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -23,6 +23,8 @@ Topics + :ref:`api-documentation` ++ :ref:`async-api-documentation` (experimental) + + :ref:`temporal-data-types` + :ref:`breaking-changes` @@ -32,6 +34,7 @@ Topics :hidden: api.rst + async_api.rst temporal_types.rst breaking_changes.rst diff --git a/docs/source/transactions.rst b/docs/source/transactions.rst index 93dc3510..1bb4db2a 100644 --- a/docs/source/transactions.rst +++ b/docs/source/transactions.rst @@ -11,7 +11,7 @@ Sessions automatically provide guarantees of causal consistency within a cluster Sessions ======== -Sessions provide the top-level of containment for database activity. +Sessions provide the top level of containment for database activity. Session creation is a lightweight operation and sessions are `not` thread safe. Connections are drawn from the :class:`neo4j.Driver` connection pool as required; an idle session will not hold onto a connection. diff --git a/neo4j/__init__.py b/neo4j/__init__.py index cd0139b6..f54aa5e1 100644 --- a/neo4j/__init__.py +++ b/neo4j/__init__.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -21,441 +18,117 @@ __all__ = [ "__version__", - "GraphDatabase", - "Driver", - "BoltDriver", - "Neo4jDriver", + "Address", + "AsyncBoltDriver", + "AsyncDriver", + "AsyncGraphDatabase", + "AsyncNeo4jDriver", + "AsyncResult", + "AsyncSession", + "AsyncTransaction", "Auth", "AuthToken", "basic_auth", - "kerberos_auth", "bearer_auth", - "custom_auth", + "BoltDriver", "Bookmark", - "ServerInfo", - "Version", - "READ_ACCESS", - "WRITE_ACCESS", + "Config", + "custom_auth", "DEFAULT_DATABASE", - "TRUST_ALL_CERTIFICATES", - "TRUST_SYSTEM_CA_SIGNED_CERTIFICATES", - "Address", + "Driver", + "ExperimentalWarning", + "GraphDatabase", "IPv4Address", "IPv6Address", - "Config", + "kerberos_auth", + "Neo4jDriver", "PoolConfig", - "WorkspaceConfig", - "SessionConfig", + "Query", + "READ_ACCESS", "Record", - "Transaction", "Result", "ResultSummary", - "SummaryCounters", - "Query", + "ServerInfo", "Session", + "SessionConfig", + "SummaryCounters", + "Transaction", + "TRUST_ALL_CERTIFICATES", + "TRUST_SYSTEM_CA_SIGNED_CERTIFICATES", "unit_of_work", - "ExperimentalWarning", + "Version", + "WorkspaceConfig", + "WRITE_ACCESS", ] -from logging import getLogger +from logging import getLogger -from neo4j.addressing import ( +from ._async.driver import ( + AsyncBoltDriver, + AsyncDriver, + AsyncGraphDatabase, + AsyncNeo4jDriver, +) +from ._async.work import ( + AsyncResult, + AsyncSession, + AsyncTransaction, +) +from ._sync.driver import ( + BoltDriver, + Driver, + GraphDatabase, + Neo4jDriver, +) +from ._sync.work import ( + Result, + Session, + Transaction, +) +from .addressing import ( Address, IPv4Address, IPv6Address, ) -from neo4j.api import ( +from .api import ( Auth, # TODO: Validate naming for Auth compared to other drivers. +) +from .api import ( AuthToken, basic_auth, - kerberos_auth, bearer_auth, - custom_auth, Bookmark, - ServerInfo, - Version, + custom_auth, + DEFAULT_DATABASE, + kerberos_auth, READ_ACCESS, - WRITE_ACCESS, + ServerInfo, SYSTEM_DATABASE, - DEFAULT_DATABASE, TRUST_ALL_CERTIFICATES, TRUST_SYSTEM_CA_SIGNED_CERTIFICATES, + Version, + WRITE_ACCESS, ) -from neo4j.conf import ( +from .conf import ( Config, PoolConfig, - WorkspaceConfig, SessionConfig, + WorkspaceConfig, ) -from neo4j.meta import ( +from .data import Record +from .meta import ( experimental, ExperimentalWarning, get_user_agent, version as __version__, ) -from neo4j.data import ( - Record, -) -from neo4j.work.simple import ( +from .work import ( Query, - Session, - unit_of_work, -) -from neo4j.work.transaction import ( - Transaction, -) -from neo4j.work.result import ( - Result, -) -from neo4j.work.summary import ( ResultSummary, SummaryCounters, + unit_of_work, ) log = getLogger("neo4j") - - -class GraphDatabase: - """Accessor for :class:`neo4j.Driver` construction. - """ - - @classmethod - def driver(cls, uri, *, auth=None, **config): - """Create a driver. - - :param uri: the connection URI for the driver, see :ref:`uri-ref` for available URIs. - :param auth: the authentication details, see :ref:`auth-ref` for available authentication details. - :param config: driver configuration key-word arguments, see :ref:`driver-configuration-ref` for available key-word arguments. - - :return: :ref:`neo4j-driver-ref` or :ref:`bolt-driver-ref` - """ - - from neo4j.api import ( - parse_neo4j_uri, - parse_routing_context, - DRIVER_BOLT, - DRIVER_NEO4j, - SECURITY_TYPE_NOT_SECURE, - SECURITY_TYPE_SELF_SIGNED_CERTIFICATE, - SECURITY_TYPE_SECURE, - URI_SCHEME_BOLT, - URI_SCHEME_NEO4J, - URI_SCHEME_BOLT_SELF_SIGNED_CERTIFICATE, - URI_SCHEME_BOLT_SECURE, - URI_SCHEME_NEO4J_SELF_SIGNED_CERTIFICATE, - URI_SCHEME_NEO4J_SECURE, - ) - - driver_type, security_type, parsed = parse_neo4j_uri(uri) - - if "trust" in config.keys(): - if config.get("trust") not in [TRUST_ALL_CERTIFICATES, TRUST_SYSTEM_CA_SIGNED_CERTIFICATES]: - from neo4j.exceptions import ConfigurationError - raise ConfigurationError("The config setting `trust` values are {!r}".format( - [ - TRUST_ALL_CERTIFICATES, - TRUST_SYSTEM_CA_SIGNED_CERTIFICATES, - ] - )) - - if security_type in [SECURITY_TYPE_SELF_SIGNED_CERTIFICATE, SECURITY_TYPE_SECURE] and ("encrypted" in config.keys() or "trust" in config.keys()): - from neo4j.exceptions import ConfigurationError - raise ConfigurationError("The config settings 'encrypted' and 'trust' can only be used with the URI schemes {!r}. Use the other URI schemes {!r} for setting encryption settings.".format( - [ - URI_SCHEME_BOLT, - URI_SCHEME_NEO4J, - ], - [ - URI_SCHEME_BOLT_SELF_SIGNED_CERTIFICATE, - URI_SCHEME_BOLT_SECURE, - URI_SCHEME_NEO4J_SELF_SIGNED_CERTIFICATE, - URI_SCHEME_NEO4J_SECURE, - ] - )) - - if security_type == SECURITY_TYPE_SECURE: - config["encrypted"] = True - elif security_type == SECURITY_TYPE_SELF_SIGNED_CERTIFICATE: - config["encrypted"] = True - config["trust"] = TRUST_ALL_CERTIFICATES - - if driver_type == DRIVER_BOLT: - return cls.bolt_driver(parsed.netloc, auth=auth, **config) - elif driver_type == DRIVER_NEO4j: - routing_context = parse_routing_context(parsed.query) - return cls.neo4j_driver(parsed.netloc, auth=auth, routing_context=routing_context, **config) - - @classmethod - def bolt_driver(cls, target, *, auth=None, **config): - """ Create a driver for direct Bolt server access that uses - socket I/O and thread-based concurrency. - """ - from neo4j._exceptions import BoltHandshakeError, BoltSecurityError - - try: - return BoltDriver.open(target, auth=auth, **config) - except (BoltHandshakeError, BoltSecurityError) as error: - from neo4j.exceptions import ServiceUnavailable - raise ServiceUnavailable(str(error)) from error - - @classmethod - def neo4j_driver(cls, *targets, auth=None, routing_context=None, **config): - """ Create a driver for routing-capable Neo4j service access - that uses socket I/O and thread-based concurrency. - """ - from neo4j._exceptions import BoltHandshakeError, BoltSecurityError - - try: - return Neo4jDriver.open(*targets, auth=auth, routing_context=routing_context, **config) - except (BoltHandshakeError, BoltSecurityError) as error: - from neo4j.exceptions import ServiceUnavailable - raise ServiceUnavailable(str(error)) from error - - -class Direct: - - default_host = "localhost" - default_port = 7687 - - default_target = ":" - - def __init__(self, address): - self._address = address - - @property - def address(self): - return self._address - - @classmethod - def parse_target(cls, target): - """ Parse a target string to produce an address. - """ - if not target: - target = cls.default_target - address = Address.parse(target, default_host=cls.default_host, - default_port=cls.default_port) - return address - - -class Routing: - - default_host = "localhost" - default_port = 7687 - - default_targets = ": :17601 :17687" - - def __init__(self, initial_addresses): - self._initial_addresses = initial_addresses - - @property - def initial_addresses(self): - return self._initial_addresses - - @classmethod - def parse_targets(cls, *targets): - """ Parse a sequence of target strings to produce an address - list. - """ - targets = " ".join(targets) - if not targets: - targets = cls.default_targets - addresses = Address.parse_list(targets, default_host=cls.default_host, default_port=cls.default_port) - return addresses - - -class Driver: - """ Base class for all types of :class:`neo4j.Driver`, instances of which are - used as the primary access point to Neo4j. - """ - - #: Connection pool - _pool = None - - def __init__(self, pool): - assert pool is not None - self._pool = pool - - def __del__(self): - self.close() - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_value, traceback): - self.close() - - @property - def encrypted(self): - return bool(self._pool.pool_config.encrypted) - - def session(self, **config): - """Create a session, see :ref:`session-construction-ref` - - :param config: session configuration key-word arguments, see :ref:`session-configuration-ref` for available key-word arguments. - - :returns: new :class:`neo4j.Session` object - """ - raise NotImplementedError - - @experimental("The pipeline API is experimental and may be removed or changed in a future release") - def pipeline(self, **config): - """ Create a pipeline. - """ - raise NotImplementedError - - def close(self): - """ Shut down, closing any open connections in the pool. - """ - self._pool.close() - - @experimental("The configuration may change in the future.") - def verify_connectivity(self, **config): - """ This verifies if the driver can connect to a remote server or a cluster - by establishing a network connection with the remote and possibly exchanging - a few data before closing the connection. It throws exception if fails to connect. - - Use the exception to further understand the cause of the connectivity problem. - - Note: Even if this method throws an exception, the driver still need to be closed via close() to free up all resources. - """ - raise NotImplementedError - - @experimental("Feature support query, based on Bolt Protocol Version and Neo4j Server Version will change in the future.") - def supports_multi_db(self): - """ Check if the server or cluster supports multi-databases. - - :return: Returns true if the server or cluster the driver connects to supports multi-databases, otherwise false. - :rtype: bool - """ - with self.session() as session: - session._connect(READ_ACCESS) - return session._connection.supports_multiple_databases - - -class BoltDriver(Direct, Driver): - """ A :class:`.BoltDriver` is created from a ``bolt`` URI and addresses - a single database machine. This may be a standalone server or could be a - specific member of a cluster. - - Connections established by a :class:`.BoltDriver` are always made to the - exact host and port detailed in the URI. - """ - - @classmethod - def open(cls, target, *, auth=None, **config): - """ - :param target: - :param auth: - :param config: The values that can be specified are found in :class: `neo4j.PoolConfig` and :class: `neo4j.WorkspaceConfig` - - :return: - :rtype: :class: `neo4j.BoltDriver` - """ - from neo4j.io import BoltPool - address = cls.parse_target(target) - pool_config, default_workspace_config = Config.consume_chain(config, PoolConfig, WorkspaceConfig) - pool = BoltPool.open(address, auth=auth, pool_config=pool_config, workspace_config=default_workspace_config) - return cls(pool, default_workspace_config) - - def __init__(self, pool, default_workspace_config): - Direct.__init__(self, pool.address) - Driver.__init__(self, pool) - self._default_workspace_config = default_workspace_config - - def session(self, **config): - """ - :param config: The values that can be specified are found in :class: `neo4j.SessionConfig` - - :return: - :rtype: :class: `neo4j.Session` - """ - from neo4j.work.simple import Session - session_config = SessionConfig(self._default_workspace_config, config) - SessionConfig.consume(config) # Consume the config - return Session(self._pool, session_config) - - def pipeline(self, **config): - from neo4j.work.pipelining import Pipeline, PipelineConfig - pipeline_config = PipelineConfig(self._default_workspace_config, config) - PipelineConfig.consume(config) # Consume the config - return Pipeline(self._pool, pipeline_config) - - @experimental("The configuration may change in the future.") - def verify_connectivity(self, **config): - server_agent = None - config["fetch_size"] = -1 - with self.session(**config) as session: - result = session.run("RETURN 1 AS x") - value = result.single().value() - summary = result.consume() - server_agent = summary.server.agent - return server_agent - - -class Neo4jDriver(Routing, Driver): - """ A :class:`.Neo4jDriver` is created from a ``neo4j`` URI. The - routing behaviour works in tandem with Neo4j's `Causal Clustering - `_ - feature by directing read and write behaviour to appropriate - cluster members. - """ - - @classmethod - def open(cls, *targets, auth=None, routing_context=None, **config): - from neo4j.io import Neo4jPool - addresses = cls.parse_targets(*targets) - pool_config, default_workspace_config = Config.consume_chain(config, PoolConfig, WorkspaceConfig) - pool = Neo4jPool.open(*addresses, auth=auth, routing_context=routing_context, pool_config=pool_config, workspace_config=default_workspace_config) - return cls(pool, default_workspace_config) - - def __init__(self, pool, default_workspace_config): - Routing.__init__(self, pool.get_default_database_initial_router_addresses()) - Driver.__init__(self, pool) - self._default_workspace_config = default_workspace_config - - def session(self, **config): - session_config = SessionConfig(self._default_workspace_config, config) - SessionConfig.consume(config) # Consume the config - return Session(self._pool, session_config) - - def pipeline(self, **config): - from neo4j.work.pipelining import Pipeline, PipelineConfig - pipeline_config = PipelineConfig(self._default_workspace_config, config) - PipelineConfig.consume(config) # Consume the config - return Pipeline(self._pool, pipeline_config) - - @experimental("The configuration may change in the future.") - def verify_connectivity(self, **config): - """ - :raise ServiceUnavailable: raised if the server does not support routing or if routing support is broken. - """ - # TODO: Improve and update Stub Test Server to be able to test. - return self._verify_routing_connectivity() - - def _verify_routing_connectivity(self): - from neo4j.exceptions import ( - Neo4jError, - ServiceUnavailable, - SessionExpired, - ) - - table = self._pool.get_routing_table_for_default_database() - routing_info = {} - for ix in list(table.routers): - try: - routing_info[ix] = self._pool.fetch_routing_info( - address=table.routers[0], - database=self._default_workspace_config.database, - imp_user=self._default_workspace_config.impersonated_user, - bookmarks=None, - timeout=self._default_workspace_config - .connection_acquisition_timeout - ) - except (ServiceUnavailable, SessionExpired, Neo4jError): - routing_info[ix] = None - for key, val in routing_info.items(): - if val is not None: - return routing_info - raise ServiceUnavailable("Could not connect to any routing servers.") diff --git a/tests/unit/time/__init__.py b/neo4j/_async/__init__.py similarity index 94% rename from tests/unit/time/__init__.py rename to neo4j/_async/__init__.py index 0665bdc9..b81a309d 100644 --- a/tests/unit/time/__init__.py +++ b/neo4j/_async/__init__.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# coding: utf-8 - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # diff --git a/neo4j/_async/driver.py b/neo4j/_async/driver.py new file mode 100644 index 00000000..009efc32 --- /dev/null +++ b/neo4j/_async/driver.py @@ -0,0 +1,380 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import asyncio + +from .._async_compat.util import AsyncUtil +from ..addressing import Address +from ..api import ( + READ_ACCESS, + TRUST_ALL_CERTIFICATES, + TRUST_SYSTEM_CA_SIGNED_CERTIFICATES, +) +from ..conf import ( + Config, + PoolConfig, + SessionConfig, + WorkspaceConfig, +) +from ..meta import experimental + + +class AsyncGraphDatabase: + """Accessor for :class:`neo4j.Driver` construction. + """ + + @classmethod + @AsyncUtil.experimental_async( + "neo4j async is in experimental phase. It might be removed or changed " + "at any time (including patch releases)." + ) + def driver(cls, uri, *, auth=None, **config): + """Create a driver. + + :param uri: the connection URI for the driver, see :ref:`async-uri-ref` for available URIs. + :param auth: the authentication details, see :ref:`auth-ref` for available authentication details. + :param config: driver configuration key-word arguments, see :ref:`async-driver-configuration-ref` for available key-word arguments. + + :rtype: AsyncNeo4jDriver or AsyncBoltDriver + """ + + from ..api import ( + DRIVER_BOLT, + DRIVER_NEO4j, + parse_neo4j_uri, + parse_routing_context, + SECURITY_TYPE_NOT_SECURE, + SECURITY_TYPE_SECURE, + SECURITY_TYPE_SELF_SIGNED_CERTIFICATE, + URI_SCHEME_BOLT, + URI_SCHEME_BOLT_SECURE, + URI_SCHEME_BOLT_SELF_SIGNED_CERTIFICATE, + URI_SCHEME_NEO4J, + URI_SCHEME_NEO4J_SECURE, + URI_SCHEME_NEO4J_SELF_SIGNED_CERTIFICATE, + ) + + driver_type, security_type, parsed = parse_neo4j_uri(uri) + + if "trust" in config.keys(): + if config.get("trust") not in [TRUST_ALL_CERTIFICATES, TRUST_SYSTEM_CA_SIGNED_CERTIFICATES]: + from neo4j.exceptions import ConfigurationError + raise ConfigurationError("The config setting `trust` values are {!r}".format( + [ + TRUST_ALL_CERTIFICATES, + TRUST_SYSTEM_CA_SIGNED_CERTIFICATES, + ] + )) + + if security_type in [SECURITY_TYPE_SELF_SIGNED_CERTIFICATE, SECURITY_TYPE_SECURE] and ("encrypted" in config.keys() or "trust" in config.keys()): + from neo4j.exceptions import ConfigurationError + raise ConfigurationError("The config settings 'encrypted' and 'trust' can only be used with the URI schemes {!r}. Use the other URI schemes {!r} for setting encryption settings.".format( + [ + URI_SCHEME_BOLT, + URI_SCHEME_NEO4J, + ], + [ + URI_SCHEME_BOLT_SELF_SIGNED_CERTIFICATE, + URI_SCHEME_BOLT_SECURE, + URI_SCHEME_NEO4J_SELF_SIGNED_CERTIFICATE, + URI_SCHEME_NEO4J_SECURE, + ] + )) + + if security_type == SECURITY_TYPE_SECURE: + config["encrypted"] = True + elif security_type == SECURITY_TYPE_SELF_SIGNED_CERTIFICATE: + config["encrypted"] = True + config["trust"] = TRUST_ALL_CERTIFICATES + + if driver_type == DRIVER_BOLT: + return cls.bolt_driver(parsed.netloc, auth=auth, **config) + elif driver_type == DRIVER_NEO4j: + routing_context = parse_routing_context(parsed.query) + return cls.neo4j_driver(parsed.netloc, auth=auth, routing_context=routing_context, **config) + + @classmethod + def bolt_driver(cls, target, *, auth=None, **config): + """ Create a driver for direct Bolt server access that uses + socket I/O and thread-based concurrency. + """ + from .._exceptions import ( + BoltHandshakeError, + BoltSecurityError, + ) + + try: + return AsyncBoltDriver.open(target, auth=auth, **config) + except (BoltHandshakeError, BoltSecurityError) as error: + from neo4j.exceptions import ServiceUnavailable + raise ServiceUnavailable(str(error)) from error + + @classmethod + def neo4j_driver(cls, *targets, auth=None, routing_context=None, **config): + """ Create a driver for routing-capable Neo4j service access + that uses socket I/O and thread-based concurrency. + """ + from neo4j._exceptions import ( + BoltHandshakeError, + BoltSecurityError, + ) + + try: + return AsyncNeo4jDriver.open(*targets, auth=auth, routing_context=routing_context, **config) + except (BoltHandshakeError, BoltSecurityError) as error: + from neo4j.exceptions import ServiceUnavailable + raise ServiceUnavailable(str(error)) from error + + +class _Direct: + + default_host = "localhost" + default_port = 7687 + + default_target = ":" + + def __init__(self, address): + self._address = address + + @property + def address(self): + return self._address + + @classmethod + def parse_target(cls, target): + """ Parse a target string to produce an address. + """ + if not target: + target = cls.default_target + address = Address.parse(target, default_host=cls.default_host, + default_port=cls.default_port) + return address + + +class _Routing: + + default_host = "localhost" + default_port = 7687 + + default_targets = ": :17601 :17687" + + def __init__(self, initial_addresses): + self._initial_addresses = initial_addresses + + @property + def initial_addresses(self): + return self._initial_addresses + + @classmethod + def parse_targets(cls, *targets): + """ Parse a sequence of target strings to produce an address + list. + """ + targets = " ".join(targets) + if not targets: + targets = cls.default_targets + addresses = Address.parse_list(targets, default_host=cls.default_host, default_port=cls.default_port) + return addresses + + +class AsyncDriver: + """ Base class for all types of :class:`neo4j.AsyncDriver`, instances of + which are used as the primary access point to Neo4j. + """ + + #: Connection pool + _pool = None + + def __init__(self, pool): + assert pool is not None + self._pool = pool + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + await self.close() + + def __del__(self): + if not asyncio.iscoroutinefunction(self.close): + self.close() + + @property + def encrypted(self): + return bool(self._pool.pool_config.encrypted) + + def session(self, **config): + """Create a session, see :ref:`async-session-construction-ref` + + :param config: session configuration key-word arguments, + see :ref:`async-session-configuration-ref` for available key-word + arguments. + + :returns: new :class:`neo4j.AsyncSession` object + """ + raise NotImplementedError + + async def close(self): + """ Shut down, closing any open connections in the pool. + """ + await self._pool.close() + + @experimental("The configuration may change in the future.") + async def verify_connectivity(self, **config): + """ This verifies if the driver can connect to a remote server or a cluster + by establishing a network connection with the remote and possibly exchanging + a few data before closing the connection. It throws exception if fails to connect. + + Use the exception to further understand the cause of the connectivity problem. + + Note: Even if this method throws an exception, the driver still need to be closed via close() to free up all resources. + """ + raise NotImplementedError + + @experimental("Feature support query, based on Bolt Protocol Version and Neo4j Server Version will change in the future.") + async def supports_multi_db(self): + """ Check if the server or cluster supports multi-databases. + + :return: Returns true if the server or cluster the driver connects to supports multi-databases, otherwise false. + :rtype: bool + """ + async with self.session() as session: + await session._connect(READ_ACCESS) + return session._connection.supports_multiple_databases + + +class AsyncBoltDriver(_Direct, AsyncDriver): + """:class:`.AsyncBoltDriver` is instantiated for ``bolt`` URIs and + addresses a single database machine. This may be a standalone server or + could be a specific member of a cluster. + + Connections established by a :class:`.AsyncBoltDriver` are always made to + the exact host and port detailed in the URI. + + This class is not supposed to be instantiated externally. Use + :meth:`AsyncGraphDatabase.driver` instead. + """ + + @classmethod + def open(cls, target, *, auth=None, **config): + """ + :param target: + :param auth: + :param config: The values that can be specified are found in :class: `neo4j.PoolConfig` and :class: `neo4j.WorkspaceConfig` + + :return: + :rtype: :class: `neo4j.BoltDriver` + """ + from .io import AsyncBoltPool + address = cls.parse_target(target) + pool_config, default_workspace_config = Config.consume_chain(config, PoolConfig, WorkspaceConfig) + pool = AsyncBoltPool.open(address, auth=auth, pool_config=pool_config, workspace_config=default_workspace_config) + return cls(pool, default_workspace_config) + + def __init__(self, pool, default_workspace_config): + _Direct.__init__(self, pool.address) + AsyncDriver.__init__(self, pool) + self._default_workspace_config = default_workspace_config + + def session(self, **config): + """ + :param config: The values that can be specified are found in :class: `neo4j.SessionConfig` + + :return: + :rtype: :class: `neo4j.AsyncSession` + """ + from .work import AsyncSession + session_config = SessionConfig(self._default_workspace_config, config) + SessionConfig.consume(config) # Consume the config + return AsyncSession(self._pool, session_config) + + @experimental("The configuration may change in the future.") + async def verify_connectivity(self, **config): + server_agent = None + config["fetch_size"] = -1 + async with self.session(**config) as session: + result = await session.run("RETURN 1 AS x") + value = await result.single().value() + summary = await result.consume() + server_agent = summary.server.agent + return server_agent + + +class AsyncNeo4jDriver(_Routing, AsyncDriver): + """:class:`.AsyncNeo4jDriver` is instantiated for ``neo4j`` URIs. The + routing behaviour works in tandem with Neo4j's `Causal Clustering + `_ + feature by directing read and write behaviour to appropriate + cluster members. + + This class is not supposed to be instantiated externally. Use + :meth:`AsyncGraphDatabase.driver` instead. + """ + + @classmethod + def open(cls, *targets, auth=None, routing_context=None, **config): + from .io import AsyncNeo4jPool + addresses = cls.parse_targets(*targets) + pool_config, default_workspace_config = Config.consume_chain(config, PoolConfig, WorkspaceConfig) + pool = AsyncNeo4jPool.open(*addresses, auth=auth, routing_context=routing_context, pool_config=pool_config, workspace_config=default_workspace_config) + return cls(pool, default_workspace_config) + + def __init__(self, pool, default_workspace_config): + _Routing.__init__(self, pool.get_default_database_initial_router_addresses()) + AsyncDriver.__init__(self, pool) + self._default_workspace_config = default_workspace_config + + def session(self, **config): + from .work import AsyncSession + session_config = SessionConfig(self._default_workspace_config, config) + SessionConfig.consume(config) # Consume the config + return AsyncSession(self._pool, session_config) + + @experimental("The configuration may change in the future.") + async def verify_connectivity(self, **config): + """ + :raise ServiceUnavailable: raised if the server does not support routing or if routing support is broken. + """ + # TODO: Improve and update Stub Test Server to be able to test. + return await self._verify_routing_connectivity() + + async def _verify_routing_connectivity(self): + from ..exceptions import ( + Neo4jError, + ServiceUnavailable, + SessionExpired, + ) + + table = self._pool.get_routing_table_for_default_database() + routing_info = {} + for ix in list(table.routers): + try: + routing_info[ix] = await self._pool.fetch_routing_info( + address=table.routers[0], + database=self._default_workspace_config.database, + imp_user=self._default_workspace_config.impersonated_user, + bookmarks=None, + timeout=self._default_workspace_config + .connection_acquisition_timeout + ) + except (ServiceUnavailable, SessionExpired, Neo4jError): + routing_info[ix] = None + for key, val in routing_info.items(): + if val is not None: + return routing_info + raise ServiceUnavailable("Could not connect to any routing servers.") diff --git a/neo4j/_async/io/__init__.py b/neo4j/_async/io/__init__.py new file mode 100644 index 00000000..32aca11f --- /dev/null +++ b/neo4j/_async/io/__init__.py @@ -0,0 +1,43 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This module contains the low-level functionality required for speaking +Bolt. It is not intended to be used directly by driver users. Instead, +the `session` module provides the main user-facing abstractions. +""" + + +__all__ = [ + "AsyncBolt", + "AsyncBoltPool", + "AsyncNeo4jPool", + "check_supported_server_product", + "ConnectionErrorHandler", +] + + +from ._bolt import AsyncBolt +from ._common import ( + check_supported_server_product, + ConnectionErrorHandler, +) +from ._pool import ( + AsyncBoltPool, + AsyncNeo4jPool, +) diff --git a/neo4j/_async/io/_bolt.py b/neo4j/_async/io/_bolt.py new file mode 100644 index 00000000..503ebeaf --- /dev/null +++ b/neo4j/_async/io/_bolt.py @@ -0,0 +1,571 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import abc +import asyncio +from collections import deque +from logging import getLogger +from time import perf_counter + +from ..._async_compat.network import AsyncBoltSocket +from ..._exceptions import BoltHandshakeError +from ...addressing import Address +from ...api import ( + ServerInfo, + Version, +) +from ...conf import PoolConfig +from ...exceptions import ( + AuthError, + IncompleteCommit, + ServiceUnavailable, + SessionExpired, +) +from ...meta import get_user_agent +from ...packstream import ( + Packer, + Unpacker, +) +from ._common import ( + AsyncInbox, + CommitResponse, + Outbox, +) + + +# Set up logger +log = getLogger("neo4j") + + +class AsyncBolt: + """ Server connection for Bolt protocol. + + A :class:`.Bolt` should be constructed following a + successful .open() + + Bolt handshake and takes the socket over which + the handshake was carried out. + """ + + MAGIC_PREAMBLE = b"\x60\x60\xB0\x17" + + PROTOCOL_VERSION = None + + # flag if connection needs RESET to go back to READY state + is_reset = False + + # The socket + in_use = False + + # The socket + _closed = False + + # The socket + _defunct = False + + #: The pool of which this connection is a member + pool = None + + # Store the id of the most recent ran query to be able to reduce sent bits by + # using the default (-1) to refer to the most recent query when pulling + # results for it. + most_recent_qid = None + + def __init__(self, unresolved_address, sock, max_connection_lifetime, *, + auth=None, user_agent=None, routing_context=None): + self.unresolved_address = unresolved_address + self.socket = sock + self.server_info = ServerInfo(Address(sock.getpeername()), + self.PROTOCOL_VERSION) + # so far `connection.recv_timeout_seconds` is the only available + # configuration hint that exists. Therefore, all hints can be stored at + # connection level. This might change in the future. + self.configuration_hints = {} + self.outbox = Outbox() + self.inbox = AsyncInbox(self.socket, on_error=self._set_defunct_read) + self.packer = Packer(self.outbox) + self.unpacker = Unpacker(self.inbox) + self.responses = deque() + self._max_connection_lifetime = max_connection_lifetime + self._creation_timestamp = perf_counter() + self.routing_context = routing_context + + # Determine the user agent + if user_agent: + self.user_agent = user_agent + else: + self.user_agent = get_user_agent() + + # Determine auth details + if not auth: + self.auth_dict = {} + elif isinstance(auth, tuple) and 2 <= len(auth) <= 3: + from neo4j import Auth + self.auth_dict = vars(Auth("basic", *auth)) + else: + try: + self.auth_dict = vars(auth) + except (KeyError, TypeError): + raise AuthError("Cannot determine auth details from %r" % auth) + + # Check for missing password + try: + credentials = self.auth_dict["credentials"] + except KeyError: + pass + else: + if credentials is None: + raise AuthError("Password cannot be None") + + def __del__(self): + if not asyncio.iscoroutinefunction(self.close): + self.close() + + @property + @abc.abstractmethod + def supports_multiple_results(self): + """ Boolean flag to indicate if the connection version supports multiple + queries to be buffered on the server side (True) or if all results need + to be eagerly pulled before sending the next RUN (False). + """ + pass + + @property + @abc.abstractmethod + def supports_multiple_databases(self): + """ Boolean flag to indicate if the connection version supports multiple + databases. + """ + pass + + @classmethod + def protocol_handlers(cls, protocol_version=None): + """ Return a dictionary of available Bolt protocol handlers, + keyed by version tuple. If an explicit protocol version is + provided, the dictionary will contain either zero or one items, + depending on whether that version is supported. If no protocol + version is provided, all available versions will be returned. + + :param protocol_version: tuple identifying a specific protocol + version (e.g. (3, 5)) or None + :return: dictionary of version tuple to handler class for all + relevant and supported protocol versions + :raise TypeError: if protocol version is not passed in a tuple + """ + + # Carry out Bolt subclass imports locally to avoid circular dependency issues. + from ._bolt3 import AsyncBolt3 + from ._bolt4 import ( + AsyncBolt4x0, + AsyncBolt4x1, + AsyncBolt4x2, + AsyncBolt4x3, + AsyncBolt4x4, + ) + + handlers = { + AsyncBolt3.PROTOCOL_VERSION: AsyncBolt3, + AsyncBolt4x0.PROTOCOL_VERSION: AsyncBolt4x0, + AsyncBolt4x1.PROTOCOL_VERSION: AsyncBolt4x1, + AsyncBolt4x2.PROTOCOL_VERSION: AsyncBolt4x2, + AsyncBolt4x3.PROTOCOL_VERSION: AsyncBolt4x3, + AsyncBolt4x4.PROTOCOL_VERSION: AsyncBolt4x4, + } + + if protocol_version is None: + return handlers + + if not isinstance(protocol_version, tuple): + raise TypeError("Protocol version must be specified as a tuple") + + if protocol_version in handlers: + return {protocol_version: handlers[protocol_version]} + + return {} + + @classmethod + def version_list(cls, versions, limit=4): + """ Return a list of supported protocol versions in order of + preference. The number of protocol versions (or ranges) + returned is limited to four. + """ + # In fact, 4.3 is the fist version to support ranges. However, the range + # support got backported to 4.2. But even if the server is too old to + # have the backport, negotiating BOLT 4.1 is no problem as it's + # equivalent to 4.2 + first_with_range_support = Version(4, 2) + result = [] + for version in versions: + if (result + and version >= first_with_range_support + and result[-1][0] == version[0] + and result[-1][1][1] == version[1] + 1): + # can use range to encompass this version + result[-1][1][1] = version[1] + continue + result.append(Version(version[0], [version[1], version[1]])) + if len(result) == 4: + break + return result + + @classmethod + def get_handshake(cls): + """ Return the supported Bolt versions as bytes. + The length is 16 bytes as specified in the Bolt version negotiation. + :return: bytes + """ + supported_versions = sorted(cls.protocol_handlers().keys(), reverse=True) + offered_versions = cls.version_list(supported_versions) + return b"".join(version.to_bytes() for version in offered_versions).ljust(16, b"\x00") + + @classmethod + async def ping(cls, address, *, timeout=None, **config): + """ Attempt to establish a Bolt connection, returning the + agreed Bolt protocol version if successful. + """ + config = PoolConfig.consume(config) + try: + s, protocol_version, handshake, data = \ + await AsyncBoltSocket.connect( + address, + timeout=timeout, + custom_resolver=config.resolver, + ssl_context=config.get_ssl_context(), + keep_alive=config.keep_alive, + ) + except (ServiceUnavailable, SessionExpired, BoltHandshakeError): + return None + else: + AsyncBoltSocket.close_socket(s) + return protocol_version + + @classmethod + async def open( + cls, address, *, auth=None, timeout=None, routing_context=None, **pool_config + ): + """ Open a new Bolt connection to a given server address. + + :param address: + :param auth: + :param timeout: the connection timeout in seconds + :param routing_context: dict containing routing context + :param pool_config: + :return: + :raise BoltHandshakeError: raised if the Bolt Protocol can not negotiate a protocol version. + :raise ServiceUnavailable: raised if there was a connection issue. + """ + pool_config = PoolConfig.consume(pool_config) + s, pool_config.protocol_version, handshake, data = \ + await AsyncBoltSocket.connect( + address, + timeout=timeout, + custom_resolver=pool_config.resolver, + ssl_context=pool_config.get_ssl_context(), + keep_alive=pool_config.keep_alive, + ) + + # Carry out Bolt subclass imports locally to avoid circular dependency + # issues. + if pool_config.protocol_version == (3, 0): + from ._bolt3 import AsyncBolt3 + bolt_cls = AsyncBolt3 + elif pool_config.protocol_version == (4, 0): + from ._bolt4 import AsyncBolt4x0 + bolt_cls = AsyncBolt4x0 + elif pool_config.protocol_version == (4, 1): + from ._bolt4 import AsyncBolt4x1 + bolt_cls = AsyncBolt4x1 + elif pool_config.protocol_version == (4, 2): + from ._bolt4 import AsyncBolt4x2 + bolt_cls = AsyncBolt4x2 + elif pool_config.protocol_version == (4, 3): + from ._bolt4 import AsyncBolt4x3 + bolt_cls = AsyncBolt4x3 + elif pool_config.protocol_version == (4, 4): + from ._bolt4 import AsyncBolt4x4 + bolt_cls = AsyncBolt4x4 + else: + log.debug("[#%04X] S: ", s.getsockname()[1]) + AsyncBoltSocket.close_socket(s) + + supported_versions = cls.protocol_handlers().keys() + raise BoltHandshakeError("The Neo4J server does not support communication with this driver. This driver have support for Bolt Protocols {}".format(supported_versions), address=address, request_data=handshake, response_data=data) + + connection = bolt_cls( + address, s, pool_config.max_connection_lifetime, auth=auth, + user_agent=pool_config.user_agent, routing_context=routing_context + ) + + try: + await connection.hello() + except Exception: + await connection.close() + raise + + return connection + + @property + @abc.abstractmethod + def encrypted(self): + pass + + @property + @abc.abstractmethod + def der_encoded_server_certificate(self): + pass + + @property + @abc.abstractmethod + def local_port(self): + pass + + @abc.abstractmethod + async def hello(self): + """ Appends a HELLO message to the outgoing queue, sends it and consumes + all remaining messages. + """ + pass + + @abc.abstractmethod + async def route(self, database=None, imp_user=None, bookmarks=None): + """ Fetch a routing table from the server for the given + `database`. For Bolt 4.3 and above, this appends a ROUTE + message; for earlier versions, a procedure call is made via + the regular Cypher execution mechanism. In all cases, this is + sent to the network, and a response is fetched. + + :param database: database for which to fetch a routing table + Requires Bolt 4.0+. + :param imp_user: the user to impersonate + Requires Bolt 4.4+. + :param bookmarks: iterable of bookmark values after which this + transaction should begin + :return: dictionary of raw routing data + """ + pass + + @abc.abstractmethod + def run(self, query, parameters=None, mode=None, bookmarks=None, + metadata=None, timeout=None, db=None, imp_user=None, **handlers): + """ Appends a RUN message to the output queue. + + :param query: Cypher query string + :param parameters: dictionary of Cypher parameters + :param mode: access mode for routing - "READ" or "WRITE" (default) + :param bookmarks: iterable of bookmark values after which this transaction should begin + :param metadata: custom metadata dictionary to attach to the transaction + :param timeout: timeout for transaction execution (seconds) + :param db: name of the database against which to begin the transaction + Requires Bolt 4.0+. + :param imp_user: the user to impersonate + Requires Bolt 4.4+. + :param handlers: handler functions passed into the returned Response object + :return: Response object + """ + pass + + @abc.abstractmethod + def discard(self, n=-1, qid=-1, **handlers): + """ Appends a DISCARD message to the output queue. + + :param n: number of records to discard, default = -1 (ALL) + :param qid: query ID to discard for, default = -1 (last query) + :param handlers: handler functions passed into the returned Response object + :return: Response object + """ + pass + + @abc.abstractmethod + def pull(self, n=-1, qid=-1, **handlers): + """ Appends a PULL message to the output queue. + + :param n: number of records to pull, default = -1 (ALL) + :param qid: query ID to pull for, default = -1 (last query) + :param handlers: handler functions passed into the returned Response object + :return: Response object + """ + pass + + @abc.abstractmethod + def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, + db=None, imp_user=None, **handlers): + """ Appends a BEGIN message to the output queue. + + :param mode: access mode for routing - "READ" or "WRITE" (default) + :param bookmarks: iterable of bookmark values after which this transaction should begin + :param metadata: custom metadata dictionary to attach to the transaction + :param timeout: timeout for transaction execution (seconds) + :param db: name of the database against which to begin the transaction + Requires Bolt 4.0+. + :param imp_user: the user to impersonate + Requires Bolt 4.4+ + :param handlers: handler functions passed into the returned Response object + :return: Response object + """ + pass + + @abc.abstractmethod + def commit(self, **handlers): + """ Appends a COMMIT message to the output queue.""" + pass + + @abc.abstractmethod + def rollback(self, **handlers): + """ Appends a ROLLBACK message to the output queue.""" + pass + + @abc.abstractmethod + async def reset(self): + """ Appends a RESET message to the outgoing queue, sends it and consumes + all remaining messages. + """ + pass + + def _append(self, signature, fields=(), response=None): + """ Appends a message to the outgoing queue. + + :param signature: the signature of the message + :param fields: the fields of the message as a tuple + :param response: a response object to handle callbacks + """ + self.packer.pack_struct(signature, fields) + self.outbox.wrap_message() + self.responses.append(response) + + async def _send_all(self): + data = self.outbox.view() + if data: + try: + await self.socket.sendall(data) + except OSError as error: + await self._set_defunct_write(error) + self.outbox.clear() + + async def send_all(self): + """ Send all queued messages to the server. + """ + if self.closed(): + raise ServiceUnavailable("Failed to write to closed connection {!r} ({!r})".format( + self.unresolved_address, self.server_info.address)) + + if self.defunct(): + raise ServiceUnavailable("Failed to write to defunct connection {!r} ({!r})".format( + self.unresolved_address, self.server_info.address)) + + await self._send_all() + + @abc.abstractmethod + async def fetch_message(self): + """ Receive at most one message from the server, if available. + + :return: 2-tuple of number of detail messages and number of summary + messages fetched + """ + pass + + async def fetch_all(self): + """ Fetch all outstanding messages. + + :return: 2-tuple of number of detail messages and number of summary + messages fetched + """ + detail_count = summary_count = 0 + while self.responses: + response = self.responses[0] + while not response.complete: + detail_delta, summary_delta = await self.fetch_message() + detail_count += detail_delta + summary_count += summary_delta + return detail_count, summary_count + + async def _set_defunct_read(self, error=None, silent=False): + message = "Failed to read from defunct connection {!r} ({!r})".format( + self.unresolved_address, self.server_info.address + ) + await self._set_defunct(message, error=error, silent=silent) + + async def _set_defunct_write(self, error=None, silent=False): + message = "Failed to write data to connection {!r} ({!r})".format( + self.unresolved_address, self.server_info.address + ) + await self._set_defunct(message, error=error, silent=silent) + + async def _set_defunct(self, message, error=None, silent=False): + from ._pool import AsyncBoltPool + direct_driver = isinstance(self.pool, AsyncBoltPool) + + if error: + log.debug("[#%04X] %s", self.socket.getsockname()[1], error) + log.error(message) + # We were attempting to receive data but the connection + # has unexpectedly terminated. So, we need to close the + # connection from the client side, and remove the address + # from the connection pool. + self._defunct = True + await self.close() + if self.pool: + await self.pool.deactivate(address=self.unresolved_address) + # Iterate through the outstanding responses, and if any correspond + # to COMMIT requests then raise an error to signal that we are + # unable to confirm that the COMMIT completed successfully. + if silent: + return + for response in self.responses: + if isinstance(response, CommitResponse): + if error: + raise IncompleteCommit(message) from error + else: + raise IncompleteCommit(message) + + if direct_driver: + if error: + raise ServiceUnavailable(message) from error + else: + raise ServiceUnavailable(message) + else: + if error: + raise SessionExpired(message) from error + else: + raise SessionExpired(message) + + def stale(self): + return (self._stale + or (0 <= self._max_connection_lifetime + <= perf_counter() - self._creation_timestamp)) + + _stale = False + + def set_stale(self): + self._stale = True + + @abc.abstractmethod + async def close(self): + """ Close the connection. + """ + pass + + @abc.abstractmethod + def closed(self): + pass + + @abc.abstractmethod + def defunct(self): + pass + + +AsyncBoltSocket.Bolt = AsyncBolt diff --git a/neo4j/_async/io/_bolt3.py b/neo4j/_async/io/_bolt3.py new file mode 100644 index 00000000..8bf7e939 --- /dev/null +++ b/neo4j/_async/io/_bolt3.py @@ -0,0 +1,396 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from enum import Enum +from logging import getLogger +from ssl import SSLSocket + +from ..._async_compat.util import AsyncUtil +from ..._exceptions import ( + BoltError, + BoltProtocolError, +) +from ...api import ( + READ_ACCESS, + Version, +) +from ...exceptions import ( + ConfigurationError, + DatabaseUnavailable, + DriverError, + ForbiddenOnReadOnlyDatabase, + Neo4jError, + NotALeader, + ServiceUnavailable, +) +from ._bolt import AsyncBolt +from ._common import ( + check_supported_server_product, + CommitResponse, + InitResponse, + Response, +) + + +log = getLogger("neo4j") + + +class ServerStates(Enum): + CONNECTED = "CONNECTED" + READY = "READY" + STREAMING = "STREAMING" + TX_READY_OR_TX_STREAMING = "TX_READY||TX_STREAMING" + FAILED = "FAILED" + + +class ServerStateManager: + _STATE_TRANSITIONS = { + ServerStates.CONNECTED: { + "hello": ServerStates.READY, + }, + ServerStates.READY: { + "run": ServerStates.STREAMING, + "begin": ServerStates.TX_READY_OR_TX_STREAMING, + }, + ServerStates.STREAMING: { + "pull": ServerStates.READY, + "discard": ServerStates.READY, + "reset": ServerStates.READY, + }, + ServerStates.TX_READY_OR_TX_STREAMING: { + "commit": ServerStates.READY, + "rollback": ServerStates.READY, + "reset": ServerStates.READY, + }, + ServerStates.FAILED: { + "reset": ServerStates.READY, + } + } + + def __init__(self, init_state, on_change=None): + self.state = init_state + self._on_change = on_change + + def transition(self, message, metadata): + if metadata.get("has_more"): + return + state_before = self.state + self.state = self._STATE_TRANSITIONS\ + .get(self.state, {})\ + .get(message, self.state) + if state_before != self.state and callable(self._on_change): + self._on_change(state_before, self.state) + + +class AsyncBolt3(AsyncBolt): + """ Protocol handler for Bolt 3. + + This is supported by Neo4j versions 3.5, 4.0, 4.1, 4.2, 4.3, and 4.4. + """ + + PROTOCOL_VERSION = Version(3, 0) + + supports_multiple_results = False + + supports_multiple_databases = False + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._server_state_manager = ServerStateManager( + ServerStates.CONNECTED, on_change=self._on_server_state_change + ) + + def _on_server_state_change(self, old_state, new_state): + log.debug("[#%04X] State: %s > %s", self.local_port, + old_state.name, new_state.name) + + @property + def is_reset(self): + # We can't be sure of the server's state if there are still pending + # responses. Unless the last message we sent was RESET. In that case + # the server state will always be READY when we're done. + if (self.responses and self.responses[-1] + and self.responses[-1].message == "reset"): + return True + return self._server_state_manager.state == ServerStates.READY + + @property + def encrypted(self): + return isinstance(self.socket, SSLSocket) + + @property + def der_encoded_server_certificate(self): + return self.socket.getpeercert(binary_form=True) + + @property + def local_port(self): + try: + return self.socket.getsockname()[1] + except OSError: + return 0 + + def get_base_headers(self): + return { + "user_agent": self.user_agent, + } + + async def hello(self): + headers = self.get_base_headers() + headers.update(self.auth_dict) + logged_headers = dict(headers) + if "credentials" in logged_headers: + logged_headers["credentials"] = "*******" + log.debug("[#%04X] C: HELLO %r", self.local_port, logged_headers) + self._append(b"\x01", (headers,), + response=InitResponse(self, "hello", + on_success=self.server_info.update)) + await self.send_all() + await self.fetch_all() + check_supported_server_product(self.server_info.agent) + + async def route(self, database=None, imp_user=None, bookmarks=None): + if database is not None: + raise ConfigurationError( + "Database name parameter for selecting database is not " + "supported in Bolt Protocol {!r}. Database name {!r}. " + "Server Agent {!r}".format( + self.PROTOCOL_VERSION, database, self.server_info.agent + ) + ) + if imp_user is not None: + raise ConfigurationError( + "Impersonation is not supported in Bolt Protocol {!r}. " + "Trying to impersonate {!r}.".format( + self.PROTOCOL_VERSION, imp_user + ) + ) + + metadata = {} + records = [] + + # Ignoring database and bookmarks because there is no multi-db support. + # The bookmarks are only relevant for making sure a previously created + # db exists before querying a routing table for it. + self.run( + "CALL dbms.cluster.routing.getRoutingTable($context)", # This is an internal procedure call. Only available if the Neo4j 3.5 is setup with clustering. + {"context": self.routing_context}, + mode="r", # Bolt Protocol Version(3, 0) supports mode="r" + on_success=metadata.update + ) + self.pull(on_success=metadata.update, on_records=records.extend) + await self.send_all() + await self.fetch_all() + routing_info = [dict(zip(metadata.get("fields", ()), values)) for values in records] + return routing_info + + def run(self, query, parameters=None, mode=None, bookmarks=None, + metadata=None, timeout=None, db=None, imp_user=None, **handlers): + if db is not None: + raise ConfigurationError( + "Database name parameter for selecting database is not " + "supported in Bolt Protocol {!r}. Database name {!r}.".format( + self.PROTOCOL_VERSION, db + ) + ) + if imp_user is not None: + raise ConfigurationError( + "Impersonation is not supported in Bolt Protocol {!r}. " + "Trying to impersonate {!r}.".format( + self.PROTOCOL_VERSION, imp_user + ) + ) + if not parameters: + parameters = {} + extra = {} + if mode in (READ_ACCESS, "r"): + extra["mode"] = "r" # It will default to mode "w" if nothing is specified + if bookmarks: + try: + extra["bookmarks"] = list(bookmarks) + except TypeError: + raise TypeError("Bookmarks must be provided within an iterable") + if metadata: + try: + extra["tx_metadata"] = dict(metadata) + except TypeError: + raise TypeError("Metadata must be coercible to a dict") + if timeout: + try: + extra["tx_timeout"] = int(1000 * timeout) + except TypeError: + raise TypeError("Timeout must be specified as a number of seconds") + fields = (query, parameters, extra) + log.debug("[#%04X] C: RUN %s", self.local_port, " ".join(map(repr, fields))) + if query.upper() == u"COMMIT": + self._append(b"\x10", fields, CommitResponse(self, "run", + **handlers)) + else: + self._append(b"\x10", fields, Response(self, "run", **handlers)) + + def discard(self, n=-1, qid=-1, **handlers): + # Just ignore n and qid, it is not supported in the Bolt 3 Protocol. + log.debug("[#%04X] C: DISCARD_ALL", self.local_port) + self._append(b"\x2F", (), Response(self, "discard", **handlers)) + + def pull(self, n=-1, qid=-1, **handlers): + # Just ignore n and qid, it is not supported in the Bolt 3 Protocol. + log.debug("[#%04X] C: PULL_ALL", self.local_port) + self._append(b"\x3F", (), Response(self, "pull", **handlers)) + + def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, + db=None, imp_user=None, **handlers): + if db is not None: + raise ConfigurationError( + "Database name parameter for selecting database is not " + "supported in Bolt Protocol {!r}. Database name {!r}.".format( + self.PROTOCOL_VERSION, db + ) + ) + if imp_user is not None: + raise ConfigurationError( + "Impersonation is not supported in Bolt Protocol {!r}. " + "Trying to impersonate {!r}.".format( + self.PROTOCOL_VERSION, imp_user + ) + ) + extra = {} + if mode in (READ_ACCESS, "r"): + extra["mode"] = "r" # It will default to mode "w" if nothing is specified + if bookmarks: + try: + extra["bookmarks"] = list(bookmarks) + except TypeError: + raise TypeError("Bookmarks must be provided within an iterable") + if metadata: + try: + extra["tx_metadata"] = dict(metadata) + except TypeError: + raise TypeError("Metadata must be coercible to a dict") + if timeout: + try: + extra["tx_timeout"] = int(1000 * timeout) + except TypeError: + raise TypeError("Timeout must be specified as a number of seconds") + log.debug("[#%04X] C: BEGIN %r", self.local_port, extra) + self._append(b"\x11", (extra,), Response(self, "begin", **handlers)) + + def commit(self, **handlers): + log.debug("[#%04X] C: COMMIT", self.local_port) + self._append(b"\x12", (), CommitResponse(self, "commit", **handlers)) + + def rollback(self, **handlers): + log.debug("[#%04X] C: ROLLBACK", self.local_port) + self._append(b"\x13", (), Response(self, "rollback", **handlers)) + + async def reset(self): + """ Add a RESET message to the outgoing queue, send + it and consume all remaining messages. + """ + + def fail(metadata): + raise BoltProtocolError("RESET failed %r" % metadata, address=self.unresolved_address) + + log.debug("[#%04X] C: RESET", self.local_port) + self._append(b"\x0F", response=Response(self, "reset", on_failure=fail)) + await self.send_all() + await self.fetch_all() + + async def fetch_message(self): + """ Receive at most one message from the server, if available. + + :return: 2-tuple of number of detail messages and number of summary + messages fetched + """ + if self._closed: + raise ServiceUnavailable("Failed to read from closed connection {!r} ({!r})".format( + self.unresolved_address, self.server_info.address)) + + if self._defunct: + raise ServiceUnavailable("Failed to read from defunct connection {!r} ({!r})".format( + self.unresolved_address, self.server_info.address)) + + if not self.responses: + return 0, 0 + + # Receive exactly one message + details, summary_signature, summary_metadata = \ + await AsyncUtil.next(self.inbox) + + if details: + log.debug("[#%04X] S: RECORD * %d", self.local_port, len(details)) # Do not log any data + await self.responses[0].on_records(details) + + if summary_signature is None: + return len(details), 0 + + response = self.responses.popleft() + response.complete = True + if summary_signature == b"\x70": + log.debug("[#%04X] S: SUCCESS %r", self.local_port, summary_metadata) + self._server_state_manager.transition(response.message, + summary_metadata) + await response.on_success(summary_metadata or {}) + elif summary_signature == b"\x7E": + log.debug("[#%04X] S: IGNORED", self.local_port) + await response.on_ignored(summary_metadata or {}) + elif summary_signature == b"\x7F": + log.debug("[#%04X] S: FAILURE %r", self.local_port, summary_metadata) + self._server_state_manager.state = ServerStates.FAILED + try: + await response.on_failure(summary_metadata or {}) + except (ServiceUnavailable, DatabaseUnavailable): + if self.pool: + await self.pool.deactivate(address=self.unresolved_address) + raise + except (NotALeader, ForbiddenOnReadOnlyDatabase): + if self.pool: + self.pool.on_write_failure(address=self.unresolved_address) + raise + except Neo4jError as e: + if self.pool and e.invalidates_all_connections(): + await self.pool.mark_all_stale() + raise + else: + raise BoltProtocolError("Unexpected response message with signature %02X" % summary_signature, address=self.unresolved_address) + + return len(details), 1 + + async def close(self): + """ Close the connection. + """ + if not self._closed: + if not self._defunct: + log.debug("[#%04X] C: GOODBYE", self.local_port) + self._append(b"\x02", ()) + try: + await self._send_all() + except (OSError, BoltError, DriverError): + pass + log.debug("[#%04X] C: ", self.local_port) + try: + self.socket.close() + except OSError: + pass + finally: + self._closed = True + + def closed(self): + return self._closed + + def defunct(self): + return self._defunct diff --git a/neo4j/_async/io/_bolt4.py b/neo4j/_async/io/_bolt4.py new file mode 100644 index 00000000..326e7212 --- /dev/null +++ b/neo4j/_async/io/_bolt4.py @@ -0,0 +1,537 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from logging import getLogger +from ssl import SSLSocket + +from ..._async_compat.util import AsyncUtil +from ..._exceptions import ( + BoltError, + BoltProtocolError, +) +from ...api import ( + READ_ACCESS, + SYSTEM_DATABASE, + Version, +) +from ...exceptions import ( + ConfigurationError, + DatabaseUnavailable, + DriverError, + ForbiddenOnReadOnlyDatabase, + Neo4jError, + NotALeader, + ServiceUnavailable, +) +from ._bolt3 import ( + ServerStateManager, + ServerStates, +) +from ._bolt import AsyncBolt +from ._common import ( + check_supported_server_product, + CommitResponse, + InitResponse, + Response, +) + + +log = getLogger("neo4j") + + +class AsyncBolt4x0(AsyncBolt): + """ Protocol handler for Bolt 4.0. + + This is supported by Neo4j versions 4.0, 4.1 and 4.2. + """ + + PROTOCOL_VERSION = Version(4, 0) + + supports_multiple_results = True + + supports_multiple_databases = True + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._server_state_manager = ServerStateManager( + ServerStates.CONNECTED, on_change=self._on_server_state_change + ) + + def _on_server_state_change(self, old_state, new_state): + log.debug("[#%04X] State: %s > %s", self.local_port, + old_state.name, new_state.name) + + @property + def is_reset(self): + # We can't be sure of the server's state if there are still pending + # responses. Unless the last message we sent was RESET. In that case + # the server state will always be READY when we're done. + if (self.responses and self.responses[-1] + and self.responses[-1].message == "reset"): + return True + return self._server_state_manager.state == ServerStates.READY + + @property + def encrypted(self): + return isinstance(self.socket, SSLSocket) + + @property + def der_encoded_server_certificate(self): + return self.socket.getpeercert(binary_form=True) + + @property + def local_port(self): + try: + return self.socket.getsockname()[1] + except OSError: + return 0 + + def get_base_headers(self): + return { + "user_agent": self.user_agent, + } + + async def hello(self): + headers = self.get_base_headers() + headers.update(self.auth_dict) + logged_headers = dict(headers) + if "credentials" in logged_headers: + logged_headers["credentials"] = "*******" + log.debug("[#%04X] C: HELLO %r", self.local_port, logged_headers) + self._append(b"\x01", (headers,), + response=InitResponse(self, "hello", + on_success=self.server_info.update)) + await self.send_all() + await self.fetch_all() + check_supported_server_product(self.server_info.agent) + + async def route(self, database=None, imp_user=None, bookmarks=None): + if imp_user is not None: + raise ConfigurationError( + "Impersonation is not supported in Bolt Protocol {!r}. " + "Trying to impersonate {!r}.".format( + self.PROTOCOL_VERSION, imp_user + ) + ) + metadata = {} + records = [] + + if database is None: # default database + self.run( + "CALL dbms.routing.getRoutingTable($context)", + {"context": self.routing_context}, + mode="r", + bookmarks=bookmarks, + db=SYSTEM_DATABASE, + on_success=metadata.update + ) + else: + self.run( + "CALL dbms.routing.getRoutingTable($context, $database)", + {"context": self.routing_context, "database": database}, + mode="r", + bookmarks=bookmarks, + db=SYSTEM_DATABASE, + on_success=metadata.update + ) + self.pull(on_success=metadata.update, on_records=records.extend) + await self.send_all() + await self.fetch_all() + routing_info = [dict(zip(metadata.get("fields", ()), values)) for values in records] + return routing_info + + def run(self, query, parameters=None, mode=None, bookmarks=None, + metadata=None, timeout=None, db=None, imp_user=None, **handlers): + if imp_user is not None: + raise ConfigurationError( + "Impersonation is not supported in Bolt Protocol {!r}. " + "Trying to impersonate {!r}.".format( + self.PROTOCOL_VERSION, imp_user + ) + ) + if not parameters: + parameters = {} + extra = {} + if mode in (READ_ACCESS, "r"): + extra["mode"] = "r" # It will default to mode "w" if nothing is specified + if db: + extra["db"] = db + if bookmarks: + try: + extra["bookmarks"] = list(bookmarks) + except TypeError: + raise TypeError("Bookmarks must be provided within an iterable") + if metadata: + try: + extra["tx_metadata"] = dict(metadata) + except TypeError: + raise TypeError("Metadata must be coercible to a dict") + if timeout: + try: + extra["tx_timeout"] = int(1000 * timeout) + except TypeError: + raise TypeError("Timeout must be specified as a number of seconds") + fields = (query, parameters, extra) + log.debug("[#%04X] C: RUN %s", self.local_port, " ".join(map(repr, fields))) + if query.upper() == u"COMMIT": + self._append(b"\x10", fields, CommitResponse(self, "run", + **handlers)) + else: + self._append(b"\x10", fields, Response(self, "run", **handlers)) + + def discard(self, n=-1, qid=-1, **handlers): + extra = {"n": n} + if qid != -1: + extra["qid"] = qid + log.debug("[#%04X] C: DISCARD %r", self.local_port, extra) + self._append(b"\x2F", (extra,), Response(self, "discard", **handlers)) + + def pull(self, n=-1, qid=-1, **handlers): + extra = {"n": n} + if qid != -1: + extra["qid"] = qid + log.debug("[#%04X] C: PULL %r", self.local_port, extra) + self._append(b"\x3F", (extra,), Response(self, "pull", **handlers)) + + def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, + db=None, imp_user=None, **handlers): + if imp_user is not None: + raise ConfigurationError( + "Impersonation is not supported in Bolt Protocol {!r}. " + "Trying to impersonate {!r}.".format( + self.PROTOCOL_VERSION, imp_user + ) + ) + extra = {} + if mode in (READ_ACCESS, "r"): + extra["mode"] = "r" # It will default to mode "w" if nothing is specified + if db: + extra["db"] = db + if bookmarks: + try: + extra["bookmarks"] = list(bookmarks) + except TypeError: + raise TypeError("Bookmarks must be provided within an iterable") + if metadata: + try: + extra["tx_metadata"] = dict(metadata) + except TypeError: + raise TypeError("Metadata must be coercible to a dict") + if timeout: + try: + extra["tx_timeout"] = int(1000 * timeout) + except TypeError: + raise TypeError("Timeout must be specified as a number of seconds") + log.debug("[#%04X] C: BEGIN %r", self.local_port, extra) + self._append(b"\x11", (extra,), Response(self, "begin", **handlers)) + + def commit(self, **handlers): + log.debug("[#%04X] C: COMMIT", self.local_port) + self._append(b"\x12", (), CommitResponse(self, "commit", **handlers)) + + def rollback(self, **handlers): + log.debug("[#%04X] C: ROLLBACK", self.local_port) + self._append(b"\x13", (), Response(self, "rollback", **handlers)) + + async def reset(self): + """ Add a RESET message to the outgoing queue, send + it and consume all remaining messages. + """ + + def fail(metadata): + raise BoltProtocolError("RESET failed %r" % metadata, self.unresolved_address) + + log.debug("[#%04X] C: RESET", self.local_port) + self._append(b"\x0F", response=Response(self, "reset", on_failure=fail)) + await self.send_all() + await self.fetch_all() + + async def fetch_message(self): + """ Receive at most one message from the server, if available. + + :return: 2-tuple of number of detail messages and number of summary + messages fetched + """ + if self._closed: + raise ServiceUnavailable("Failed to read from closed connection {!r} ({!r})".format( + self.unresolved_address, self.server_info.address)) + + if self._defunct: + raise ServiceUnavailable("Failed to read from defunct connection {!r} ({!r})".format( + self.unresolved_address, self.server_info.address)) + + if not self.responses: + return 0, 0 + + # Receive exactly one message + details, summary_signature, summary_metadata = \ + await AsyncUtil.next(self.inbox) + + if details: + log.debug("[#%04X] S: RECORD * %d", self.local_port, len(details)) # Do not log any data + await self.responses[0].on_records(details) + + if summary_signature is None: + return len(details), 0 + + response = self.responses.popleft() + response.complete = True + if summary_signature == b"\x70": + log.debug("[#%04X] S: SUCCESS %r", self.local_port, summary_metadata) + self._server_state_manager.transition(response.message, + summary_metadata) + await response.on_success(summary_metadata or {}) + elif summary_signature == b"\x7E": + log.debug("[#%04X] S: IGNORED", self.local_port) + await response.on_ignored(summary_metadata or {}) + elif summary_signature == b"\x7F": + log.debug("[#%04X] S: FAILURE %r", self.local_port, summary_metadata) + self._server_state_manager.state = ServerStates.FAILED + try: + await response.on_failure(summary_metadata or {}) + except (ServiceUnavailable, DatabaseUnavailable): + if self.pool: + await self.pool.deactivate(address=self.unresolved_address) + raise + except (NotALeader, ForbiddenOnReadOnlyDatabase): + if self.pool: + self.pool.on_write_failure(address=self.unresolved_address) + raise + except Neo4jError as e: + if self.pool and e.invalidates_all_connections(): + await self.pool.mark_all_stale() + raise + else: + raise BoltProtocolError("Unexpected response message with signature " + "%02X" % ord(summary_signature), self.unresolved_address) + + return len(details), 1 + + async def close(self): + """ Close the connection. + """ + if not self._closed: + if not self._defunct: + log.debug("[#%04X] C: GOODBYE", self.local_port) + self._append(b"\x02", ()) + try: + await self._send_all() + except (OSError, BoltError, DriverError): + pass + log.debug("[#%04X] C: ", self.local_port) + try: + self.socket.close() + except OSError: + pass + finally: + self._closed = True + + def closed(self): + return self._closed + + def defunct(self): + return self._defunct + + +class AsyncBolt4x1(AsyncBolt4x0): + """ Protocol handler for Bolt 4.1. + + This is supported by Neo4j versions 4.1 - 4.4. + """ + + PROTOCOL_VERSION = Version(4, 1) + + def get_base_headers(self): + """ Bolt 4.1 passes the routing context, originally taken from + the URI, into the connection initialisation message. This + enables server-side routing to propagate the same behaviour + through its driver. + """ + headers = { + "user_agent": self.user_agent, + } + if self.routing_context is not None: + headers["routing"] = self.routing_context + return headers + + +class AsyncBolt4x2(AsyncBolt4x1): + """ Protocol handler for Bolt 4.2. + + This is supported by Neo4j version 4.2 - 4.4. + """ + + PROTOCOL_VERSION = Version(4, 2) + + +class AsyncBolt4x3(AsyncBolt4x2): + """ Protocol handler for Bolt 4.3. + + This is supported by Neo4j version 4.3 - 4.4. + """ + + PROTOCOL_VERSION = Version(4, 3) + + async def route(self, database=None, imp_user=None, bookmarks=None): + if imp_user is not None: + raise ConfigurationError( + "Impersonation is not supported in Bolt Protocol {!r}. " + "Trying to impersonate {!r}.".format( + self.PROTOCOL_VERSION, imp_user + ) + ) + + routing_context = self.routing_context or {} + log.debug("[#%04X] C: ROUTE %r %r %r", self.local_port, + routing_context, bookmarks, database) + metadata = {} + if bookmarks is None: + bookmarks = [] + else: + bookmarks = list(bookmarks) + self._append(b"\x66", (routing_context, bookmarks, database), + response=Response(self, "route", + on_success=metadata.update)) + await self.send_all() + await self.fetch_all() + return [metadata.get("rt")] + + async def hello(self): + def on_success(metadata): + self.configuration_hints.update(metadata.pop("hints", {})) + self.server_info.update(metadata) + if "connection.recv_timeout_seconds" in self.configuration_hints: + recv_timeout = self.configuration_hints[ + "connection.recv_timeout_seconds" + ] + if isinstance(recv_timeout, int) and recv_timeout > 0: + self.socket.settimeout(recv_timeout) + else: + log.info("[#%04X] Server supplied an invalid value for " + "connection.recv_timeout_seconds (%r). Make sure " + "the server and network is set up correctly.", + self.local_port, recv_timeout) + + headers = self.get_base_headers() + headers.update(self.auth_dict) + logged_headers = dict(headers) + if "credentials" in logged_headers: + logged_headers["credentials"] = "*******" + log.debug("[#%04X] C: HELLO %r", self.local_port, logged_headers) + self._append(b"\x01", (headers,), + response=InitResponse(self, "hello", + on_success=on_success)) + await self.send_all() + await self.fetch_all() + check_supported_server_product(self.server_info.agent) + + +class AsyncBolt4x4(AsyncBolt4x3): + """ Protocol handler for Bolt 4.4. + + This is supported by Neo4j version 4.4. + """ + + PROTOCOL_VERSION = Version(4, 4) + + async def route(self, database=None, imp_user=None, bookmarks=None): + routing_context = self.routing_context or {} + db_context = {} + if database is not None: + db_context.update(db=database) + if imp_user is not None: + db_context.update(imp_user=imp_user) + log.debug("[#%04X] C: ROUTE %r %r %r", self.local_port, + routing_context, bookmarks, db_context) + metadata = {} + if bookmarks is None: + bookmarks = [] + else: + bookmarks = list(bookmarks) + self._append(b"\x66", (routing_context, bookmarks, db_context), + response=Response(self, "route", + on_success=metadata.update)) + await self.send_all() + await self.fetch_all() + return [metadata.get("rt")] + + def run(self, query, parameters=None, mode=None, bookmarks=None, + metadata=None, timeout=None, db=None, imp_user=None, **handlers): + if not parameters: + parameters = {} + extra = {} + if mode in (READ_ACCESS, "r"): + # It will default to mode "w" if nothing is specified + extra["mode"] = "r" + if db: + extra["db"] = db + if imp_user: + extra["imp_user"] = imp_user + if bookmarks: + try: + extra["bookmarks"] = list(bookmarks) + except TypeError: + raise TypeError("Bookmarks must be provided within an iterable") + if metadata: + try: + extra["tx_metadata"] = dict(metadata) + except TypeError: + raise TypeError("Metadata must be coercible to a dict") + if timeout: + try: + extra["tx_timeout"] = int(1000 * timeout) + except TypeError: + raise TypeError("Timeout must be specified as a number of " + "seconds") + fields = (query, parameters, extra) + log.debug("[#%04X] C: RUN %s", self.local_port, + " ".join(map(repr, fields))) + if query.upper() == u"COMMIT": + self._append(b"\x10", fields, CommitResponse(self, "run", + **handlers)) + else: + self._append(b"\x10", fields, Response(self, "run", **handlers)) + + def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, + db=None, imp_user=None, **handlers): + extra = {} + if mode in (READ_ACCESS, "r"): + # It will default to mode "w" if nothing is specified + extra["mode"] = "r" + if db: + extra["db"] = db + if imp_user: + extra["imp_user"] = imp_user + if bookmarks: + try: + extra["bookmarks"] = list(bookmarks) + except TypeError: + raise TypeError("Bookmarks must be provided within an iterable") + if metadata: + try: + extra["tx_metadata"] = dict(metadata) + except TypeError: + raise TypeError("Metadata must be coercible to a dict") + if timeout: + try: + extra["tx_timeout"] = int(1000 * timeout) + except TypeError: + raise TypeError("Timeout must be specified as a number of " + "seconds") + log.debug("[#%04X] C: BEGIN %r", self.local_port, extra) + self._append(b"\x11", (extra,), Response(self, "begin", **handlers)) diff --git a/neo4j/_async/io/_common.py b/neo4j/_async/io/_common.py new file mode 100644 index 00000000..5e0d18f5 --- /dev/null +++ b/neo4j/_async/io/_common.py @@ -0,0 +1,280 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import asyncio +import logging +import socket +from struct import pack as struct_pack + +from ..._async_compat.util import AsyncUtil +from ...exceptions import ( + Neo4jError, + ServiceUnavailable, + SessionExpired, + UnsupportedServerProduct, +) +from ...packstream import ( + UnpackableBuffer, + Unpacker, +) + + +log = logging.getLogger("neo4j") + + +class AsyncMessageInbox: + + def __init__(self, s, on_error): + self.on_error = on_error + self._messages = self._yield_messages(s) + + async def _yield_messages(self, sock): + try: + buffer = UnpackableBuffer() + unpacker = Unpacker(buffer) + chunk_size = 0 + while True: + + while chunk_size == 0: + # Determine the chunk size and skip noop + await receive_into_buffer(sock, buffer, 2) + chunk_size = buffer.pop_u16() + if chunk_size == 0: + log.debug("[#%04X] S: ", sock.getsockname()[1]) + + await receive_into_buffer(sock, buffer, chunk_size + 2) + chunk_size = buffer.pop_u16() + + if chunk_size == 0: + # chunk_size was the end marker for the message + size, tag = unpacker.unpack_structure_header() + fields = [unpacker.unpack() for _ in range(size)] + yield tag, fields + # Reset for new message + unpacker.reset() + + except (OSError, socket.timeout) as error: + await AsyncUtil.callback(self.on_error, error) + + async def pop(self): + return await AsyncUtil.next(self._messages) + + +class AsyncInbox(AsyncMessageInbox): + + async def __anext__(self): + tag, fields = await self.pop() + if tag == b"\x71": + return fields, None, None + elif fields: + return [], tag, fields[0] + else: + return [], tag, None + + +class Outbox: + + def __init__(self, max_chunk_size=16384): + self._max_chunk_size = max_chunk_size + self._chunked_data = bytearray() + self._raw_data = bytearray() + self.write = self._raw_data.extend + + def max_chunk_size(self): + return self._max_chunk_size + + def clear(self): + self._chunked_data = bytearray() + self._raw_data.clear() + + def _chunk_data(self): + data_len = len(self._raw_data) + num_full_chunks, chunk_rest = divmod( + data_len, self._max_chunk_size + ) + num_chunks = num_full_chunks + bool(chunk_rest) + + data_view = memoryview(self._raw_data) + header_start = len(self._chunked_data) + data_start = header_start + 2 + raw_data_start = 0 + for i in range(num_chunks): + chunk_size = min(data_len - raw_data_start, + self._max_chunk_size) + self._chunked_data[header_start:data_start] = struct_pack( + ">H", chunk_size + ) + self._chunked_data[data_start:(data_start + chunk_size)] = \ + data_view[raw_data_start:(raw_data_start + chunk_size)] + header_start += chunk_size + 2 + data_start = header_start + 2 + raw_data_start += chunk_size + del data_view + self._raw_data.clear() + + def wrap_message(self): + self._chunk_data() + self._chunked_data += b"\x00\x00" + + def view(self): + self._chunk_data() + return memoryview(self._chunked_data) + + +class ConnectionErrorHandler: + """ + Wrapper class for handling connection errors. + + The class will wrap each method to invoke a callback if the method raises + Neo4jError, SessionExpired, or ServiceUnavailable. + The error will be re-raised after the callback. + """ + + def __init__(self, connection, on_error): + """ + :param connection the connection object to warp + :type connection Bolt + :param on_error the function to be called when a method of + connection raises of of the caught errors. + :type on_error callable + """ + self.__connection = connection + self.__on_error = on_error + + def __getattr__(self, name): + connection_attr = getattr(self.__connection, name) + if not callable(connection_attr): + return connection_attr + + def outer(func): + def inner(*args, **kwargs): + try: + func(*args, **kwargs) + except (Neo4jError, ServiceUnavailable, SessionExpired) as exc: + assert not asyncio.iscoroutinefunction(self.__on_error) + self.__on_error(exc) + raise + return inner + + def outer_async(coroutine_func): + async def inner(*args, **kwargs): + try: + await coroutine_func(*args, **kwargs) + except (Neo4jError, ServiceUnavailable, SessionExpired) as exc: + await AsyncUtil.callback(self.__on_error, exc) + raise + return inner + + if asyncio.iscoroutinefunction(connection_attr): + return outer_async(connection_attr) + return outer(connection_attr) + + def __setattr__(self, name, value): + if name.startswith("_" + self.__class__.__name__ + "__"): + super().__setattr__(name, value) + else: + setattr(self.__connection, name, value) + + +class Response: + """ Subscriber object for a full response (zero or + more detail messages followed by one summary message). + """ + + def __init__(self, connection, message, **handlers): + self.connection = connection + self.handlers = handlers + self.message = message + self.complete = False + + async def on_records(self, records): + """ Called when one or more RECORD messages have been received. + """ + handler = self.handlers.get("on_records") + await AsyncUtil.callback(handler, records) + + async def on_success(self, metadata): + """ Called when a SUCCESS message has been received. + """ + handler = self.handlers.get("on_success") + await AsyncUtil.callback(handler, metadata) + + if not metadata.get("has_more"): + handler = self.handlers.get("on_summary") + await AsyncUtil.callback(handler) + + async def on_failure(self, metadata): + """ Called when a FAILURE message has been received. + """ + try: + self.connection.reset() + except (SessionExpired, ServiceUnavailable): + pass + handler = self.handlers.get("on_failure") + await AsyncUtil.callback(handler, metadata) + handler = self.handlers.get("on_summary") + await AsyncUtil.callback(handler) + raise Neo4jError.hydrate(**metadata) + + async def on_ignored(self, metadata=None): + """ Called when an IGNORED message has been received. + """ + handler = self.handlers.get("on_ignored") + await AsyncUtil.callback(handler, metadata) + handler = self.handlers.get("on_summary") + await AsyncUtil.callback(handler) + + +class InitResponse(Response): + + async def on_failure(self, metadata): + code = metadata.get("code") + if code == "Neo.ClientError.Security.Unauthorized": + raise Neo4jError.hydrate(**metadata) + else: + raise ServiceUnavailable( + metadata.get("message", "Connection initialisation failed") + ) + + +class CommitResponse(Response): + + pass + + +def check_supported_server_product(agent): + """ Checks that a server product is supported by the driver by + looking at the server agent string. + + :param agent: server agent string to check for validity + :raises UnsupportedServerProduct: if the product is not supported + """ + if not agent.startswith("Neo4j/"): + raise UnsupportedServerProduct(agent) + + +async def receive_into_buffer(sock, buffer, n_bytes): + end = buffer.used + n_bytes + if end > len(buffer.data): + buffer.data += bytearray(end - len(buffer.data)) + view = memoryview(buffer.data) + while buffer.used < end: + n = await sock.recv_into(view[buffer.used:end], end - buffer.used) + if n == 0: + raise OSError("No data") + buffer.used += n diff --git a/neo4j/_async/io/_pool.py b/neo4j/_async/io/_pool.py new file mode 100644 index 00000000..30998b26 --- /dev/null +++ b/neo4j/_async/io/_pool.py @@ -0,0 +1,701 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import abc +from collections import ( + defaultdict, + deque, +) +import logging +from logging import getLogger +from random import choice +from time import perf_counter + +from ..._async_compat.concurrency import ( + AsyncCondition, + AsyncRLock, +) +from ..._async_compat.network import AsyncNetworkUtil +from ..._exceptions import BoltError +from ...api import ( + READ_ACCESS, + WRITE_ACCESS, +) +from ...conf import ( + PoolConfig, + WorkspaceConfig, +) +from ...exceptions import ( + ClientError, + ConfigurationError, + DriverError, + Neo4jError, + ReadServiceUnavailable, + ServiceUnavailable, + SessionExpired, + WriteServiceUnavailable, +) +from ...routing import RoutingTable +from ._bolt import AsyncBolt + + +# Set up logger +log = getLogger("neo4j") + + +class AsyncIOPool(abc.ABC): + """ A collection of connections to one or more server addresses. + """ + + def __init__(self, opener, pool_config, workspace_config): + assert callable(opener) + assert isinstance(pool_config, PoolConfig) + assert isinstance(workspace_config, WorkspaceConfig) + + self.opener = opener + self.pool_config = pool_config + self.workspace_config = workspace_config + self.connections = defaultdict(deque) + self.lock = AsyncRLock() + self.cond = AsyncCondition(self.lock) + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + await self.close() + + async def _acquire(self, address, timeout): + """ Acquire a connection to a given address from the pool. + The address supplied should always be an IP address, not + a host name. + + This method is thread safe. + """ + t0 = perf_counter() + if timeout is None: + timeout = self.workspace_config.connection_acquisition_timeout + + async with self.lock: + def time_remaining(): + t = timeout - (perf_counter() - t0) + return t if t > 0 else 0 + + while True: + # try to find a free connection in pool + for connection in list(self.connections.get(address, [])): + if (connection.closed() or connection.defunct() + or (connection.stale() and not connection.in_use)): + # `close` is a noop on already closed connections. + # This is to make sure that the connection is + # gracefully closed, e.g. if it's just marked as + # `stale` but still alive. + if log.isEnabledFor(logging.DEBUG): + log.debug( + "[#%04X] C: removing old connection " + "(closed=%s, defunct=%s, stale=%s, in_use=%s)", + connection.local_port, + connection.closed(), connection.defunct(), + connection.stale(), connection.in_use + ) + await connection.close() + try: + self.connections.get(address, []).remove(connection) + except ValueError: + # If closure fails (e.g. because the server went + # down), all connections to the same address will + # be removed. Therefore, we silently ignore if the + # connection isn't in the pool anymore. + pass + continue + if not connection.in_use: + connection.in_use = True + return connection + # all connections in pool are in-use + connections = self.connections[address] + max_pool_size = self.pool_config.max_connection_pool_size + infinite_pool_size = (max_pool_size < 0 + or max_pool_size == float("inf")) + can_create_new_connection = ( + infinite_pool_size + or len(connections) < max_pool_size + ) + if can_create_new_connection: + timeout = min(self.pool_config.connection_timeout, + time_remaining()) + try: + connection = await self.opener(address, timeout) + except ServiceUnavailable: + await self.remove(address) + raise + else: + connection.pool = self + connection.in_use = True + connections.append(connection) + return connection + + # failed to obtain a connection from pool because the + # pool is full and no free connection in the pool + if time_remaining(): + await self.cond.wait(time_remaining()) + # if timed out, then we throw error. This time + # computation is needed, as with python 2.7, we + # cannot tell if the condition is notified or + # timed out when we come to this line + if not time_remaining(): + raise ClientError("Failed to obtain a connection from pool " + "within {!r}s".format(timeout)) + else: + raise ClientError("Failed to obtain a connection from pool " + "within {!r}s".format(timeout)) + + @abc.abstractmethod + async def acquire( + self, access_mode=None, timeout=None, database=None, bookmarks=None + ): + """ Acquire a connection to a server that can satisfy a set of parameters. + + :param access_mode: + :param timeout: + :param database: + :param bookmarks: + """ + + async def release(self, *connections): + """ Release a connection back into the pool. + This method is thread safe. + """ + async with self.lock: + for connection in connections: + if not (connection.defunct() + or connection.closed() + or connection.is_reset): + try: + await connection.reset() + except (Neo4jError, DriverError, BoltError) as e: + log.debug( + "Failed to reset connection on release: %s", e + ) + connection.in_use = False + self.cond.notify_all() + + def in_use_connection_count(self, address): + """ Count the number of connections currently in use to a given + address. + """ + try: + connections = self.connections[address] + except KeyError: + return 0 + else: + return sum(1 if connection.in_use else 0 for connection in connections) + + async def mark_all_stale(self): + async with self.lock: + for address in self.connections: + for connection in self.connections[address]: + connection.set_stale() + + async def deactivate(self, address): + """ Deactivate an address from the connection pool, if present, closing + all idle connection to that address + """ + async with self.lock: + try: + connections = self.connections[address] + except KeyError: # already removed from the connection pool + return + for conn in list(connections): + if not conn.in_use: + connections.remove(conn) + try: + await conn.close() + except OSError: + pass + if not connections: + await self.remove(address) + + def on_write_failure(self, address): + raise WriteServiceUnavailable( + "No write service available for pool {}".format(self) + ) + + async def remove(self, address): + """ Remove an address from the connection pool, if present, closing + all connections to that address. + """ + async with self.lock: + for connection in self.connections.pop(address, ()): + try: + await connection.close() + except OSError: + pass + + async def close(self): + """ Close all connections and empty the pool. + This method is thread safe. + """ + try: + async with self.lock: + for address in list(self.connections): + await self.remove(address) + except TypeError: + pass + + +class AsyncBoltPool(AsyncIOPool): + + @classmethod + def open(cls, address, *, auth, pool_config, workspace_config): + """Create a new BoltPool + + :param address: + :param auth: + :param pool_config: + :param workspace_config: + :return: BoltPool + """ + + async def opener(addr, timeout): + return await AsyncBolt.open( + addr, auth=auth, timeout=timeout, routing_context=None, + **pool_config + ) + + pool = cls(opener, pool_config, workspace_config, address) + return pool + + def __init__(self, opener, pool_config, workspace_config, address): + super().__init__(opener, pool_config, workspace_config) + self.address = address + + def __repr__(self): + return "<{} address={!r}>".format(self.__class__.__name__, + self.address) + + async def acquire( + self, access_mode=None, timeout=None, database=None, bookmarks=None + ): + # The access_mode and database is not needed for a direct connection, + # it's just there for consistency. + return await self._acquire(self.address, timeout) + + +class AsyncNeo4jPool(AsyncIOPool): + """ Connection pool with routing table. + """ + + @classmethod + def open(cls, *addresses, auth, pool_config, workspace_config, + routing_context=None): + """Create a new Neo4jPool + + :param addresses: one or more address as positional argument + :param auth: + :param pool_config: + :param workspace_config: + :param routing_context: + :return: Neo4jPool + """ + + address = addresses[0] + if routing_context is None: + routing_context = {} + elif "address" in routing_context: + raise ConfigurationError("The key 'address' is reserved for routing context.") + routing_context["address"] = str(address) + + async def opener(addr, timeout): + return await AsyncBolt.open( + addr, auth=auth, timeout=timeout, + routing_context=routing_context, **pool_config + ) + + pool = cls(opener, pool_config, workspace_config, address) + return pool + + def __init__(self, opener, pool_config, workspace_config, address): + """ + + :param opener: + :param pool_config: + :param workspace_config: + :param address: + """ + super().__init__(opener, pool_config, workspace_config) + # Each database have a routing table, the default database is a special case. + log.debug("[#0000] C: routing address %r", address) + self.address = address + self.routing_tables = {workspace_config.database: RoutingTable(database=workspace_config.database, routers=[address])} + self.refresh_lock = AsyncRLock() + + def __repr__(self): + """ The representation shows the initial routing addresses. + + :return: The representation + :rtype: str + """ + return "<{} addresses={!r}>".format(self.__class__.__name__, self.get_default_database_initial_router_addresses()) + + @property + def first_initial_routing_address(self): + return self.get_default_database_initial_router_addresses()[0] + + def get_default_database_initial_router_addresses(self): + """ Get the initial router addresses for the default database. + + :return: + :rtype: OrderedSet + """ + return self.get_routing_table_for_default_database().initial_routers + + def get_default_database_router_addresses(self): + """ Get the router addresses for the default database. + + :return: + :rtype: OrderedSet + """ + return self.get_routing_table_for_default_database().routers + + def get_routing_table_for_default_database(self): + return self.routing_tables[self.workspace_config.database] + + async def get_or_create_routing_table(self, database): + async with self.refresh_lock: + if database not in self.routing_tables: + self.routing_tables[database] = RoutingTable( + database=database, + routers=self.get_default_database_initial_router_addresses() + ) + return self.routing_tables[database] + + async def fetch_routing_info( + self, address, database, imp_user, bookmarks, timeout + ): + """ Fetch raw routing info from a given router address. + + :param address: router address + :param database: the database name to get routing table for + :param imp_user: the user to impersonate while fetching the routing + table + :type imp_user: str or None + :param bookmarks: iterable of bookmark values after which the routing + info should be fetched + :param timeout: connection acquisition timeout in seconds + + :return: list of routing records, or None if no connection + could be established or if no readers or writers are present + :raise ServiceUnavailable: if the server does not support + routing, or if routing support is broken or outdated + """ + cx = await self._acquire(address, timeout) + try: + routing_table = await cx.route( + database or self.workspace_config.database, + imp_user or self.workspace_config.impersonated_user, + bookmarks + ) + finally: + await self.release(cx) + return routing_table + + async def fetch_routing_table( + self, *, address, timeout, database, imp_user, bookmarks + ): + """ Fetch a routing table from a given router address. + + :param address: router address + :param timeout: seconds + :param database: the database name + :type: str + :param imp_user: the user to impersonate while fetching the routing + table + :type imp_user: str or None + :param bookmarks: bookmarks used when fetching routing table + + :return: a new RoutingTable instance or None if the given router is + currently unable to provide routing information + """ + new_routing_info = None + try: + new_routing_info = await self.fetch_routing_info( + address, database, imp_user, bookmarks, timeout + ) + except Neo4jError as e: + # checks if the code is an error that is caused by the client. In + # this case there is no sense in trying to fetch a RT from another + # router. Hence, the driver should fail fast during discovery. + if e.is_fatal_during_discovery(): + raise + except (ServiceUnavailable, SessionExpired): + pass + if not new_routing_info: + log.debug("Failed to fetch routing info %s", address) + return None + else: + servers = new_routing_info[0]["servers"] + ttl = new_routing_info[0]["ttl"] + database = new_routing_info[0].get("db", database) + new_routing_table = RoutingTable.parse_routing_info( + database=database, servers=servers, ttl=ttl + ) + + # Parse routing info and count the number of each type of server + num_routers = len(new_routing_table.routers) + num_readers = len(new_routing_table.readers) + + # num_writers = len(new_routing_table.writers) + # If no writers are available. This likely indicates a temporary state, + # such as leader switching, so we should not signal an error. + + # No routers + if num_routers == 0: + log.debug("No routing servers returned from server %s", address) + return None + + # No readers + if num_readers == 0: + log.debug("No read servers returned from server %s", address) + return None + + # At least one of each is fine, so return this table + return new_routing_table + + async def _update_routing_table_from( + self, *routers, database=None, imp_user=None, bookmarks=None, + database_callback=None + ): + """ Try to update routing tables with the given routers. + + :return: True if the routing table is successfully updated, + otherwise False + """ + log.debug("Attempting to update routing table from {}".format( + ", ".join(map(repr, routers))) + ) + for router in routers: + async for address in AsyncNetworkUtil.resolve_address( + router, resolver=self.pool_config.resolver + ): + new_routing_table = await self.fetch_routing_table( + address=address, + timeout=self.pool_config.connection_timeout, + database=database, imp_user=imp_user, bookmarks=bookmarks + ) + if new_routing_table is not None: + new_databse = new_routing_table.database + old_routing_table = await self.get_or_create_routing_table( + new_databse + ) + old_routing_table.update(new_routing_table) + log.debug( + "[#0000] C: address=%r (%r)", + address, self.routing_tables[new_databse] + ) + if callable(database_callback): + database_callback(new_databse) + return True + await self.deactivate(router) + return False + + async def update_routing_table( + self, *, database, imp_user, bookmarks, database_callback=None + ): + """ Update the routing table from the first router able to provide + valid routing information. + + :param database: The database name + :param imp_user: the user to impersonate while fetching the routing + table + :type imp_user: str or None + :param bookmarks: bookmarks used when fetching routing table + :param database_callback: A callback function that will be called with + the database name as only argument when a new routing table has been + acquired. This database name might different from `database` if that + was None and the underlying protocol supports reporting back the + actual database. + + :raise neo4j.exceptions.ServiceUnavailable: + """ + async with self.refresh_lock: + routing_table = await self.get_or_create_routing_table(database) + # copied because it can be modified + existing_routers = set(routing_table.routers) + + prefer_initial_routing_address = \ + self.routing_tables[database].initialized_without_writers + + if prefer_initial_routing_address: + # TODO: Test this state + if await self._update_routing_table_from( + self.first_initial_routing_address, database=database, + imp_user=imp_user, bookmarks=bookmarks, + database_callback=database_callback + ): + # Why is only the first initial routing address used? + return + if await self._update_routing_table_from( + *(existing_routers - {self.first_initial_routing_address}), + database=database, imp_user=imp_user, bookmarks=bookmarks, + database_callback=database_callback + ): + return + + if not prefer_initial_routing_address: + if await self._update_routing_table_from( + self.first_initial_routing_address, database=database, + imp_user=imp_user, bookmarks=bookmarks, + database_callback=database_callback + ): + # Why is only the first initial routing address used? + return + + # None of the routers have been successful, so just fail + log.error("Unable to retrieve routing information") + raise ServiceUnavailable("Unable to retrieve routing information") + + async def update_connection_pool(self, *, database): + routing_table = await self.get_or_create_routing_table(database) + servers = routing_table.servers() + for address in list(self.connections): + if address.unresolved not in servers: + await super(AsyncNeo4jPool, self).deactivate(address) + + async def ensure_routing_table_is_fresh( + self, *, access_mode, database, imp_user, bookmarks, + database_callback=None + ): + """ Update the routing table if stale. + + This method performs two freshness checks, before and after acquiring + the refresh lock. If the routing table is already fresh on entry, the + method exits immediately; otherwise, the refresh lock is acquired and + the second freshness check that follows determines whether an update + is still required. + + This method is thread-safe. + + :return: `True` if an update was required, `False` otherwise. + """ + from neo4j.api import READ_ACCESS + async with self.refresh_lock: + routing_table = await self.get_or_create_routing_table(database) + if routing_table.is_fresh(readonly=(access_mode == READ_ACCESS)): + # Readers are fresh. + return False + + await self.update_routing_table( + database=database, imp_user=imp_user, bookmarks=bookmarks, + database_callback=database_callback + ) + await self.update_connection_pool(database=database) + + for database in list(self.routing_tables.keys()): + # Remove unused databases in the routing table + # Remove the routing table after a timeout = TTL + 30s + log.debug("[#0000] C: database=%s", database) + if (self.routing_tables[database].should_be_purged_from_memory() + and database != self.workspace_config.database): + del self.routing_tables[database] + + return True + + async def _select_address(self, *, access_mode, database): + from ...api import READ_ACCESS + """ Selects the address with the fewest in-use connections. + """ + async with self.refresh_lock: + if access_mode == READ_ACCESS: + addresses = self.routing_tables[database].readers + else: + addresses = self.routing_tables[database].writers + addresses_by_usage = {} + for address in addresses: + addresses_by_usage.setdefault( + self.in_use_connection_count(address), [] + ).append(address) + if not addresses_by_usage: + if access_mode == READ_ACCESS: + raise ReadServiceUnavailable( + "No read service currently available" + ) + else: + raise WriteServiceUnavailable( + "No write service currently available" + ) + return choice(addresses_by_usage[min(addresses_by_usage)]) + + async def acquire( + self, access_mode=None, timeout=None, database=None, bookmarks=None + ): + if access_mode not in (WRITE_ACCESS, READ_ACCESS): + raise ClientError("Non valid 'access_mode'; {}".format(access_mode)) + if not timeout: + raise ClientError("'timeout' must be a float larger than 0; {}" + .format(timeout)) + + from neo4j.api import check_access_mode + access_mode = check_access_mode(access_mode) + async with self.refresh_lock: + log.debug("[#0000] C: %r", + self.routing_tables) + await self.ensure_routing_table_is_fresh( + access_mode=access_mode, database=database, imp_user=None, + bookmarks=bookmarks + ) + + while True: + try: + # Get an address for a connection that have the fewest in-use + # connections. + address = await self._select_address( + access_mode=access_mode, database=database + ) + except (ReadServiceUnavailable, WriteServiceUnavailable) as err: + raise SessionExpired("Failed to obtain connection towards '%s' server." % access_mode) from err + try: + log.debug("[#0000] C: database=%r address=%r", database, address) + # should always be a resolved address + connection = await self._acquire(address, timeout=timeout) + except ServiceUnavailable: + await self.deactivate(address=address) + else: + return connection + + async def deactivate(self, address): + """ Deactivate an address from the connection pool, + if present, remove from the routing table and also closing + all idle connections to that address. + """ + log.debug("[#0000] C: Deactivating address %r", address) + # We use `discard` instead of `remove` here since the former + # will not fail if the address has already been removed. + for database in self.routing_tables.keys(): + self.routing_tables[database].routers.discard(address) + self.routing_tables[database].readers.discard(address) + self.routing_tables[database].writers.discard(address) + log.debug("[#0000] C: table=%r", self.routing_tables) + await super(AsyncNeo4jPool, self).deactivate(address) + + def on_write_failure(self, address): + """ Remove a writer address from the routing table, if present. + """ + log.debug("[#0000] C: Removing writer %r", address) + for database in self.routing_tables.keys(): + self.routing_tables[database].writers.discard(address) + log.debug("[#0000] C: table=%r", self.routing_tables) diff --git a/neo4j/_async/work/__init__.py b/neo4j/_async/work/__init__.py new file mode 100644 index 00000000..e48e1c21 --- /dev/null +++ b/neo4j/_async/work/__init__.py @@ -0,0 +1,32 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from .session import ( + AsyncResult, + AsyncSession, + AsyncTransaction, + AsyncWorkspace, +) + + +__all__ = [ + "AsyncResult", + "AsyncSession", + "AsyncTransaction", + "AsyncWorkspace", +] diff --git a/neo4j/_async/work/result.py b/neo4j/_async/work/result.py new file mode 100644 index 00000000..6184d963 --- /dev/null +++ b/neo4j/_async/work/result.py @@ -0,0 +1,379 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from collections import deque +from warnings import warn + +from ..._async_compat.util import AsyncUtil +from ...data import DataDehydrator +from ...work import ResultSummary +from ..io import ConnectionErrorHandler + + +class AsyncResult: + """A handler for the result of Cypher query execution. Instances + of this class are typically constructed and returned by + :meth:`.AyncSession.run` and :meth:`.AsyncTransaction.run`. + """ + + def __init__(self, connection, hydrant, fetch_size, on_closed, + on_error): + self._connection = ConnectionErrorHandler(connection, on_error) + self._hydrant = hydrant + self._on_closed = on_closed + self._metadata = None + self._record_buffer = deque() + self._summary = None + self._bookmark = None + self._raw_qid = -1 + self._fetch_size = fetch_size + + # states + self._discarding = False # discard the remainder of records + self._attached = False # attached to a connection + # there are still more response messages we wait for + self._streaming = False + # there ar more records available to pull from the server + self._has_more = False + # the result has been fully iterated or consumed + self._closed = False + + @property + def _qid(self): + if self._raw_qid == self._connection.most_recent_qid: + return -1 + else: + return self._raw_qid + + async def _tx_ready_run(self, query, parameters, **kwargs): + # BEGIN+RUN does not carry any extra on the RUN message. + # BEGIN {extra} + # RUN "query" {parameters} {extra} + await self._run( + query, parameters, None, None, None, None, **kwargs + ) + + async def _run( + self, query, parameters, db, imp_user, access_mode, bookmarks, + **kwargs + ): + query_text = str(query) # Query or string object + query_metadata = getattr(query, "metadata", None) + query_timeout = getattr(query, "timeout", None) + + parameters = DataDehydrator.fix_parameters(dict(parameters or {}, **kwargs)) + + self._metadata = { + "query": query_text, + "parameters": parameters, + "server": self._connection.server_info, + } + + def on_attached(metadata): + self._metadata.update(metadata) + # For auto-commit there is no qid and Bolt 3 does not support qid + self._raw_qid = metadata.get("qid", -1) + if self._raw_qid != -1: + self._connection.most_recent_qid = self._raw_qid + self._keys = metadata.get("fields") + self._attached = True + + async def on_failed_attach(metadata): + self._metadata.update(metadata) + self._attached = False + await AsyncUtil.callback(self._on_closed) + + self._connection.run( + query_text, + parameters=parameters, + mode=access_mode, + bookmarks=bookmarks, + metadata=query_metadata, + timeout=query_timeout, + db=db, + imp_user=imp_user, + on_success=on_attached, + on_failure=on_failed_attach, + ) + self._pull() + await self._connection.send_all() + await self._attach() + + def _pull(self): + def on_records(records): + if not self._discarding: + self._record_buffer.extend(self._hydrant.hydrate_records(self._keys, records)) + + async def on_summary(): + self._attached = False + await AsyncUtil.callback(self._on_closed) + + async def on_failure(metadata): + self._attached = False + await AsyncUtil.callback(self._on_closed) + + def on_success(summary_metadata): + self._streaming = False + has_more = summary_metadata.get("has_more") + self._has_more = bool(has_more) + if has_more: + return + self._metadata.update(summary_metadata) + self._bookmark = summary_metadata.get("bookmark") + + self._connection.pull( + n=self._fetch_size, + qid=self._qid, + on_records=on_records, + on_success=on_success, + on_failure=on_failure, + on_summary=on_summary, + ) + self._streaming = True + + def _discard(self): + async def on_summary(): + self._attached = False + await AsyncUtil.callback(self._on_closed) + + async def on_failure(metadata): + self._metadata.update(metadata) + self._attached = False + await AsyncUtil.callback(self._on_closed) + + def on_success(summary_metadata): + self._streaming = False + has_more = summary_metadata.get("has_more") + self._has_more = bool(has_more) + if has_more: + return + self._discarding = False + self._metadata.update(summary_metadata) + self._bookmark = summary_metadata.get("bookmark") + + # This was the last page received, discard the rest + self._connection.discard( + n=-1, + qid=self._qid, + on_success=on_success, + on_failure=on_failure, + on_summary=on_summary, + ) + self._streaming = True + + async def __aiter__(self): + """Iterator returning Records. + :returns: Record, it is an immutable ordered collection of key-value pairs. + :rtype: :class:`neo4j.Record` + """ + while self._record_buffer or self._attached: + if self._record_buffer: + yield self._record_buffer.popleft() + elif self._streaming: + await self._connection.fetch_message() + elif self._discarding: + self._discard() + await self._connection.send_all() + elif self._has_more: + self._pull() + await self._connection.send_all() + + self._closed = True + + async def _attach(self): + """Sets the Result object in an attached state by fetching messages from + the connection to the buffer. + """ + if self._closed is False: + while self._attached is False: + await self._connection.fetch_message() + + async def _buffer(self, n=None): + """Try to fill `self._record_buffer` with n records. + + Might end up with more records in the buffer if the fetch size makes it + overshoot. + Might ent up with fewer records in the buffer if there are not enough + records available. + """ + record_buffer = deque() + async for record in self: + record_buffer.append(record) + if n is not None and len(record_buffer) >= n: + break + self._closed = False + if n is None: + self._record_buffer = record_buffer + else: + self._record_buffer.extend(record_buffer) + + async def _buffer_all(self): + """Sets the Result object in an detached state by fetching all records + from the connection to the buffer. + """ + await self._buffer() + + def _obtain_summary(self): + """Obtain the summary of this result, buffering any remaining records. + + :returns: The :class:`neo4j.ResultSummary` for this result + """ + if self._summary is None: + if self._metadata: + self._summary = ResultSummary( + self._connection.unresolved_address, **self._metadata + ) + elif self._connection: + self._summary = ResultSummary( + self._connection.unresolved_address, + server=self._connection.server_info + ) + + return self._summary + + def keys(self): + """The keys for the records in this result. + + :returns: tuple of key names + :rtype: tuple + """ + return self._keys + + async def consume(self): + """Consume the remainder of this result and return a :class:`neo4j.ResultSummary`. + + Example:: + + def create_node_tx(tx, name): + result = await tx.run( + "CREATE (n:ExampleNode { name: $name }) RETURN n", name=name + ) + record = await result.single() + value = record.value() + info = await result.consume() + return value, info + + async with driver.session() as session: + node_id, info = await session.write_transaction(create_node_tx, "example") + + Example:: + + async def get_two_tx(tx): + result = await tx.run("UNWIND [1,2,3,4] AS x RETURN x") + values = [] + async for record in result: + if len(values) >= 2: + break + values.append(record.values()) + # discard the remaining records if there are any + info = await result.consume() + # use the info for logging etc. + return values, info + + with driver.session() as session: + values, info = session.read_transaction(get_two_tx) + + :returns: The :class:`neo4j.ResultSummary` for this result + """ + if self._closed is False: + self._discarding = True + async for _ in self: + pass + + return self._obtain_summary() + + async def single(self): + """Obtain the next and only remaining record from this result if available else return None. + Calling this method always exhausts the result. + + A warning is generated if more than one record is available but + the first of these is still returned. + + :returns: the next :class:`neo4j.Record` or :const:`None` if none remain + :warns: if more than one record is available + """ + # TODO in 5.0 replace with this code that raises an error if there's not + # exactly one record in the left result stream. + # self._buffer(2). + # if len(self._record_buffer) != 1: + # raise SomeError("Expected exactly 1 record, found %i" + # % len(self._record_buffer)) + # return self._record_buffer.popleft() + # TODO: exhausts the result with self.consume if there are more records. + records = await AsyncUtil.list(self) + size = len(records) + if size == 0: + return None + if size != 1: + warn("Expected a result with a single record, but this result contains %d" % size) + return records[0] + + async def peek(self): + """Obtain the next record from this result without consuming it. + This leaves the record in the buffer for further processing. + + :returns: the next :class:`.Record` or :const:`None` if none remain + """ + await self._buffer(1) + if self._record_buffer: + return self._record_buffer[0] + + async def graph(self): + """Return a :class:`neo4j.graph.Graph` instance containing all the graph objects + in the result. After calling this method, the result becomes + detached, buffering all remaining records. + + :returns: a result graph + :rtype: :class:`neo4j.graph.Graph` + """ + await self._buffer_all() + return self._hydrant.graph + + async def value(self, key=0, default=None): + """Helper function that return the remainder of the result as a list of values. + + See :class:`neo4j.Record.value` + + :param key: field to return for each remaining record. Obtain a single value from the record by index or key. + :param default: default value, used if the index of key is unavailable + :returns: list of individual values + :rtype: list + """ + return [record.value(key, default) async for record in self] + + async def values(self, *keys): + """Helper function that return the remainder of the result as a list of values lists. + + See :class:`neo4j.Record.values` + + :param keys: fields to return for each remaining record. Optionally filtering to include only certain values by index or key. + :returns: list of values lists + :rtype: list + """ + return [record.values(*keys) async for record in self] + + async def data(self, *keys): + """Helper function that return the remainder of the result as a list of dictionaries. + + See :class:`neo4j.Record.data` + + :param keys: fields to return for each remaining record. Optionally filtering to include only certain values by index or key. + :returns: list of dictionaries + :rtype: list + """ + return [record.data(*keys) async for record in self] diff --git a/neo4j/_async/work/session.py b/neo4j/_async/work/session.py new file mode 100644 index 00000000..bedf2915 --- /dev/null +++ b/neo4j/_async/work/session.py @@ -0,0 +1,447 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import asyncio +from logging import getLogger +from random import random +from time import perf_counter + +from ..._async_compat import async_sleep +from ...api import ( + READ_ACCESS, + WRITE_ACCESS, +) +from ...conf import SessionConfig +from ...data import DataHydrator +from ...exceptions import ( + ClientError, + IncompleteCommit, + Neo4jError, + ServiceUnavailable, + SessionExpired, + TransactionError, + TransientError, +) +from ...work import Query +from .result import AsyncResult +from .transaction import AsyncTransaction +from .workspace import AsyncWorkspace + + +log = getLogger("neo4j") + + +class AsyncSession(AsyncWorkspace): + """A :class:`.AsyncSession` is a logical context for transactional units + of work. Connections are drawn from the :class:`.AsyncDriver` connection + pool as required. + + Session creation is a lightweight operation and sessions are not safe to + be used in concurrent contexts (multiple threads/coroutines). + Therefore, a session should generally be short-lived, and must not + span multiple threads/coroutines. + + In general, sessions will be created and destroyed within a `with` + context. For example:: + + async with driver.session() as session: + result = await session.run("MATCH (n:Person) RETURN n.name AS name") + # do something with the result... + + :param pool: connection pool instance + :param config: session config instance + """ + + # The current connection. + _connection = None + + # The current :class:`.Transaction` instance, if any. + _transaction = None + + # The current auto-transaction result, if any. + _auto_result = None + + # The state this session is in. + _state_failed = False + + # Session have been properly closed. + _closed = False + + def __init__(self, pool, session_config): + super().__init__(pool, session_config) + assert isinstance(session_config, SessionConfig) + self._bookmarks = tuple(session_config.bookmarks) + + def __del__(self): + if asyncio.iscoroutinefunction(self.close): + return + try: + self.close() + except (OSError, ServiceUnavailable, SessionExpired): + pass + + async def __aenter__(self): + return self + + async def __aexit__(self, exception_type, exception_value, traceback): + if exception_type: + self._state_failed = True + await self.close() + + async def _connect(self, access_mode): + if access_mode is None: + access_mode = self._config.default_access_mode + await super()._connect(access_mode) + + def _collect_bookmark(self, bookmark): + if bookmark: + self._bookmarks = [bookmark] + + async def _result_closed(self): + if self._auto_result: + self._collect_bookmark(self._auto_result._bookmark) + self._auto_result = None + await self._disconnect() + + async def _result_error(self, _): + if self._auto_result: + self._auto_result = None + await self._disconnect() + + async def close(self): + """Close the session. + + This will release any borrowed resources, such as connections, and will + roll back any outstanding transactions. + """ + if self._connection: + if self._auto_result: + if self._state_failed is False: + try: + await self._auto_result.consume() + self._collect_bookmark(self._auto_result._bookmark) + except Exception as error: + # TODO: Investigate potential non graceful close states + self._auto_result = None + self._state_failed = True + + if self._transaction: + if self._transaction.closed() is False: + await self._transaction.rollback() # roll back the transaction if it is not closed + self._transaction = None + + try: + if self._connection: + await self._connection.send_all() + await self._connection.fetch_all() + # TODO: Investigate potential non graceful close states + except Neo4jError: + pass + except TransactionError: + pass + except ServiceUnavailable: + pass + except SessionExpired: + pass + finally: + await self._disconnect() + + self._state_failed = False + self._closed = True + + async def run(self, query, parameters=None, **kwargs): + """Run a Cypher query within an auto-commit transaction. + + The query is sent and the result header received + immediately but the :class:`neo4j.Result` content is + fetched lazily as consumed by the client application. + + If a query is executed before a previous + :class:`neo4j.AsyncResult` in the same :class:`.AsyncSession` has + been fully consumed, the first result will be fully fetched + and buffered. Note therefore that the generally recommended + pattern of usage is to fully consume one result before + executing a subsequent query. If two results need to be + consumed in parallel, multiple :class:`.AsyncSession` objects + can be used as an alternative to result buffering. + + For more usage details, see :meth:`.AsyncTransaction.run`. + + :param query: cypher query + :type query: str, neo4j.Query + :param parameters: dictionary of parameters + :type parameters: dict + :param kwargs: additional keyword parameters + :returns: a new :class:`neo4j.AsyncResult` object + :rtype: AsyncResult + """ + if not query: + raise ValueError("Cannot run an empty query") + if not isinstance(query, (str, Query)): + raise TypeError("query must be a string or a Query instance") + + if self._transaction: + raise ClientError("Explicit Transaction must be handled explicitly") + + if self._auto_result: + # This will buffer upp all records for the previous auto-transaction + await self._auto_result._buffer_all() + + if not self._connection: + await self._connect(self._config.default_access_mode) + cx = self._connection + protocol_version = cx.PROTOCOL_VERSION + server_info = cx.server_info + + hydrant = DataHydrator() + + self._auto_result = AsyncResult( + cx, hydrant, self._config.fetch_size, self._result_closed, + self._result_error + ) + await self._auto_result._run( + query, parameters, self._config.database, + self._config.impersonated_user, self._config.default_access_mode, + self._bookmarks, **kwargs + ) + + return self._auto_result + + async def last_bookmark(self): + """Return the bookmark received following the last completed transaction. + Note: For auto-transaction (Session.run) this will trigger an consume for the current result. + + :returns: :class:`neo4j.Bookmark` object + """ + # The set of bookmarks to be passed into the next transaction. + + if self._auto_result: + await self._auto_result.consume() + + if self._transaction and self._transaction._closed: + self._collect_bookmark(self._transaction._bookmark) + self._transaction = None + + if len(self._bookmarks): + return self._bookmarks[len(self._bookmarks)-1] + return None + + async def _transaction_closed_handler(self): + if self._transaction: + self._collect_bookmark(self._transaction._bookmark) + self._transaction = None + await self._disconnect() + + async def _transaction_error_handler(self, _): + if self._transaction: + self._transaction = None + await self._disconnect() + + async def _open_transaction(self, *, access_mode, metadata=None, + timeout=None): + await self._connect(access_mode=access_mode) + self._transaction = AsyncTransaction( + self._connection, self._config.fetch_size, + self._transaction_closed_handler, + self._transaction_error_handler + ) + await self._transaction._begin( + self._config.database, self._config.impersonated_user, + self._bookmarks, access_mode, metadata, timeout + ) + + async def begin_transaction(self, metadata=None, timeout=None): + """ Begin a new unmanaged transaction. Creates a new :class:`.AsyncTransaction` within this session. + At most one transaction may exist in a session at any point in time. + To maintain multiple concurrent transactions, use multiple concurrent sessions. + + Note: For auto-transaction (AsyncSession.run) this will trigger an consume for the current result. + + :param metadata: + a dictionary with metadata. + Specified metadata will be attached to the executing transaction and visible in the output of ``dbms.listQueries`` and ``dbms.listTransactions`` procedures. + It will also get logged to the ``query.log``. + This functionality makes it easier to tag transactions and is equivalent to ``dbms.setTXMetaData`` procedure, see https://neo4j.com/docs/operations-manual/current/reference/procedures/ for procedure reference. + :type metadata: dict + + :param timeout: + the transaction timeout in seconds. + Transactions that execute longer than the configured timeout will be terminated by the database. + This functionality allows to limit query/transaction execution time. + Specified timeout overrides the default timeout configured in the database using ``dbms.transaction.timeout`` setting. + Value should not represent a duration of zero or negative duration. + :type timeout: int + + :returns: A new transaction instance. + :rtype: AsyncTransaction + + :raises TransactionError: :class:`neo4j.exceptions.TransactionError` if a transaction is already open. + """ + # TODO: Implement TransactionConfig consumption + + if self._auto_result: + self._auto_result.consume() + + if self._transaction: + raise TransactionError("Explicit transaction already open") + + await self._open_transaction( + access_mode=self._config.default_access_mode, metadata=metadata, + timeout=timeout + ) + + return self._transaction + + async def _run_transaction( + self, access_mode, transaction_function, *args, **kwargs + ): + if not callable(transaction_function): + raise TypeError("Unit of work is not callable") + + metadata = getattr(transaction_function, "metadata", None) + timeout = getattr(transaction_function, "timeout", None) + + retry_delay = retry_delay_generator(self._config.initial_retry_delay, self._config.retry_delay_multiplier, self._config.retry_delay_jitter_factor) + + errors = [] + + t0 = -1 # Timer + + while True: + try: + await self._open_transaction( + access_mode=access_mode, metadata=metadata, + timeout=timeout + ) + tx = self._transaction + try: + result = await transaction_function(tx, *args, **kwargs) + except Exception: + await tx.close() + raise + else: + await tx.commit() + except IncompleteCommit: + raise + except (ServiceUnavailable, SessionExpired) as error: + errors.append(error) + await self._disconnect() + except TransientError as transient_error: + if not transient_error.is_retriable(): + raise + errors.append(transient_error) + else: + return result + if t0 == -1: + t0 = perf_counter() # The timer should be started after the first attempt + t1 = perf_counter() + if t1 - t0 > self._config.max_transaction_retry_time: + break + delay = next(retry_delay) + log.warning("Transaction failed and will be retried in {}s ({})".format(delay, "; ".join(errors[-1].args))) + await async_sleep(delay) + + if errors: + raise errors[-1] + else: + raise ServiceUnavailable("Transaction failed") + + async def read_transaction(self, transaction_function, *args, **kwargs): + """Execute a unit of work in a managed read transaction. + This transaction will automatically be committed unless an exception is thrown during query execution or by the user code. + Note, that this function perform retries and that the supplied `transaction_function` might get invoked more than once. + + Managed transactions should not generally be explicitly committed + (via ``await tx.commit()``). + + Example:: + + async def do_cypher_tx(tx, cypher): + result = await tx.run(cypher) + values = [record.values() async for record in result] + return values + + async with driver.session() as session: + values = await session.read_transaction(do_cypher_tx, "RETURN 1 AS x") + + Example:: + + async def get_two_tx(tx): + result = await tx.run("UNWIND [1,2,3,4] AS x RETURN x") + values = [] + async for record in result: + if len(values) >= 2: + break + values.append(record.values()) + # discard the remaining records if there are any + info = await result.consume() + # use the info for logging etc. + return values + + async with driver.session() as session: + values = await session.read_transaction(get_two_tx) + + :param transaction_function: a function that takes a transaction as an + argument and does work with the transaction. + `transaction_function(tx, *args, **kwargs)` where `tx` is a + :class:`.AsyncTransaction`. + :param args: arguments for the `transaction_function` + :param kwargs: key word arguments for the `transaction_function` + :return: a result as returned by the given unit of work + """ + return await self._run_transaction( + READ_ACCESS, transaction_function, *args, **kwargs + ) + + async def write_transaction(self, transaction_function, *args, **kwargs): + """Execute a unit of work in a managed write transaction. + This transaction will automatically be committed unless an exception is thrown during query execution or by the user code. + Note, that this function perform retries and that the supplied `transaction_function` might get invoked more than once. + + Managed transactions should not generally be explicitly committed (via tx.commit()). + + Example:: + + async def create_node_tx(tx, name): + query = "CREATE (n:NodeExample { name: $name }) RETURN id(n) AS node_id" + result = await tx.run(query, name=name) + record = await result.single() + return record["node_id"] + + async with driver.session() as session: + node_id = await session.write_transaction(create_node_tx, "example") + + :param transaction_function: a function that takes a transaction as an + argument and does work with the transaction. + `transaction_function(tx, *args, **kwargs)` where `tx` is a + :class:`.AsyncTransaction`. + :param args: key word arguments for the `transaction_function` + :param kwargs: key word arguments for the `transaction_function` + :return: a result as returned by the given unit of work + """ + return await self._run_transaction( + WRITE_ACCESS, transaction_function, *args, **kwargs + ) + + +def retry_delay_generator(initial_delay, multiplier, jitter_factor): + delay = initial_delay + while True: + jitter = jitter_factor * delay + yield delay - jitter + (2 * jitter * random()) + delay *= multiplier diff --git a/neo4j/_async/work/transaction.py b/neo4j/_async/work/transaction.py new file mode 100644 index 00000000..a2ae32f5 --- /dev/null +++ b/neo4j/_async/work/transaction.py @@ -0,0 +1,199 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from ..._async_compat.util import AsyncUtil +from ...data import DataHydrator +from ...exceptions import TransactionError +from ...work import Query +from ..io import ConnectionErrorHandler +from .result import AsyncResult + + +class AsyncTransaction: + """ Container for multiple Cypher queries to be executed within a single + context. asynctransactions can be used within a :py:const:`async with` + block where the transaction is committed or rolled back on based on + whether an exception is raised:: + + async with session.begin_transaction() as tx: + ... + + """ + + def __init__(self, connection, fetch_size, on_closed, on_error): + self._connection = connection + self._error_handling_connection = ConnectionErrorHandler( + connection, self._error_handler + ) + self._bookmark = None + self._results = [] + self._closed = False + self._last_error = None + self._fetch_size = fetch_size + self._on_closed = on_closed + self._on_error = on_error + + async def __aenter__(self): + return self + + async def __aexit__(self, exception_type, exception_value, traceback): + if self._closed: + return + success = not bool(exception_type) + if success: + await self.commit() + await self.close() + + async def _begin( + self, database, imp_user, bookmarks, access_mode, metadata, timeout + ): + self._connection.begin( + bookmarks=bookmarks, metadata=metadata, timeout=timeout, + mode=access_mode, db=database, imp_user=imp_user + ) + await self._error_handling_connection.send_all() + await self._error_handling_connection.fetch_all() + + async def _result_on_closed_handler(self): + pass + + async def _error_handler(self, exc): + self._last_error = exc + await AsyncUtil.callback(self._on_error, exc) + + async def _consume_results(self): + for result in self._results: + await result.consume() + self._results = [] + + async def run(self, query, parameters=None, **kwparameters): + """ Run a Cypher query within the context of this transaction. + + Cypher is typically expressed as a query template plus a + set of named parameters. In Python, parameters may be expressed + through a dictionary of parameters, through individual parameter + arguments, or as a mixture of both. For example, the `run` + queries below are all equivalent:: + + >>> query = "CREATE (a:Person { name: $name, age: $age })" + >>> result = await tx.run(query, {"name": "Alice", "age": 33}) + >>> result = await tx.run(query, {"name": "Alice"}, age=33) + >>> result = await tx.run(query, name="Alice", age=33) + + Parameter values can be of any type supported by the Neo4j type + system. In Python, this includes :class:`bool`, :class:`int`, + :class:`str`, :class:`list` and :class:`dict`. Note however that + :class:`list` properties must be homogenous. + + :param query: cypher query + :type query: str + :param parameters: dictionary of parameters + :type parameters: dict + :param kwparameters: additional keyword parameters + :returns: a new :class:`neo4j.Result` object + :rtype: :class:`neo4j.Result` + :raise TransactionError: if the transaction is already closed + """ + if isinstance(query, Query): + raise ValueError("Query object is only supported for session.run") + + if self._closed: + raise TransactionError(self, "Transaction closed") + if self._last_error: + raise TransactionError(self, + "Transaction failed") from self._last_error + + if (self._results + and self._connection.supports_multiple_results is False): + # Bolt 3 Support + # Buffer up all records for the previous Result because it does not + # have any qid to fetch in batches. + await self._results[-1]._buffer_all() + + result = AsyncResult( + self._connection, DataHydrator(), self._fetch_size, + self._result_on_closed_handler, + self._error_handler + ) + self._results.append(result) + + await result._tx_ready_run(query, parameters, **kwparameters) + + return result + + async def commit(self): + """Mark this transaction as successful and close in order to trigger a COMMIT. + + :raise TransactionError: if the transaction is already closed + """ + if self._closed: + raise TransactionError(self, "Transaction closed") + if self._last_error: + raise TransactionError(self, + "Transaction failed") from self._last_error + + metadata = {} + try: + # DISCARD pending records then do a commit. + await self._consume_results() + self._connection.commit(on_success=metadata.update) + await self._connection.send_all() + await self._connection.fetch_all() + self._bookmark = metadata.get("bookmark") + finally: + self._closed = True + await AsyncUtil.callback(self._on_closed) + + return self._bookmark + + async def rollback(self): + """Mark this transaction as unsuccessful and close in order to trigger a ROLLBACK. + + :raise TransactionError: if the transaction is already closed + """ + if self._closed: + raise TransactionError(self, "Transaction closed") + + metadata = {} + try: + if not (self._connection.defunct() + or self._connection.closed() + or self._connection.is_reset): + # DISCARD pending records then do a rollback. + await self._consume_results() + self._connection.rollback(on_success=metadata.update) + await self._connection.send_all() + await self._connection.fetch_all() + finally: + self._closed = True + await AsyncUtil.callback(self._on_closed) + + async def close(self): + """Close this transaction, triggering a ROLLBACK if not closed. + """ + if self._closed: + return + await self.rollback() + + def closed(self): + """Indicator to show whether the transaction has been closed. + + :return: :const:`True` if closed, :const:`False` otherwise. + :rtype: bool + """ + return self._closed diff --git a/neo4j/_async/work/workspace.py b/neo4j/_async/work/workspace.py new file mode 100644 index 00000000..5dee191a --- /dev/null +++ b/neo4j/_async/work/workspace.py @@ -0,0 +1,102 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import asyncio + +from ...conf import WorkspaceConfig +from ...exceptions import ServiceUnavailable +from ..io import AsyncNeo4jPool + + +class AsyncWorkspace: + + def __init__(self, pool, config): + assert isinstance(config, WorkspaceConfig) + self._pool = pool + self._config = config + self._connection = None + self._connection_access_mode = None + # Sessions are supposed to cache the database on which to operate. + self._cached_database = False + self._bookmarks = None + + def __del__(self): + if asyncio.iscoroutinefunction(self.close): + return + try: + self.close() + except OSError: + pass + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + await self.close() + + def _set_cached_database(self, database): + self._cached_database = True + self._config.database = database + + async def _connect(self, access_mode): + if self._connection: + # TODO: Investigate this + # log.warning("FIXME: should always disconnect before connect") + await self._connection.send_all() + await self._connection.fetch_all() + await self._disconnect() + if not self._cached_database: + if (self._config.database is not None + or not isinstance(self._pool, AsyncNeo4jPool)): + self._set_cached_database(self._config.database) + else: + # This is the first time we open a connection to a server in a + # cluster environment for this session without explicitly + # configured database. Hence, we request a routing table update + # to try to fetch the home database. If provided by the server, + # we shall use this database explicitly for all subsequent + # actions within this session. + await self._pool.update_routing_table( + database=self._config.database, + imp_user=self._config.impersonated_user, + bookmarks=self._bookmarks, + database_callback=self._set_cached_database + ) + self._connection = await self._pool.acquire( + access_mode=access_mode, + timeout=self._config.connection_acquisition_timeout, + database=self._config.database, + bookmarks=self._bookmarks + ) + self._connection_access_mode = access_mode + + async def _disconnect(self, sync=False): + if self._connection: + if sync: + try: + await self._connection.send_all() + await self._connection.fetch_all() + except ServiceUnavailable: + pass + if self._connection: + await self._pool.release(self._connection) + self._connection = None + self._connection_access_mode = None + + async def close(self): + await self._disconnect(sync=True) diff --git a/neo4j/_async_compat/__init__.py b/neo4j/_async_compat/__init__.py new file mode 100644 index 00000000..87430ab5 --- /dev/null +++ b/neo4j/_async_compat/__init__.py @@ -0,0 +1,26 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from asyncio import sleep as async_sleep +from time import sleep + + +__all__ = [ + "async_sleep", + "sleep", +] diff --git a/neo4j/_async_compat/concurrency.py b/neo4j/_async_compat/concurrency.py new file mode 100644 index 00000000..868ad1bd --- /dev/null +++ b/neo4j/_async_compat/concurrency.py @@ -0,0 +1,231 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import asyncio +import collections +import re +import threading + + +__all__ = [ + "AsyncCondition", + "AsyncRLock", + "Condition", + "RLock", +] + + +class AsyncRLock(asyncio.Lock): + """Reentrant asyncio.lock + + Inspired by Python's RLock implementation + """ + + _WAITERS_RE = re.compile(r"(?:\W|^)waiters[:=](\d+)(?:\W|$)") + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._owner = None + self._count = 0 + + def __repr__(self): + res = object.__repr__(self) + lock_repr = super().__repr__() + extra = "locked" if self._count > 0 else "unlocked" + extra += f" count={self._count}" + waiters_match = self._WAITERS_RE.search(lock_repr) + if waiters_match: + extra += f" waiters={waiters_match.group(1)}" + if self._owner: + extra += f" owner={self._owner}" + return f'<{res[1:-1]} [{extra}]>' + + def is_owner(self, task=None): + if task is None: + task = asyncio.current_task() + return self._owner == task + + async def _acquire(self, me): + if self.is_owner(task=me): + self._count += 1 + return + await super().acquire() + self._owner = me + self._count = 1 + + async def acquire(self, timeout=None): + """Acquire the lock.""" + me = asyncio.current_task() + if timeout is None: + return await self._acquire(me) + return await asyncio.wait_for(self._acquire(me), timeout) + + __aenter__ = acquire + + def _release(self, me): + if not self.is_owner(task=me): + if self._owner is None: + raise RuntimeError("Cannot release un-acquired lock.") + raise RuntimeError("Cannot release foreign lock.") + self._count -= 1 + if not self._count: + self._owner = None + super().release() + + def release(self): + """Release the lock""" + me = asyncio.current_task() + return self._release(me) + + async def __aexit__(self, t, v, tb): + self.release() + + +# copied and modified from asyncio.locks (3.7) +# to add support for `.wait(timeout)` +class AsyncCondition: + """Asynchronous equivalent to threading.Condition. + + This class implements condition variable objects. A condition variable + allows one or more coroutines to wait until they are notified by another + coroutine. + + A new Lock object is created and used as the underlying lock. + """ + + def __init__(self, lock=None, *, loop=None): + if loop is not None: + self._loop = loop + else: + self._loop = asyncio.get_event_loop() + + if lock is None: + lock = asyncio.Lock(loop=self._loop) + elif (hasattr(lock, "_loop") + and lock._loop is not None + and lock._loop is not self._loop): + raise ValueError("loop argument must agree with lock") + + self._lock = lock + # Export the lock's locked(), acquire() and release() methods. + self.locked = lock.locked + self.acquire = lock.acquire + self.release = lock.release + + self._waiters = collections.deque() + + async def __aenter__(self): + await self.acquire() + # We have no use for the "as ..." clause in the with + # statement for locks. + return None + + async def __aexit__(self, exc_type, exc, tb): + self.release() + + def __repr__(self): + res = super().__repr__() + extra = 'locked' if self.locked() else 'unlocked' + if self._waiters: + extra = f'{extra}, waiters:{len(self._waiters)}' + return f'<{res[1:-1]} [{extra}]>' + + async def _wait(self, me=None): + """Wait until notified. + + If the calling coroutine has not acquired the lock when this + method is called, a RuntimeError is raised. + + This method releases the underlying lock, and then blocks + until it is awakened by a notify() or notify_all() call for + the same condition variable in another coroutine. Once + awakened, it re-acquires the lock and returns True. + """ + if not self.locked(): + raise RuntimeError('cannot wait on un-acquired lock') + + if isinstance(self._lock, AsyncRLock): + self._lock._release(me) + else: + self._lock.release() + try: + fut = self._loop.create_future() + self._waiters.append(fut) + try: + await fut + return True + finally: + self._waiters.remove(fut) + + finally: + # Must reacquire lock even if wait is cancelled + cancelled = False + while True: + try: + if isinstance(self._lock, AsyncRLock): + await self._lock._acquire(me) + else: + await self._lock.acquire() + break + except asyncio.CancelledError: + cancelled = True + + if cancelled: + raise asyncio.CancelledError + + async def wait(self, timeout=None): + if not timeout: + return await self._wait() + me = asyncio.current_task() + return await asyncio.wait_for(self._wait(me), timeout) + + def notify(self, n=1): + """By default, wake up one coroutine waiting on this condition, if any. + If the calling coroutine has not acquired the lock when this method + is called, a RuntimeError is raised. + + This method wakes up at most n of the coroutines waiting for the + condition variable; it is a no-op if no coroutines are waiting. + + Note: an awakened coroutine does not actually return from its + wait() call until it can reacquire the lock. Since notify() does + not release the lock, its caller should. + """ + if not self.locked(): + raise RuntimeError('cannot notify on un-acquired lock') + + idx = 0 + for fut in self._waiters: + if idx >= n: + break + + if not fut.done(): + idx += 1 + fut.set_result(False) + + def notify_all(self): + """Wake up all threads waiting on this condition. This method acts + like notify(), but wakes up all waiting threads instead of one. If the + calling thread has not acquired the lock when this method is called, + a RuntimeError is raised. + """ + self.notify(len(self._waiters)) + + +Condition = threading.Condition +RLock = threading.RLock diff --git a/neo4j/_async_compat/network/__init__.py b/neo4j/_async_compat/network/__init__.py new file mode 100644 index 00000000..c2611569 --- /dev/null +++ b/neo4j/_async_compat/network/__init__.py @@ -0,0 +1,34 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from .bolt_socket import ( + AsyncBoltSocket, + BoltSocket, +) +from .util import ( + AsyncNetworkUtil, + NetworkUtil, +) + + +__all__ = [ + "AsyncBoltSocket", + "AsyncNetworkUtil", + "BoltSocket", + "NetworkUtil", +] diff --git a/neo4j/_async_compat/network/bolt_socket.py b/neo4j/_async_compat/network/bolt_socket.py new file mode 100644 index 00000000..1af0793a --- /dev/null +++ b/neo4j/_async_compat/network/bolt_socket.py @@ -0,0 +1,539 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import asyncio +import logging +import selectors +import socket +from socket import ( + AF_INET, + AF_INET6, + SHUT_RDWR, + SO_KEEPALIVE, + socket, + SOL_SOCKET, + timeout as SocketTimeout, +) +from ssl import ( + CertificateError, + HAS_SNI, + SSLError, +) +import struct + +from ... import addressing +from ..._exceptions import ( + BoltError, + BoltProtocolError, + BoltSecurityError, +) +from ...exceptions import ( + DriverError, + ServiceUnavailable, +) +from .util import ( + AsyncNetworkUtil, + NetworkUtil, +) + + +log = logging.getLogger("neo4j") + + +class AsyncBoltSocket: + Bolt = None + + def __init__(self, reader, protocol, writer): # , loop): + self._reader = reader # type: asyncio.StreamReader + self._protocol = protocol # type: asyncio.StreamReaderProtocol + self._writer = writer # type: asyncio.StreamWriter + # self._loop = loop # type: asyncio.BaseEventLoop + # 0 - non-blocking + # None infinitely blocking + # int - seconds to wait for data + self._timeout = None + + async def _wait_for_io(self, io_fut): + if self._timeout is not None and self._timeout <= 0: + # give the io-operation time for one loop cycle to do its thing + await asyncio.sleep(0) + try: + return await asyncio.wait_for(io_fut, self._timeout) + except asyncio.TimeoutError: + raise SocketTimeout + + @property + def _socket(self) -> socket: + return self._writer.transport.get_extra_info("socket") + + def getsockname(self): + return self._writer.transport.get_extra_info("sockname") + + def getpeername(self): + return self._writer.transport.get_extra_info("peername") + + def getpeercert(self, *args, **kwargs): + return self._writer.transport.get_extra_info("ssl_object")\ + .getpeercert(*args, **kwargs) + + def gettimeout(self): + return self._timeout + + def settimeout(self, timeout): + if timeout is None: + self._timeout = timeout + else: + assert timeout >= 0 + self._timeout = timeout + + async def recv(self, n): + io_fut = self._reader.read(n) + return await self._wait_for_io(io_fut) + + async def recv_into(self, buffer, nbytes): + # FIXME: not particularly memory or time efficient + io_fut = self._reader.read(nbytes) + res = await self._wait_for_io(io_fut) + buffer[:len(res)] = res + return len(res) + + async def sendall(self, data): + self._writer.write(data) + io_fut = self._writer.drain() + return await self._wait_for_io(io_fut) + + def close(self): + self._writer.close() + + @classmethod + async def _connect_secure(cls, resolved_address, timeout, keep_alive, ssl): + """ + + :param resolved_address: + :param timeout: seconds + :param keep_alive: True or False + :param ssl: SSLContext or None + + :return: AsyncBoltSocket object + """ + + loop = asyncio.get_event_loop() + s = None + + # TODO: tomorrow me: fix this mess + try: + if len(resolved_address) == 2: + s = socket(AF_INET) + elif len(resolved_address) == 4: + s = socket(AF_INET6) + else: + raise ValueError( + "Unsupported address {!r}".format(resolved_address)) + s.setblocking(False) # asyncio + blocking = no-no! + log.debug("[#0000] C: %s", resolved_address) + await asyncio.wait_for( + loop.sock_connect(s, resolved_address), + timeout + ) + + keep_alive = 1 if keep_alive else 0 + s.setsockopt(SOL_SOCKET, SO_KEEPALIVE, keep_alive) + + ssl_kwargs = {} + + if ssl is not None: + hostname = resolved_address.host_name or None + ssl_kwargs.update( + ssl=ssl, server_hostname=hostname if HAS_SNI else None + ) + + reader = asyncio.StreamReader( + limit=2 ** 16, # 64 KiB, + loop=loop + ) + protocol = asyncio.StreamReaderProtocol(reader, loop=loop) + transport, _ = await loop.create_connection( + lambda: protocol, sock=s, **ssl_kwargs + ) + writer = asyncio.StreamWriter(transport, protocol, reader, loop) + + if ssl is not None: + # Check that the server provides a certificate + der_encoded_server_certificate = \ + transport.get_extra_info("ssl_object").getpeercert( + binary_form=True) + if der_encoded_server_certificate is None: + local_port = s.getsockname()[1] + raise BoltProtocolError( + "When using an encrypted socket, the server should " + "always provide a certificate", + address=(resolved_address.host_name, local_port) + ) + + return cls(reader, protocol, writer) + + except asyncio.TimeoutError: + log.debug("[#0000] C: %s", resolved_address) + log.debug("[#0000] C: %s", resolved_address) + if s: + cls.close_socket(s) + raise ServiceUnavailable( + "Timed out trying to establish connection to {!r}".format( + resolved_address)) + except (SSLError, CertificateError) as error: + local_port = s.getsockname()[1] + raise BoltSecurityError( + message="Failed to establish encrypted connection.", + address=(resolved_address.host_name, local_port) + ) from error + except OSError as error: + log.debug("[#0000] C: %s %s", type(error).__name__, + " ".join(map(repr, error.args))) + log.debug("[#0000] C: %s", resolved_address) + s.close() + raise ServiceUnavailable( + "Failed to establish connection to {!r} (reason {})".format( + resolved_address, error)) + + async def _handshake(self, resolved_address): + """ + + :param s: Socket + :param resolved_address: + + :return: (socket, version, client_handshake, server_response_data) + """ + local_port = self.getsockname()[1] + + # TODO: Optimize logging code + handshake = self.Bolt.get_handshake() + handshake = struct.unpack(">16B", handshake) + handshake = [handshake[i:i + 4] for i in range(0, len(handshake), 4)] + + supported_versions = [ + ("0x%02X%02X%02X%02X" % (vx[0], vx[1], vx[2], vx[3])) for vx in + handshake] + + log.debug("[#%04X] C: 0x%08X", local_port, + int.from_bytes(self.Bolt.MAGIC_PREAMBLE, byteorder="big")) + log.debug("[#%04X] C: %s %s %s %s", local_port, + *supported_versions) + + data = self.Bolt.MAGIC_PREAMBLE + self.Bolt.get_handshake() + await self.sendall(data) + + # Handle the handshake response + original_timeout = self.gettimeout() + if original_timeout is not None: + self.settimeout(original_timeout + 1) + try: + data = await self.recv(4) + except OSError: + raise ServiceUnavailable( + "Failed to read any data from server {!r} " + "after connected".format(resolved_address)) + finally: + self.settimeout(original_timeout) + data_size = len(data) + if data_size == 0: + # If no data is returned after a successful select + # response, the server has closed the connection + log.debug("[#%04X] S: ", local_port) + self.close() + raise ServiceUnavailable( + "Connection to {address} closed without handshake response".format( + address=resolved_address)) + if data_size != 4: + # Some garbled data has been received + log.debug("[#%04X] S: @*#!", local_port) + self.close() + raise BoltProtocolError( + "Expected four byte Bolt handshake response from %r, received %r instead; check for incorrect port number" % ( + resolved_address, data), address=resolved_address) + elif data == b"HTTP": + log.debug("[#%04X] S: ", local_port) + self.close() + raise ServiceUnavailable( + "Cannot to connect to Bolt service on {!r} " + "(looks like HTTP)".format(resolved_address)) + agreed_version = data[-1], data[-2] + log.debug("[#%04X] S: 0x%06X%02X", local_port, + agreed_version[1], agreed_version[0]) + return self, agreed_version, handshake, data + + @classmethod + def close_socket(cls, socket_): + if isinstance(socket_, socket): + try: + socket_.shutdown(SHUT_RDWR) + socket_.close() + except OSError: + pass + else: + socket_.close() + + @classmethod + async def connect(cls, address, *, timeout, custom_resolver, ssl_context, + keep_alive): + """ Connect and perform a handshake and return a valid Connection object, + assuming a protocol version can be agreed. + """ + errors = [] + failed_addresses = [] + # Establish a connection to the host and port specified + # Catches refused connections see: + # https://docs.python.org/2/library/errno.html + + resolved_addresses = AsyncNetworkUtil.resolve_address( + addressing.Address(address), resolver=custom_resolver + ) + async for resolved_address in resolved_addresses: + s = None + try: + s = await cls._connect_secure( + resolved_address, timeout, keep_alive, ssl_context + ) + return await s._handshake(resolved_address) + except (BoltError, DriverError, OSError) as error: + try: + local_port = s.getsockname()[1] + except (OSError, AttributeError, TypeError): + local_port = 0 + err_str = error.__class__.__name__ + if str(error): + err_str += ": " + str(error) + log.debug("[#%04X] C: %s", local_port, + err_str) + if s: + cls.close_socket(s) + errors.append(error) + failed_addresses.append(resolved_address) + except Exception: + if s: + cls.close_socket(s) + raise + if not errors: + raise ServiceUnavailable( + "Couldn't connect to %s (resolved to %s)" % ( + str(address), tuple(map(str, failed_addresses))) + ) + else: + raise ServiceUnavailable( + "Couldn't connect to %s (resolved to %s):\n%s" % ( + str(address), tuple(map(str, failed_addresses)), + "\n".join(map(str, errors)) + ) + ) from errors[0] + + +class BoltSocket: + Bolt = None + + @classmethod + def _connect(cls, resolved_address, timeout, keep_alive): + """ + + :param resolved_address: + :param timeout: seconds + :param keep_alive: True or False + :return: socket object + """ + + s = None # The socket + + try: + if len(resolved_address) == 2: + s = socket(AF_INET) + elif len(resolved_address) == 4: + s = socket(AF_INET6) + else: + raise ValueError( + "Unsupported address {!r}".format(resolved_address)) + t = s.gettimeout() + if timeout: + s.settimeout(timeout) + log.debug("[#0000] C: %s", resolved_address) + s.connect(resolved_address) + s.settimeout(t) + keep_alive = 1 if keep_alive else 0 + s.setsockopt(SOL_SOCKET, SO_KEEPALIVE, keep_alive) + return s + except SocketTimeout: + log.debug("[#0000] C: %s", resolved_address) + log.debug("[#0000] C: %s", resolved_address) + cls.close_socket(s) + raise ServiceUnavailable( + "Timed out trying to establish connection to {!r}".format( + resolved_address)) + except OSError as error: + log.debug("[#0000] C: %s %s", type(error).__name__, + " ".join(map(repr, error.args))) + log.debug("[#0000] C: %s", resolved_address) + s.close() + raise ServiceUnavailable( + "Failed to establish connection to {!r} (reason {})".format( + resolved_address, error)) + + @classmethod + def _secure(cls, s, host, ssl_context): + local_port = s.getsockname()[1] + # Secure the connection if an SSL context has been provided + if ssl_context: + log.debug("[#%04X] C: %s", local_port, host) + try: + sni_host = host if HAS_SNI and host else None + s = ssl_context.wrap_socket(s, server_hostname=sni_host) + except (OSError, SSLError, CertificateError) as cause: + raise BoltSecurityError( + message="Failed to establish encrypted connection.", + address=(host, local_port) + ) from cause + # Check that the server provides a certificate + der_encoded_server_certificate = s.getpeercert(binary_form=True) + if der_encoded_server_certificate is None: + raise BoltProtocolError( + "When using an encrypted socket, the server should always " + "provide a certificate", address=(host, local_port) + ) + return s + return s + + @classmethod + def _handshake(cls, s, resolved_address): + """ + + :param s: Socket + :param resolved_address: + + :return: (socket, version, client_handshake, server_response_data) + """ + local_port = s.getsockname()[1] + + # TODO: Optimize logging code + handshake = cls.Bolt.get_handshake() + handshake = struct.unpack(">16B", handshake) + handshake = [handshake[i:i + 4] for i in range(0, len(handshake), 4)] + + supported_versions = [ + ("0x%02X%02X%02X%02X" % (vx[0], vx[1], vx[2], vx[3])) for vx in + handshake] + + log.debug("[#%04X] C: 0x%08X", local_port, + int.from_bytes(cls.Bolt.MAGIC_PREAMBLE, byteorder="big")) + log.debug("[#%04X] C: %s %s %s %s", local_port, + *supported_versions) + + data = cls.Bolt.MAGIC_PREAMBLE + cls.Bolt.get_handshake() + s.sendall(data) + + # Handle the handshake response + ready_to_read = False + with selectors.DefaultSelector() as selector: + selector.register(s, selectors.EVENT_READ) + selector.select(1) + try: + data = s.recv(4) + except OSError: + raise ServiceUnavailable( + "Failed to read any data from server {!r} " + "after connected".format(resolved_address)) + data_size = len(data) + if data_size == 0: + # If no data is returned after a successful select + # response, the server has closed the connection + log.debug("[#%04X] S: ", local_port) + BoltSocket.close_socket(s) + raise ServiceUnavailable( + "Connection to {address} closed without handshake response".format( + address=resolved_address)) + if data_size != 4: + # Some garbled data has been received + log.debug("[#%04X] S: @*#!", local_port) + s.close() + raise BoltProtocolError( + "Expected four byte Bolt handshake response from %r, received %r instead; check for incorrect port number" % ( + resolved_address, data), address=resolved_address) + elif data == b"HTTP": + log.debug("[#%04X] S: ", local_port) + BoltSocket.close_socket(s) + raise ServiceUnavailable( + "Cannot to connect to Bolt service on {!r} " + "(looks like HTTP)".format(resolved_address)) + agreed_version = data[-1], data[-2] + log.debug("[#%04X] S: 0x%06X%02X", local_port, + agreed_version[1], agreed_version[0]) + return s, agreed_version, handshake, data + + @classmethod + def close_socket(cls, socket_): + try: + socket_.shutdown(SHUT_RDWR) + socket_.close() + except OSError: + pass + + @classmethod + def connect(cls, address, *, timeout, custom_resolver, ssl_context, + keep_alive): + """ Connect and perform a handshake and return a valid Connection object, + assuming a protocol version can be agreed. + """ + errors = [] + # Establish a connection to the host and port specified + # Catches refused connections see: + # https://docs.python.org/2/library/errno.html + + resolved_addresses = NetworkUtil.resolve_address( + addressing.Address(address), resolver=custom_resolver + ) + for resolved_address in resolved_addresses: + s = None + try: + s = BoltSocket._connect(resolved_address, timeout, keep_alive) + s = BoltSocket._secure(s, resolved_address.host_name, + ssl_context) + return BoltSocket._handshake(s, resolved_address) + except (BoltError, DriverError, OSError) as error: + try: + local_port = s.getsockname()[1] + except (OSError, AttributeError): + local_port = 0 + err_str = error.__class__.__name__ + if str(error): + err_str += ": " + str(error) + log.debug("[#%04X] C: %s", local_port, + err_str) + if s: + BoltSocket.close_socket(s) + errors.append(error) + except Exception: + if s: + BoltSocket.close_socket(s) + raise + if not errors: + raise ServiceUnavailable( + "Couldn't connect to %s (resolved to %s)" % ( + str(address), tuple(map(str, resolved_addresses))) + ) + else: + raise ServiceUnavailable( + "Couldn't connect to %s (resolved to %s):\n%s" % ( + str(address), tuple(map(str, resolved_addresses)), + "\n".join(map(str, errors)) + ) + ) from errors[0] diff --git a/neo4j/_async_compat/network/util.py b/neo4j/_async_compat/network/util.py new file mode 100644 index 00000000..a9243512 --- /dev/null +++ b/neo4j/_async_compat/network/util.py @@ -0,0 +1,150 @@ +import asyncio +import logging +import socket + +from ... import addressing + + +log = logging.getLogger("neo4j") + + +def _resolved_addresses_from_info(info, host_name): + resolved = [] + for fam, _, _, _, addr in info: + if fam == socket.AF_INET6 and addr[3] != 0: + # skip any IPv6 addresses with a non-zero scope id + # as these appear to cause problems on some platforms + continue + if addr not in resolved: + resolved.append(addr) + yield addressing.ResolvedAddress( + addr, host_name=host_name + ) + + +class AsyncNetworkUtil: + @staticmethod + async def get_address_info(host, port, *, + family=0, type=0, proto=0, flags=0): + loop = asyncio.get_event_loop() + return await loop.getaddrinfo( + host, port, family=family, type=type, proto=proto, flags=flags + ) + + @staticmethod + async def _dns_resolver(address, family=0): + """ Regular DNS resolver. Takes an address object and optional + address family for filtering. + + :param address: + :param family: + :return: + """ + try: + info = await AsyncNetworkUtil.get_address_info( + address.host, address.port, family=family, + type=socket.SOCK_STREAM + ) + except OSError: + raise ValueError("Cannot resolve address {}".format(address)) + return _resolved_addresses_from_info(info, address.host_name) + + @staticmethod + async def resolve_address(address, family=0, resolver=None): + """ Carry out domain name resolution on this Address object. + + If a resolver function is supplied, and is callable, this is + called first, with this object as its argument. This may yield + multiple output addresses, which are chained into a subsequent + regular DNS resolution call. If no resolver function is passed, + the DNS resolution is carried out on the original Address + object. + + This function returns a list of resolved Address objects. + + :param address: the Address to resolve + :param family: optional address family to filter resolved + addresses by (e.g. `socket.AF_INET6`) + :param resolver: optional customer resolver function to be + called before regular DNS resolution + """ + if isinstance(address, addressing.ResolvedAddress): + yield address + return + + log.debug("[#0000] C: %s", address) + if resolver: + if asyncio.iscoroutinefunction(resolver): + resolved_addresses = await resolver(address) + else: + resolved_addresses = resolver(address) + for address in map(addressing.Address, resolved_addresses): + for resolved_address in await AsyncNetworkUtil._dns_resolver( + address, family=family + ): + yield resolved_address + else: + for resolved_address in await AsyncNetworkUtil._dns_resolver( + address, family=family + ): + yield resolved_address + + +class NetworkUtil: + @staticmethod + def get_address_info(host, port, *, family=0, type=0, proto=0, flags=0): + return socket.getaddrinfo(host, port, family, type, proto, flags) + + @staticmethod + def _dns_resolver(address, family=0): + """ Regular DNS resolver. Takes an address object and optional + address family for filtering. + + :param address: + :param family: + :return: + """ + try: + info = NetworkUtil.get_address_info( + address.host, address.port, family=family, + type=socket.SOCK_STREAM + ) + except OSError: + raise ValueError("Cannot resolve address {}".format(address)) + return _resolved_addresses_from_info(info, address.host_name) + + @staticmethod + def resolve_address(address, family=0, resolver=None): + """ Carry out domain name resolution on this Address object. + + If a resolver function is supplied, and is callable, this is + called first, with this object as its argument. This may yield + multiple output addresses, which are chained into a subsequent + regular DNS resolution call. If no resolver function is passed, + the DNS resolution is carried out on the original Address + object. + + This function returns a list of resolved Address objects. + + :param address: the Address to resolve + :param family: optional address family to filter resolved + addresses by (e.g. `socket.AF_INET6`) + :param resolver: optional customer resolver function to be + called before regular DNS resolution + """ + if isinstance(address, addressing.ResolvedAddress): + yield address + return + + addressing.log.debug("[#0000] C: %s", address) + if resolver: + for address in map(addressing.Address, resolver(address)): + for resolved_address in NetworkUtil._dns_resolver( + address, family=family + ): + yield resolved_address + else: + for resolved_address in NetworkUtil._dns_resolver( + address, family=family + ): + yield resolved_address diff --git a/neo4j/_async_compat/util.py b/neo4j/_async_compat/util.py new file mode 100644 index 00000000..ca432766 --- /dev/null +++ b/neo4j/_async_compat/util.py @@ -0,0 +1,63 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import inspect + +from ..meta import experimental + + +class AsyncUtil: + @staticmethod + async def iter(it): + async for x in it: + yield x + + @staticmethod + async def next(it): + return await it.__anext__() + + @staticmethod + async def list(it): + return [x async for x in it] + + @staticmethod + async def callback(cb, *args, **kwargs): + if callable(cb): + res = cb(*args, **kwargs) + if inspect.isawaitable(res): + return await res + return res + + experimental_async = experimental + + +class Util: + iter = iter + next = next + list = list + + @staticmethod + def callback(cb, *args, **kwargs): + if callable(cb): + return cb(*args, **kwargs) + + @staticmethod + def experimental_async(message): + def f_(f): + return f + return f_ diff --git a/neo4j/_exceptions.py b/neo4j/_exceptions.py index 67db7f6c..a48fec26 100644 --- a/neo4j/_exceptions.py +++ b/neo4j/_exceptions.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # diff --git a/neo4j/_sync/__init__.py b/neo4j/_sync/__init__.py new file mode 100644 index 00000000..b81a309d --- /dev/null +++ b/neo4j/_sync/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/neo4j/_driver.py b/neo4j/_sync/driver.py similarity index 58% rename from neo4j/_driver.py rename to neo4j/_sync/driver.py index c8aa688b..711b99e7 100644 --- a/neo4j/_driver.py +++ b/neo4j/_sync/driver.py @@ -16,19 +16,132 @@ # limitations under the License. -from .addressing import Address -from .api import READ_ACCESS -from .conf import ( +import asyncio + +from .._async_compat.util import Util +from ..addressing import Address +from ..api import ( + READ_ACCESS, + TRUST_ALL_CERTIFICATES, + TRUST_SYSTEM_CA_SIGNED_CERTIFICATES, +) +from ..conf import ( Config, PoolConfig, SessionConfig, WorkspaceConfig, ) -from .meta import experimental -from .work.simple import Session +from ..meta import experimental + + +class GraphDatabase: + """Accessor for :class:`neo4j.Driver` construction. + """ + + @classmethod + @Util.experimental_async( + "neo4j is in experimental phase. It might be removed or changed " + "at any time (including patch releases)." + ) + def driver(cls, uri, *, auth=None, **config): + """Create a driver. + + :param uri: the connection URI for the driver, see :ref:`uri-ref` for available URIs. + :param auth: the authentication details, see :ref:`auth-ref` for available authentication details. + :param config: driver configuration key-word arguments, see :ref:`driver-configuration-ref` for available key-word arguments. + + :rtype: Neo4jDriver or BoltDriver + """ + + from ..api import ( + DRIVER_BOLT, + DRIVER_NEO4j, + parse_neo4j_uri, + parse_routing_context, + SECURITY_TYPE_NOT_SECURE, + SECURITY_TYPE_SECURE, + SECURITY_TYPE_SELF_SIGNED_CERTIFICATE, + URI_SCHEME_BOLT, + URI_SCHEME_BOLT_SECURE, + URI_SCHEME_BOLT_SELF_SIGNED_CERTIFICATE, + URI_SCHEME_NEO4J, + URI_SCHEME_NEO4J_SECURE, + URI_SCHEME_NEO4J_SELF_SIGNED_CERTIFICATE, + ) + + driver_type, security_type, parsed = parse_neo4j_uri(uri) + + if "trust" in config.keys(): + if config.get("trust") not in [TRUST_ALL_CERTIFICATES, TRUST_SYSTEM_CA_SIGNED_CERTIFICATES]: + from neo4j.exceptions import ConfigurationError + raise ConfigurationError("The config setting `trust` values are {!r}".format( + [ + TRUST_ALL_CERTIFICATES, + TRUST_SYSTEM_CA_SIGNED_CERTIFICATES, + ] + )) + + if security_type in [SECURITY_TYPE_SELF_SIGNED_CERTIFICATE, SECURITY_TYPE_SECURE] and ("encrypted" in config.keys() or "trust" in config.keys()): + from neo4j.exceptions import ConfigurationError + raise ConfigurationError("The config settings 'encrypted' and 'trust' can only be used with the URI schemes {!r}. Use the other URI schemes {!r} for setting encryption settings.".format( + [ + URI_SCHEME_BOLT, + URI_SCHEME_NEO4J, + ], + [ + URI_SCHEME_BOLT_SELF_SIGNED_CERTIFICATE, + URI_SCHEME_BOLT_SECURE, + URI_SCHEME_NEO4J_SELF_SIGNED_CERTIFICATE, + URI_SCHEME_NEO4J_SECURE, + ] + )) + + if security_type == SECURITY_TYPE_SECURE: + config["encrypted"] = True + elif security_type == SECURITY_TYPE_SELF_SIGNED_CERTIFICATE: + config["encrypted"] = True + config["trust"] = TRUST_ALL_CERTIFICATES + + if driver_type == DRIVER_BOLT: + return cls.bolt_driver(parsed.netloc, auth=auth, **config) + elif driver_type == DRIVER_NEO4j: + routing_context = parse_routing_context(parsed.query) + return cls.neo4j_driver(parsed.netloc, auth=auth, routing_context=routing_context, **config) + + @classmethod + def bolt_driver(cls, target, *, auth=None, **config): + """ Create a driver for direct Bolt server access that uses + socket I/O and thread-based concurrency. + """ + from .._exceptions import ( + BoltHandshakeError, + BoltSecurityError, + ) + + try: + return BoltDriver.open(target, auth=auth, **config) + except (BoltHandshakeError, BoltSecurityError) as error: + from neo4j.exceptions import ServiceUnavailable + raise ServiceUnavailable(str(error)) from error + + @classmethod + def neo4j_driver(cls, *targets, auth=None, routing_context=None, **config): + """ Create a driver for routing-capable Neo4j service access + that uses socket I/O and thread-based concurrency. + """ + from neo4j._exceptions import ( + BoltHandshakeError, + BoltSecurityError, + ) + + try: + return Neo4jDriver.open(*targets, auth=auth, routing_context=routing_context, **config) + except (BoltHandshakeError, BoltSecurityError) as error: + from neo4j.exceptions import ServiceUnavailable + raise ServiceUnavailable(str(error)) from error -class Direct: +class _Direct: default_host = "localhost" default_port = 7687 @@ -53,7 +166,7 @@ def parse_target(cls, target): return address -class Routing: +class _Routing: default_host = "localhost" default_port = 7687 @@ -80,8 +193,8 @@ def parse_targets(cls, *targets): class Driver: - """ Base class for all types of :class:`neo4j.Driver`, instances of which are - used as the primary access point to Neo4j. + """ Base class for all types of :class:`neo4j.Driver`, instances of + which are used as the primary access point to Neo4j. """ #: Connection pool @@ -91,15 +204,16 @@ def __init__(self, pool): assert pool is not None self._pool = pool - def __del__(self): - self.close() - def __enter__(self): return self def __exit__(self, exc_type, exc_value, traceback): self.close() + def __del__(self): + if not asyncio.iscoroutinefunction(self.close): + self.close() + @property def encrypted(self): return bool(self._pool.pool_config.encrypted) @@ -107,18 +221,14 @@ def encrypted(self): def session(self, **config): """Create a session, see :ref:`session-construction-ref` - :param config: session configuration key-word arguments, see :ref:`session-configuration-ref` for available key-word arguments. + :param config: session configuration key-word arguments, + see :ref:`session-configuration-ref` for available key-word + arguments. :returns: new :class:`neo4j.Session` object """ raise NotImplementedError - @experimental("The pipeline API is experimental and may be removed or changed in a future release") - def pipeline(self, **config): - """ Create a pipeline. - """ - raise NotImplementedError - def close(self): """ Shut down, closing any open connections in the pool. """ @@ -148,13 +258,16 @@ def supports_multi_db(self): return session._connection.supports_multiple_databases -class BoltDriver(Direct, Driver): - """ A :class:`.BoltDriver` is created from a ``bolt`` URI and addresses - a single database machine. This may be a standalone server or could be a - specific member of a cluster. +class BoltDriver(_Direct, Driver): + """:class:`.BoltDriver` is instantiated for ``bolt`` URIs and + addresses a single database machine. This may be a standalone server or + could be a specific member of a cluster. - Connections established by a :class:`.BoltDriver` are always made to the - exact host and port detailed in the URI. + Connections established by a :class:`.BoltDriver` are always made to + the exact host and port detailed in the URI. + + This class is not supposed to be instantiated externally. Use + :meth:`GraphDatabase.driver` instead. """ @classmethod @@ -167,14 +280,14 @@ def open(cls, target, *, auth=None, **config): :return: :rtype: :class: `neo4j.BoltDriver` """ - from neo4j.io import BoltPool + from .io import BoltPool address = cls.parse_target(target) pool_config, default_workspace_config = Config.consume_chain(config, PoolConfig, WorkspaceConfig) pool = BoltPool.open(address, auth=auth, pool_config=pool_config, workspace_config=default_workspace_config) return cls(pool, default_workspace_config) def __init__(self, pool, default_workspace_config): - Direct.__init__(self, pool.address) + _Direct.__init__(self, pool.address) Driver.__init__(self, pool) self._default_workspace_config = default_workspace_config @@ -185,20 +298,11 @@ def session(self, **config): :return: :rtype: :class: `neo4j.Session` """ - from neo4j.work.simple import Session + from .work import Session session_config = SessionConfig(self._default_workspace_config, config) SessionConfig.consume(config) # Consume the config return Session(self._pool, session_config) - def pipeline(self, **config): - from neo4j.work.pipelining import ( - Pipeline, - PipelineConfig, - ) - pipeline_config = PipelineConfig(self._default_workspace_config, config) - PipelineConfig.consume(config) # Consume the config - return Pipeline(self._pool, pipeline_config) - @experimental("The configuration may change in the future.") def verify_connectivity(self, **config): server_agent = None @@ -211,41 +315,36 @@ def verify_connectivity(self, **config): return server_agent -class Neo4jDriver(Routing, Driver): - """ A :class:`.Neo4jDriver` is created from a ``neo4j`` URI. The +class Neo4jDriver(_Routing, Driver): + """:class:`.Neo4jDriver` is instantiated for ``neo4j`` URIs. The routing behaviour works in tandem with Neo4j's `Causal Clustering `_ feature by directing read and write behaviour to appropriate cluster members. + + This class is not supposed to be instantiated externally. Use + :meth:`GraphDatabase.driver` instead. """ @classmethod def open(cls, *targets, auth=None, routing_context=None, **config): - from neo4j.io import Neo4jPool + from .io import Neo4jPool addresses = cls.parse_targets(*targets) pool_config, default_workspace_config = Config.consume_chain(config, PoolConfig, WorkspaceConfig) pool = Neo4jPool.open(*addresses, auth=auth, routing_context=routing_context, pool_config=pool_config, workspace_config=default_workspace_config) return cls(pool, default_workspace_config) def __init__(self, pool, default_workspace_config): - Routing.__init__(self, pool.get_default_database_initial_router_addresses()) + _Routing.__init__(self, pool.get_default_database_initial_router_addresses()) Driver.__init__(self, pool) self._default_workspace_config = default_workspace_config def session(self, **config): + from .work import Session session_config = SessionConfig(self._default_workspace_config, config) SessionConfig.consume(config) # Consume the config return Session(self._pool, session_config) - def pipeline(self, **config): - from neo4j.work.pipelining import ( - Pipeline, - PipelineConfig, - ) - pipeline_config = PipelineConfig(self._default_workspace_config, config) - PipelineConfig.consume(config) # Consume the config - return Pipeline(self._pool, pipeline_config) - @experimental("The configuration may change in the future.") def verify_connectivity(self, **config): """ @@ -255,7 +354,7 @@ def verify_connectivity(self, **config): return self._verify_routing_connectivity() def _verify_routing_connectivity(self): - from neo4j.exceptions import ( + from ..exceptions import ( Neo4jError, ServiceUnavailable, SessionExpired, diff --git a/neo4j/_sync/io/__init__.py b/neo4j/_sync/io/__init__.py new file mode 100644 index 00000000..b598d07d --- /dev/null +++ b/neo4j/_sync/io/__init__.py @@ -0,0 +1,43 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This module contains the low-level functionality required for speaking +Bolt. It is not intended to be used directly by driver users. Instead, +the `session` module provides the main user-facing abstractions. +""" + + +__all__ = [ + "Bolt", + "BoltPool", + "Neo4jPool", + "check_supported_server_product", + "ConnectionErrorHandler", +] + + +from ._bolt import Bolt +from ._common import ( + check_supported_server_product, + ConnectionErrorHandler, +) +from ._pool import ( + BoltPool, + Neo4jPool, +) diff --git a/neo4j/_sync/io/_bolt.py b/neo4j/_sync/io/_bolt.py new file mode 100644 index 00000000..82ee8b62 --- /dev/null +++ b/neo4j/_sync/io/_bolt.py @@ -0,0 +1,571 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import abc +import asyncio +from collections import deque +from logging import getLogger +from time import perf_counter + +from ..._async_compat.network import BoltSocket +from ..._exceptions import BoltHandshakeError +from ...addressing import Address +from ...api import ( + ServerInfo, + Version, +) +from ...conf import PoolConfig +from ...exceptions import ( + AuthError, + IncompleteCommit, + ServiceUnavailable, + SessionExpired, +) +from ...meta import get_user_agent +from ...packstream import ( + Packer, + Unpacker, +) +from ._common import ( + CommitResponse, + Inbox, + Outbox, +) + + +# Set up logger +log = getLogger("neo4j") + + +class Bolt: + """ Server connection for Bolt protocol. + + A :class:`.Bolt` should be constructed following a + successful .open() + + Bolt handshake and takes the socket over which + the handshake was carried out. + """ + + MAGIC_PREAMBLE = b"\x60\x60\xB0\x17" + + PROTOCOL_VERSION = None + + # flag if connection needs RESET to go back to READY state + is_reset = False + + # The socket + in_use = False + + # The socket + _closed = False + + # The socket + _defunct = False + + #: The pool of which this connection is a member + pool = None + + # Store the id of the most recent ran query to be able to reduce sent bits by + # using the default (-1) to refer to the most recent query when pulling + # results for it. + most_recent_qid = None + + def __init__(self, unresolved_address, sock, max_connection_lifetime, *, + auth=None, user_agent=None, routing_context=None): + self.unresolved_address = unresolved_address + self.socket = sock + self.server_info = ServerInfo(Address(sock.getpeername()), + self.PROTOCOL_VERSION) + # so far `connection.recv_timeout_seconds` is the only available + # configuration hint that exists. Therefore, all hints can be stored at + # connection level. This might change in the future. + self.configuration_hints = {} + self.outbox = Outbox() + self.inbox = Inbox(self.socket, on_error=self._set_defunct_read) + self.packer = Packer(self.outbox) + self.unpacker = Unpacker(self.inbox) + self.responses = deque() + self._max_connection_lifetime = max_connection_lifetime + self._creation_timestamp = perf_counter() + self.routing_context = routing_context + + # Determine the user agent + if user_agent: + self.user_agent = user_agent + else: + self.user_agent = get_user_agent() + + # Determine auth details + if not auth: + self.auth_dict = {} + elif isinstance(auth, tuple) and 2 <= len(auth) <= 3: + from neo4j import Auth + self.auth_dict = vars(Auth("basic", *auth)) + else: + try: + self.auth_dict = vars(auth) + except (KeyError, TypeError): + raise AuthError("Cannot determine auth details from %r" % auth) + + # Check for missing password + try: + credentials = self.auth_dict["credentials"] + except KeyError: + pass + else: + if credentials is None: + raise AuthError("Password cannot be None") + + def __del__(self): + if not asyncio.iscoroutinefunction(self.close): + self.close() + + @property + @abc.abstractmethod + def supports_multiple_results(self): + """ Boolean flag to indicate if the connection version supports multiple + queries to be buffered on the server side (True) or if all results need + to be eagerly pulled before sending the next RUN (False). + """ + pass + + @property + @abc.abstractmethod + def supports_multiple_databases(self): + """ Boolean flag to indicate if the connection version supports multiple + databases. + """ + pass + + @classmethod + def protocol_handlers(cls, protocol_version=None): + """ Return a dictionary of available Bolt protocol handlers, + keyed by version tuple. If an explicit protocol version is + provided, the dictionary will contain either zero or one items, + depending on whether that version is supported. If no protocol + version is provided, all available versions will be returned. + + :param protocol_version: tuple identifying a specific protocol + version (e.g. (3, 5)) or None + :return: dictionary of version tuple to handler class for all + relevant and supported protocol versions + :raise TypeError: if protocol version is not passed in a tuple + """ + + # Carry out Bolt subclass imports locally to avoid circular dependency issues. + from ._bolt3 import Bolt3 + from ._bolt4 import ( + Bolt4x0, + Bolt4x1, + Bolt4x2, + Bolt4x3, + Bolt4x4, + ) + + handlers = { + Bolt3.PROTOCOL_VERSION: Bolt3, + Bolt4x0.PROTOCOL_VERSION: Bolt4x0, + Bolt4x1.PROTOCOL_VERSION: Bolt4x1, + Bolt4x2.PROTOCOL_VERSION: Bolt4x2, + Bolt4x3.PROTOCOL_VERSION: Bolt4x3, + Bolt4x4.PROTOCOL_VERSION: Bolt4x4, + } + + if protocol_version is None: + return handlers + + if not isinstance(protocol_version, tuple): + raise TypeError("Protocol version must be specified as a tuple") + + if protocol_version in handlers: + return {protocol_version: handlers[protocol_version]} + + return {} + + @classmethod + def version_list(cls, versions, limit=4): + """ Return a list of supported protocol versions in order of + preference. The number of protocol versions (or ranges) + returned is limited to four. + """ + # In fact, 4.3 is the fist version to support ranges. However, the range + # support got backported to 4.2. But even if the server is too old to + # have the backport, negotiating BOLT 4.1 is no problem as it's + # equivalent to 4.2 + first_with_range_support = Version(4, 2) + result = [] + for version in versions: + if (result + and version >= first_with_range_support + and result[-1][0] == version[0] + and result[-1][1][1] == version[1] + 1): + # can use range to encompass this version + result[-1][1][1] = version[1] + continue + result.append(Version(version[0], [version[1], version[1]])) + if len(result) == 4: + break + return result + + @classmethod + def get_handshake(cls): + """ Return the supported Bolt versions as bytes. + The length is 16 bytes as specified in the Bolt version negotiation. + :return: bytes + """ + supported_versions = sorted(cls.protocol_handlers().keys(), reverse=True) + offered_versions = cls.version_list(supported_versions) + return b"".join(version.to_bytes() for version in offered_versions).ljust(16, b"\x00") + + @classmethod + def ping(cls, address, *, timeout=None, **config): + """ Attempt to establish a Bolt connection, returning the + agreed Bolt protocol version if successful. + """ + config = PoolConfig.consume(config) + try: + s, protocol_version, handshake, data = \ + BoltSocket.connect( + address, + timeout=timeout, + custom_resolver=config.resolver, + ssl_context=config.get_ssl_context(), + keep_alive=config.keep_alive, + ) + except (ServiceUnavailable, SessionExpired, BoltHandshakeError): + return None + else: + BoltSocket.close_socket(s) + return protocol_version + + @classmethod + def open( + cls, address, *, auth=None, timeout=None, routing_context=None, **pool_config + ): + """ Open a new Bolt connection to a given server address. + + :param address: + :param auth: + :param timeout: the connection timeout in seconds + :param routing_context: dict containing routing context + :param pool_config: + :return: + :raise BoltHandshakeError: raised if the Bolt Protocol can not negotiate a protocol version. + :raise ServiceUnavailable: raised if there was a connection issue. + """ + pool_config = PoolConfig.consume(pool_config) + s, pool_config.protocol_version, handshake, data = \ + BoltSocket.connect( + address, + timeout=timeout, + custom_resolver=pool_config.resolver, + ssl_context=pool_config.get_ssl_context(), + keep_alive=pool_config.keep_alive, + ) + + # Carry out Bolt subclass imports locally to avoid circular dependency + # issues. + if pool_config.protocol_version == (3, 0): + from ._bolt3 import Bolt3 + bolt_cls = Bolt3 + elif pool_config.protocol_version == (4, 0): + from ._bolt4 import Bolt4x0 + bolt_cls = Bolt4x0 + elif pool_config.protocol_version == (4, 1): + from ._bolt4 import Bolt4x1 + bolt_cls = Bolt4x1 + elif pool_config.protocol_version == (4, 2): + from ._bolt4 import Bolt4x2 + bolt_cls = Bolt4x2 + elif pool_config.protocol_version == (4, 3): + from ._bolt4 import Bolt4x3 + bolt_cls = Bolt4x3 + elif pool_config.protocol_version == (4, 4): + from ._bolt4 import Bolt4x4 + bolt_cls = Bolt4x4 + else: + log.debug("[#%04X] S: ", s.getsockname()[1]) + BoltSocket.close_socket(s) + + supported_versions = cls.protocol_handlers().keys() + raise BoltHandshakeError("The Neo4J server does not support communication with this driver. This driver have support for Bolt Protocols {}".format(supported_versions), address=address, request_data=handshake, response_data=data) + + connection = bolt_cls( + address, s, pool_config.max_connection_lifetime, auth=auth, + user_agent=pool_config.user_agent, routing_context=routing_context + ) + + try: + connection.hello() + except Exception: + connection.close() + raise + + return connection + + @property + @abc.abstractmethod + def encrypted(self): + pass + + @property + @abc.abstractmethod + def der_encoded_server_certificate(self): + pass + + @property + @abc.abstractmethod + def local_port(self): + pass + + @abc.abstractmethod + def hello(self): + """ Appends a HELLO message to the outgoing queue, sends it and consumes + all remaining messages. + """ + pass + + @abc.abstractmethod + def route(self, database=None, imp_user=None, bookmarks=None): + """ Fetch a routing table from the server for the given + `database`. For Bolt 4.3 and above, this appends a ROUTE + message; for earlier versions, a procedure call is made via + the regular Cypher execution mechanism. In all cases, this is + sent to the network, and a response is fetched. + + :param database: database for which to fetch a routing table + Requires Bolt 4.0+. + :param imp_user: the user to impersonate + Requires Bolt 4.4+. + :param bookmarks: iterable of bookmark values after which this + transaction should begin + :return: dictionary of raw routing data + """ + pass + + @abc.abstractmethod + def run(self, query, parameters=None, mode=None, bookmarks=None, + metadata=None, timeout=None, db=None, imp_user=None, **handlers): + """ Appends a RUN message to the output queue. + + :param query: Cypher query string + :param parameters: dictionary of Cypher parameters + :param mode: access mode for routing - "READ" or "WRITE" (default) + :param bookmarks: iterable of bookmark values after which this transaction should begin + :param metadata: custom metadata dictionary to attach to the transaction + :param timeout: timeout for transaction execution (seconds) + :param db: name of the database against which to begin the transaction + Requires Bolt 4.0+. + :param imp_user: the user to impersonate + Requires Bolt 4.4+. + :param handlers: handler functions passed into the returned Response object + :return: Response object + """ + pass + + @abc.abstractmethod + def discard(self, n=-1, qid=-1, **handlers): + """ Appends a DISCARD message to the output queue. + + :param n: number of records to discard, default = -1 (ALL) + :param qid: query ID to discard for, default = -1 (last query) + :param handlers: handler functions passed into the returned Response object + :return: Response object + """ + pass + + @abc.abstractmethod + def pull(self, n=-1, qid=-1, **handlers): + """ Appends a PULL message to the output queue. + + :param n: number of records to pull, default = -1 (ALL) + :param qid: query ID to pull for, default = -1 (last query) + :param handlers: handler functions passed into the returned Response object + :return: Response object + """ + pass + + @abc.abstractmethod + def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, + db=None, imp_user=None, **handlers): + """ Appends a BEGIN message to the output queue. + + :param mode: access mode for routing - "READ" or "WRITE" (default) + :param bookmarks: iterable of bookmark values after which this transaction should begin + :param metadata: custom metadata dictionary to attach to the transaction + :param timeout: timeout for transaction execution (seconds) + :param db: name of the database against which to begin the transaction + Requires Bolt 4.0+. + :param imp_user: the user to impersonate + Requires Bolt 4.4+ + :param handlers: handler functions passed into the returned Response object + :return: Response object + """ + pass + + @abc.abstractmethod + def commit(self, **handlers): + """ Appends a COMMIT message to the output queue.""" + pass + + @abc.abstractmethod + def rollback(self, **handlers): + """ Appends a ROLLBACK message to the output queue.""" + pass + + @abc.abstractmethod + def reset(self): + """ Appends a RESET message to the outgoing queue, sends it and consumes + all remaining messages. + """ + pass + + def _append(self, signature, fields=(), response=None): + """ Appends a message to the outgoing queue. + + :param signature: the signature of the message + :param fields: the fields of the message as a tuple + :param response: a response object to handle callbacks + """ + self.packer.pack_struct(signature, fields) + self.outbox.wrap_message() + self.responses.append(response) + + def _send_all(self): + data = self.outbox.view() + if data: + try: + self.socket.sendall(data) + except OSError as error: + self._set_defunct_write(error) + self.outbox.clear() + + def send_all(self): + """ Send all queued messages to the server. + """ + if self.closed(): + raise ServiceUnavailable("Failed to write to closed connection {!r} ({!r})".format( + self.unresolved_address, self.server_info.address)) + + if self.defunct(): + raise ServiceUnavailable("Failed to write to defunct connection {!r} ({!r})".format( + self.unresolved_address, self.server_info.address)) + + self._send_all() + + @abc.abstractmethod + def fetch_message(self): + """ Receive at most one message from the server, if available. + + :return: 2-tuple of number of detail messages and number of summary + messages fetched + """ + pass + + def fetch_all(self): + """ Fetch all outstanding messages. + + :return: 2-tuple of number of detail messages and number of summary + messages fetched + """ + detail_count = summary_count = 0 + while self.responses: + response = self.responses[0] + while not response.complete: + detail_delta, summary_delta = self.fetch_message() + detail_count += detail_delta + summary_count += summary_delta + return detail_count, summary_count + + def _set_defunct_read(self, error=None, silent=False): + message = "Failed to read from defunct connection {!r} ({!r})".format( + self.unresolved_address, self.server_info.address + ) + self._set_defunct(message, error=error, silent=silent) + + def _set_defunct_write(self, error=None, silent=False): + message = "Failed to write data to connection {!r} ({!r})".format( + self.unresolved_address, self.server_info.address + ) + self._set_defunct(message, error=error, silent=silent) + + def _set_defunct(self, message, error=None, silent=False): + from ._pool import BoltPool + direct_driver = isinstance(self.pool, BoltPool) + + if error: + log.debug("[#%04X] %s", self.socket.getsockname()[1], error) + log.error(message) + # We were attempting to receive data but the connection + # has unexpectedly terminated. So, we need to close the + # connection from the client side, and remove the address + # from the connection pool. + self._defunct = True + self.close() + if self.pool: + self.pool.deactivate(address=self.unresolved_address) + # Iterate through the outstanding responses, and if any correspond + # to COMMIT requests then raise an error to signal that we are + # unable to confirm that the COMMIT completed successfully. + if silent: + return + for response in self.responses: + if isinstance(response, CommitResponse): + if error: + raise IncompleteCommit(message) from error + else: + raise IncompleteCommit(message) + + if direct_driver: + if error: + raise ServiceUnavailable(message) from error + else: + raise ServiceUnavailable(message) + else: + if error: + raise SessionExpired(message) from error + else: + raise SessionExpired(message) + + def stale(self): + return (self._stale + or (0 <= self._max_connection_lifetime + <= perf_counter() - self._creation_timestamp)) + + _stale = False + + def set_stale(self): + self._stale = True + + @abc.abstractmethod + def close(self): + """ Close the connection. + """ + pass + + @abc.abstractmethod + def closed(self): + pass + + @abc.abstractmethod + def defunct(self): + pass + + +BoltSocket.Bolt = Bolt diff --git a/neo4j/io/_bolt3.py b/neo4j/_sync/io/_bolt3.py similarity index 97% rename from neo4j/io/_bolt3.py rename to neo4j/_sync/io/_bolt3.py index 704d8740..e19bd4c2 100644 --- a/neo4j/io/_bolt3.py +++ b/neo4j/_sync/io/_bolt3.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -18,19 +15,21 @@ # See the License for the specific language governing permissions and # limitations under the License. + from enum import Enum from logging import getLogger from ssl import SSLSocket -from neo4j._exceptions import ( +from ..._async_compat.util import Util +from ..._exceptions import ( BoltError, BoltProtocolError, ) -from neo4j.api import ( +from ...api import ( READ_ACCESS, Version, ) -from neo4j.exceptions import ( +from ...exceptions import ( ConfigurationError, DatabaseUnavailable, DriverError, @@ -39,11 +38,9 @@ NotALeader, ServiceUnavailable, ) -from neo4j.io import ( - Bolt, +from ._bolt import Bolt +from ._common import ( check_supported_server_product, -) -from neo4j.io._common import ( CommitResponse, InitResponse, Response, @@ -93,8 +90,8 @@ def transition(self, message, metadata): if metadata.get("has_more"): return state_before = self.state - self.state = self._STATE_TRANSITIONS\ - .get(self.state, {})\ + self.state = self._STATE_TRANSITIONS \ + .get(self.state, {}) \ .get(message, self.state) if state_before != self.state and callable(self._on_change): self._on_change(state_before, self.state) @@ -331,7 +328,8 @@ def fetch_message(self): return 0, 0 # Receive exactly one message - details, summary_signature, summary_metadata = next(self.inbox) + details, summary_signature, summary_metadata = \ + Util.next(self.inbox) if details: log.debug("[#%04X] S: RECORD * %d", self.local_port, len(details)) # Do not log any data @@ -357,11 +355,11 @@ def fetch_message(self): response.on_failure(summary_metadata or {}) except (ServiceUnavailable, DatabaseUnavailable): if self.pool: - self.pool.deactivate(address=self.unresolved_address), + self.pool.deactivate(address=self.unresolved_address) raise except (NotALeader, ForbiddenOnReadOnlyDatabase): if self.pool: - self.pool.on_write_failure(address=self.unresolved_address), + self.pool.on_write_failure(address=self.unresolved_address) raise except Neo4jError as e: if self.pool and e.invalidates_all_connections(): diff --git a/neo4j/io/_bolt4.py b/neo4j/_sync/io/_bolt4.py similarity index 98% rename from neo4j/io/_bolt4.py rename to neo4j/_sync/io/_bolt4.py index 3bb90ab6..33242216 100644 --- a/neo4j/io/_bolt4.py +++ b/neo4j/_sync/io/_bolt4.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -18,20 +15,21 @@ # See the License for the specific language governing permissions and # limitations under the License. -from enum import Enum + from logging import getLogger from ssl import SSLSocket -from neo4j._exceptions import ( +from ..._async_compat.util import Util +from ..._exceptions import ( BoltError, BoltProtocolError, ) -from neo4j.api import ( +from ...api import ( READ_ACCESS, SYSTEM_DATABASE, Version, ) -from neo4j.exceptions import ( +from ...exceptions import ( ConfigurationError, DatabaseUnavailable, DriverError, @@ -40,19 +38,17 @@ NotALeader, ServiceUnavailable, ) -from neo4j.io import ( - Bolt, - check_supported_server_product, +from ._bolt3 import ( + ServerStateManager, + ServerStates, ) -from neo4j.io._common import ( +from ._bolt import Bolt +from ._common import ( + check_supported_server_product, CommitResponse, InitResponse, Response, ) -from neo4j.io._bolt3 import ( - ServerStateManager, - ServerStates, -) log = getLogger("neo4j") @@ -283,7 +279,8 @@ def fetch_message(self): return 0, 0 # Receive exactly one message - details, summary_signature, summary_metadata = next(self.inbox) + details, summary_signature, summary_metadata = \ + Util.next(self.inbox) if details: log.debug("[#%04X] S: RECORD * %d", self.local_port, len(details)) # Do not log any data @@ -309,11 +306,11 @@ def fetch_message(self): response.on_failure(summary_metadata or {}) except (ServiceUnavailable, DatabaseUnavailable): if self.pool: - self.pool.deactivate(address=self.unresolved_address), + self.pool.deactivate(address=self.unresolved_address) raise except (NotALeader, ForbiddenOnReadOnlyDatabase): if self.pool: - self.pool.on_write_failure(address=self.unresolved_address), + self.pool.on_write_failure(address=self.unresolved_address) raise except Neo4jError as e: if self.pool and e.invalidates_all_connections(): diff --git a/neo4j/io/_common.py b/neo4j/_sync/io/_common.py similarity index 78% rename from neo4j/io/_common.py rename to neo4j/_sync/io/_common.py index becb7db4..408de0a1 100644 --- a/neo4j/io/_common.py +++ b/neo4j/_sync/io/_common.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -19,21 +16,24 @@ # limitations under the License. +import asyncio +import logging import socket from struct import pack as struct_pack -from neo4j.exceptions import ( - AuthError, +from ..._async_compat.util import Util +from ...exceptions import ( Neo4jError, ServiceUnavailable, SessionExpired, + UnsupportedServerProduct, ) -from neo4j.packstream import ( +from ...packstream import ( UnpackableBuffer, Unpacker, ) -import logging + log = logging.getLogger("neo4j") @@ -52,12 +52,12 @@ def _yield_messages(self, sock): while chunk_size == 0: # Determine the chunk size and skip noop - buffer.receive(sock, 2) + receive_into_buffer(sock, buffer, 2) chunk_size = buffer.pop_u16() if chunk_size == 0: log.debug("[#%04X] S: ", sock.getsockname()[1]) - buffer.receive(sock, chunk_size + 2) + receive_into_buffer(sock, buffer, chunk_size + 2) chunk_size = buffer.pop_u16() if chunk_size == 0: @@ -69,10 +69,10 @@ def _yield_messages(self, sock): unpacker.reset() except (OSError, socket.timeout) as error: - self.on_error(error) + Util.callback(self.on_error, error) def pop(self): - return next(self._messages) + return Util.next(self._messages) class Inbox(MessageInbox): @@ -166,10 +166,22 @@ def inner(*args, **kwargs): try: func(*args, **kwargs) except (Neo4jError, ServiceUnavailable, SessionExpired) as exc: + assert not asyncio.iscoroutinefunction(self.__on_error) self.__on_error(exc) raise return inner + def outer_async(coroutine_func): + def inner(*args, **kwargs): + try: + coroutine_func(*args, **kwargs) + except (Neo4jError, ServiceUnavailable, SessionExpired) as exc: + Util.callback(self.__on_error, exc) + raise + return inner + + if asyncio.iscoroutinefunction(connection_attr): + return outer_async(connection_attr) return outer(connection_attr) def __setattr__(self, name, value): @@ -194,20 +206,17 @@ def on_records(self, records): """ Called when one or more RECORD messages have been received. """ handler = self.handlers.get("on_records") - if callable(handler): - handler(records) + Util.callback(handler, records) def on_success(self, metadata): """ Called when a SUCCESS message has been received. """ handler = self.handlers.get("on_success") - if callable(handler): - handler(metadata) + Util.callback(handler, metadata) if not metadata.get("has_more"): handler = self.handlers.get("on_summary") - if callable(handler): - handler() + Util.callback(handler) def on_failure(self, metadata): """ Called when a FAILURE message has been received. @@ -217,22 +226,18 @@ def on_failure(self, metadata): except (SessionExpired, ServiceUnavailable): pass handler = self.handlers.get("on_failure") - if callable(handler): - handler(metadata) + Util.callback(handler, metadata) handler = self.handlers.get("on_summary") - if callable(handler): - handler() + Util.callback(handler) raise Neo4jError.hydrate(**metadata) def on_ignored(self, metadata=None): """ Called when an IGNORED message has been received. """ handler = self.handlers.get("on_ignored") - if callable(handler): - handler(metadata) + Util.callback(handler, metadata) handler = self.handlers.get("on_summary") - if callable(handler): - handler() + Util.callback(handler) class InitResponse(Response): @@ -250,3 +255,26 @@ def on_failure(self, metadata): class CommitResponse(Response): pass + + +def check_supported_server_product(agent): + """ Checks that a server product is supported by the driver by + looking at the server agent string. + + :param agent: server agent string to check for validity + :raises UnsupportedServerProduct: if the product is not supported + """ + if not agent.startswith("Neo4j/"): + raise UnsupportedServerProduct(agent) + + +def receive_into_buffer(sock, buffer, n_bytes): + end = buffer.used + n_bytes + if end > len(buffer.data): + buffer.data += bytearray(end - len(buffer.data)) + view = memoryview(buffer.data) + while buffer.used < end: + n = sock.recv_into(view[buffer.used:end], end - buffer.used) + if n == 0: + raise OSError("No data") + buffer.used += n diff --git a/neo4j/_sync/io/_pool.py b/neo4j/_sync/io/_pool.py new file mode 100644 index 00000000..5969fd6d --- /dev/null +++ b/neo4j/_sync/io/_pool.py @@ -0,0 +1,701 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import abc +from collections import ( + defaultdict, + deque, +) +import logging +from logging import getLogger +from random import choice +from time import perf_counter + +from ..._async_compat.concurrency import ( + Condition, + RLock, +) +from ..._async_compat.network import NetworkUtil +from ..._exceptions import BoltError +from ...api import ( + READ_ACCESS, + WRITE_ACCESS, +) +from ...conf import ( + PoolConfig, + WorkspaceConfig, +) +from ...exceptions import ( + ClientError, + ConfigurationError, + DriverError, + Neo4jError, + ReadServiceUnavailable, + ServiceUnavailable, + SessionExpired, + WriteServiceUnavailable, +) +from ...routing import RoutingTable +from ._bolt import Bolt + + +# Set up logger +log = getLogger("neo4j") + + +class IOPool(abc.ABC): + """ A collection of connections to one or more server addresses. + """ + + def __init__(self, opener, pool_config, workspace_config): + assert callable(opener) + assert isinstance(pool_config, PoolConfig) + assert isinstance(workspace_config, WorkspaceConfig) + + self.opener = opener + self.pool_config = pool_config + self.workspace_config = workspace_config + self.connections = defaultdict(deque) + self.lock = RLock() + self.cond = Condition(self.lock) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.close() + + def _acquire(self, address, timeout): + """ Acquire a connection to a given address from the pool. + The address supplied should always be an IP address, not + a host name. + + This method is thread safe. + """ + t0 = perf_counter() + if timeout is None: + timeout = self.workspace_config.connection_acquisition_timeout + + with self.lock: + def time_remaining(): + t = timeout - (perf_counter() - t0) + return t if t > 0 else 0 + + while True: + # try to find a free connection in pool + for connection in list(self.connections.get(address, [])): + if (connection.closed() or connection.defunct() + or (connection.stale() and not connection.in_use)): + # `close` is a noop on already closed connections. + # This is to make sure that the connection is + # gracefully closed, e.g. if it's just marked as + # `stale` but still alive. + if log.isEnabledFor(logging.DEBUG): + log.debug( + "[#%04X] C: removing old connection " + "(closed=%s, defunct=%s, stale=%s, in_use=%s)", + connection.local_port, + connection.closed(), connection.defunct(), + connection.stale(), connection.in_use + ) + connection.close() + try: + self.connections.get(address, []).remove(connection) + except ValueError: + # If closure fails (e.g. because the server went + # down), all connections to the same address will + # be removed. Therefore, we silently ignore if the + # connection isn't in the pool anymore. + pass + continue + if not connection.in_use: + connection.in_use = True + return connection + # all connections in pool are in-use + connections = self.connections[address] + max_pool_size = self.pool_config.max_connection_pool_size + infinite_pool_size = (max_pool_size < 0 + or max_pool_size == float("inf")) + can_create_new_connection = ( + infinite_pool_size + or len(connections) < max_pool_size + ) + if can_create_new_connection: + timeout = min(self.pool_config.connection_timeout, + time_remaining()) + try: + connection = self.opener(address, timeout) + except ServiceUnavailable: + self.remove(address) + raise + else: + connection.pool = self + connection.in_use = True + connections.append(connection) + return connection + + # failed to obtain a connection from pool because the + # pool is full and no free connection in the pool + if time_remaining(): + self.cond.wait(time_remaining()) + # if timed out, then we throw error. This time + # computation is needed, as with python 2.7, we + # cannot tell if the condition is notified or + # timed out when we come to this line + if not time_remaining(): + raise ClientError("Failed to obtain a connection from pool " + "within {!r}s".format(timeout)) + else: + raise ClientError("Failed to obtain a connection from pool " + "within {!r}s".format(timeout)) + + @abc.abstractmethod + def acquire( + self, access_mode=None, timeout=None, database=None, bookmarks=None + ): + """ Acquire a connection to a server that can satisfy a set of parameters. + + :param access_mode: + :param timeout: + :param database: + :param bookmarks: + """ + + def release(self, *connections): + """ Release a connection back into the pool. + This method is thread safe. + """ + with self.lock: + for connection in connections: + if not (connection.defunct() + or connection.closed() + or connection.is_reset): + try: + connection.reset() + except (Neo4jError, DriverError, BoltError) as e: + log.debug( + "Failed to reset connection on release: %s", e + ) + connection.in_use = False + self.cond.notify_all() + + def in_use_connection_count(self, address): + """ Count the number of connections currently in use to a given + address. + """ + try: + connections = self.connections[address] + except KeyError: + return 0 + else: + return sum(1 if connection.in_use else 0 for connection in connections) + + def mark_all_stale(self): + with self.lock: + for address in self.connections: + for connection in self.connections[address]: + connection.set_stale() + + def deactivate(self, address): + """ Deactivate an address from the connection pool, if present, closing + all idle connection to that address + """ + with self.lock: + try: + connections = self.connections[address] + except KeyError: # already removed from the connection pool + return + for conn in list(connections): + if not conn.in_use: + connections.remove(conn) + try: + conn.close() + except OSError: + pass + if not connections: + self.remove(address) + + def on_write_failure(self, address): + raise WriteServiceUnavailable( + "No write service available for pool {}".format(self) + ) + + def remove(self, address): + """ Remove an address from the connection pool, if present, closing + all connections to that address. + """ + with self.lock: + for connection in self.connections.pop(address, ()): + try: + connection.close() + except OSError: + pass + + def close(self): + """ Close all connections and empty the pool. + This method is thread safe. + """ + try: + with self.lock: + for address in list(self.connections): + self.remove(address) + except TypeError: + pass + + +class BoltPool(IOPool): + + @classmethod + def open(cls, address, *, auth, pool_config, workspace_config): + """Create a new BoltPool + + :param address: + :param auth: + :param pool_config: + :param workspace_config: + :return: BoltPool + """ + + def opener(addr, timeout): + return Bolt.open( + addr, auth=auth, timeout=timeout, routing_context=None, + **pool_config + ) + + pool = cls(opener, pool_config, workspace_config, address) + return pool + + def __init__(self, opener, pool_config, workspace_config, address): + super().__init__(opener, pool_config, workspace_config) + self.address = address + + def __repr__(self): + return "<{} address={!r}>".format(self.__class__.__name__, + self.address) + + def acquire( + self, access_mode=None, timeout=None, database=None, bookmarks=None + ): + # The access_mode and database is not needed for a direct connection, + # it's just there for consistency. + return self._acquire(self.address, timeout) + + +class Neo4jPool(IOPool): + """ Connection pool with routing table. + """ + + @classmethod + def open(cls, *addresses, auth, pool_config, workspace_config, + routing_context=None): + """Create a new Neo4jPool + + :param addresses: one or more address as positional argument + :param auth: + :param pool_config: + :param workspace_config: + :param routing_context: + :return: Neo4jPool + """ + + address = addresses[0] + if routing_context is None: + routing_context = {} + elif "address" in routing_context: + raise ConfigurationError("The key 'address' is reserved for routing context.") + routing_context["address"] = str(address) + + def opener(addr, timeout): + return Bolt.open( + addr, auth=auth, timeout=timeout, + routing_context=routing_context, **pool_config + ) + + pool = cls(opener, pool_config, workspace_config, address) + return pool + + def __init__(self, opener, pool_config, workspace_config, address): + """ + + :param opener: + :param pool_config: + :param workspace_config: + :param address: + """ + super().__init__(opener, pool_config, workspace_config) + # Each database have a routing table, the default database is a special case. + log.debug("[#0000] C: routing address %r", address) + self.address = address + self.routing_tables = {workspace_config.database: RoutingTable(database=workspace_config.database, routers=[address])} + self.refresh_lock = RLock() + + def __repr__(self): + """ The representation shows the initial routing addresses. + + :return: The representation + :rtype: str + """ + return "<{} addresses={!r}>".format(self.__class__.__name__, self.get_default_database_initial_router_addresses()) + + @property + def first_initial_routing_address(self): + return self.get_default_database_initial_router_addresses()[0] + + def get_default_database_initial_router_addresses(self): + """ Get the initial router addresses for the default database. + + :return: + :rtype: OrderedSet + """ + return self.get_routing_table_for_default_database().initial_routers + + def get_default_database_router_addresses(self): + """ Get the router addresses for the default database. + + :return: + :rtype: OrderedSet + """ + return self.get_routing_table_for_default_database().routers + + def get_routing_table_for_default_database(self): + return self.routing_tables[self.workspace_config.database] + + def get_or_create_routing_table(self, database): + with self.refresh_lock: + if database not in self.routing_tables: + self.routing_tables[database] = RoutingTable( + database=database, + routers=self.get_default_database_initial_router_addresses() + ) + return self.routing_tables[database] + + def fetch_routing_info( + self, address, database, imp_user, bookmarks, timeout + ): + """ Fetch raw routing info from a given router address. + + :param address: router address + :param database: the database name to get routing table for + :param imp_user: the user to impersonate while fetching the routing + table + :type imp_user: str or None + :param bookmarks: iterable of bookmark values after which the routing + info should be fetched + :param timeout: connection acquisition timeout in seconds + + :return: list of routing records, or None if no connection + could be established or if no readers or writers are present + :raise ServiceUnavailable: if the server does not support + routing, or if routing support is broken or outdated + """ + cx = self._acquire(address, timeout) + try: + routing_table = cx.route( + database or self.workspace_config.database, + imp_user or self.workspace_config.impersonated_user, + bookmarks + ) + finally: + self.release(cx) + return routing_table + + def fetch_routing_table( + self, *, address, timeout, database, imp_user, bookmarks + ): + """ Fetch a routing table from a given router address. + + :param address: router address + :param timeout: seconds + :param database: the database name + :type: str + :param imp_user: the user to impersonate while fetching the routing + table + :type imp_user: str or None + :param bookmarks: bookmarks used when fetching routing table + + :return: a new RoutingTable instance or None if the given router is + currently unable to provide routing information + """ + new_routing_info = None + try: + new_routing_info = self.fetch_routing_info( + address, database, imp_user, bookmarks, timeout + ) + except Neo4jError as e: + # checks if the code is an error that is caused by the client. In + # this case there is no sense in trying to fetch a RT from another + # router. Hence, the driver should fail fast during discovery. + if e.is_fatal_during_discovery(): + raise + except (ServiceUnavailable, SessionExpired): + pass + if not new_routing_info: + log.debug("Failed to fetch routing info %s", address) + return None + else: + servers = new_routing_info[0]["servers"] + ttl = new_routing_info[0]["ttl"] + database = new_routing_info[0].get("db", database) + new_routing_table = RoutingTable.parse_routing_info( + database=database, servers=servers, ttl=ttl + ) + + # Parse routing info and count the number of each type of server + num_routers = len(new_routing_table.routers) + num_readers = len(new_routing_table.readers) + + # num_writers = len(new_routing_table.writers) + # If no writers are available. This likely indicates a temporary state, + # such as leader switching, so we should not signal an error. + + # No routers + if num_routers == 0: + log.debug("No routing servers returned from server %s", address) + return None + + # No readers + if num_readers == 0: + log.debug("No read servers returned from server %s", address) + return None + + # At least one of each is fine, so return this table + return new_routing_table + + def _update_routing_table_from( + self, *routers, database=None, imp_user=None, bookmarks=None, + database_callback=None + ): + """ Try to update routing tables with the given routers. + + :return: True if the routing table is successfully updated, + otherwise False + """ + log.debug("Attempting to update routing table from {}".format( + ", ".join(map(repr, routers))) + ) + for router in routers: + for address in NetworkUtil.resolve_address( + router, resolver=self.pool_config.resolver + ): + new_routing_table = self.fetch_routing_table( + address=address, + timeout=self.pool_config.connection_timeout, + database=database, imp_user=imp_user, bookmarks=bookmarks + ) + if new_routing_table is not None: + new_databse = new_routing_table.database + old_routing_table = self.get_or_create_routing_table( + new_databse + ) + old_routing_table.update(new_routing_table) + log.debug( + "[#0000] C: address=%r (%r)", + address, self.routing_tables[new_databse] + ) + if callable(database_callback): + database_callback(new_databse) + return True + self.deactivate(router) + return False + + def update_routing_table( + self, *, database, imp_user, bookmarks, database_callback=None + ): + """ Update the routing table from the first router able to provide + valid routing information. + + :param database: The database name + :param imp_user: the user to impersonate while fetching the routing + table + :type imp_user: str or None + :param bookmarks: bookmarks used when fetching routing table + :param database_callback: A callback function that will be called with + the database name as only argument when a new routing table has been + acquired. This database name might different from `database` if that + was None and the underlying protocol supports reporting back the + actual database. + + :raise neo4j.exceptions.ServiceUnavailable: + """ + with self.refresh_lock: + routing_table = self.get_or_create_routing_table(database) + # copied because it can be modified + existing_routers = set(routing_table.routers) + + prefer_initial_routing_address = \ + self.routing_tables[database].initialized_without_writers + + if prefer_initial_routing_address: + # TODO: Test this state + if self._update_routing_table_from( + self.first_initial_routing_address, database=database, + imp_user=imp_user, bookmarks=bookmarks, + database_callback=database_callback + ): + # Why is only the first initial routing address used? + return + if self._update_routing_table_from( + *(existing_routers - {self.first_initial_routing_address}), + database=database, imp_user=imp_user, bookmarks=bookmarks, + database_callback=database_callback + ): + return + + if not prefer_initial_routing_address: + if self._update_routing_table_from( + self.first_initial_routing_address, database=database, + imp_user=imp_user, bookmarks=bookmarks, + database_callback=database_callback + ): + # Why is only the first initial routing address used? + return + + # None of the routers have been successful, so just fail + log.error("Unable to retrieve routing information") + raise ServiceUnavailable("Unable to retrieve routing information") + + def update_connection_pool(self, *, database): + routing_table = self.get_or_create_routing_table(database) + servers = routing_table.servers() + for address in list(self.connections): + if address.unresolved not in servers: + super(Neo4jPool, self).deactivate(address) + + def ensure_routing_table_is_fresh( + self, *, access_mode, database, imp_user, bookmarks, + database_callback=None + ): + """ Update the routing table if stale. + + This method performs two freshness checks, before and after acquiring + the refresh lock. If the routing table is already fresh on entry, the + method exits immediately; otherwise, the refresh lock is acquired and + the second freshness check that follows determines whether an update + is still required. + + This method is thread-safe. + + :return: `True` if an update was required, `False` otherwise. + """ + from neo4j.api import READ_ACCESS + with self.refresh_lock: + routing_table = self.get_or_create_routing_table(database) + if routing_table.is_fresh(readonly=(access_mode == READ_ACCESS)): + # Readers are fresh. + return False + + self.update_routing_table( + database=database, imp_user=imp_user, bookmarks=bookmarks, + database_callback=database_callback + ) + self.update_connection_pool(database=database) + + for database in list(self.routing_tables.keys()): + # Remove unused databases in the routing table + # Remove the routing table after a timeout = TTL + 30s + log.debug("[#0000] C: database=%s", database) + if (self.routing_tables[database].should_be_purged_from_memory() + and database != self.workspace_config.database): + del self.routing_tables[database] + + return True + + def _select_address(self, *, access_mode, database): + from ...api import READ_ACCESS + """ Selects the address with the fewest in-use connections. + """ + with self.refresh_lock: + if access_mode == READ_ACCESS: + addresses = self.routing_tables[database].readers + else: + addresses = self.routing_tables[database].writers + addresses_by_usage = {} + for address in addresses: + addresses_by_usage.setdefault( + self.in_use_connection_count(address), [] + ).append(address) + if not addresses_by_usage: + if access_mode == READ_ACCESS: + raise ReadServiceUnavailable( + "No read service currently available" + ) + else: + raise WriteServiceUnavailable( + "No write service currently available" + ) + return choice(addresses_by_usage[min(addresses_by_usage)]) + + def acquire( + self, access_mode=None, timeout=None, database=None, bookmarks=None + ): + if access_mode not in (WRITE_ACCESS, READ_ACCESS): + raise ClientError("Non valid 'access_mode'; {}".format(access_mode)) + if not timeout: + raise ClientError("'timeout' must be a float larger than 0; {}" + .format(timeout)) + + from neo4j.api import check_access_mode + access_mode = check_access_mode(access_mode) + with self.refresh_lock: + log.debug("[#0000] C: %r", + self.routing_tables) + self.ensure_routing_table_is_fresh( + access_mode=access_mode, database=database, imp_user=None, + bookmarks=bookmarks + ) + + while True: + try: + # Get an address for a connection that have the fewest in-use + # connections. + address = self._select_address( + access_mode=access_mode, database=database + ) + except (ReadServiceUnavailable, WriteServiceUnavailable) as err: + raise SessionExpired("Failed to obtain connection towards '%s' server." % access_mode) from err + try: + log.debug("[#0000] C: database=%r address=%r", database, address) + # should always be a resolved address + connection = self._acquire(address, timeout=timeout) + except ServiceUnavailable: + self.deactivate(address=address) + else: + return connection + + def deactivate(self, address): + """ Deactivate an address from the connection pool, + if present, remove from the routing table and also closing + all idle connections to that address. + """ + log.debug("[#0000] C: Deactivating address %r", address) + # We use `discard` instead of `remove` here since the former + # will not fail if the address has already been removed. + for database in self.routing_tables.keys(): + self.routing_tables[database].routers.discard(address) + self.routing_tables[database].readers.discard(address) + self.routing_tables[database].writers.discard(address) + log.debug("[#0000] C: table=%r", self.routing_tables) + super(Neo4jPool, self).deactivate(address) + + def on_write_failure(self, address): + """ Remove a writer address from the routing table, if present. + """ + log.debug("[#0000] C: Removing writer %r", address) + for database in self.routing_tables.keys(): + self.routing_tables[database].writers.discard(address) + log.debug("[#0000] C: table=%r", self.routing_tables) diff --git a/neo4j/_sync/work/__init__.py b/neo4j/_sync/work/__init__.py new file mode 100644 index 00000000..3ceebdb1 --- /dev/null +++ b/neo4j/_sync/work/__init__.py @@ -0,0 +1,32 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from .session import ( + Result, + Session, + Transaction, + Workspace, +) + + +__all__ = [ + "Result", + "Session", + "Transaction", + "Workspace", +] diff --git a/neo4j/work/result.py b/neo4j/_sync/work/result.py similarity index 90% rename from neo4j/work/result.py rename to neo4j/_sync/work/result.py index bb3dba02..8d6342fd 100644 --- a/neo4j/work/result.py +++ b/neo4j/_sync/work/result.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -22,16 +19,16 @@ from collections import deque from warnings import warn -from neo4j.data import DataDehydrator -from neo4j.io import ConnectionErrorHandler -from neo4j.work.summary import ResultSummary -from neo4j.exceptions import ResultConsumedError +from ..._async_compat.util import Util +from ...data import DataDehydrator +from ...work import ResultSummary +from ..io import ConnectionErrorHandler class Result: """A handler for the result of Cypher query execution. Instances of this class are typically constructed and returned by - :meth:`.Session.run` and :meth:`.Transaction.run`. + :meth:`.AyncSession.run` and :meth:`.Transaction.run`. """ def __init__(self, connection, hydrant, fetch_size, on_closed, @@ -63,19 +60,23 @@ def _qid(self): else: return self._raw_qid - def _tx_ready_run(self, query, parameters, **kwparameters): + def _tx_ready_run(self, query, parameters, **kwargs): # BEGIN+RUN does not carry any extra on the RUN message. # BEGIN {extra} # RUN "query" {parameters} {extra} - self._run(query, parameters, None, None, None, None, **kwparameters) + self._run( + query, parameters, None, None, None, None, **kwargs + ) - def _run(self, query, parameters, db, imp_user, access_mode, bookmarks, - **kwparameters): + def _run( + self, query, parameters, db, imp_user, access_mode, bookmarks, + **kwargs + ): query_text = str(query) # Query or string object query_metadata = getattr(query, "metadata", None) query_timeout = getattr(query, "timeout", None) - parameters = DataDehydrator.fix_parameters(dict(parameters or {}, **kwparameters)) + parameters = DataDehydrator.fix_parameters(dict(parameters or {}, **kwargs)) self._metadata = { "query": query_text, @@ -95,7 +96,7 @@ def on_attached(metadata): def on_failed_attach(metadata): self._metadata.update(metadata) self._attached = False - self._on_closed() + Util.callback(self._on_closed) self._connection.run( query_text, @@ -120,11 +121,11 @@ def on_records(records): def on_summary(): self._attached = False - self._on_closed() + Util.callback(self._on_closed) def on_failure(metadata): self._attached = False - self._on_closed() + Util.callback(self._on_closed) def on_success(summary_metadata): self._streaming = False @@ -148,12 +149,12 @@ def on_success(summary_metadata): def _discard(self): def on_summary(): self._attached = False - self._on_closed() + Util.callback(self._on_closed) def on_failure(metadata): self._metadata.update(metadata) self._attached = False - self._on_closed() + Util.callback(self._on_closed) def on_success(summary_metadata): self._streaming = False @@ -203,7 +204,7 @@ def _attach(self): self._connection.fetch_message() def _buffer(self, n=None): - """Try to fill `self_record_buffer` with n records. + """Try to fill `self._record_buffer` with n records. Might end up with more records in the buffer if the fetch size makes it overshoot. @@ -259,7 +260,9 @@ def consume(self): Example:: def create_node_tx(tx, name): - result = tx.run("CREATE (n:ExampleNode { name: $name }) RETURN n", name=name) + result = tx.run( + "CREATE (n:ExampleNode { name: $name }) RETURN n", name=name + ) record = result.single() value = record.value() info = result.consume() @@ -273,11 +276,12 @@ def create_node_tx(tx, name): def get_two_tx(tx): result = tx.run("UNWIND [1,2,3,4] AS x RETURN x") values = [] - for ix, record in enumerate(result): - if x > 1: + for record in result: + if len(values) >= 2: break values.append(record.values()) - info = result.consume() # discard the remaining records if there are any + # discard the remaining records if there are any + info = result.consume() # use the info for logging etc. return values, info @@ -310,7 +314,8 @@ def single(self): # raise SomeError("Expected exactly 1 record, found %i" # % len(self._record_buffer)) # return self._record_buffer.popleft() - records = list(self) # TODO: exhausts the result with self.consume if there are more records. + # TODO: exhausts the result with self.consume if there are more records. + records = Util.list(self) size = len(records) if size == 0: return None diff --git a/neo4j/work/simple.py b/neo4j/_sync/work/session.py similarity index 74% rename from neo4j/work/simple.py rename to neo4j/_sync/work/session.py index 9c404da4..dbdf9482 100644 --- a/neo4j/work/simple.py +++ b/neo4j/_sync/work/session.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -19,31 +16,32 @@ # limitations under the License. +import asyncio from logging import getLogger from random import random -from time import ( - perf_counter, - sleep, -) +from time import perf_counter -from neo4j.api import ( +from ..._async_compat import sleep +from ...api import ( READ_ACCESS, WRITE_ACCESS, ) -from neo4j.conf import SessionConfig -from neo4j.data import DataHydrator -from neo4j.exceptions import ( +from ...conf import SessionConfig +from ...data import DataHydrator +from ...exceptions import ( ClientError, IncompleteCommit, Neo4jError, ServiceUnavailable, SessionExpired, - TransientError, TransactionError, + TransientError, ) -from neo4j.work import Workspace -from neo4j.work.result import Result -from neo4j.work.transaction import Transaction +from ...work import Query +from .result import Result +from .transaction import Transaction +from .workspace import Workspace + log = getLogger("neo4j") @@ -53,9 +51,10 @@ class Session(Workspace): of work. Connections are drawn from the :class:`.Driver` connection pool as required. - Session creation is a lightweight operation and sessions are not thread - safe. Therefore a session should generally be short-lived, and not - span multiple threads. + Session creation is a lightweight operation and sessions are not safe to + be used in concurrent contexts (multiple threads/coroutines). + Therefore, a session should generally be short-lived, and must not + span multiple threads/coroutines. In general, sessions will be created and destroyed within a `with` context. For example:: @@ -75,7 +74,7 @@ class Session(Workspace): _transaction = None # The current auto-transaction result, if any. - _autoResult = None + _auto_result = None # The state this session is in. _state_failed = False @@ -89,6 +88,8 @@ def __init__(self, pool, session_config): self._bookmarks = tuple(session_config.bookmarks) def __del__(self): + if asyncio.iscoroutinefunction(self.close): + return try: self.close() except (OSError, ServiceUnavailable, SessionExpired): @@ -112,14 +113,14 @@ def _collect_bookmark(self, bookmark): self._bookmarks = [bookmark] def _result_closed(self): - if self._autoResult: - self._collect_bookmark(self._autoResult._bookmark) - self._autoResult = None + if self._auto_result: + self._collect_bookmark(self._auto_result._bookmark) + self._auto_result = None self._disconnect() def _result_error(self, _): - if self._autoResult: - self._autoResult = None + if self._auto_result: + self._auto_result = None self._disconnect() def close(self): @@ -129,14 +130,14 @@ def close(self): roll back any outstanding transactions. """ if self._connection: - if self._autoResult: + if self._auto_result: if self._state_failed is False: try: - self._autoResult.consume() - self._collect_bookmark(self._autoResult._bookmark) + self._auto_result.consume() + self._collect_bookmark(self._auto_result._bookmark) except Exception as error: # TODO: Investigate potential non graceful close states - self._autoResult = None + self._auto_result = None self._state_failed = True if self._transaction: @@ -163,7 +164,7 @@ def close(self): self._state_failed = False self._closed = True - def run(self, query, parameters=None, **kwparameters): + def run(self, query, parameters=None, **kwargs): """Run a Cypher query within an auto-commit transaction. The query is sent and the result header received @@ -185,9 +186,9 @@ def run(self, query, parameters=None, **kwparameters): :type query: str, neo4j.Query :param parameters: dictionary of parameters :type parameters: dict - :param kwparameters: additional keyword parameters + :param kwargs: additional keyword parameters :returns: a new :class:`neo4j.Result` object - :rtype: :class:`neo4j.Result` + :rtype: Result """ if not query: raise ValueError("Cannot run an empty query") @@ -197,8 +198,9 @@ def run(self, query, parameters=None, **kwparameters): if self._transaction: raise ClientError("Explicit Transaction must be handled explicitly") - if self._autoResult: - self._autoResult._buffer_all() # This will buffer upp all records for the previous auto-transaction + if self._auto_result: + # This will buffer upp all records for the previous auto-transaction + self._auto_result._buffer_all() if not self._connection: self._connect(self._config.default_access_mode) @@ -208,17 +210,17 @@ def run(self, query, parameters=None, **kwparameters): hydrant = DataHydrator() - self._autoResult = Result( + self._auto_result = Result( cx, hydrant, self._config.fetch_size, self._result_closed, self._result_error ) - self._autoResult._run( + self._auto_result._run( query, parameters, self._config.database, self._config.impersonated_user, self._config.default_access_mode, - self._bookmarks, **kwparameters + self._bookmarks, **kwargs ) - return self._autoResult + return self._auto_result def last_bookmark(self): """Return the bookmark received following the last completed transaction. @@ -228,8 +230,8 @@ def last_bookmark(self): """ # The set of bookmarks to be passed into the next transaction. - if self._autoResult: - self._autoResult.consume() + if self._auto_result: + self._auto_result.consume() if self._transaction and self._transaction._closed: self._collect_bookmark(self._transaction._bookmark) @@ -286,25 +288,28 @@ def begin_transaction(self, metadata=None, timeout=None): :type timeout: int :returns: A new transaction instance. - :rtype: :class:`neo4j.Transaction` + :rtype: Transaction :raises TransactionError: :class:`neo4j.exceptions.TransactionError` if a transaction is already open. """ # TODO: Implement TransactionConfig consumption - if self._autoResult: - self._autoResult.consume() + if self._auto_result: + self._auto_result.consume() if self._transaction: raise TransactionError("Explicit transaction already open") - self._open_transaction(access_mode=self._config.default_access_mode, - metadata=metadata, timeout=timeout) + self._open_transaction( + access_mode=self._config.default_access_mode, metadata=metadata, + timeout=timeout + ) return self._transaction - def _run_transaction(self, access_mode, transaction_function, *args, **kwargs): - + def _run_transaction( + self, access_mode, transaction_function, *args, **kwargs + ): if not callable(transaction_function): raise TypeError("Unit of work is not callable") @@ -319,7 +324,10 @@ def _run_transaction(self, access_mode, transaction_function, *args, **kwargs): while True: try: - self._open_transaction(access_mode=access_mode, metadata=metadata, timeout=timeout) + self._open_transaction( + access_mode=access_mode, metadata=metadata, + timeout=timeout + ) tx = self._transaction try: result = transaction_function(tx, *args, **kwargs) @@ -358,15 +366,14 @@ def read_transaction(self, transaction_function, *args, **kwargs): This transaction will automatically be committed unless an exception is thrown during query execution or by the user code. Note, that this function perform retries and that the supplied `transaction_function` might get invoked more than once. - Managed transactions should not generally be explicitly committed (via tx.commit()). + Managed transactions should not generally be explicitly committed + (via ``tx.commit()``). Example:: def do_cypher_tx(tx, cypher): result = tx.run(cypher) - values = [] - for record in result: - values.append(record.values()) + values = [record.values() for record in result] return values with driver.session() as session: @@ -377,23 +384,29 @@ def do_cypher_tx(tx, cypher): def get_two_tx(tx): result = tx.run("UNWIND [1,2,3,4] AS x RETURN x") values = [] - for ix, record in enumerate(result): - if x > 1: + for record in result: + if len(values) >= 2: break values.append(record.values()) - info = result.consume() # discard the remaining records if there are any + # discard the remaining records if there are any + info = result.consume() # use the info for logging etc. return values with driver.session() as session: values = session.read_transaction(get_two_tx) - :param transaction_function: a function that takes a transaction as an argument and does work with the transaction. `tx_function(tx, *args, **kwargs)` + :param transaction_function: a function that takes a transaction as an + argument and does work with the transaction. + `transaction_function(tx, *args, **kwargs)` where `tx` is a + :class:`.Transaction`. :param args: arguments for the `transaction_function` :param kwargs: key word arguments for the `transaction_function` :return: a result as returned by the given unit of work """ - return self._run_transaction(READ_ACCESS, transaction_function, *args, **kwargs) + return self._run_transaction( + READ_ACCESS, transaction_function, *args, **kwargs + ) def write_transaction(self, transaction_function, *args, **kwargs): """Execute a unit of work in a managed write transaction. @@ -405,79 +418,25 @@ def write_transaction(self, transaction_function, *args, **kwargs): Example:: def create_node_tx(tx, name): - result = tx.run("CREATE (n:NodeExample { name: $name }) RETURN id(n) AS node_id", name=name) + query = "CREATE (n:NodeExample { name: $name }) RETURN id(n) AS node_id" + result = tx.run(query, name=name) record = result.single() return record["node_id"] with driver.session() as session: node_id = session.write_transaction(create_node_tx, "example") - - :param transaction_function: a function that takes a transaction as an argument and does work with the transaction. `tx_function(tx, *args, **kwargs)` + :param transaction_function: a function that takes a transaction as an + argument and does work with the transaction. + `transaction_function(tx, *args, **kwargs)` where `tx` is a + :class:`.Transaction`. :param args: key word arguments for the `transaction_function` :param kwargs: key word arguments for the `transaction_function` :return: a result as returned by the given unit of work """ - return self._run_transaction(WRITE_ACCESS, transaction_function, *args, **kwargs) - - -class Query: - """ Create a new query. - - :param text: The query text. - :type text: str - :param metadata: metadata attached to the query. - :type metadata: dict - :param timeout: seconds. - :type timeout: int - """ - def __init__(self, text, metadata=None, timeout=None): - self.text = text - - self.metadata = metadata - self.timeout = timeout - - def __str__(self): - return str(self.text) - - -def unit_of_work(metadata=None, timeout=None): - """This function is a decorator for transaction functions that allows extra control over how the transaction is carried out. - - For example, a timeout may be applied:: - - @unit_of_work(timeout=100) - def count_people_tx(tx): - result = tx.run("MATCH (a:Person) RETURN count(a) AS persons") - record = result.single() - return record["persons"] - - :param metadata: - a dictionary with metadata. - Specified metadata will be attached to the executing transaction and visible in the output of ``dbms.listQueries`` and ``dbms.listTransactions`` procedures. - It will also get logged to the ``query.log``. - This functionality makes it easier to tag transactions and is equivalent to ``dbms.setTXMetaData`` procedure, see https://neo4j.com/docs/operations-manual/current/reference/procedures/ for procedure reference. - :type metadata: dict - - :param timeout: - the transaction timeout in seconds. - Transactions that execute longer than the configured timeout will be terminated by the database. - This functionality allows to limit query/transaction execution time. - Specified timeout overrides the default timeout configured in the database using ``dbms.transaction.timeout`` setting. - Value should not represent a duration of zero or negative duration. - :type timeout: int - """ - - def wrapper(f): - - def wrapped(*args, **kwargs): - return f(*args, **kwargs) - - wrapped.metadata = metadata - wrapped.timeout = timeout - return wrapped - - return wrapper + return self._run_transaction( + WRITE_ACCESS, transaction_function, *args, **kwargs + ) def retry_delay_generator(initial_delay, multiplier, jitter_factor): diff --git a/neo4j/work/transaction.py b/neo4j/_sync/work/transaction.py similarity index 90% rename from neo4j/work/transaction.py rename to neo4j/_sync/work/transaction.py index 8b773987..73d08238 100644 --- a/neo4j/work/transaction.py +++ b/neo4j/_sync/work/transaction.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -19,22 +16,22 @@ # limitations under the License. -from neo4j.work.result import Result -from neo4j.data import DataHydrator -from neo4j.exceptions import ( - TransactionError, -) -from neo4j.io import ConnectionErrorHandler +from ..._async_compat.util import Util +from ...data import DataHydrator +from ...exceptions import TransactionError +from ...work import Query +from ..io import ConnectionErrorHandler +from .result import Result class Transaction: - """ Container for multiple Cypher queries to be executed within - a single context. Transactions can be used within a :py:const:`with` + """ Container for multiple Cypher queries to be executed within a single + context. asynctransactions can be used within a :py:const:`with` block where the transaction is committed or rolled back on based on - whether or not an exception is raised:: + whether an exception is raised:: with session.begin_transaction() as tx: - pass + ... """ @@ -62,7 +59,9 @@ def __exit__(self, exception_type, exception_value, traceback): self.commit() self.close() - def _begin(self, database, imp_user, bookmarks, access_mode, metadata, timeout): + def _begin( + self, database, imp_user, bookmarks, access_mode, metadata, timeout + ): self._connection.begin( bookmarks=bookmarks, metadata=metadata, timeout=timeout, mode=access_mode, db=database, imp_user=imp_user @@ -75,7 +74,7 @@ def _result_on_closed_handler(self): def _error_handler(self, exc): self._last_error = exc - self._on_error(exc) + Util.callback(self._on_error, exc) def _consume_results(self): for result in self._results: @@ -110,7 +109,6 @@ def run(self, query, parameters=None, **kwparameters): :rtype: :class:`neo4j.Result` :raise TransactionError: if the transaction is already closed """ - from neo4j.work.simple import Query if isinstance(query, Query): raise ValueError("Query object is only supported for session.run") @@ -151,14 +149,15 @@ def commit(self): metadata = {} try: - self._consume_results() # DISCARD pending records then do a commit. + # DISCARD pending records then do a commit. + self._consume_results() self._connection.commit(on_success=metadata.update) self._connection.send_all() self._connection.fetch_all() self._bookmark = metadata.get("bookmark") finally: self._closed = True - self._on_closed() + Util.callback(self._on_closed) return self._bookmark @@ -182,7 +181,7 @@ def rollback(self): self._connection.fetch_all() finally: self._closed = True - self._on_closed() + Util.callback(self._on_closed) def close(self): """Close this transaction, triggering a ROLLBACK if not closed. diff --git a/neo4j/_sync/work/workspace.py b/neo4j/_sync/work/workspace.py new file mode 100644 index 00000000..3ed50ad2 --- /dev/null +++ b/neo4j/_sync/work/workspace.py @@ -0,0 +1,102 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import asyncio + +from ...conf import WorkspaceConfig +from ...exceptions import ServiceUnavailable +from ..io import Neo4jPool + + +class Workspace: + + def __init__(self, pool, config): + assert isinstance(config, WorkspaceConfig) + self._pool = pool + self._config = config + self._connection = None + self._connection_access_mode = None + # Sessions are supposed to cache the database on which to operate. + self._cached_database = False + self._bookmarks = None + + def __del__(self): + if asyncio.iscoroutinefunction(self.close): + return + try: + self.close() + except OSError: + pass + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.close() + + def _set_cached_database(self, database): + self._cached_database = True + self._config.database = database + + def _connect(self, access_mode): + if self._connection: + # TODO: Investigate this + # log.warning("FIXME: should always disconnect before connect") + self._connection.send_all() + self._connection.fetch_all() + self._disconnect() + if not self._cached_database: + if (self._config.database is not None + or not isinstance(self._pool, Neo4jPool)): + self._set_cached_database(self._config.database) + else: + # This is the first time we open a connection to a server in a + # cluster environment for this session without explicitly + # configured database. Hence, we request a routing table update + # to try to fetch the home database. If provided by the server, + # we shall use this database explicitly for all subsequent + # actions within this session. + self._pool.update_routing_table( + database=self._config.database, + imp_user=self._config.impersonated_user, + bookmarks=self._bookmarks, + database_callback=self._set_cached_database + ) + self._connection = self._pool.acquire( + access_mode=access_mode, + timeout=self._config.connection_acquisition_timeout, + database=self._config.database, + bookmarks=self._bookmarks + ) + self._connection_access_mode = access_mode + + def _disconnect(self, sync=False): + if self._connection: + if sync: + try: + self._connection.send_all() + self._connection.fetch_all() + except ServiceUnavailable: + pass + if self._connection: + self._pool.release(self._connection) + self._connection = None + self._connection_access_mode = None + + def close(self): + self._disconnect(sync=True) diff --git a/neo4j/addressing.py b/neo4j/addressing.py index f547f05d..780e2c23 100644 --- a/neo4j/addressing.py +++ b/neo4j/addressing.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -19,14 +16,12 @@ # limitations under the License. +import logging from socket import ( - getaddrinfo, - getservbyname, - SOCK_STREAM, AF_INET, AF_INET6, + getservbyname, ) -import logging log = logging.getLogger("neo4j") @@ -142,59 +137,6 @@ def port(self): def unresolved(self): return self - @classmethod - def _dns_resolve(cls, address, family=0): - """ Regular DNS resolver. Takes an address object and optional - address family for filtering. - - :param address: - :param family: - :return: - """ - try: - info = getaddrinfo(address.host, address.port, family, SOCK_STREAM) - except OSError: - raise ValueError("Cannot resolve address {}".format(address)) - else: - resolved = [] - for fam, _, _, _, addr in info: - if fam == AF_INET6 and addr[3] != 0: - # skip any IPv6 addresses with a non-zero scope id - # as these appear to cause problems on some platforms - continue - if addr not in resolved: - resolved.append(ResolvedAddress( - addr, host_name=address.host_name) - ) - return resolved - - def resolve(self, family=0, resolver=None): - """ Carry out domain name resolution on this Address object. - - If a resolver function is supplied, and is callable, this is - called first, with this object as its argument. This may yield - multiple output addresses, which are chained into a subsequent - regular DNS resolution call. If no resolver function is passed, - the DNS resolution is carried out on the original Address - object. - - This function returns a list of resolved Address objects. - - :param family: optional address family to filter resolved - addresses by (e.g. AF_INET6) - :param resolver: optional customer resolver function to be - called before regular DNS resolution - """ - - log.debug("[#0000] C: %s", self) - resolved = [] - if resolver: - for address in map(Address, resolver(self)): - resolved.extend(self._dns_resolve(address, family)) - else: - resolved.extend(self._dns_resolve(self, family)) - return resolved - @property def port_number(self): try: @@ -234,9 +176,6 @@ def host_name(self): def unresolved(self): return super().__new__(Address, (self._host_name, *self[1:])) - def resolve(self, family=0, resolver=None): - return [self] - def __new__(cls, iterable, host_name=None): new = super().__new__(cls, iterable) new._host_name = host_name diff --git a/neo4j/api.py b/neo4j/api.py index 1275c2a3..36c05cbb 100644 --- a/neo4j/api.py +++ b/neo4j/api.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -18,20 +15,22 @@ # See the License for the specific language governing permissions and # limitations under the License. +""" Base classes and helpers. +""" + + from urllib.parse import ( - urlparse, parse_qs, + urlparse, ) -from.exceptions import ( - DriverError, + +from .exceptions import ( ConfigurationError, + DriverError, ) from .meta import deprecated -""" Base classes and helpers. -""" - READ_ACCESS = "READ" WRITE_ACCESS = "WRITE" @@ -107,7 +106,8 @@ def basic_auth(user, password, realm=None): :param realm: specifies the authentication provider :type realm: str or None - :return: auth token for use with :meth:`GraphDatabase.driver` + :return: auth token for use with :meth:`GraphDatabase.driver` or + :meth:`AsyncGraphDatabase.driver` :rtype: :class:`neo4j.Auth` """ return Auth("basic", user, password, realm) @@ -122,7 +122,8 @@ def kerberos_auth(base64_encoded_ticket): the credentials :type base64_encoded_ticket: str - :return: auth token for use with :meth:`GraphDatabase.driver` + :return: auth token for use with :meth:`GraphDatabase.driver` or + :meth:`AsyncGraphDatabase.driver` :rtype: :class:`neo4j.Auth` """ return Auth("kerberos", "", base64_encoded_ticket) @@ -137,7 +138,8 @@ def bearer_auth(base64_encoded_token): by a Single-Sign-On provider. :type base64_encoded_token: str - :return: auth token for use with :meth:`GraphDatabase.driver` + :return: auth token for use with :meth:`GraphDatabase.driver` or + :meth:`AsyncGraphDatabase.driver` :rtype: :class:`neo4j.Auth` """ return Auth("bearer", None, base64_encoded_token) @@ -158,7 +160,8 @@ def custom_auth(principal, credentials, realm, scheme, **parameters): authentication provider :type parameters: Dict[str, Any] - :return: auth token for use with :meth:`GraphDatabase.driver` + :return: auth token for use with :meth:`GraphDatabase.driver` or + :meth:`AsyncGraphDatabase.driver` :rtype: :class:`neo4j.Auth` """ return Auth(scheme, principal, credentials, realm, **parameters) diff --git a/neo4j/conf.py b/neo4j/conf.py index fb86c489..7131ba56 100644 --- a/neo4j/conf.py +++ b/neo4j/conf.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -23,17 +20,14 @@ from collections.abc import Mapping from warnings import warn -from neo4j.meta import get_user_agent - -from neo4j.api import ( - TRUST_SYSTEM_CA_SIGNED_CERTIFICATES, +from .api import ( + DEFAULT_DATABASE, TRUST_ALL_CERTIFICATES, + TRUST_SYSTEM_CA_SIGNED_CERTIFICATES, WRITE_ACCESS, - DEFAULT_DATABASE, -) -from neo4j.exceptions import ( - ConfigurationError, ) +from .exceptions import ConfigurationError +from .meta import get_user_agent def iter_items(iterable): diff --git a/neo4j/data.py b/neo4j/data.py index e1a3a959..d60937cd 100644 --- a/neo4j/data.py +++ b/neo4j/data.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -19,22 +16,57 @@ # limitations under the License. -from abc import ABCMeta, abstractmethod -from collections.abc import Sequence, Set, Mapping -from datetime import date, time, datetime, timedelta +from abc import ( + ABCMeta, + abstractmethod, +) +from collections.abc import ( + Mapping, + Sequence, + Set, +) +from datetime import ( + date, + datetime, + time, + timedelta, +) from functools import reduce from operator import xor as xor_operator -from neo4j.conf import iter_items -from neo4j.graph import Graph, Node, Relationship, Path -from neo4j.packstream import INT64_MIN, INT64_MAX, Structure -from neo4j.spatial import Point, hydrate_point, dehydrate_point -from neo4j.time import Date, Time, DateTime, Duration -from neo4j.time.hydration import ( - hydrate_date, dehydrate_date, - hydrate_time, dehydrate_time, - hydrate_datetime, dehydrate_datetime, - hydrate_duration, dehydrate_duration, dehydrate_timedelta, +from .conf import iter_items +from .graph import ( + Graph, + Node, + Path, + Relationship, +) +from .packstream import ( + INT64_MAX, + INT64_MIN, + Structure, +) +from .spatial import ( + dehydrate_point, + hydrate_point, + Point, +) +from .time import ( + Date, + DateTime, + Duration, + Time, +) +from .time.hydration import ( + dehydrate_date, + dehydrate_datetime, + dehydrate_duration, + dehydrate_time, + dehydrate_timedelta, + hydrate_date, + hydrate_datetime, + hydrate_duration, + hydrate_time, ) @@ -44,7 +76,7 @@ class Record(tuple, Mapping): """ A :class:`.Record` is an immutable ordered collection of key-value pairs. It is generally closer to a :py:class:`namedtuple` than to a - :py:class:`OrderedDict` inasmuch as iteration of the collection will + :py:class:`OrderedDict` in as much as iteration of the collection will yield values rather than keys. """ diff --git a/neo4j/debug.py b/neo4j/debug.py index b8872a37..7b2f0db5 100644 --- a/neo4j/debug.py +++ b/neo4j/debug.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -19,7 +16,16 @@ # limitations under the License. -from logging import CRITICAL, ERROR, WARNING, INFO, DEBUG, Formatter, StreamHandler, getLogger +from logging import ( + CRITICAL, + DEBUG, + ERROR, + Formatter, + getLogger, + INFO, + StreamHandler, + WARNING, +) from sys import stderr diff --git a/neo4j/exceptions.py b/neo4j/exceptions.py index c69c75c7..2b944bcf 100644 --- a/neo4j/exceptions.py +++ b/neo4j/exceptions.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # diff --git a/neo4j/graph/__init__.py b/neo4j/graph/__init__.py index 66285af7..bc1e8a52 100644 --- a/neo4j/graph/__init__.py +++ b/neo4j/graph/__init__.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -18,22 +15,23 @@ # See the License for the specific language governing permissions and # limitations under the License. + """ Graph data types """ -from collections.abc import Mapping - - __all__ = [ "Graph", "Node", - "Relationship", "Path", + "Relationship", ] +from collections.abc import Mapping + + class Graph: """ Local, self-contained graph object that acts as a container for :class:`.Node` and :class:`.Relationship` instances. diff --git a/neo4j/io/README.rst b/neo4j/io/README.rst deleted file mode 100644 index dfe8742a..00000000 --- a/neo4j/io/README.rst +++ /dev/null @@ -1 +0,0 @@ -Regular (non-async) I/O for Neo4j. \ No newline at end of file diff --git a/neo4j/io/__init__.py b/neo4j/io/__init__.py deleted file mode 100644 index 348bbce8..00000000 --- a/neo4j/io/__init__.py +++ /dev/null @@ -1,1414 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -# Copyright (c) "Neo4j" -# Neo4j Sweden AB [http://neo4j.com] -# -# This file is part of Neo4j. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This module contains the low-level functionality required for speaking -Bolt. It is not intended to be used directly by driver users. Instead, -the `session` module provides the main user-facing abstractions. -""" - - -__all__ = [ - "Bolt", - "BoltPool", - "ConnectionErrorHandler", - "Neo4jPool", - "check_supported_server_product", -] - -import abc -from collections import ( - defaultdict, - deque, -) -import logging -from logging import getLogger -from random import choice -import selectors -from socket import ( - AF_INET, - AF_INET6, - SHUT_RDWR, - SO_KEEPALIVE, - socket, - SOL_SOCKET, - timeout as SocketTimeout, -) -from ssl import ( - CertificateError, - HAS_SNI, - SSLError, -) -from threading import ( - Condition, - RLock, -) -from time import perf_counter - -from neo4j._exceptions import ( - BoltError, - BoltHandshakeError, - BoltProtocolError, - BoltSecurityError, -) -from neo4j.addressing import Address -from neo4j.api import ( - READ_ACCESS, - ServerInfo, - Version, - WRITE_ACCESS, -) -from neo4j.conf import ( - PoolConfig, - WorkspaceConfig, -) -from neo4j.exceptions import ( - AuthError, - ClientError, - ConfigurationError, - DriverError, - IncompleteCommit, - Neo4jError, - ReadServiceUnavailable, - ServiceUnavailable, - SessionExpired, - UnsupportedServerProduct, - WriteServiceUnavailable, -) -from neo4j.io._common import ( - CommitResponse, - ConnectionErrorHandler, - Inbox, - InitResponse, - Outbox, - Response, -) -from neo4j.meta import get_user_agent -from neo4j.packstream import ( - Packer, - Unpacker, -) -from neo4j.routing import RoutingTable - -# Set up logger -log = getLogger("neo4j") - - -class Bolt(abc.ABC): - """ Server connection for Bolt protocol. - - A :class:`.Bolt` should be constructed following a - successful .open() - - Bolt handshake and takes the socket over which - the handshake was carried out. - """ - - MAGIC_PREAMBLE = b"\x60\x60\xB0\x17" - - PROTOCOL_VERSION = None - - # flag if connection needs RESET to go back to READY state - is_reset = False - - # The socket - in_use = False - - # The socket - _closed = False - - # The socket - _defunct = False - - #: The pool of which this connection is a member - pool = None - - # Store the id of the most recent ran query to be able to reduce sent bits by - # using the default (-1) to refer to the most recent query when pulling - # results for it. - most_recent_qid = None - - def __init__(self, unresolved_address, sock, max_connection_lifetime, *, auth=None, user_agent=None, routing_context=None): - self.unresolved_address = unresolved_address - self.socket = sock - self.server_info = ServerInfo(Address(sock.getpeername()), self.PROTOCOL_VERSION) - # so far `connection.recv_timeout_seconds` is the only available - # configuration hint that exists. Therefore, all hints can be stored at - # connection level. This might change in the future. - self.configuration_hints = {} - self.outbox = Outbox() - self.inbox = Inbox(self.socket, on_error=self._set_defunct_read) - self.packer = Packer(self.outbox) - self.unpacker = Unpacker(self.inbox) - self.responses = deque() - self._max_connection_lifetime = max_connection_lifetime - self._creation_timestamp = perf_counter() - self.routing_context = routing_context - - # Determine the user agent - if user_agent: - self.user_agent = user_agent - else: - self.user_agent = get_user_agent() - - # Determine auth details - if not auth: - self.auth_dict = {} - elif isinstance(auth, tuple) and 2 <= len(auth) <= 3: - from neo4j import Auth - self.auth_dict = vars(Auth("basic", *auth)) - else: - try: - self.auth_dict = vars(auth) - except (KeyError, TypeError): - raise AuthError("Cannot determine auth details from %r" % auth) - - # Check for missing password - try: - credentials = self.auth_dict["credentials"] - except KeyError: - pass - else: - if credentials is None: - raise AuthError("Password cannot be None") - - @property - @abc.abstractmethod - def supports_multiple_results(self): - """ Boolean flag to indicate if the connection version supports multiple - queries to be buffered on the server side (True) or if all results need - to be eagerly pulled before sending the next RUN (False). - """ - pass - - @property - @abc.abstractmethod - def supports_multiple_databases(self): - """ Boolean flag to indicate if the connection version supports multiple - databases. - """ - pass - - @classmethod - def protocol_handlers(cls, protocol_version=None): - """ Return a dictionary of available Bolt protocol handlers, - keyed by version tuple. If an explicit protocol version is - provided, the dictionary will contain either zero or one items, - depending on whether that version is supported. If no protocol - version is provided, all available versions will be returned. - - :param protocol_version: tuple identifying a specific protocol - version (e.g. (3, 5)) or None - :return: dictionary of version tuple to handler class for all - relevant and supported protocol versions - :raise TypeError: if protocol version is not passed in a tuple - """ - - # Carry out Bolt subclass imports locally to avoid circular dependency issues. - from neo4j.io._bolt3 import Bolt3 - from neo4j.io._bolt4 import Bolt4x0, Bolt4x1, Bolt4x2, Bolt4x3, Bolt4x4 - - handlers = { - Bolt3.PROTOCOL_VERSION: Bolt3, - Bolt4x0.PROTOCOL_VERSION: Bolt4x0, - Bolt4x1.PROTOCOL_VERSION: Bolt4x1, - Bolt4x2.PROTOCOL_VERSION: Bolt4x2, - Bolt4x3.PROTOCOL_VERSION: Bolt4x3, - Bolt4x4.PROTOCOL_VERSION: Bolt4x4, - } - - if protocol_version is None: - return handlers - - if not isinstance(protocol_version, tuple): - raise TypeError("Protocol version must be specified as a tuple") - - if protocol_version in handlers: - return {protocol_version: handlers[protocol_version]} - - return {} - - @classmethod - def version_list(cls, versions, limit=4): - """ Return a list of supported protocol versions in order of - preference. The number of protocol versions (or ranges) - returned is limited to four. - """ - # In fact, 4.3 is the fist version to support ranges. However, the range - # support got backported to 4.2. But even if the server is too old to - # have the backport, negotiating BOLT 4.1 is no problem as it's - # equivalent to 4.2 - first_with_range_support = Version(4, 2) - result = [] - for version in versions: - if (result - and version >= first_with_range_support - and result[-1][0] == version[0] - and result[-1][1][1] == version[1] + 1): - # can use range to encompass this version - result[-1][1][1] = version[1] - continue - result.append(Version(version[0], [version[1], version[1]])) - if len(result) == 4: - break - return result - - @classmethod - def get_handshake(cls): - """ Return the supported Bolt versions as bytes. - The length is 16 bytes as specified in the Bolt version negotiation. - :return: bytes - """ - supported_versions = sorted(cls.protocol_handlers().keys(), reverse=True) - offered_versions = cls.version_list(supported_versions) - return b"".join(version.to_bytes() for version in offered_versions).ljust(16, b"\x00") - - @classmethod - def ping(cls, address, *, timeout=None, **config): - """ Attempt to establish a Bolt connection, returning the - agreed Bolt protocol version if successful. - """ - config = PoolConfig.consume(config) - try: - s, protocol_version, handshake, data = connect( - address, - timeout=timeout, - custom_resolver=config.resolver, - ssl_context=config.get_ssl_context(), - keep_alive=config.keep_alive, - ) - except (ServiceUnavailable, SessionExpired, BoltHandshakeError): - return None - else: - _close_socket(s) - return protocol_version - - @classmethod - def open(cls, address, *, auth=None, timeout=None, routing_context=None, **pool_config): - """ Open a new Bolt connection to a given server address. - - :param address: - :param auth: - :param timeout: the connection timeout in seconds - :param routing_context: dict containing routing context - :param pool_config: - :return: - :raise BoltHandshakeError: raised if the Bolt Protocol can not negotiate a protocol version. - :raise ServiceUnavailable: raised if there was a connection issue. - """ - pool_config = PoolConfig.consume(pool_config) - s, pool_config.protocol_version, handshake, data = connect( - address, - timeout=timeout, - custom_resolver=pool_config.resolver, - ssl_context=pool_config.get_ssl_context(), - keep_alive=pool_config.keep_alive, - ) - - # Carry out Bolt subclass imports locally to avoid circular dependency - # issues. - if pool_config.protocol_version == (3, 0): - from neo4j.io._bolt3 import Bolt3 - bolt_cls = Bolt3 - elif pool_config.protocol_version == (4, 0): - from neo4j.io._bolt4 import Bolt4x0 - bolt_cls = Bolt4x0 - elif pool_config.protocol_version == (4, 1): - from neo4j.io._bolt4 import Bolt4x1 - bolt_cls = Bolt4x1 - elif pool_config.protocol_version == (4, 2): - from neo4j.io._bolt4 import Bolt4x2 - bolt_cls = Bolt4x2 - elif pool_config.protocol_version == (4, 3): - from neo4j.io._bolt4 import Bolt4x3 - bolt_cls = Bolt4x3 - elif pool_config.protocol_version == (4, 4): - from neo4j.io._bolt4 import Bolt4x4 - bolt_cls = Bolt4x4 - else: - log.debug("[#%04X] S: ", s.getsockname()[1]) - _close_socket(s) - - supported_versions = Bolt.protocol_handlers().keys() - raise BoltHandshakeError("The Neo4J server does not support communication with this driver. This driver have support for Bolt Protocols {}".format(supported_versions), address=address, request_data=handshake, response_data=data) - - connection = bolt_cls( - address, s, pool_config.max_connection_lifetime, auth=auth, - user_agent=pool_config.user_agent, routing_context=routing_context - ) - - try: - connection.hello() - except Exception: - connection.close() - raise - - return connection - - @property - @abc.abstractmethod - def encrypted(self): - pass - - @property - @abc.abstractmethod - def der_encoded_server_certificate(self): - pass - - @property - @abc.abstractmethod - def local_port(self): - pass - - @abc.abstractmethod - def hello(self): - """ Appends a HELLO message to the outgoing queue, sends it and consumes - all remaining messages. - """ - pass - - def __del__(self): - try: - self.close() - except OSError: - pass - - @abc.abstractmethod - def route(self, database=None, imp_user=None, bookmarks=None): - """ Fetch a routing table from the server for the given - `database`. For Bolt 4.3 and above, this appends a ROUTE - message; for earlier versions, a procedure call is made via - the regular Cypher execution mechanism. In all cases, this is - sent to the network, and a response is fetched. - - :param database: database for which to fetch a routing table - :param imp_user: the user to impersonate - :param bookmarks: iterable of bookmark values after which this - transaction should begin - :return: dictionary of raw routing data - """ - pass - - @abc.abstractmethod - def run(self, query, parameters=None, mode=None, bookmarks=None, - metadata=None, timeout=None, db=None, imp_user=None, **handlers): - """ Appends a RUN message to the output queue. - - :param query: Cypher query string - :param parameters: dictionary of Cypher parameters - :param mode: access mode for routing - "READ" or "WRITE" (default) - :param bookmarks: iterable of bookmark values after which this transaction should begin - :param metadata: custom metadata dictionary to attach to the transaction - :param timeout: timeout for transaction execution (seconds) - :param db: name of the database against which to begin the transaction - :param imp_user: the user to impersonate - :param handlers: handler functions passed into the returned Response object - :return: Response object - """ - pass - - @abc.abstractmethod - def discard(self, n=-1, qid=-1, **handlers): - """ Appends a DISCARD message to the output queue. - - :param n: number of records to discard, default = -1 (ALL) - :param qid: query ID to discard for, default = -1 (last query) - :param handlers: handler functions passed into the returned Response object - :return: Response object - """ - pass - - @abc.abstractmethod - def pull(self, n=-1, qid=-1, **handlers): - """ Appends a PULL message to the output queue. - - :param n: number of records to pull, default = -1 (ALL) - :param qid: query ID to pull for, default = -1 (last query) - :param handlers: handler functions passed into the returned Response object - :return: Response object - """ - pass - - @abc.abstractmethod - def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, - db=None, imp_user=None, **handlers): - """ Appends a BEGIN message to the output queue. - - :param mode: access mode for routing - "READ" or "WRITE" (default) - :param bookmarks: iterable of bookmark values after which this transaction should begin - :param metadata: custom metadata dictionary to attach to the transaction - :param timeout: timeout for transaction execution (seconds) - :param db: name of the database against which to begin the transaction - :param imp_user: the user to impersonate - :param handlers: handler functions passed into the returned Response object - :return: Response object - """ - pass - - @abc.abstractmethod - def commit(self, **handlers): - """ Appends a COMMIT message to the output queue.""" - pass - - @abc.abstractmethod - def rollback(self, **handlers): - """ Appends a ROLLBACK message to the output queue.""" - pass - - @abc.abstractmethod - def reset(self): - """ Appends a RESET message to the outgoing queue, sends it and consumes - all remaining messages. - """ - pass - - def _append(self, signature, fields=(), response=None): - """ Appends a message to the outgoing queue. - - :param signature: the signature of the message - :param fields: the fields of the message as a tuple - :param response: a response object to handle callbacks - """ - self.packer.pack_struct(signature, fields) - self.outbox.wrap_message() - self.responses.append(response) - - def _send_all(self): - data = self.outbox.view() - if data: - try: - self.socket.sendall(data) - except OSError as error: - self._set_defunct_write(error) - self.outbox.clear() - - def send_all(self): - """ Send all queued messages to the server. - """ - if self.closed(): - raise ServiceUnavailable("Failed to write to closed connection {!r} ({!r})".format( - self.unresolved_address, self.server_info.address)) - - if self.defunct(): - raise ServiceUnavailable("Failed to write to defunct connection {!r} ({!r})".format( - self.unresolved_address, self.server_info.address)) - - self._send_all() - - @abc.abstractmethod - def fetch_message(self): - """ Receive at most one message from the server, if available. - - :return: 2-tuple of number of detail messages and number of summary - messages fetched - """ - pass - - def fetch_all(self): - """ Fetch all outstanding messages. - - :return: 2-tuple of number of detail messages and number of summary - messages fetched - """ - detail_count = summary_count = 0 - while self.responses: - response = self.responses[0] - while not response.complete: - detail_delta, summary_delta = self.fetch_message() - detail_count += detail_delta - summary_count += summary_delta - return detail_count, summary_count - - def _set_defunct_read(self, error=None, silent=False): - message = "Failed to read from defunct connection {!r} ({!r})".format( - self.unresolved_address, self.server_info.address - ) - self._set_defunct(message, error=error, silent=silent) - - def _set_defunct_write(self, error=None, silent=False): - message = "Failed to write data to connection {!r} ({!r})".format( - self.unresolved_address, self.server_info.address - ) - self._set_defunct(message, error=error, silent=silent) - - def _set_defunct(self, message, error=None, silent=False): - direct_driver = isinstance(self.pool, BoltPool) - - if error: - log.debug("[#%04X] %s", self.socket.getsockname()[1], error) - log.error(message) - # We were attempting to receive data but the connection - # has unexpectedly terminated. So, we need to close the - # connection from the client side, and remove the address - # from the connection pool. - self._defunct = True - self.close() - if self.pool: - self.pool.deactivate(address=self.unresolved_address) - # Iterate through the outstanding responses, and if any correspond - # to COMMIT requests then raise an error to signal that we are - # unable to confirm that the COMMIT completed successfully. - if silent: - return - for response in self.responses: - if isinstance(response, CommitResponse): - if error: - raise IncompleteCommit(message) from error - else: - raise IncompleteCommit(message) - - if direct_driver: - if error: - raise ServiceUnavailable(message) from error - else: - raise ServiceUnavailable(message) - else: - if error: - raise SessionExpired(message) from error - else: - raise SessionExpired(message) - - def stale(self): - return (self._stale - or (0 <= self._max_connection_lifetime - <= perf_counter() - self._creation_timestamp)) - - _stale = False - - def set_stale(self): - self._stale = True - - @abc.abstractmethod - def close(self): - """ Close the connection. - """ - pass - - @abc.abstractmethod - def closed(self): - pass - - @abc.abstractmethod - def defunct(self): - pass - - -class IOPool: - """ A collection of connections to one or more server addresses. - """ - - def __init__(self, opener, pool_config, workspace_config): - assert callable(opener) - assert isinstance(pool_config, PoolConfig) - assert isinstance(workspace_config, WorkspaceConfig) - - self.opener = opener - self.pool_config = pool_config - self.workspace_config = workspace_config - self.connections = defaultdict(deque) - self.lock = RLock() - self.cond = Condition(self.lock) - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_value, traceback): - self.close() - - def _acquire(self, address, timeout): - """ Acquire a connection to a given address from the pool. - The address supplied should always be an IP address, not - a host name. - - This method is thread safe. - """ - t0 = perf_counter() - if timeout is None: - timeout = self.workspace_config.connection_acquisition_timeout - - with self.lock: - def time_remaining(): - t = timeout - (perf_counter() - t0) - return t if t > 0 else 0 - - while True: - # try to find a free connection in pool - for connection in list(self.connections.get(address, [])): - if (connection.closed() or connection.defunct() - or (connection.stale() and not connection.in_use)): - # `close` is a noop on already closed connections. - # This is to make sure that the connection is gracefully - # closed, e.g. if it's just marked as `stale` but still - # alive. - if log.isEnabledFor(logging.DEBUG): - log.debug( - "[#%04X] C: removing old connection " - "(closed=%s, defunct=%s, stale=%s, in_use=%s)", - connection.local_port, - connection.closed(), connection.defunct(), - connection.stale(), connection.in_use - ) - connection.close() - try: - self.connections.get(address, []).remove(connection) - except ValueError: - # If closure fails (e.g. because the server went - # down), all connections to the same address will - # be removed. Therefore, we silently ignore if the - # connection isn't in the pool anymore. - pass - continue - if not connection.in_use: - connection.in_use = True - return connection - # all connections in pool are in-use - connections = self.connections[address] - max_pool_size = self.pool_config.max_connection_pool_size - infinite_pool_size = (max_pool_size < 0 - or max_pool_size == float("inf")) - can_create_new_connection = ( - infinite_pool_size - or len(connections) < max_pool_size - ) - if can_create_new_connection: - timeout = min(self.pool_config.connection_timeout, - time_remaining()) - try: - connection = self.opener(address, timeout) - except ServiceUnavailable: - self.remove(address) - raise - else: - connection.pool = self - connection.in_use = True - connections.append(connection) - return connection - - # failed to obtain a connection from pool because the - # pool is full and no free connection in the pool - if time_remaining(): - self.cond.wait(time_remaining()) - # if timed out, then we throw error. This time - # computation is needed, as with python 2.7, we - # cannot tell if the condition is notified or - # timed out when we come to this line - if not time_remaining(): - raise ClientError("Failed to obtain a connection from pool " - "within {!r}s".format(timeout)) - else: - raise ClientError("Failed to obtain a connection from pool " - "within {!r}s".format(timeout)) - - def acquire(self, access_mode=None, timeout=None, database=None, - bookmarks=None): - """ Acquire a connection to a server that can satisfy a set of parameters. - - :param access_mode: - :param timeout: - :param database: - :param bookmarks: - """ - - def release(self, *connections): - """ Release a connection back into the pool. - This method is thread safe. - """ - with self.lock: - for connection in connections: - if not (connection.defunct() - or connection.closed() - or connection.is_reset): - try: - connection.reset() - except (Neo4jError, DriverError, BoltError) as e: - log.debug( - "Failed to reset connection on release: %s", e - ) - connection.in_use = False - self.cond.notify_all() - - def in_use_connection_count(self, address): - """ Count the number of connections currently in use to a given - address. - """ - try: - connections = self.connections[address] - except KeyError: - return 0 - else: - return sum(1 if connection.in_use else 0 for connection in connections) - - def mark_all_stale(self): - with self.lock: - for address in self.connections: - for connection in self.connections[address]: - connection.set_stale() - - def deactivate(self, address): - """ Deactivate an address from the connection pool, if present, closing - all idle connection to that address - """ - with self.lock: - try: - connections = self.connections[address] - except KeyError: # already removed from the connection pool - return - for conn in list(connections): - if not conn.in_use: - connections.remove(conn) - try: - conn.close() - except OSError: - pass - if not connections: - self.remove(address) - - def on_write_failure(self, address): - raise WriteServiceUnavailable("No write service available for pool {}".format(self)) - - def remove(self, address): - """ Remove an address from the connection pool, if present, closing - all connections to that address. - """ - with self.lock: - for connection in self.connections.pop(address, ()): - try: - connection.close() - except OSError: - pass - - def close(self): - """ Close all connections and empty the pool. - This method is thread safe. - """ - try: - with self.lock: - for address in list(self.connections): - self.remove(address) - except TypeError: - pass - - -class BoltPool(IOPool): - - @classmethod - def open(cls, address, *, auth, pool_config, workspace_config): - """Create a new BoltPool - - :param address: - :param auth: - :param pool_config: - :param workspace_config: - :return: BoltPool - """ - - def opener(addr, timeout): - return Bolt.open( - addr, auth=auth, timeout=timeout, routing_context=None, - **pool_config - ) - - pool = cls(opener, pool_config, workspace_config, address) - return pool - - def __init__(self, opener, pool_config, workspace_config, address): - super(BoltPool, self).__init__(opener, pool_config, workspace_config) - self.address = address - - def __repr__(self): - return "<{} address={!r}>".format(self.__class__.__name__, self.address) - - def acquire(self, access_mode=None, timeout=None, database=None, bookmarks=None): - # The access_mode and database is not needed for a direct connection, its just there for consistency. - return self._acquire(self.address, timeout) - - -class Neo4jPool(IOPool): - """ Connection pool with routing table. - """ - - @classmethod - def open(cls, *addresses, auth, pool_config, workspace_config, routing_context=None): - """Create a new Neo4jPool - - :param addresses: one or more address as positional argument - :param auth: - :param pool_config: - :param workspace_config: - :param routing_context: - :return: Neo4jPool - """ - - address = addresses[0] - if routing_context is None: - routing_context = {} - elif "address" in routing_context: - raise ConfigurationError("The key 'address' is reserved for routing context.") - routing_context["address"] = str(address) - - def opener(addr, timeout): - return Bolt.open(addr, auth=auth, timeout=timeout, - routing_context=routing_context, **pool_config) - - pool = cls(opener, pool_config, workspace_config, address) - return pool - - def __init__(self, opener, pool_config, workspace_config, address): - """ - - :param opener: - :param pool_config: - :param workspace_config: - :param address: - """ - super(Neo4jPool, self).__init__(opener, pool_config, workspace_config) - # Each database have a routing table, the default database is a special case. - log.debug("[#0000] C: routing address %r", address) - self.address = address - self.routing_tables = {workspace_config.database: RoutingTable(database=workspace_config.database, routers=[address])} - self.refresh_lock = RLock() - - def __repr__(self): - """ The representation shows the initial routing addresses. - - :return: The representation - :rtype: str - """ - return "<{} addresses={!r}>".format(self.__class__.__name__, self.get_default_database_initial_router_addresses()) - - @property - def first_initial_routing_address(self): - return self.get_default_database_initial_router_addresses()[0] - - def get_default_database_initial_router_addresses(self): - """ Get the initial router addresses for the default database. - - :return: - :rtype: OrderedSet - """ - return self.get_routing_table_for_default_database().initial_routers - - def get_default_database_router_addresses(self): - """ Get the router addresses for the default database. - - :return: - :rtype: OrderedSet - """ - return self.get_routing_table_for_default_database().routers - - def get_routing_table_for_default_database(self): - return self.routing_tables[self.workspace_config.database] - - def get_or_create_routing_table(self, database): - with self.refresh_lock: - if database not in self.routing_tables: - self.routing_tables[database] = RoutingTable( - database=database, - routers=self.get_default_database_initial_router_addresses() - ) - return self.routing_tables[database] - - def fetch_routing_info(self, address, database, imp_user, bookmarks, - timeout): - """ Fetch raw routing info from a given router address. - - :param address: router address - :param database: the database name to get routing table for - :param imp_user: the user to impersonate while fetching the routing - table - :type imp_user: str or None - :param bookmarks: iterable of bookmark values after which the routing - info should be fetched - :param timeout: connection acquisition timeout in seconds - - :return: list of routing records, or None if no connection - could be established or if no readers or writers are present - :raise ServiceUnavailable: if the server does not support - routing, or if routing support is broken or outdated - """ - cx = self._acquire(address, timeout) - try: - routing_table = cx.route( - database or self.workspace_config.database, - imp_user or self.workspace_config.impersonated_user, - bookmarks - ) - finally: - self.release(cx) - return routing_table - - def fetch_routing_table(self, *, address, timeout, database, imp_user, - bookmarks): - """ Fetch a routing table from a given router address. - - :param address: router address - :param timeout: seconds - :param database: the database name - :type: str - :param imp_user: the user to impersonate while fetching the routing - table - :type imp_user: str or None - :param bookmarks: bookmarks used when fetching routing table - - :return: a new RoutingTable instance or None if the given router is - currently unable to provide routing information - """ - new_routing_info = None - try: - new_routing_info = self.fetch_routing_info( - address, database, imp_user, bookmarks, timeout - ) - except Neo4jError as e: - # checks if the code is an error that is caused by the client. In - # this case there is no sense in trying to fetch a RT from another - # router. Hence, the driver should fail fast during discovery. - if e.is_fatal_during_discovery(): - raise - except (ServiceUnavailable, SessionExpired): - pass - if not new_routing_info: - log.debug("Failed to fetch routing info %s", address) - return None - else: - servers = new_routing_info[0]["servers"] - ttl = new_routing_info[0]["ttl"] - database = new_routing_info[0].get("db", database) - new_routing_table = RoutingTable.parse_routing_info( - database=database, servers=servers, ttl=ttl - ) - - # Parse routing info and count the number of each type of server - num_routers = len(new_routing_table.routers) - num_readers = len(new_routing_table.readers) - - # num_writers = len(new_routing_table.writers) - # If no writers are available. This likely indicates a temporary state, - # such as leader switching, so we should not signal an error. - - # No routers - if num_routers == 0: - log.debug("No routing servers returned from server %s", address) - return None - - # No readers - if num_readers == 0: - log.debug("No read servers returned from server %s", address) - return None - - # At least one of each is fine, so return this table - return new_routing_table - - def _update_routing_table_from(self, *routers, database=None, imp_user=None, - bookmarks=None, database_callback=None): - """ Try to update routing tables with the given routers. - - :return: True if the routing table is successfully updated, - otherwise False - """ - log.debug("Attempting to update routing table from {}".format(", ".join(map(repr, routers)))) - for router in routers: - for address in router.resolve(resolver=self.pool_config.resolver): - new_routing_table = self.fetch_routing_table( - address=address, - timeout=self.pool_config.connection_timeout, - database=database, imp_user=imp_user, bookmarks=bookmarks - ) - if new_routing_table is not None: - new_databse = new_routing_table.database - self.get_or_create_routing_table(new_databse)\ - .update(new_routing_table) - log.debug( - "[#0000] C: address=%r (%r)", - address, self.routing_tables[new_databse] - ) - if callable(database_callback): - database_callback(new_databse) - return True - self.deactivate(router) - return False - - def update_routing_table(self, *, database, imp_user, bookmarks, - database_callback=None): - """ Update the routing table from the first router able to provide - valid routing information. - - :param database: The database name - :param imp_user: the user to impersonate while fetching the routing - table - :type imp_user: str or None - :param bookmarks: bookmarks used when fetching routing table - :param database_callback: A callback function that will be called with - the database name as only argument when a new routing table has been - acquired. This database name might different from `database` if that - was None and the underlying protocol supports reporting back the - actual database. - - :raise neo4j.exceptions.ServiceUnavailable: - """ - with self.refresh_lock: - # copied because it can be modified - existing_routers = set( - self.get_or_create_routing_table(database).routers - ) - - prefer_initial_routing_address = \ - self.routing_tables[database].initialized_without_writers - - if prefer_initial_routing_address: - # TODO: Test this state - if self._update_routing_table_from( - self.first_initial_routing_address, database=database, - imp_user=imp_user, bookmarks=bookmarks, - database_callback=database_callback - ): - # Why is only the first initial routing address used? - return - if self._update_routing_table_from( - *(existing_routers - {self.first_initial_routing_address}), - database=database, imp_user=imp_user, bookmarks=bookmarks, - database_callback=database_callback - ): - return - - if not prefer_initial_routing_address: - if self._update_routing_table_from( - self.first_initial_routing_address, database=database, - imp_user=imp_user, bookmarks=bookmarks, - database_callback=database_callback - ): - # Why is only the first initial routing address used? - return - - # None of the routers have been successful, so just fail - log.error("Unable to retrieve routing information") - raise ServiceUnavailable("Unable to retrieve routing information") - - def update_connection_pool(self, *, database): - servers = self.get_or_create_routing_table(database).servers() - for address in list(self.connections): - if address.unresolved not in servers: - super(Neo4jPool, self).deactivate(address) - - def ensure_routing_table_is_fresh(self, *, access_mode, database, imp_user, - bookmarks, database_callback=None): - """ Update the routing table if stale. - - This method performs two freshness checks, before and after acquiring - the refresh lock. If the routing table is already fresh on entry, the - method exits immediately; otherwise, the refresh lock is acquired and - the second freshness check that follows determines whether an update - is still required. - - This method is thread-safe. - - :return: `True` if an update was required, `False` otherwise. - """ - from neo4j.api import READ_ACCESS - with self.refresh_lock: - if self.get_or_create_routing_table(database)\ - .is_fresh(readonly=(access_mode == READ_ACCESS)): - # Readers are fresh. - return False - - self.update_routing_table( - database=database, imp_user=imp_user, bookmarks=bookmarks, - database_callback=database_callback - ) - self.update_connection_pool(database=database) - - for database in list(self.routing_tables.keys()): - # Remove unused databases in the routing table - # Remove the routing table after a timeout = TTL + 30s - log.debug("[#0000] C: database=%s", database) - if (self.routing_tables[database].should_be_purged_from_memory() - and database != self.workspace_config.database): - del self.routing_tables[database] - - return True - - def _select_address(self, *, access_mode, database): - from neo4j.api import READ_ACCESS - """ Selects the address with the fewest in-use connections. - """ - with self.refresh_lock: - if access_mode == READ_ACCESS: - addresses = self.routing_tables[database].readers - else: - addresses = self.routing_tables[database].writers - addresses_by_usage = {} - for address in addresses: - addresses_by_usage.setdefault( - self.in_use_connection_count(address), [] - ).append(address) - if not addresses_by_usage: - if access_mode == READ_ACCESS: - raise ReadServiceUnavailable( - "No read service currently available" - ) - else: - raise WriteServiceUnavailable( - "No write service currently available" - ) - return choice(addresses_by_usage[min(addresses_by_usage)]) - - def acquire(self, access_mode=None, timeout=None, database=None, - bookmarks=None): - if access_mode not in (WRITE_ACCESS, READ_ACCESS): - raise ClientError("Non valid 'access_mode'; {}".format(access_mode)) - if not timeout: - raise ClientError("'timeout' must be a float larger than 0; {}" - .format(timeout)) - - from neo4j.api import check_access_mode - access_mode = check_access_mode(access_mode) - with self.refresh_lock: - log.debug("[#0000] C: %r", - self.routing_tables) - self.ensure_routing_table_is_fresh( - access_mode=access_mode, database=database, imp_user=None, - bookmarks=bookmarks - ) - - while True: - try: - # Get an address for a connection that have the fewest in-use - # connections. - address = self._select_address(access_mode=access_mode, - database=database) - except (ReadServiceUnavailable, WriteServiceUnavailable) as err: - raise SessionExpired("Failed to obtain connection towards '%s' server." % access_mode) from err - try: - log.debug("[#0000] C: database=%r address=%r", database, address) - connection = self._acquire(address, timeout=timeout) # should always be a resolved address - except ServiceUnavailable: - self.deactivate(address=address) - else: - return connection - - def deactivate(self, address): - """ Deactivate an address from the connection pool, - if present, remove from the routing table and also closing - all idle connections to that address. - """ - log.debug("[#0000] C: Deactivating address %r", address) - # We use `discard` instead of `remove` here since the former - # will not fail if the address has already been removed. - for database in self.routing_tables.keys(): - self.routing_tables[database].routers.discard(address) - self.routing_tables[database].readers.discard(address) - self.routing_tables[database].writers.discard(address) - log.debug("[#0000] C: table=%r", self.routing_tables) - super(Neo4jPool, self).deactivate(address) - - def on_write_failure(self, address): - """ Remove a writer address from the routing table, if present. - """ - log.debug("[#0000] C: Removing writer %r", address) - for database in self.routing_tables.keys(): - self.routing_tables[database].writers.discard(address) - log.debug("[#0000] C: table=%r", self.routing_tables) - - -def _connect(resolved_address, timeout, keep_alive): - """ - - :param resolved_address: - :param timeout: seconds - :param keep_alive: True or False - :return: socket object - """ - - s = None # The socket - - try: - if len(resolved_address) == 2: - s = socket(AF_INET) - elif len(resolved_address) == 4: - s = socket(AF_INET6) - else: - raise ValueError("Unsupported address {!r}".format(resolved_address)) - t = s.gettimeout() - if timeout: - s.settimeout(timeout) - log.debug("[#0000] C: %s", resolved_address) - s.connect(resolved_address) - s.settimeout(t) - keep_alive = 1 if keep_alive else 0 - s.setsockopt(SOL_SOCKET, SO_KEEPALIVE, keep_alive) - except SocketTimeout: - log.debug("[#0000] C: %s", resolved_address) - log.debug("[#0000] C: %s", resolved_address) - _close_socket(s) - raise ServiceUnavailable("Timed out trying to establish connection to {!r}".format(resolved_address)) - except OSError as error: - log.debug("[#0000] C: %s %s", type(error).__name__, - " ".join(map(repr, error.args))) - log.debug("[#0000] C: %s", resolved_address) - s.close() - raise ServiceUnavailable("Failed to establish connection to {!r} (reason {})".format(resolved_address, error)) - else: - return s - - -def _secure(s, host, ssl_context): - local_port = s.getsockname()[1] - # Secure the connection if an SSL context has been provided - if ssl_context: - last_error = None - log.debug("[#%04X] C: %s", local_port, host) - try: - sni_host = host if HAS_SNI and host else None - s = ssl_context.wrap_socket(s, server_hostname=sni_host) - except (OSError, SSLError, CertificateError) as cause: - raise BoltSecurityError( - message="Failed to establish encrypted connection.", - address=(host, local_port) - ) from cause - # Check that the server provides a certificate - der_encoded_server_certificate = s.getpeercert(binary_form=True) - if der_encoded_server_certificate is None: - raise BoltProtocolError( - "When using an encrypted socket, the server should always " - "provide a certificate", address=(host, local_port) - ) - return s - return s - - -def _handshake(s, resolved_address): - """ - - :param s: Socket - :param resolved_address: - - :return: (socket, version, client_handshake, server_response_data) - """ - local_port = s.getsockname()[1] - - # TODO: Optimize logging code - handshake = Bolt.get_handshake() - import struct - handshake = struct.unpack(">16B", handshake) - handshake = [handshake[i:i + 4] for i in range(0, len(handshake), 4)] - - supported_versions = [("0x%02X%02X%02X%02X" % (vx[0], vx[1], vx[2], vx[3])) for vx in handshake] - - log.debug("[#%04X] C: 0x%08X", local_port, int.from_bytes(Bolt.MAGIC_PREAMBLE, byteorder="big")) - log.debug("[#%04X] C: %s %s %s %s", local_port, *supported_versions) - - data = Bolt.MAGIC_PREAMBLE + Bolt.get_handshake() - s.sendall(data) - - # Handle the handshake response - ready_to_read = False - with selectors.DefaultSelector() as selector: - selector.register(s, selectors.EVENT_READ) - selector.select(1) - try: - data = s.recv(4) - except OSError: - raise ServiceUnavailable("Failed to read any data from server {!r} " - "after connected".format(resolved_address)) - data_size = len(data) - if data_size == 0: - # If no data is returned after a successful select - # response, the server has closed the connection - log.debug("[#%04X] S: ", local_port) - _close_socket(s) - raise ServiceUnavailable("Connection to {address} closed without handshake response".format(address=resolved_address)) - if data_size != 4: - # Some garbled data has been received - log.debug("[#%04X] S: @*#!", local_port) - s.close() - raise BoltProtocolError("Expected four byte Bolt handshake response from %r, received %r instead; check for incorrect port number" % (resolved_address, data), address=resolved_address) - elif data == b"HTTP": - log.debug("[#%04X] S: ", local_port) - _close_socket(s) - raise ServiceUnavailable("Cannot to connect to Bolt service on {!r} " - "(looks like HTTP)".format(resolved_address)) - agreed_version = data[-1], data[-2] - log.debug("[#%04X] S: 0x%06X%02X", local_port, agreed_version[1], agreed_version[0]) - return s, agreed_version, handshake, data - - -def _close_socket(socket_): - try: - socket_.shutdown(SHUT_RDWR) - socket_.close() - except OSError: - pass - - -def connect(address, *, timeout, custom_resolver, ssl_context, keep_alive): - """ Connect and perform a handshake and return a valid Connection object, - assuming a protocol version can be agreed. - """ - errors = [] - # Establish a connection to the host and port specified - # Catches refused connections see: - # https://docs.python.org/2/library/errno.html - - resolved_addresses = Address(address).resolve(resolver=custom_resolver) - for resolved_address in resolved_addresses: - s = None - try: - s = _connect(resolved_address, timeout, keep_alive) - s = _secure(s, resolved_address.host_name, ssl_context) - return _handshake(s, resolved_address) - except (BoltError, DriverError, OSError) as error: - try: - local_port = s.getsockname()[1] - except (OSError, AttributeError): - local_port = 0 - err_str = error.__class__.__name__ - if str(error): - err_str += ": " + str(error) - log.debug("[#%04X] C: %s", local_port, err_str) - if s: - _close_socket(s) - errors.append(error) - except Exception: - if s: - _close_socket(s) - raise - if not errors: - raise ServiceUnavailable( - "Couldn't connect to %s (resolved to %s)" % ( - str(address), tuple(map(str, resolved_addresses))) - ) - else: - raise ServiceUnavailable( - "Couldn't connect to %s (resolved to %s):\n%s" % ( - str(address), tuple(map(str, resolved_addresses)), - "\n".join(map(str, errors)) - ) - ) from errors[0] - - -def check_supported_server_product(agent): - """ Checks that a server product is supported by the driver by - looking at the server agent string. - - :param agent: server agent string to check for validity - :raises UnsupportedServerProduct: if the product is not supported - """ - if not agent.startswith("Neo4j/"): - raise UnsupportedServerProduct(agent) diff --git a/neo4j/meta.py b/neo4j/meta.py index dd727eab..37d8f8d0 100644 --- a/neo4j/meta.py +++ b/neo4j/meta.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -31,7 +28,10 @@ def get_user_agent(): """ Obtain the default user agent string sent to the server after a successful handshake. """ - from sys import platform, version_info + from sys import ( + platform, + version_info, + ) template = "neo4j-python/{} Python/{}.{}.{}-{}-{} ({})" fields = (version,) + tuple(version_info) + (platform,) return template.format(*fields) diff --git a/neo4j/packstream.py b/neo4j/packstream.py index 406d761e..e453013c 100644 --- a/neo4j/packstream.py +++ b/neo4j/packstream.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -20,8 +17,11 @@ from codecs import decode -from io import BytesIO -from struct import pack as struct_pack, unpack as struct_unpack +from struct import ( + pack as struct_pack, + unpack as struct_unpack, +) + PACKED_UINT_8 = [struct_pack(">B", value) for value in range(0x100)] PACKED_UINT_16 = [struct_pack(">H", value) for value in range(0x10000)] @@ -472,14 +472,3 @@ def pop_u16(self): return value else: return -1 - - def receive(self, sock, n_bytes): - end = self.used + n_bytes - if end > len(self.data): - self.data += bytearray(end - len(self.data)) - view = memoryview(self.data) - while self.used < end: - n = sock.recv_into(view[self.used:end], end - self.used) - if n == 0: - raise OSError("No data") - self.used += n diff --git a/neo4j/routing.py b/neo4j/routing.py index 8303f4c2..b0546d12 100644 --- a/neo4j/routing.py +++ b/neo4j/routing.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -19,12 +16,11 @@ # limitations under the License. -from collections import OrderedDict from collections.abc import MutableSet from logging import getLogger from time import perf_counter -from neo4j.addressing import Address +from .addressing import Address log = getLogger("neo4j") @@ -33,7 +29,8 @@ class OrderedSet(MutableSet): def __init__(self, elements=()): - self._elements = OrderedDict.fromkeys(elements) + # dicts keep insertion order starting with Python 3.7 + self._elements = dict.fromkeys(elements) self._current = None def __repr__(self): @@ -70,12 +67,12 @@ def remove(self, element): raise ValueError(element) def update(self, elements=()): - self._elements.update(OrderedDict.fromkeys(elements)) + self._elements.update(dict.fromkeys(elements)) def replace(self, elements=()): e = self._elements e.clear() - e.update(OrderedDict.fromkeys(elements)) + e.update(dict.fromkeys(elements)) class RoutingTable: diff --git a/neo4j/spatial/__init__.py b/neo4j/spatial/__init__.py index 36e8d66f..5e250338 100644 --- a/neo4j/spatial/__init__.py +++ b/neo4j/spatial/__init__.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -24,21 +21,21 @@ """ -from threading import Lock - -from neo4j.packstream import Structure - - __all__ = [ - "Point", "CartesianPoint", - "WGS84Point", - "point_type", - "hydrate_point", "dehydrate_point", + "hydrate_point", + "Point", + "point_type", + "WGS84Point", ] +from threading import Lock + +from neo4j.packstream import Structure + + # SRID to subclass mappings __srid_table = {} __srid_table_lock = Lock() diff --git a/neo4j/time/__init__.py b/neo4j/time/__init__.py index 0a76280f..4e070d7a 100644 --- a/neo4j/time/__init__.py +++ b/neo4j/time/__init__.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# coding: utf-8 - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -24,19 +21,18 @@ as a number of utility functions. """ + from contextlib import contextmanager from datetime import ( - timedelta, date, - time, datetime, + time, + timedelta, ) from decimal import ( Decimal, localcontext, - ROUND_DOWN, ROUND_HALF_EVEN, - ROUND_HALF_UP, ) from functools import total_ordering from re import compile as re_compile @@ -48,18 +44,18 @@ from neo4j.meta import ( deprecated, - deprecation_warn + deprecation_warn, ) from neo4j.time.arithmetic import ( nano_add, nano_div, - symmetric_divmod, round_half_to_even, + symmetric_divmod, ) from neo4j.time.metaclasses import ( + DateTimeType, DateType, TimeType, - DateTimeType, ) @@ -84,19 +80,15 @@ def inner(*args, **kwargs): MIN_INT64 = -(2 ** 63) MAX_INT64 = (2 ** 63) - 1 +#: The smallest year number allowed in a :class:`.Date` or :class:`.DateTime` +#: object to be compatible with :class:`datetime.date` and +#: :class:`datetime.datetime`. MIN_YEAR = 1 -""" -The smallest year number allowed in a :class:`.Date` or :class:`.DateTime` -object to be compatible with :class:`datetime.date` and -:class:`datetime.datetime`. -""" +#: The largest year number allowed in a :class:`.Date` or :class:`.DateTime` +#: object to be compatible with :class:`datetime.date` and +#: :class:`datetime.datetime`. MAX_YEAR = 9999 -""" -The largest year number allowed in a :class:`.Date` or :class:`.DateTime` -object to be compatible with :class:`datetime.date` and -:class:`datetime.datetime`. -""" DATE_ISO_PATTERN = re_compile(r"^(\d{4})-(\d{2})-(\d{2})$") TIME_ISO_PATTERN = re_compile( @@ -1387,11 +1379,9 @@ def __format__(self, format_spec): Date.max = Date.from_ordinal(3652059) Date.resolution = Duration(days=1) +#: A :class:`neo4j.time.Date` instance set to `0000-00-00`. +#: This has an ordinal value of `0`. ZeroDate = object.__new__(Date) -""" -A :class:`neo4j.time.Date` instance set to `0000-00-00`. -This has an ordinal value of `0`. -""" class Time(metaclass=TimeType): @@ -2001,17 +1991,13 @@ def __format__(self, format_spec): Time.max = Time(hour=23, minute=59, second=59, nanosecond=999999999) Time.resolution = Duration(nanoseconds=1) +#: A :class:`.Time` instance set to `00:00:00`. +#: This has a :attr:`.ticks_ns` value of `0`. Midnight = Time.min -""" -A :class:`.Time` instance set to `00:00:00`. -This has a :attr:`.ticks_ns` value of `0`. -""" +#: A :class:`.Time` instance set to `12:00:00`. +#: This has a :attr:`.ticks_ns` value of `43200000000000`. Midday = Time(hour=12) -""" -A :class:`.Time` instance set to `12:00:00`. -This has a :attr:`.ticks_ns` value of `43200000000000`. -""" @total_ordering @@ -2621,12 +2607,9 @@ def __format__(self, format_spec): DateTime.max = DateTime.combine(Date.max, Time.max) DateTime.resolution = Time.resolution +#: A :class:`.DateTime` instance set to `0000-00-00T00:00:00`. +#: This has a :class:`.Date` component equal to :attr:`ZeroDate` and a Never = DateTime.combine(ZeroDate, Midnight) -""" -A :class:`.DateTime` instance set to `0000-00-00T00:00:00`. -This has a :class:`.Date` component equal to :attr:`ZeroDate` and a -:class:`.Time` component equal to :attr:`Midnight`. -""" +#: A :class:`.DateTime` instance set to `1970-01-01T00:00:00`. UnixEpoch = DateTime(1970, 1, 1, 0, 0, 0) -"""A :class:`.DateTime` instance set to `1970-01-01T00:00:00`.""" diff --git a/neo4j/time/__main__.py b/neo4j/time/__main__.py index 447b375c..9d1858b5 100644 --- a/neo4j/time/__main__.py +++ b/neo4j/time/__main__.py @@ -1,5 +1,4 @@ #!/usr/bin/env python -# coding: utf-8 # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] @@ -20,7 +19,11 @@ def main(): - from neo4j.time import Clock, DateTime, UnixEpoch + from neo4j.time import ( + Clock, + DateTime, + UnixEpoch, + ) clock = Clock() time = clock.utc_time() print("Using %s" % type(clock).__name__) diff --git a/neo4j/time/arithmetic.py b/neo4j/time/arithmetic.py index c39ea2c2..71c43f9f 100644 --- a/neo4j/time/arithmetic.py +++ b/neo4j/time/arithmetic.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# coding: utf-8 - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -19,9 +16,6 @@ # limitations under the License. -from math import isnan - - def nano_add(x, y): """ diff --git a/neo4j/time/clock_implementations.py b/neo4j/time/clock_implementations.py index fccbd433..de4d0ba7 100644 --- a/neo4j/time/clock_implementations.py +++ b/neo4j/time/clock_implementations.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# coding: utf-8 - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -19,10 +16,19 @@ # limitations under the License. -from ctypes import CDLL, Structure, c_longlong, c_long, byref +from ctypes import ( + byref, + c_long, + c_longlong, + CDLL, + Structure, +) from platform import uname -from neo4j.time import Clock, ClockTime +from neo4j.time import ( + Clock, + ClockTime, +) from neo4j.time.arithmetic import nano_divmod diff --git a/neo4j/time/hydration.py b/neo4j/time/hydration.py index 69a3ddcf..f522ebd7 100644 --- a/neo4j/time/hydration.py +++ b/neo4j/time/hydration.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -20,17 +17,17 @@ from datetime import ( - time, datetime, + time, timedelta, ) from neo4j.packstream import Structure from neo4j.time import ( - Duration, Date, - Time, DateTime, + Duration, + Time, ) @@ -114,7 +111,10 @@ def hydrate_datetime(seconds, nanoseconds, tz=None): :param tz: :return: datetime """ - from pytz import FixedOffset, timezone + from pytz import ( + FixedOffset, + timezone, + ) minutes, seconds = map(int, divmod(seconds, 60)) hours, minutes = map(int, divmod(minutes, 60)) days, hours = map(int, divmod(hours, 24)) diff --git a/neo4j/time/metaclasses.py b/neo4j/time/metaclasses.py index 278a0fb6..ad5e53ff 100644 --- a/neo4j/time/metaclasses.py +++ b/neo4j/time/metaclasses.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# coding: utf-8 - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # diff --git a/neo4j/work/__init__.py b/neo4j/work/__init__.py index 154f7381..a4254afa 100644 --- a/neo4j/work/__init__.py +++ b/neo4j/work/__init__.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -19,88 +16,19 @@ # limitations under the License. -from neo4j.conf import WorkspaceConfig -from neo4j.exceptions import ServiceUnavailable -from neo4j.io import Neo4jPool - - -class Workspace: - - def __init__(self, pool, config): - assert isinstance(config, WorkspaceConfig) - self._pool = pool - self._config = config - self._connection = None - self._connection_access_mode = None - # Sessions are supposed to cache the database on which to operate. - self._cached_database = False - self._bookmarks = None - - def __del__(self): - try: - self.close() - except OSError: - pass - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_value, traceback): - self.close() - - def _set_cached_database(self, database): - self._cached_database = True - self._config.database = database - - def _connect(self, access_mode): - if self._connection: - # TODO: Investigate this - # log.warning("FIXME: should always disconnect before connect") - self._connection.send_all() - self._connection.fetch_all() - self._disconnect() - if not self._cached_database: - if (self._config.database is not None - or not isinstance(self._pool, Neo4jPool)): - self._set_cached_database(self._config.database) - else: - # This is the first time we open a connection to a server in a - # cluster environment for this session without explicitly - # configured database. Hence, we request a routing table update - # to try to fetch the home database. If provided by the server, - # we shall use this database explicitly for all subsequent - # actions within this session. - self._pool.update_routing_table( - database=self._config.database, - imp_user=self._config.impersonated_user, - bookmarks=self._bookmarks, - database_callback=self._set_cached_database - ) - self._connection = self._pool.acquire( - access_mode=access_mode, - timeout=self._config.connection_acquisition_timeout, - database=self._config.database, - bookmarks=self._bookmarks - ) - self._connection_access_mode = access_mode - - def _disconnect(self, sync=False): - if self._connection: - if sync: - try: - self._connection.send_all() - self._connection.fetch_all() - except (WorkspaceError, ServiceUnavailable): - pass - if self._connection: - self._pool.release(self._connection) - self._connection = None - self._connection_access_mode = None - - def close(self): - self._disconnect(sync=True) - +from .query import ( + Query, + unit_of_work, +) +from .summary import ( + ResultSummary, + SummaryCounters, +) -class WorkspaceError(Exception): - pass +__all__ = [ + "Query", + "ResultSummary", + "SummaryCounters", + "unit_of_work", +] diff --git a/neo4j/work/pipelining.py b/neo4j/work/pipelining.py deleted file mode 100644 index ccca8420..00000000 --- a/neo4j/work/pipelining.py +++ /dev/null @@ -1,136 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -# Copyright (c) "Neo4j" -# Neo4j Sweden AB [http://neo4j.com] -# -# This file is part of Neo4j. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from collections import deque -from threading import Thread, Lock -from time import sleep - -from neo4j.work import Workspace -from neo4j.conf import WorkspaceConfig -from neo4j.api import ( - WRITE_ACCESS, -) - -class PipelineConfig(WorkspaceConfig): - - #: - flush_every = 8192 # bytes - - -class Pipeline(Workspace): - - def __init__(self, pool, config): - assert isinstance(config, PipelineConfig) - super(Pipeline, self).__init__(pool, config) - self._connect(WRITE_ACCESS) - self._flush_every = config.flush_every - self._data = deque() - self._pull_lock = Lock() - - def push(self, statement, parameters=None): - self._connection.run(statement, parameters) - self._connection.pull(on_records=self._data.extend) - output_buffer_size = len(self._connection.outbox.view()) - if output_buffer_size >= self._flush_every: - self._connection.send_all() - - def _results_generator(self): - results_returned_count = 0 - try: - summary = 0 - while summary == 0: - _, summary = self._connection.fetch_message() - summary = 0 - while summary == 0: - detail, summary = self._connection.fetch_message() - for n in range(detail): - response = self._data.popleft() - results_returned_count += 1 - yield response - finally: - self._pull_lock.release() - - def pull(self): - """Returns a generator containing the results of the next query in the pipeline""" - # n.b. pull is now somewhat misleadingly named because it doesn't do anything - # the connection isn't touched until you try and iterate the generator we return - lock_acquired = self._pull_lock.acquire(blocking=False) - if not lock_acquired: - raise PullOrderException() - return self._results_generator() - - -class PullOrderException(Exception): - """Raise when calling pull if a previous pull result has not been fully consumed""" - - -class Pusher(Thread): - - def __init__(self, pipeline): - super(Pusher, self).__init__() - self.pipeline = pipeline - self.running = True - self.count = 0 - - def run(self): - while self.running: - self.pipeline.push("RETURN $x", {"x": self.count}) - self.count += 1 - - -class Puller(Thread): - - def __init__(self, pipeline): - super(Puller, self).__init__() - self.pipeline = pipeline - self.running = True - self.count = 0 - - def run(self): - while self.running: - for _ in self.pipeline.pull(): - pass # consume and discard records - self.count += 1 - - -def main(): - from neo4j import Driver - # from neo4j.bolt.diagnostics import watch - # watch("neobolt") - with Driver("bolt://", auth=("neo4j", "password")) as dx: - p = dx.pipeline(flush_every=1024) - pusher = Pusher(p) - puller = Puller(p) - try: - pusher.start() - puller.start() - while True: - print("sent %d, received %d, backlog %d" % (pusher.count, puller.count, pusher.count - puller.count)) - sleep(1) - except KeyboardInterrupt: - pusher.running = False - pusher.join() - puller.running = False - puller.join() - - -if __name__ == "__main__": - main() diff --git a/neo4j/work/query.py b/neo4j/work/query.py new file mode 100644 index 00000000..acefa6e7 --- /dev/null +++ b/neo4j/work/query.py @@ -0,0 +1,75 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +class Query: + """ Create a new query. + + :param text: The query text. + :type text: str + :param metadata: metadata attached to the query. + :type metadata: dict + :param timeout: seconds. + :type timeout: int + """ + def __init__(self, text, metadata=None, timeout=None): + self.text = text + + self.metadata = metadata + self.timeout = timeout + + def __str__(self): + return str(self.text) + + +def unit_of_work(metadata=None, timeout=None): + """This function is a decorator for transaction functions that allows extra control over how the transaction is carried out. + + For example, a timeout may be applied:: + + @unit_of_work(timeout=100) + def count_people_tx(tx): + result = tx.run("MATCH (a:Person) RETURN count(a) AS persons") + record = result.single() + return record["persons"] + + :param metadata: + a dictionary with metadata. + Specified metadata will be attached to the executing transaction and visible in the output of ``dbms.listQueries`` and ``dbms.listTransactions`` procedures. + It will also get logged to the ``query.log``. + This functionality makes it easier to tag transactions and is equivalent to ``dbms.setTXMetaData`` procedure, see https://neo4j.com/docs/operations-manual/current/reference/procedures/ for procedure reference. + :type metadata: dict + + :param timeout: + the transaction timeout in seconds. + Transactions that execute longer than the configured timeout will be terminated by the database. + This functionality allows to limit query/transaction execution time. + Specified timeout overrides the default timeout configured in the database using ``dbms.transaction.timeout`` setting. + Value should not represent a duration of zero or negative duration. + :type timeout: int + """ + + def wrapper(f): + + def wrapped(*args, **kwargs): + return f(*args, **kwargs) + + wrapped.metadata = metadata + wrapped.timeout = timeout + return wrapped + + return wrapper diff --git a/neo4j/work/summary.py b/neo4j/work/summary.py index c0955b00..600cfe2e 100644 --- a/neo4j/work/summary.py +++ b/neo4j/work/summary.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -19,9 +16,8 @@ # limitations under the License. -from collections import namedtuple +from .._exceptions import BoltProtocolError -from neo4j._exceptions import BoltProtocolError BOLT_VERSION_1 = 1 BOLT_VERSION_2 = 2 diff --git a/requirements-dev.txt b/requirements-dev.txt new file mode 100644 index 00000000..cc40eecd --- /dev/null +++ b/requirements-dev.txt @@ -0,0 +1,10 @@ +# auto-generate sync driver from async code +unasync>=0.5.0 +pre-commit>=2.15.0 +isort>=5.10.0 + +# needed for running tests +-r tests/requirements.txt + +# production dependencies +-r requirements.txt diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 00000000..c63faed9 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,13 @@ +[isort] +combine_as_imports=true +ensure_newline_before_comments=true +force_grid_wrap=2 +force_sort_within_sections=true +include_trailing_comma=true +#lines_before_imports=2 # currently broken +lines_after_imports=2 +lines_between_sections=1 +multi_line_output=3 +order_by_type=false +remove_redundant_aliases=true +use_parentheses=true diff --git a/setup.py b/setup.py index 24384398..f22a1842 100644 --- a/setup.py +++ b/setup.py @@ -1,5 +1,4 @@ #!/usr/bin/env python -# -*- encoding: utf-8 -*- # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] @@ -20,9 +19,17 @@ import os -from setuptools import find_packages, setup -from neo4j.meta import package, version +from setuptools import ( + find_packages, + setup, +) + +from neo4j.meta import ( + package, + version, +) + install_requires = [ "pytz", @@ -44,7 +51,8 @@ } packages = find_packages(exclude=["tests"]) -readme_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "README.rst")) +readme_path = os.path.abspath(os.path.join(os.path.dirname(__file__), + "README.rst")) with open(readme_path, mode="r", encoding="utf-8") as fr: readme = fr.read() diff --git a/testkit/.dockerignore b/testkit/.dockerignore index fe7d40a5..f104652b 100644 --- a/testkit/.dockerignore +++ b/testkit/.dockerignore @@ -1,2 +1 @@ *.py - diff --git a/testkit/backend.py b/testkit/backend.py index ca5a4779..5366afe9 100644 --- a/testkit/backend.py +++ b/testkit/backend.py @@ -1,8 +1,30 @@ +#!/usr/bin/env python + +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os import subprocess import sys + if __name__ == "__main__": - subprocess.check_call( - ["python", "-m", "testkitbackend"], - stdout=sys.stdout, stderr=sys.stderr - ) + cmd = ["python", "-m", "testkitbackend"] + if "TEST_BACKEND_SERVER" in os.environ: + cmd.append(os.environ["TEST_BACKEND_SERVER"]) + subprocess.check_call(cmd, stdout=sys.stdout, stderr=sys.stderr) diff --git a/testkit/build.py b/testkit/build.py index 132452e5..5de76f14 100644 --- a/testkit/build.py +++ b/testkit/build.py @@ -1,7 +1,29 @@ +#!/usr/bin/env python + +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + """ Executed in Go driver container. Responsible for building driver and test backend. """ + + import subprocess diff --git a/testkit/integration.py b/testkit/integration.py index 244a421f..c4b1c685 100644 --- a/testkit/integration.py +++ b/testkit/integration.py @@ -1,2 +1,22 @@ +#!/usr/bin/env python + +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + if __name__ == "__main__": pass diff --git a/testkit/stress.py b/testkit/stress.py index bc9718f2..3d171c49 100644 --- a/testkit/stress.py +++ b/testkit/stress.py @@ -1,5 +1,23 @@ -import subprocess -import os +#!/usr/bin/env python + +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + import sys diff --git a/testkit/unittests.py b/testkit/unittests.py index 9561611d..a3730195 100644 --- a/testkit/unittests.py +++ b/testkit/unittests.py @@ -1,3 +1,23 @@ +#!/usr/bin/env python + +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + import subprocess @@ -7,5 +27,4 @@ def run(args): if __name__ == "__main__": - run([ - "python", "-m", "tox", "-c", "tox-unit.ini"]) + run(["python", "-m", "tox", "-c", "tox-unit.ini"]) diff --git a/testkitbackend/__init__.py b/testkitbackend/__init__.py index e69de29b..b81a309d 100644 --- a/testkitbackend/__init__.py +++ b/testkitbackend/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/testkitbackend/__main__.py b/testkitbackend/__main__.py index 3e8a4b87..2dec1adb 100644 --- a/testkitbackend/__main__.py +++ b/testkitbackend/__main__.py @@ -1,6 +1,52 @@ -from .server import Server +#!/usr/bin/env python -if __name__ == "__main__": +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import asyncio +import sys + +from .server import ( + AsyncServer, + Server, +) + + +def sync_main(): server = Server(("0.0.0.0", 9876)) while True: server.handle_request() + + +def async_main(): + async def main(): + server = AsyncServer(("0.0.0.0", 9876)) + await server.start() + try: + await server.serve_forever() + finally: + server.stop() + + asyncio.run(main()) + + +if __name__ == "__main__": + if len(sys.argv) == 2 and sys.argv[1].lower().strip() == "async": + async_main() + else: + sync_main() diff --git a/testkitbackend/_async/__init__.py b/testkitbackend/_async/__init__.py new file mode 100644 index 00000000..b81a309d --- /dev/null +++ b/testkitbackend/_async/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/testkitbackend/_async/backend.py b/testkitbackend/_async/backend.py new file mode 100644 index 00000000..014e99e0 --- /dev/null +++ b/testkitbackend/_async/backend.py @@ -0,0 +1,145 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import asyncio +from inspect import ( + getmembers, + isfunction, +) +from json import ( + dumps, + loads, +) +import traceback + +from neo4j._exceptions import BoltError +from neo4j.exceptions import ( + DriverError, + Neo4jError, + UnsupportedServerProduct, +) + +from . import requests +from .._driver_logger import ( + buffer_handler, + log, +) +from ..backend import Request + + +class AsyncBackend: + def __init__(self, rd, wr): + self._rd = rd + self._wr = wr + self.drivers = {} + self.custom_resolutions = {} + self.dns_resolutions = {} + self.sessions = {} + self.results = {} + self.errors = {} + self.transactions = {} + self.errors = {} + self.key = 0 + # Collect all request handlers + self._requestHandlers = dict( + [m for m in getmembers(requests, isfunction)]) + + def next_key(self): + self.key = self.key + 1 + return self.key + + async def process_request(self): + """ Reads next request from the stream and processes it. + """ + in_request = False + request = "" + async for line in self._rd: + # Remove trailing newline + line = line.decode('UTF-8').rstrip() + if line == "#request begin": + in_request = True + elif line == "#request end": + await self._process(request) + return True + else: + if in_request: + request = request + line + return False + + async def _process(self, request): + """ Process a received request by retrieving handler that + corresponds to the request name. + """ + try: + request = loads(request, object_pairs_hook=Request) + if not isinstance(request, Request): + raise Exception("Request is not an object") + name = request.get('name', 'invalid') + handler = self._requestHandlers.get(name) + if not handler: + raise Exception("No request handler for " + name) + data = request["data"] + log.info("<<< " + name + dumps(data)) + await handler(self, data) + unsused_keys = request.unseen_keys + if unsused_keys: + raise NotImplementedError( + "Backend does not support some properties of the " + name + + " request: " + ", ".join(unsused_keys) + ) + except (Neo4jError, DriverError, UnsupportedServerProduct, + BoltError) as e: + log.debug(traceback.format_exc()) + if isinstance(e, Neo4jError): + msg = "" if e.message is None else str(e.message) + else: + msg = str(e.args[0]) if e.args else "" + + key = self.next_key() + self.errors[key] = e + payload = {"id": key, "errorType": str(type(e)), "msg": msg} + if isinstance(e, Neo4jError): + payload["code"] = e.code + await self.send_response("DriverError", payload) + except requests.FrontendError as e: + await self.send_response("FrontendError", {"msg": str(e)}) + except Exception: + tb = traceback.format_exc() + log.error(tb) + await self.send_response("BackendError", {"msg": tb}) + + async def send_response(self, name, data): + """ Sends a response to backend. + """ + with buffer_handler.lock: + log_output = buffer_handler.stream.getvalue() + buffer_handler.stream.truncate(0) + buffer_handler.stream.seek(0) + if not log_output.endswith("\n"): + log_output += "\n" + self._wr.write(log_output.encode("utf-8")) + response = {"name": name, "data": data} + response = dumps(response) + log.info(">>> " + name + dumps(data)) + self._wr.write(b"#response begin\n") + self._wr.write(bytes(response+"\n", "utf-8")) + self._wr.write(b"#response end\n") + if isinstance(self._wr, asyncio.StreamWriter): + await self._wr.drain() + else: + self._wr.flush() diff --git a/testkitbackend/_async/requests.py b/testkitbackend/_async/requests.py new file mode 100644 index 00000000..e80e39c5 --- /dev/null +++ b/testkitbackend/_async/requests.py @@ -0,0 +1,444 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import json +from os import path + +import neo4j +from neo4j._async_compat.util import AsyncUtil + +from .. import ( + fromtestkit, + totestkit, +) + + +class FrontendError(Exception): + pass + + +def load_config(): + config_path = path.join(path.dirname(__file__), "..", "test_config.json") + with open(config_path, "r") as fd: + config = json.load(fd) + skips = config["skips"] + features = [k for k, v in config["features"].items() if v is True] + import ssl + if ssl.HAS_TLSv1_3: + features += ["Feature:TLS:1.3"] + return skips, features + + +SKIPPED_TESTS, FEATURES = load_config() + + +async def StartTest(backend, data): + if data["testName"] in SKIPPED_TESTS: + await backend.send_response("SkipTest", { + "reason": SKIPPED_TESTS[data["testName"]] + }) + else: + await backend.send_response("RunTest", {}) + + +async def GetFeatures(backend, data): + await backend.send_response("FeatureList", {"features": FEATURES}) + + +async def NewDriver(backend, data): + auth_token = data["authorizationToken"]["data"] + data["authorizationToken"].mark_item_as_read_if_equals( + "name", "AuthorizationToken" + ) + scheme = auth_token["scheme"] + if scheme == "basic": + auth = neo4j.basic_auth( + auth_token["principal"], auth_token["credentials"], + realm=auth_token.get("realm", None) + ) + elif scheme == "kerberos": + auth = neo4j.kerberos_auth(auth_token["credentials"]) + elif scheme == "bearer": + auth = neo4j.bearer_auth(auth_token["credentials"]) + else: + auth = neo4j.custom_auth( + auth_token["principal"], auth_token["credentials"], + auth_token["realm"], auth_token["scheme"], + **auth_token.get("parameters", {}) + ) + auth_token.mark_item_as_read("parameters", recursive=True) + resolver = None + if data["resolverRegistered"] or data["domainNameResolverRegistered"]: + resolver = resolution_func(backend, data["resolverRegistered"], + data["domainNameResolverRegistered"]) + connection_timeout = data.get("connectionTimeoutMs") + if connection_timeout is not None: + connection_timeout /= 1000 + max_transaction_retry_time = data.get("maxTxRetryTimeMs") + if max_transaction_retry_time is not None: + max_transaction_retry_time /= 1000 + data.mark_item_as_read("domainNameResolverRegistered") + driver = neo4j.AsyncGraphDatabase.driver( + data["uri"], auth=auth, user_agent=data["userAgent"], + resolver=resolver, connection_timeout=connection_timeout, + fetch_size=data.get("fetchSize"), + max_transaction_retry_time=max_transaction_retry_time, + ) + key = backend.next_key() + backend.drivers[key] = driver + await backend.send_response("Driver", {"id": key}) + + +async def VerifyConnectivity(backend, data): + driver_id = data["driverId"] + driver = backend.drivers[driver_id] + await driver.verify_connectivity() + await backend.send_response("Driver", {"id": driver_id}) + + +async def CheckMultiDBSupport(backend, data): + driver_id = data["driverId"] + driver = backend.drivers[driver_id] + await backend.send_response("MultiDBSupport", { + "id": backend.next_key(), "available": await driver.supports_multi_db() + }) + + +def resolution_func(backend, custom_resolver=False, custom_dns_resolver=False): + # This solution (putting custom resolution together with DNS resolution + # into one function only works because the Python driver calls the custom + # resolver function for every connection, which is not true for all + # drivers. Properly exposing a way to change the DNS lookup behavior is not + # possible without changing the driver's code. + assert custom_resolver or custom_dns_resolver + + async def resolve(address): + addresses = [":".join(map(str, address))] + if custom_resolver: + key = backend.next_key() + await backend.send_response("ResolverResolutionRequired", { + "id": key, + "address": addresses[0] + }) + if not await backend.process_request(): + # connection was closed before end of next message + return [] + if key not in backend.custom_resolutions: + raise RuntimeError( + "Backend did not receive expected " + "ResolverResolutionCompleted message for id %s" % key + ) + addresses = backend.custom_resolutions.pop(key) + if custom_dns_resolver: + dns_resolved_addresses = [] + for address in addresses: + key = backend.next_key() + address = address.rsplit(":", 1) + await backend.send_response("DomainNameResolutionRequired", { + "id": key, + "name": address[0] + }) + if not await backend.process_request(): + # connection was closed before end of next message + return [] + if key not in backend.dns_resolutions: + raise RuntimeError( + "Backend did not receive expected " + "DomainNameResolutionCompleted message for id %s" % key + ) + dns_resolved_addresses += list(map( + lambda a: ":".join((a, *address[1:])), + backend.dns_resolutions.pop(key) + )) + + addresses = dns_resolved_addresses + + return list(map(neo4j.Address.parse, addresses)) + + return resolve + + +async def ResolverResolutionCompleted(backend, data): + backend.custom_resolutions[data["requestId"]] = data["addresses"] + + +async def DomainNameResolutionCompleted(backend, data): + backend.dns_resolutions[data["requestId"]] = data["addresses"] + + +async def DriverClose(backend, data): + key = data["driverId"] + driver = backend.drivers[key] + await driver.close() + await backend.send_response("Driver", {"id": key}) + + +class SessionTracker: + """ Keeps some extra state about the tracked session + """ + + def __init__(self, session): + self.session = session + self.state = "" + self.error_id = "" + + +async def NewSession(backend, data): + driver = backend.drivers[data["driverId"]] + access_mode = data["accessMode"] + if access_mode == "r": + access_mode = neo4j.READ_ACCESS + elif access_mode == "w": + access_mode = neo4j.WRITE_ACCESS + else: + raise ValueError("Unknown access mode:" + access_mode) + config = { + "default_access_mode": access_mode, + "bookmarks": data["bookmarks"], + "database": data["database"], + "fetch_size": data.get("fetchSize", None), + "impersonated_user": data.get("impersonatedUser", None), + + } + session = driver.session(**config) + key = backend.next_key() + backend.sessions[key] = SessionTracker(session) + await backend.send_response("Session", {"id": key}) + + +async def SessionRun(backend, data): + session = backend.sessions[data["sessionId"]].session + query, params = fromtestkit.to_query_and_params(data) + result = await session.run(query, parameters=params) + key = backend.next_key() + backend.results[key] = result + await backend.send_response("Result", {"id": key, "keys": result.keys()}) + + +async def SessionClose(backend, data): + key = data["sessionId"] + session = backend.sessions[key].session + await session.close() + del backend.sessions[key] + await backend.send_response("Session", {"id": key}) + + +async def SessionBeginTransaction(backend, data): + key = data["sessionId"] + session = backend.sessions[key].session + metadata, timeout = fromtestkit.to_meta_and_timeout(data) + tx = await session.begin_transaction(metadata=metadata, timeout=timeout) + key = backend.next_key() + backend.transactions[key] = tx + await backend.send_response("Transaction", {"id": key}) + + +async def SessionReadTransaction(backend, data): + await transactionFunc(backend, data, True) + + +async def SessionWriteTransaction(backend, data): + await transactionFunc(backend, data, False) + + +async def transactionFunc(backend, data, is_read): + key = data["sessionId"] + session_tracker = backend.sessions[key] + session = session_tracker.session + metadata, timeout = fromtestkit.to_meta_and_timeout(data) + + @neo4j.unit_of_work(metadata=metadata, timeout=timeout) + async def func(tx): + txkey = backend.next_key() + backend.transactions[txkey] = tx + session_tracker.state = '' + await backend.send_response("RetryableTry", {"id": txkey}) + + cont = True + while cont: + cont = await backend.process_request() + if session_tracker.state == '+': + cont = False + elif session_tracker.state == '-': + if session_tracker.error_id: + raise backend.errors[session_tracker.error_id] + else: + raise FrontendError("Client said no") + + if is_read: + await session.read_transaction(func) + else: + await session.write_transaction(func) + await backend.send_response("RetryableDone", {}) + + +async def RetryablePositive(backend, data): + key = data["sessionId"] + session_tracker = backend.sessions[key] + session_tracker.state = '+' + + +async def RetryableNegative(backend, data): + key = data["sessionId"] + session_tracker = backend.sessions[key] + session_tracker.state = '-' + session_tracker.error_id = data.get('errorId', '') + + +async def SessionLastBookmarks(backend, data): + key = data["sessionId"] + session = backend.sessions[key].session + bookmark = await session.last_bookmark() + bookmarks = [] + if bookmark: + bookmarks.append(bookmark) + await backend.send_response("Bookmarks", {"bookmarks": bookmarks}) + + +async def TransactionRun(backend, data): + key = data["txId"] + tx = backend.transactions[key] + cypher, params = fromtestkit.to_cypher_and_params(data) + result = await tx.run(cypher, parameters=params) + key = backend.next_key() + backend.results[key] = result + await backend.send_response("Result", {"id": key, "keys": result.keys()}) + + +async def TransactionCommit(backend, data): + key = data["txId"] + tx = backend.transactions[key] + await tx.commit() + await backend.send_response("Transaction", {"id": key}) + + +async def TransactionRollback(backend, data): + key = data["txId"] + tx = backend.transactions[key] + await tx.rollback() + await backend.send_response("Transaction", {"id": key}) + + +async def TransactionClose(backend, data): + key = data["txId"] + tx = backend.transactions[key] + await tx.close() + await backend.send_response("Transaction", {"id": key}) + + +async def ResultNext(backend, data): + result = backend.results[data["resultId"]] + + try: + record = await AsyncUtil.next(AsyncUtil.iter(result)) + except StopAsyncIteration: + await backend.send_response("NullRecord", {}) + return + await backend.send_response("Record", totestkit.record(record)) + + +async def ResultSingle(backend, data): + result = backend.results[data["resultId"]] + await backend.send_response("Record", totestkit.record(result.single())) + + +async def ResultPeek(backend, data): + result = backend.results[data["resultId"]] + record = await result.peek() + if record is not None: + await backend.send_response("Record", totestkit.record(record)) + else: + await backend.send_response("NullRecord", {}) + + +async def ResultList(backend, data): + result = backend.results[data["resultId"]] + records = await AsyncUtil.list(result) + await backend.send_response("RecordList", { + "records": [totestkit.record(r) for r in records] + }) + + +async def ResultConsume(backend, data): + result = backend.results[data["resultId"]] + summary = await result.consume() + from neo4j import ResultSummary + assert isinstance(summary, ResultSummary) + await backend.send_response("Summary", { + "serverInfo": { + "address": ":".join(map(str, summary.server.address)), + "agent": summary.server.agent, + "protocolVersion": + ".".join(map(str, summary.server.protocol_version)), + }, + "counters": None if not summary.counters else { + "constraintsAdded": summary.counters.constraints_added, + "constraintsRemoved": summary.counters.constraints_removed, + "containsSystemUpdates": summary.counters.contains_system_updates, + "containsUpdates": summary.counters.contains_updates, + "indexesAdded": summary.counters.indexes_added, + "indexesRemoved": summary.counters.indexes_removed, + "labelsAdded": summary.counters.labels_added, + "labelsRemoved": summary.counters.labels_removed, + "nodesCreated": summary.counters.nodes_created, + "nodesDeleted": summary.counters.nodes_deleted, + "propertiesSet": summary.counters.properties_set, + "relationshipsCreated": summary.counters.relationships_created, + "relationshipsDeleted": summary.counters.relationships_deleted, + "systemUpdates": summary.counters.system_updates, + }, + "database": summary.database, + "notifications": summary.notifications, + "plan": summary.plan, + "profile": summary.profile, + "query": { + "text": summary.query, + "parameters": {k: totestkit.field(v) + for k, v in summary.parameters.items()}, + }, + "queryType": summary.query_type, + "resultAvailableAfter": summary.result_available_after, + "resultConsumedAfter": summary.result_consumed_after, + }) + + +async def ForcedRoutingTableUpdate(backend, data): + driver_id = data["driverId"] + driver = backend.drivers[driver_id] + database = data["database"] + bookmarks = data["bookmarks"] + async with driver._pool.refresh_lock: + await driver._pool.update_routing_table( + database=database, imp_user=None, bookmarks=bookmarks + ) + await backend.send_response("Driver", {"id": driver_id}) + + +async def GetRoutingTable(backend, data): + driver_id = data["driverId"] + database = data["database"] + driver = backend.drivers[driver_id] + routing_table = driver._pool.routing_tables[database] + response_data = { + "database": routing_table.database, + "ttl": routing_table.ttl, + } + for role in ("routers", "readers", "writers"): + addresses = routing_table.__getattribute__(role) + response_data[role] = list(map(str, addresses)) + await backend.send_response("RoutingTable", response_data) diff --git a/testkitbackend/_driver_logger.py b/testkitbackend/_driver_logger.py new file mode 100644 index 00000000..ef09ec35 --- /dev/null +++ b/testkitbackend/_driver_logger.py @@ -0,0 +1,41 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import io +import logging +import sys + + +buffer_handler = logging.StreamHandler(io.StringIO()) +buffer_handler.setLevel(logging.DEBUG) + +handler = logging.StreamHandler(sys.stdout) +handler.setLevel(logging.DEBUG) +logging.getLogger("neo4j").addHandler(handler) +logging.getLogger("neo4j").addHandler(buffer_handler) +logging.getLogger("neo4j").setLevel(logging.DEBUG) + +log = logging.getLogger("testkitbackend") +log.addHandler(handler) +log.setLevel(logging.DEBUG) + + +__all__ = [ + "buffer_handler", + "log", +] diff --git a/testkitbackend/_sync/__init__.py b/testkitbackend/_sync/__init__.py new file mode 100644 index 00000000..b81a309d --- /dev/null +++ b/testkitbackend/_sync/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/testkitbackend/_sync/backend.py b/testkitbackend/_sync/backend.py new file mode 100644 index 00000000..9f1bae94 --- /dev/null +++ b/testkitbackend/_sync/backend.py @@ -0,0 +1,145 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import asyncio +from inspect import ( + getmembers, + isfunction, +) +from json import ( + dumps, + loads, +) +import traceback + +from neo4j._exceptions import BoltError +from neo4j.exceptions import ( + DriverError, + Neo4jError, + UnsupportedServerProduct, +) + +from . import requests +from .._driver_logger import ( + buffer_handler, + log, +) +from ..backend import Request + + +class Backend: + def __init__(self, rd, wr): + self._rd = rd + self._wr = wr + self.drivers = {} + self.custom_resolutions = {} + self.dns_resolutions = {} + self.sessions = {} + self.results = {} + self.errors = {} + self.transactions = {} + self.errors = {} + self.key = 0 + # Collect all request handlers + self._requestHandlers = dict( + [m for m in getmembers(requests, isfunction)]) + + def next_key(self): + self.key = self.key + 1 + return self.key + + def process_request(self): + """ Reads next request from the stream and processes it. + """ + in_request = False + request = "" + for line in self._rd: + # Remove trailing newline + line = line.decode('UTF-8').rstrip() + if line == "#request begin": + in_request = True + elif line == "#request end": + self._process(request) + return True + else: + if in_request: + request = request + line + return False + + def _process(self, request): + """ Process a received request by retrieving handler that + corresponds to the request name. + """ + try: + request = loads(request, object_pairs_hook=Request) + if not isinstance(request, Request): + raise Exception("Request is not an object") + name = request.get('name', 'invalid') + handler = self._requestHandlers.get(name) + if not handler: + raise Exception("No request handler for " + name) + data = request["data"] + log.info("<<< " + name + dumps(data)) + handler(self, data) + unsused_keys = request.unseen_keys + if unsused_keys: + raise NotImplementedError( + "Backend does not support some properties of the " + name + + " request: " + ", ".join(unsused_keys) + ) + except (Neo4jError, DriverError, UnsupportedServerProduct, + BoltError) as e: + log.debug(traceback.format_exc()) + if isinstance(e, Neo4jError): + msg = "" if e.message is None else str(e.message) + else: + msg = str(e.args[0]) if e.args else "" + + key = self.next_key() + self.errors[key] = e + payload = {"id": key, "errorType": str(type(e)), "msg": msg} + if isinstance(e, Neo4jError): + payload["code"] = e.code + self.send_response("DriverError", payload) + except requests.FrontendError as e: + self.send_response("FrontendError", {"msg": str(e)}) + except Exception: + tb = traceback.format_exc() + log.error(tb) + self.send_response("BackendError", {"msg": tb}) + + def send_response(self, name, data): + """ Sends a response to backend. + """ + with buffer_handler.lock: + log_output = buffer_handler.stream.getvalue() + buffer_handler.stream.truncate(0) + buffer_handler.stream.seek(0) + if not log_output.endswith("\n"): + log_output += "\n" + self._wr.write(log_output.encode("utf-8")) + response = {"name": name, "data": data} + response = dumps(response) + log.info(">>> " + name + dumps(data)) + self._wr.write(b"#response begin\n") + self._wr.write(bytes(response+"\n", "utf-8")) + self._wr.write(b"#response end\n") + if isinstance(self._wr, asyncio.StreamWriter): + self._wr.drain() + else: + self._wr.flush() diff --git a/testkitbackend/_sync/requests.py b/testkitbackend/_sync/requests.py new file mode 100644 index 00000000..744d8356 --- /dev/null +++ b/testkitbackend/_sync/requests.py @@ -0,0 +1,444 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import json +from os import path + +import neo4j +from neo4j._async_compat.util import Util + +from .. import ( + fromtestkit, + totestkit, +) + + +class FrontendError(Exception): + pass + + +def load_config(): + config_path = path.join(path.dirname(__file__), "..", "test_config.json") + with open(config_path, "r") as fd: + config = json.load(fd) + skips = config["skips"] + features = [k for k, v in config["features"].items() if v is True] + import ssl + if ssl.HAS_TLSv1_3: + features += ["Feature:TLS:1.3"] + return skips, features + + +SKIPPED_TESTS, FEATURES = load_config() + + +def StartTest(backend, data): + if data["testName"] in SKIPPED_TESTS: + backend.send_response("SkipTest", { + "reason": SKIPPED_TESTS[data["testName"]] + }) + else: + backend.send_response("RunTest", {}) + + +def GetFeatures(backend, data): + backend.send_response("FeatureList", {"features": FEATURES}) + + +def NewDriver(backend, data): + auth_token = data["authorizationToken"]["data"] + data["authorizationToken"].mark_item_as_read_if_equals( + "name", "AuthorizationToken" + ) + scheme = auth_token["scheme"] + if scheme == "basic": + auth = neo4j.basic_auth( + auth_token["principal"], auth_token["credentials"], + realm=auth_token.get("realm", None) + ) + elif scheme == "kerberos": + auth = neo4j.kerberos_auth(auth_token["credentials"]) + elif scheme == "bearer": + auth = neo4j.bearer_auth(auth_token["credentials"]) + else: + auth = neo4j.custom_auth( + auth_token["principal"], auth_token["credentials"], + auth_token["realm"], auth_token["scheme"], + **auth_token.get("parameters", {}) + ) + auth_token.mark_item_as_read("parameters", recursive=True) + resolver = None + if data["resolverRegistered"] or data["domainNameResolverRegistered"]: + resolver = resolution_func(backend, data["resolverRegistered"], + data["domainNameResolverRegistered"]) + connection_timeout = data.get("connectionTimeoutMs") + if connection_timeout is not None: + connection_timeout /= 1000 + max_transaction_retry_time = data.get("maxTxRetryTimeMs") + if max_transaction_retry_time is not None: + max_transaction_retry_time /= 1000 + data.mark_item_as_read("domainNameResolverRegistered") + driver = neo4j.GraphDatabase.driver( + data["uri"], auth=auth, user_agent=data["userAgent"], + resolver=resolver, connection_timeout=connection_timeout, + fetch_size=data.get("fetchSize"), + max_transaction_retry_time=max_transaction_retry_time, + ) + key = backend.next_key() + backend.drivers[key] = driver + backend.send_response("Driver", {"id": key}) + + +def VerifyConnectivity(backend, data): + driver_id = data["driverId"] + driver = backend.drivers[driver_id] + driver.verify_connectivity() + backend.send_response("Driver", {"id": driver_id}) + + +def CheckMultiDBSupport(backend, data): + driver_id = data["driverId"] + driver = backend.drivers[driver_id] + backend.send_response("MultiDBSupport", { + "id": backend.next_key(), "available": driver.supports_multi_db() + }) + + +def resolution_func(backend, custom_resolver=False, custom_dns_resolver=False): + # This solution (putting custom resolution together with DNS resolution + # into one function only works because the Python driver calls the custom + # resolver function for every connection, which is not true for all + # drivers. Properly exposing a way to change the DNS lookup behavior is not + # possible without changing the driver's code. + assert custom_resolver or custom_dns_resolver + + def resolve(address): + addresses = [":".join(map(str, address))] + if custom_resolver: + key = backend.next_key() + backend.send_response("ResolverResolutionRequired", { + "id": key, + "address": addresses[0] + }) + if not backend.process_request(): + # connection was closed before end of next message + return [] + if key not in backend.custom_resolutions: + raise RuntimeError( + "Backend did not receive expected " + "ResolverResolutionCompleted message for id %s" % key + ) + addresses = backend.custom_resolutions.pop(key) + if custom_dns_resolver: + dns_resolved_addresses = [] + for address in addresses: + key = backend.next_key() + address = address.rsplit(":", 1) + backend.send_response("DomainNameResolutionRequired", { + "id": key, + "name": address[0] + }) + if not backend.process_request(): + # connection was closed before end of next message + return [] + if key not in backend.dns_resolutions: + raise RuntimeError( + "Backend did not receive expected " + "DomainNameResolutionCompleted message for id %s" % key + ) + dns_resolved_addresses += list(map( + lambda a: ":".join((a, *address[1:])), + backend.dns_resolutions.pop(key) + )) + + addresses = dns_resolved_addresses + + return list(map(neo4j.Address.parse, addresses)) + + return resolve + + +def ResolverResolutionCompleted(backend, data): + backend.custom_resolutions[data["requestId"]] = data["addresses"] + + +def DomainNameResolutionCompleted(backend, data): + backend.dns_resolutions[data["requestId"]] = data["addresses"] + + +def DriverClose(backend, data): + key = data["driverId"] + driver = backend.drivers[key] + driver.close() + backend.send_response("Driver", {"id": key}) + + +class SessionTracker: + """ Keeps some extra state about the tracked session + """ + + def __init__(self, session): + self.session = session + self.state = "" + self.error_id = "" + + +def NewSession(backend, data): + driver = backend.drivers[data["driverId"]] + access_mode = data["accessMode"] + if access_mode == "r": + access_mode = neo4j.READ_ACCESS + elif access_mode == "w": + access_mode = neo4j.WRITE_ACCESS + else: + raise ValueError("Unknown access mode:" + access_mode) + config = { + "default_access_mode": access_mode, + "bookmarks": data["bookmarks"], + "database": data["database"], + "fetch_size": data.get("fetchSize", None), + "impersonated_user": data.get("impersonatedUser", None), + + } + session = driver.session(**config) + key = backend.next_key() + backend.sessions[key] = SessionTracker(session) + backend.send_response("Session", {"id": key}) + + +def SessionRun(backend, data): + session = backend.sessions[data["sessionId"]].session + query, params = fromtestkit.to_query_and_params(data) + result = session.run(query, parameters=params) + key = backend.next_key() + backend.results[key] = result + backend.send_response("Result", {"id": key, "keys": result.keys()}) + + +def SessionClose(backend, data): + key = data["sessionId"] + session = backend.sessions[key].session + session.close() + del backend.sessions[key] + backend.send_response("Session", {"id": key}) + + +def SessionBeginTransaction(backend, data): + key = data["sessionId"] + session = backend.sessions[key].session + metadata, timeout = fromtestkit.to_meta_and_timeout(data) + tx = session.begin_transaction(metadata=metadata, timeout=timeout) + key = backend.next_key() + backend.transactions[key] = tx + backend.send_response("Transaction", {"id": key}) + + +def SessionReadTransaction(backend, data): + transactionFunc(backend, data, True) + + +def SessionWriteTransaction(backend, data): + transactionFunc(backend, data, False) + + +def transactionFunc(backend, data, is_read): + key = data["sessionId"] + session_tracker = backend.sessions[key] + session = session_tracker.session + metadata, timeout = fromtestkit.to_meta_and_timeout(data) + + @neo4j.unit_of_work(metadata=metadata, timeout=timeout) + def func(tx): + txkey = backend.next_key() + backend.transactions[txkey] = tx + session_tracker.state = '' + backend.send_response("RetryableTry", {"id": txkey}) + + cont = True + while cont: + cont = backend.process_request() + if session_tracker.state == '+': + cont = False + elif session_tracker.state == '-': + if session_tracker.error_id: + raise backend.errors[session_tracker.error_id] + else: + raise FrontendError("Client said no") + + if is_read: + session.read_transaction(func) + else: + session.write_transaction(func) + backend.send_response("RetryableDone", {}) + + +def RetryablePositive(backend, data): + key = data["sessionId"] + session_tracker = backend.sessions[key] + session_tracker.state = '+' + + +def RetryableNegative(backend, data): + key = data["sessionId"] + session_tracker = backend.sessions[key] + session_tracker.state = '-' + session_tracker.error_id = data.get('errorId', '') + + +def SessionLastBookmarks(backend, data): + key = data["sessionId"] + session = backend.sessions[key].session + bookmark = session.last_bookmark() + bookmarks = [] + if bookmark: + bookmarks.append(bookmark) + backend.send_response("Bookmarks", {"bookmarks": bookmarks}) + + +def TransactionRun(backend, data): + key = data["txId"] + tx = backend.transactions[key] + cypher, params = fromtestkit.to_cypher_and_params(data) + result = tx.run(cypher, parameters=params) + key = backend.next_key() + backend.results[key] = result + backend.send_response("Result", {"id": key, "keys": result.keys()}) + + +def TransactionCommit(backend, data): + key = data["txId"] + tx = backend.transactions[key] + tx.commit() + backend.send_response("Transaction", {"id": key}) + + +def TransactionRollback(backend, data): + key = data["txId"] + tx = backend.transactions[key] + tx.rollback() + backend.send_response("Transaction", {"id": key}) + + +def TransactionClose(backend, data): + key = data["txId"] + tx = backend.transactions[key] + tx.close() + backend.send_response("Transaction", {"id": key}) + + +def ResultNext(backend, data): + result = backend.results[data["resultId"]] + + try: + record = Util.next(Util.iter(result)) + except StopIteration: + backend.send_response("NullRecord", {}) + return + backend.send_response("Record", totestkit.record(record)) + + +def ResultSingle(backend, data): + result = backend.results[data["resultId"]] + backend.send_response("Record", totestkit.record(result.single())) + + +def ResultPeek(backend, data): + result = backend.results[data["resultId"]] + record = result.peek() + if record is not None: + backend.send_response("Record", totestkit.record(record)) + else: + backend.send_response("NullRecord", {}) + + +def ResultList(backend, data): + result = backend.results[data["resultId"]] + records = Util.list(result) + backend.send_response("RecordList", { + "records": [totestkit.record(r) for r in records] + }) + + +def ResultConsume(backend, data): + result = backend.results[data["resultId"]] + summary = result.consume() + from neo4j import ResultSummary + assert isinstance(summary, ResultSummary) + backend.send_response("Summary", { + "serverInfo": { + "address": ":".join(map(str, summary.server.address)), + "agent": summary.server.agent, + "protocolVersion": + ".".join(map(str, summary.server.protocol_version)), + }, + "counters": None if not summary.counters else { + "constraintsAdded": summary.counters.constraints_added, + "constraintsRemoved": summary.counters.constraints_removed, + "containsSystemUpdates": summary.counters.contains_system_updates, + "containsUpdates": summary.counters.contains_updates, + "indexesAdded": summary.counters.indexes_added, + "indexesRemoved": summary.counters.indexes_removed, + "labelsAdded": summary.counters.labels_added, + "labelsRemoved": summary.counters.labels_removed, + "nodesCreated": summary.counters.nodes_created, + "nodesDeleted": summary.counters.nodes_deleted, + "propertiesSet": summary.counters.properties_set, + "relationshipsCreated": summary.counters.relationships_created, + "relationshipsDeleted": summary.counters.relationships_deleted, + "systemUpdates": summary.counters.system_updates, + }, + "database": summary.database, + "notifications": summary.notifications, + "plan": summary.plan, + "profile": summary.profile, + "query": { + "text": summary.query, + "parameters": {k: totestkit.field(v) + for k, v in summary.parameters.items()}, + }, + "queryType": summary.query_type, + "resultAvailableAfter": summary.result_available_after, + "resultConsumedAfter": summary.result_consumed_after, + }) + + +def ForcedRoutingTableUpdate(backend, data): + driver_id = data["driverId"] + driver = backend.drivers[driver_id] + database = data["database"] + bookmarks = data["bookmarks"] + with driver._pool.refresh_lock: + driver._pool.update_routing_table( + database=database, imp_user=None, bookmarks=bookmarks + ) + backend.send_response("Driver", {"id": driver_id}) + + +def GetRoutingTable(backend, data): + driver_id = data["driverId"] + database = data["database"] + driver = backend.drivers[driver_id] + routing_table = driver._pool.routing_tables[database] + response_data = { + "database": routing_table.database, + "ttl": routing_table.ttl, + } + for role in ("routers", "readers", "writers"): + addresses = routing_table.__getattribute__(role) + response_data[role] = list(map(str, addresses)) + backend.send_response("RoutingTable", response_data) diff --git a/testkitbackend/backend.py b/testkitbackend/backend.py index c488f8ad..f97dcc1c 100644 --- a/testkitbackend/backend.py +++ b/testkitbackend/backend.py @@ -14,39 +14,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from inspect import ( - getmembers, - isfunction, -) -import io -from json import loads, dumps -import logging -import sys -import traceback - -from neo4j._exceptions import ( - BoltError -) -from neo4j.exceptions import ( - DriverError, - Neo4jError, - UnsupportedServerProduct, -) - -import testkitbackend.requests as requests - -buffer_handler = logging.StreamHandler(io.StringIO()) -buffer_handler.setLevel(logging.DEBUG) - -handler = logging.StreamHandler(sys.stdout) -handler.setLevel(logging.DEBUG) -logging.getLogger("neo4j").addHandler(handler) -logging.getLogger("neo4j").addHandler(buffer_handler) -logging.getLogger("neo4j").setLevel(logging.DEBUG) - -log = logging.getLogger("testkitbackend") -log.addHandler(handler) -log.setLevel(logging.DEBUG) class Request(dict): @@ -93,104 +60,3 @@ def unseen_keys(self): @property def seen_all_keys(self): return not self.unseen_keys - - -class Backend: - def __init__(self, rd, wr): - self._rd = rd - self._wr = wr - self.drivers = {} - self.custom_resolutions = {} - self.dns_resolutions = {} - self.sessions = {} - self.results = {} - self.errors = {} - self.transactions = {} - self.errors = {} - self.key = 0 - # Collect all request handlers - self._requestHandlers = dict( - [m for m in getmembers(requests, isfunction)]) - - def next_key(self): - self.key = self.key + 1 - return self.key - - def process_request(self): - """ Reads next request from the stream and processes it. - """ - in_request = False - request = "" - for line in self._rd: - # Remove trailing newline - line = line.decode('UTF-8').rstrip() - if line == "#request begin": - in_request = True - elif line == "#request end": - self._process(request) - return True - else: - if in_request: - request = request + line - return False - - def _process(self, request): - """ Process a received request by retrieving handler that - corresponds to the request name. - """ - try: - request = loads(request, object_pairs_hook=Request) - if not isinstance(request, Request): - raise Exception("Request is not an object") - name = request.get('name', 'invalid') - handler = self._requestHandlers.get(name) - if not handler: - raise Exception("No request handler for " + name) - data = request["data"] - log.info("<<< " + name + dumps(data)) - handler(self, data) - unsused_keys = request.unseen_keys - if unsused_keys: - raise NotImplementedError( - "Backend does not support some properties of the " + name + - " request: " + ", ".join(unsused_keys) - ) - except (Neo4jError, DriverError, UnsupportedServerProduct, - BoltError) as e: - log.debug(traceback.format_exc()) - if isinstance(e, Neo4jError): - msg = "" if e.message is None else str(e.message) - else: - msg = str(e.args[0]) if e.args else "" - - key = self.next_key() - self.errors[key] = e - payload = {"id": key, "errorType": str(type(e)), "msg": msg} - if isinstance(e, Neo4jError): - payload["code"] = e.code - self.send_response("DriverError", payload) - except requests.FrontendError as e: - self.send_response("FrontendError", {"msg": str(e)}) - except Exception: - tb = traceback.format_exc() - log.error(tb) - self.send_response("BackendError", {"msg": tb}) - - def send_response(self, name, data): - """ Sends a response to backend. - """ - buffer_handler.acquire() - log_output = buffer_handler.stream.getvalue() - buffer_handler.stream.truncate(0) - buffer_handler.stream.seek(0) - buffer_handler.release() - if not log_output.endswith("\n"): - log_output += "\n" - self._wr.write(log_output.encode("utf-8")) - response = {"name": name, "data": data} - response = dumps(response) - log.info(">>> " + name + dumps(data)) - self._wr.write(b"#response begin\n") - self._wr.write(bytes(response+"\n", "utf-8")) - self._wr.write(b"#response end\n") - self._wr.flush() diff --git a/testkitbackend/fromtestkit.py b/testkitbackend/fromtestkit.py index 5c3c92d8..7cf17b92 100644 --- a/testkitbackend/fromtestkit.py +++ b/testkitbackend/fromtestkit.py @@ -15,7 +15,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from neo4j.work.simple import Query + +from neo4j import Query def to_cypher_and_params(data): diff --git a/testkitbackend/requests.py b/testkitbackend/requests.py index 09a2c972..72e4d66f 100644 --- a/testkitbackend/requests.py +++ b/testkitbackend/requests.py @@ -14,13 +14,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + + import json from os import path import neo4j -import testkitbackend.fromtestkit as fromtestkit -import testkitbackend.totestkit as totestkit -from testkitbackend.fromtestkit import to_meta_and_timeout + +from . import ( + fromtestkit, + totestkit, +) +from .fromtestkit import to_meta_and_timeout class FrontendError(Exception): @@ -357,7 +362,7 @@ def ResultList(backend, data): def ResultConsume(backend, data): result = backend.results[data["resultId"]] summary = result.consume() - from neo4j.work.summary import ResultSummary + from neo4j import ResultSummary assert isinstance(summary, ResultSummary) backend.send_response("Summary", { "serverInfo": { diff --git a/testkitbackend/server.py b/testkitbackend/server.py index 45728164..0eb36bd1 100644 --- a/testkitbackend/server.py +++ b/testkitbackend/server.py @@ -14,8 +14,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from socketserver import TCPServer, StreamRequestHandler -from testkitbackend.backend import Backend + + +import asyncio +from socketserver import ( + StreamRequestHandler, + TCPServer, +) + +from ._async.backend import AsyncBackend +from ._sync.backend import Backend class Server(TCPServer): @@ -29,3 +37,32 @@ def handle(self): pass print("Disconnected") super(Server, self).__init__(address, Handler) + + +class AsyncServer: + def __init__(self, address): + self._address = address + self._server = None + + @staticmethod + async def _handler(reader, writer): + backend = AsyncBackend(reader, writer) + while await backend.process_request(): + pass + print("Disconnected") + + async def start(self): + self._server = await asyncio.start_server( + self._handler, host=self._address[0], port=self._address[1], + limit=float("inf") # this is dirty but works (for now) + ) + + async def serve_forever(self): + if not self._server: + raise RuntimeError("Server not started") + await self._server.serve_forever() + + def stop(self): + if not self._server: + raise RuntimeError("Try starting the server before stopping it ;)") + self._server.close() diff --git a/testkitbackend/totestkit.py b/testkitbackend/totestkit.py index 3068017a..2e48861e 100644 --- a/testkitbackend/totestkit.py +++ b/testkitbackend/totestkit.py @@ -14,6 +14,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + + import math from neo4j.graph import ( diff --git a/tests/__init__.py b/tests/__init__.py index e69de29b..b81a309d 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/env.py b/tests/env.py index cacb2440..f3e965cb 100644 --- a/tests/env.py +++ b/tests/env.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 015ba64d..9c123ac5 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -21,18 +18,19 @@ from math import ceil from os import getenv -from os.path import dirname, join +from os.path import ( + dirname, + join, +) from threading import RLock import pytest -import urllib -from neo4j import ( - GraphDatabase, -) -from neo4j.exceptions import ServiceUnavailable +from neo4j import GraphDatabase from neo4j._exceptions import BoltHandshakeError -from neo4j.io import Bolt +from neo4j._sync.io import Bolt +from neo4j.exceptions import ServiceUnavailable + # import logging # log = logging.getLogger("neo4j") @@ -50,7 +48,7 @@ NEO4J_CORES = 3 NEO4J_REPLICAS = 2 NEO4J_USER = "neo4j" -NEO4J_PASSWORD = "password" +NEO4J_PASSWORD = "pass" NEO4J_AUTH = (NEO4J_USER, NEO4J_PASSWORD) NEO4J_LOCK = RLock() NEO4J_SERVICE = None @@ -76,7 +74,10 @@ def __init__(self, name=None, image=None, auth=None, n_cores=None, n_replicas=None, bolt_port=None, http_port=None, debug_port=None, debug_suspend=None, dir_spec=None, config=None): - from boltkit.legacy.controller import _install, create_controller + from boltkit.legacy.controller import ( + _install, + create_controller, + ) assert image.endswith("-enterprise") release = image[:-11] if release == "snapshot": diff --git a/tests/integration/examples/__init__.py b/tests/integration/examples/__init__.py index 87bbaf26..a629622e 100644 --- a/tests/integration/examples/__init__.py +++ b/tests/integration/examples/__init__.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # diff --git a/tests/integration/examples/test_autocommit_transaction_example.py b/tests/integration/examples/test_autocommit_transaction_example.py index 0f5b16e1..f4ffa7e7 100644 --- a/tests/integration/examples/test_autocommit_transaction_example.py +++ b/tests/integration/examples/test_autocommit_transaction_example.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -19,9 +16,11 @@ # limitations under the License. +# isort: off # tag::autocommit-transaction-import[] from neo4j import Query # end::autocommit-transaction-import[] +# isort: on # python -m pytest tests/integration/examples/test_autocommit_transaction_example.py -s -v diff --git a/tests/integration/examples/test_basic_auth_example.py b/tests/integration/examples/test_basic_auth_example.py index 90e32808..482ac7ba 100644 --- a/tests/integration/examples/test_basic_auth_example.py +++ b/tests/integration/examples/test_basic_auth_example.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -21,13 +18,16 @@ import pytest +from neo4j._exceptions import BoltHandshakeError +from neo4j.exceptions import ServiceUnavailable +from tests.integration.examples import DriverSetupExample + + +# isort: off # tag::basic-auth-import[] from neo4j import GraphDatabase # end::basic-auth-import[] - -from neo4j.exceptions import ServiceUnavailable -from neo4j._exceptions import BoltHandshakeError -from tests.integration.examples import DriverSetupExample +# isort: on # python -m pytest tests/integration/examples/test_basic_auth_example.py -s -v diff --git a/tests/integration/examples/test_bearer_auth_example.py b/tests/integration/examples/test_bearer_auth_example.py index 727dd6a9..191ca7e0 100644 --- a/tests/integration/examples/test_bearer_auth_example.py +++ b/tests/integration/examples/test_bearer_auth_example.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -19,22 +16,22 @@ # limitations under the License. -import pytest - import neo4j +from tests.integration.examples import DriverSetupExample + + +# isort: off # tag::bearer-auth-import[] from neo4j import ( bearer_auth, GraphDatabase, ) # end::bearer-auth-import[] - -from tests.integration.examples import DriverSetupExample +# isort: on # python -m pytest tests/integration/examples/test_bearer_auth_example.py -s -v - class BearerAuthExample(DriverSetupExample): # tag::bearer-auth[] @@ -61,4 +58,3 @@ def test_example(uri, mocker): assert not hasattr(auth, "principal") assert auth.credentials == token assert not hasattr(auth, "parameters") - diff --git a/tests/integration/examples/test_config_connection_pool_example.py b/tests/integration/examples/test_config_connection_pool_example.py index 7ddaacc7..233adaa6 100644 --- a/tests/integration/examples/test_config_connection_pool_example.py +++ b/tests/integration/examples/test_config_connection_pool_example.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -21,14 +18,16 @@ import pytest -from neo4j.exceptions import ServiceUnavailable from neo4j._exceptions import BoltHandshakeError +from neo4j.exceptions import ServiceUnavailable +from tests.integration.examples import DriverSetupExample + +# isort: off # tag::config-connection-pool-import[] from neo4j import GraphDatabase # end::config-connection-pool-import[] - -from tests.integration.examples import DriverSetupExample +# isort: on # python -m pytest tests/integration/examples/test_config_connection_pool_example.py -s -v diff --git a/tests/integration/examples/test_config_connection_timeout_example.py b/tests/integration/examples/test_config_connection_timeout_example.py index 00d6c97d..ef17e2b4 100644 --- a/tests/integration/examples/test_config_connection_timeout_example.py +++ b/tests/integration/examples/test_config_connection_timeout_example.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -21,14 +18,16 @@ import pytest -from neo4j.exceptions import ServiceUnavailable from neo4j._exceptions import BoltHandshakeError +from neo4j.exceptions import ServiceUnavailable +from tests.integration.examples import DriverSetupExample + +# isort: off # tag::config-connection-timeout-import[] from neo4j import GraphDatabase # end::config-connection-timeout-import[] - -from tests.integration.examples import DriverSetupExample +# isort: on # python -m pytest tests/integration/examples/test_config_connection_timeout_example.py -s -v diff --git a/tests/integration/examples/test_config_max_retry_time_example.py b/tests/integration/examples/test_config_max_retry_time_example.py index 8cd25905..3a28235d 100644 --- a/tests/integration/examples/test_config_max_retry_time_example.py +++ b/tests/integration/examples/test_config_max_retry_time_example.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -21,13 +18,16 @@ import pytest +from neo4j._exceptions import BoltHandshakeError +from neo4j.exceptions import ServiceUnavailable +from tests.integration.examples import DriverSetupExample + + +# isort: off # tag::config-max-retry-time-import[] from neo4j import GraphDatabase # end::config-max-retry-time-import[] - -from neo4j.exceptions import ServiceUnavailable -from neo4j._exceptions import BoltHandshakeError -from tests.integration.examples import DriverSetupExample +# isort: on # python -m pytest tests/integration/examples/test_config_max_retry_time_example.py -s -v diff --git a/tests/integration/examples/test_config_secure_example.py b/tests/integration/examples/test_config_secure_example.py index ebd048dc..2a6c6603 100644 --- a/tests/integration/examples/test_config_secure_example.py +++ b/tests/integration/examples/test_config_secure_example.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -21,15 +18,21 @@ import pytest -# tag::config-secure-import[] -from neo4j import GraphDatabase, TRUST_SYSTEM_CA_SIGNED_CERTIFICATES -# end::config-secure-import[] - -from neo4j.exceptions import ServiceUnavailable from neo4j._exceptions import BoltHandshakeError +from neo4j.exceptions import ServiceUnavailable from tests.integration.examples import DriverSetupExample +# isort: off +# tag::config-secure-import[] +from neo4j import ( + GraphDatabase, + TRUST_SYSTEM_CA_SIGNED_CERTIFICATES, +) +# end::config-secure-import[] +# isort: off + + # python -m pytest tests/integration/examples/test_config_secure_example.py -s -v class ConfigSecureExample(DriverSetupExample): diff --git a/tests/integration/examples/test_config_trust_example.py b/tests/integration/examples/test_config_trust_example.py index 1466c536..b63139f1 100644 --- a/tests/integration/examples/test_config_trust_example.py +++ b/tests/integration/examples/test_config_trust_example.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -21,15 +18,17 @@ import pytest +from tests.integration.examples import DriverSetupExample + + +# isort: off # tag::config-trust-import[] from neo4j import ( GraphDatabase, - TRUST_SYSTEM_CA_SIGNED_CERTIFICATES, - TRUST_ALL_CERTIFICATES, + TRUST_ALL_CERTIFICATES ) # end::config-trust-import[] - -from tests.integration.examples import DriverSetupExample +# isort: on class ConfigTrustExample(DriverSetupExample): diff --git a/tests/integration/examples/test_config_unencrypted_example.py b/tests/integration/examples/test_config_unencrypted_example.py index 6d70ad03..d04297df 100644 --- a/tests/integration/examples/test_config_unencrypted_example.py +++ b/tests/integration/examples/test_config_unencrypted_example.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -21,14 +18,16 @@ import pytest +from neo4j._exceptions import BoltHandshakeError +from neo4j.exceptions import ServiceUnavailable +from tests.integration.examples import DriverSetupExample + + +# isort: off # tag::config-unencrypted-import[] from neo4j import GraphDatabase # end::config-unencrypted-import[] - -from neo4j.exceptions import ServiceUnavailable -from neo4j._exceptions import BoltHandshakeError - -from tests.integration.examples import DriverSetupExample +# isort: on # python -m pytest tests/integration/examples/test_config_unencrypted_example.py -s -v diff --git a/tests/integration/examples/test_custom_auth_example.py b/tests/integration/examples/test_custom_auth_example.py index 1aaddd1e..c9621c6c 100644 --- a/tests/integration/examples/test_custom_auth_example.py +++ b/tests/integration/examples/test_custom_auth_example.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -21,16 +18,19 @@ import pytest +from neo4j._exceptions import BoltHandshakeError +from neo4j.exceptions import ServiceUnavailable +from tests.integration.examples import DriverSetupExample + + +# isort: off # tag::custom-auth-import[] from neo4j import ( GraphDatabase, custom_auth, ) # end::custom-auth-import[] - -from neo4j.exceptions import ServiceUnavailable -from neo4j._exceptions import BoltHandshakeError -from tests.integration.examples import DriverSetupExample +# isort: on # python -m pytest tests/integration/examples/test_custom_auth_example.py -s -v diff --git a/tests/integration/examples/test_custom_resolver_example.py b/tests/integration/examples/test_custom_resolver_example.py index fe4e5e08..9bc349de 100644 --- a/tests/integration/examples/test_custom_resolver_example.py +++ b/tests/integration/examples/test_custom_resolver_example.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -21,15 +18,18 @@ import pytest +from neo4j._exceptions import BoltHandshakeError +from neo4j.exceptions import ServiceUnavailable + + +# isort: off # tag::custom-resolver-import[] from neo4j import ( GraphDatabase, WRITE_ACCESS, ) # end::custom-resolver-import[] - -from neo4j.exceptions import ServiceUnavailable -from neo4j._exceptions import BoltHandshakeError +# isort: on # python -m pytest tests/integration/examples/test_custom_resolver_example.py -s -v diff --git a/tests/integration/examples/test_cypher_error_example.py b/tests/integration/examples/test_cypher_error_example.py index c2108b9f..411929db 100644 --- a/tests/integration/examples/test_cypher_error_example.py +++ b/tests/integration/examples/test_cypher_error_example.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -22,9 +19,12 @@ from contextlib import redirect_stdout from io import StringIO + +# isort: off # tag::cypher-error-import[] from neo4j.exceptions import ClientError # end::cypher-error-import[] +# isort: on class Neo4jErrorExample: diff --git a/tests/integration/examples/test_database_selection_example.py b/tests/integration/examples/test_database_selection_example.py index 99293ea8..65b8b8fc 100644 --- a/tests/integration/examples/test_database_selection_example.py +++ b/tests/integration/examples/test_database_selection_example.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -19,24 +16,22 @@ # limitations under the License. -import pytest - from contextlib import redirect_stdout from io import StringIO -from neo4j import GraphDatabase +# isort: off # tag::database-selection-import[] -from neo4j import READ_ACCESS +from neo4j import ( + GraphDatabase, + READ_ACCESS, +) # end::database-selection-import[] - -from neo4j.exceptions import ServiceUnavailable -from neo4j._exceptions import BoltHandshakeError +# isort: on # python -m pytest tests/integration/examples/test_database_selection_example.py -s -v - class DatabaseSelectionExample: def __init__(self, uri, user, password): diff --git a/tests/integration/examples/test_driver_introduction_example.py b/tests/integration/examples/test_driver_introduction_example.py index febd26a7..4348b35f 100644 --- a/tests/integration/examples/test_driver_introduction_example.py +++ b/tests/integration/examples/test_driver_introduction_example.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -19,11 +16,15 @@ # limitations under the License. -import pytest - from contextlib import redirect_stdout from io import StringIO +import pytest + +from neo4j._exceptions import BoltHandshakeError + + +# isort: off # tag::driver-introduction-example-import[] import logging import sys @@ -31,8 +32,7 @@ from neo4j import GraphDatabase from neo4j.exceptions import ServiceUnavailable # end::driver-introduction-example-import[] - -from neo4j._exceptions import BoltHandshakeError +# isort: on # python -m pytest tests/integration/examples/test_driver_introduction_example.py -s -v diff --git a/tests/integration/examples/test_driver_lifecycle_example.py b/tests/integration/examples/test_driver_lifecycle_example.py index d416173f..3ba18854 100644 --- a/tests/integration/examples/test_driver_lifecycle_example.py +++ b/tests/integration/examples/test_driver_lifecycle_example.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -21,12 +18,15 @@ import pytest +from neo4j._exceptions import BoltHandshakeError +from neo4j.exceptions import ServiceUnavailable + + +# isort: off # tag::driver-lifecycle-import[] from neo4j import GraphDatabase # end::driver-lifecycle-import[] - -from neo4j.exceptions import ServiceUnavailable -from neo4j._exceptions import BoltHandshakeError +# isort: on # python -m pytest tests/integration/examples/test_driver_lifecycle_example.py -s -v diff --git a/tests/integration/examples/test_geospatial_types_example.py b/tests/integration/examples/test_geospatial_types_example.py index 30b7ff4e..39a2df8e 100644 --- a/tests/integration/examples/test_geospatial_types_example.py +++ b/tests/integration/examples/test_geospatial_types_example.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -18,8 +15,10 @@ # See the License for the specific language governing permissions and # limitations under the License. + import pytest + # python -m pytest tests/integration/examples/test_geospatial_types_example.py -s -v @@ -28,9 +27,11 @@ def _echo(tx, x): def test_cartesian_point(driver): + # isort: off # tag::geospatial-types-cartesian-import[] from neo4j.spatial import CartesianPoint # end::geospatial-types-cartesian-import[] + # isort: on # tag::geospatial-types-cartesian[] # Creating a 2D point in Cartesian space @@ -87,9 +88,11 @@ def test_cartesian_point(driver): def test_wgs84_point(driver): + # isort: off # tag::geospatial-types-wgs84-import[] from neo4j.spatial import WGS84Point # end::geospatial-types-wgs84-import[] + # isort: on # tag::geospatial-types-wgs84[] # Creating a 2D point in WSG84 space diff --git a/tests/integration/examples/test_hello_world_example.py b/tests/integration/examples/test_hello_world_example.py index a2723453..b3c83e8a 100644 --- a/tests/integration/examples/test_hello_world_example.py +++ b/tests/integration/examples/test_hello_world_example.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -19,17 +16,20 @@ # limitations under the License. -import pytest - from contextlib import redirect_stdout from io import StringIO +import pytest + +from neo4j._exceptions import BoltHandshakeError +from neo4j.exceptions import ServiceUnavailable + + +# isort: off # tag::hello-world-import[] from neo4j import GraphDatabase # end::hello-world-import[] - -from neo4j.exceptions import ServiceUnavailable -from neo4j._exceptions import BoltHandshakeError +# isort: on # python -m pytest tests/integration/examples/test_hello_world_example.py -s -v @@ -79,4 +79,3 @@ def test_hello_world_example(uri, auth): except ServiceUnavailable as error: if isinstance(error.__cause__, BoltHandshakeError): pytest.skip(error.args[0]) - diff --git a/tests/integration/examples/test_kerberos_auth_example.py b/tests/integration/examples/test_kerberos_auth_example.py index 4a8c42ec..ed8b8c17 100644 --- a/tests/integration/examples/test_kerberos_auth_example.py +++ b/tests/integration/examples/test_kerberos_auth_example.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -18,16 +15,19 @@ # See the License for the specific language governing permissions and # limitations under the License. +import pytest + +from tests.integration.examples import DriverSetupExample + -import neo4j +# isort: off # tag::kerberos-auth-import[] from neo4j import ( GraphDatabase, kerberos_auth, ) # end::kerberos-auth-import[] - -from tests.integration.examples import DriverSetupExample +# isort: on # python -m pytest tests/integration/examples/test_kerberos_auth_example.py -s -v @@ -39,21 +39,5 @@ def __init__(self, uri, ticket): # end::kerberos-auth[] -def test_example(uri, mocker): - # Currently, there is no way of running the test against a server with SSO - # setup. - mocker.patch("neo4j.GraphDatabase.bolt_driver") - mocker.patch("neo4j.GraphDatabase.neo4j_driver") - - ticket = "myTicket" - KerberosAuthExample(uri, ticket) - calls = (neo4j.GraphDatabase.bolt_driver.call_args_list - + neo4j.GraphDatabase.neo4j_driver.call_args_list) - assert len(calls) == 1 - args_, kwargs = calls[0] - auth = kwargs.get("auth") - assert isinstance(auth, neo4j.Auth) - assert auth.scheme == "kerberos" - assert auth.principal == "" - assert auth.credentials == ticket - assert not hasattr(auth, "parameters") +def test_example(): + pytest.skip("Currently no way to test Kerberos auth") diff --git a/tests/integration/examples/test_pass_bookmarks_example.py b/tests/integration/examples/test_pass_bookmarks_example.py index 6f51584a..51cc33d9 100644 --- a/tests/integration/examples/test_pass_bookmarks_example.py +++ b/tests/integration/examples/test_pass_bookmarks_example.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -21,12 +18,15 @@ import pytest +from neo4j._exceptions import BoltHandshakeError +from neo4j.exceptions import ServiceUnavailable + + +# isort: off # tag::pass-bookmarks-import[] from neo4j import GraphDatabase # end::pass-bookmarks-import[] - -from neo4j.exceptions import ServiceUnavailable -from neo4j._exceptions import BoltHandshakeError +# isort: on # python -m pytest tests/integration/examples/test_pass_bookmarks_example.py -s -v diff --git a/tests/integration/examples/test_read_write_transaction_example.py b/tests/integration/examples/test_read_write_transaction_example.py index 04a1a4ef..787a4849 100644 --- a/tests/integration/examples/test_read_write_transaction_example.py +++ b/tests/integration/examples/test_read_write_transaction_example.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # diff --git a/tests/integration/examples/test_result_consume_example.py b/tests/integration/examples/test_result_consume_example.py index 53da413b..3770d9ed 100644 --- a/tests/integration/examples/test_result_consume_example.py +++ b/tests/integration/examples/test_result_consume_example.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # diff --git a/tests/integration/examples/test_result_retain_example.py b/tests/integration/examples/test_result_retain_example.py index a96bf02e..828a8236 100644 --- a/tests/integration/examples/test_result_retain_example.py +++ b/tests/integration/examples/test_result_retain_example.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # diff --git a/tests/integration/examples/test_service_unavailable_example.py b/tests/integration/examples/test_service_unavailable_example.py index 7b2961db..21972798 100644 --- a/tests/integration/examples/test_service_unavailable_example.py +++ b/tests/integration/examples/test_service_unavailable_example.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -21,9 +18,12 @@ import pytest + +# isort: off # tag::service-unavailable-import[] from neo4j.exceptions import ServiceUnavailable # end::service-unavailable-import[] +# isort: on def service_unavailable_example(driver): diff --git a/tests/integration/examples/test_session_example.py b/tests/integration/examples/test_session_example.py index 65cb3255..14302c31 100644 --- a/tests/integration/examples/test_session_example.py +++ b/tests/integration/examples/test_session_example.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # diff --git a/tests/integration/examples/test_temporal_types_example.py b/tests/integration/examples/test_temporal_types_example.py index 97b57e88..42b17cdc 100644 --- a/tests/integration/examples/test_temporal_types_example.py +++ b/tests/integration/examples/test_temporal_types_example.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -26,12 +23,14 @@ def _echo(tx, x): def test_datetime(driver): + # isort: off # tag::temporal-types-datetime-import[] from datetime import datetime from neo4j.time import DateTime import pytz # end::temporal-types-datetime-import[] + # isort: on # tag::temporal-types-datetime[] # Create datetimes to be used as query parameters @@ -78,11 +77,13 @@ def test_datetime(driver): def test_date(driver): + # isort: off # tag::temporal-types-date-import[] from datetime import date from neo4j.time import Date # end::temporal-types-date-import[] + # isort: on # tag::temporal-types-date[] # Create dates to be used as query parameters @@ -127,12 +128,14 @@ def test_date(driver): def test_time(driver): + # isort: off # tag::temporal-types-time-import[] from datetime import time from neo4j.time import Time import pytz # end::temporal-types-time-import[] + # isort: on # tag::temporal-types-time[] # Create datetimes to be used as query parameters @@ -176,12 +179,13 @@ def test_time(driver): def test_local_datetime(driver): + # isort: off # tag::temporal-types-local-datetime-import[] from datetime import datetime from neo4j.time import DateTime - import pytz # end::temporal-types-local-datetime-import[] + # isort: on # tag::temporal-types-local-datetime[] # Create datetimes to be used as query parameters @@ -226,12 +230,13 @@ def test_local_datetime(driver): def test_local_time(driver): + # isort: off # tag::temporal-types-local-time-import[] from datetime import time from neo4j.time import Time - import pytz # end::temporal-types-local-time-import[] + # isort: on # tag::temporal-types-local-time[] # Create datetimes to be used as query parameters @@ -275,11 +280,13 @@ def test_local_time(driver): def test_duration_example(driver): + # isort: off # tag::temporal-types-duration-import[] from datetime import timedelta from neo4j.time import Duration # end::temporal-types-duration-import[] + # isort: on # tag::temporal-types-duration[] # Creating durations to be used as query parameters diff --git a/tests/integration/examples/test_transaction_function_example.py b/tests/integration/examples/test_transaction_function_example.py index 5e525dff..e0b6588c 100644 --- a/tests/integration/examples/test_transaction_function_example.py +++ b/tests/integration/examples/test_transaction_function_example.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -19,9 +16,11 @@ # limitations under the License. +# isort: off # tag::transaction-function-import[] from neo4j import unit_of_work # end::transaction-function-import[] +# isort: on # python -m pytest tests/integration/examples/test_transaction_function_example.py -s -v diff --git a/tests/integration/examples/test_transaction_metadata_config_example.py b/tests/integration/examples/test_transaction_metadata_config_example.py index ef7ef39d..8862d22a 100644 --- a/tests/integration/examples/test_transaction_metadata_config_example.py +++ b/tests/integration/examples/test_transaction_metadata_config_example.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -18,7 +15,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from neo4j import unit_of_work, Query + +from neo4j import unit_of_work # python -m pytest tests/integration/examples/test_transaction_metadata_config_example.py -s -v diff --git a/tests/integration/examples/test_transaction_timeout_config_example.py b/tests/integration/examples/test_transaction_timeout_config_example.py index cf3ed5a6..e3503691 100644 --- a/tests/integration/examples/test_transaction_timeout_config_example.py +++ b/tests/integration/examples/test_transaction_timeout_config_example.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -18,7 +15,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from neo4j import unit_of_work, Query + +from neo4j import unit_of_work # python -m pytest tests/integration/examples/test_transaction_timeout_config_example.py -s -v diff --git a/tests/integration/test_autocommit.py b/tests/integration/test_autocommit.py index 32619991..960cbe3f 100644 --- a/tests/integration/test_autocommit.py +++ b/tests/integration/test_autocommit.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -19,7 +16,7 @@ # limitations under the License. -from neo4j.work.simple import Query +from neo4j import Query # TODO: this test will stay until a uniform behavior for `.single()` across the diff --git a/tests/integration/test_bolt_driver.py b/tests/integration/test_bolt_driver.py index 5b2e76a0..346b82cd 100644 --- a/tests/integration/test_bolt_driver.py +++ b/tests/integration/test_bolt_driver.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -19,30 +16,8 @@ # limitations under the License. -import pytest from pytest import fixture -from neo4j import ( - GraphDatabase, - BoltDriver, - Version, - READ_ACCESS, - WRITE_ACCESS, - ResultSummary, - unit_of_work, - Transaction, - Result, - ServerInfo, -) -from neo4j.exceptions import ( - ServiceUnavailable, - AuthError, - ConfigurationError, - ClientError, -) -from neo4j._exceptions import BoltHandshakeError -from neo4j.io._bolt3 import Bolt3 - # TODO: this test will stay until a uniform behavior for `.single()` across the # drivers has been specified and tests are created in testkit diff --git a/tests/integration/test_pipelines.py b/tests/integration/test_pipelines.py deleted file mode 100644 index 299ec76f..00000000 --- a/tests/integration/test_pipelines.py +++ /dev/null @@ -1,308 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -# Copyright (c) "Neo4j" -# Neo4j Sweden AB [http://neo4j.com] -# -# This file is part of Neo4j. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import pytest -from uuid import uuid4 - - -from neo4j.packstream import Structure -from neo4j.exceptions import ( - Neo4jError, - CypherSyntaxError, -) -from neo4j.graph import ( - Node, - Relationship, - Path, -) -from neo4j.work.pipelining import PullOrderException - - -def test_can_run_simple_statement(bolt_driver): - pipeline = bolt_driver.pipeline(flush_every=0) - pipeline.push("RETURN 1 AS n") - for record in pipeline.pull(): - assert len(record) == 1 - assert record[0] == 1 - # TODO: why does pipeline result not look like a regular result? - # assert record["n"] == 1 - # with pytest.raises(KeyError): - # _ = record["x"] - # assert record["n"] == 1 - # with pytest.raises(KeyError): - # _ = record["x"] - with pytest.raises(TypeError): - _ = record[object()] - assert repr(record) - assert len(record) == 1 - pipeline.close() - - -def test_can_run_simple_statement_with_params(bolt_driver): - pipeline = bolt_driver.pipeline(flush_every=0) - count = 0 - pipeline.push("RETURN $x AS n", {"x": {"abc": ["d", "e", "f"]}}) - for record in pipeline.pull(): - assert record[0] == {"abc": ["d", "e", "f"]} - # TODO: why does pipeline result not look like a regular result? - # assert record["n"] == {"abc": ["d", "e", "f"]} - assert repr(record) - assert len(record) == 1 - count += 1 - pipeline.close() - assert count == 1 - - -def test_can_run_write_statement_with_no_return(driver): - pipeline = driver.pipeline(flush_every=0) - count = 0 - test_uid = str(uuid4()) - pipeline.push("CREATE (a:Person {uid:$test_uid})", dict(test_uid=test_uid)) - - for _ in pipeline.pull(): - raise Exception("Should not return any results from create with no return") - # Note you still have to consume the generator if you want to be allowed to pull from the pipeline again even - # though it doesn't apparently return any items. - - pipeline.push("MATCH (a:Person {uid:$test_uid}) RETURN a LIMIT 1", dict(test_uid=test_uid)) - for _ in pipeline.pull(): - count += 1 - pipeline.close() - assert count == 1 - - -def test_fails_on_bad_syntax(bolt_driver): - pipeline = bolt_driver.pipeline(flush_every=0) - with pytest.raises(Neo4jError): - pipeline.push("X") - next(pipeline.pull()) - - -def test_doesnt_fail_on_bad_syntax_somewhere(bolt_driver): - pipeline = bolt_driver.pipeline(flush_every=0) - pipeline.push("RETURN 1 AS n") - pipeline.push("X") - assert next(pipeline.pull())[0] == 1 - with pytest.raises(Neo4jError): - next(pipeline.pull()) - - -def test_fails_on_missing_parameter(bolt_driver): - pipeline = bolt_driver.pipeline(flush_every=0) - with pytest.raises(Neo4jError): - pipeline.push("RETURN $x") - next(pipeline.pull()) - - -def test_can_run_simple_statement_from_bytes_string(bolt_driver): - pipeline = bolt_driver.pipeline(flush_every=0) - count = 0 - pytest.skip("FIXME: why can't pipeline handle bytes string?") - pipeline.push(b"RETURN 1 AS n") - for record in pipeline.pull(): - assert record[0] == 1 - assert record["n"] == 1 - assert repr(record) - assert len(record) == 1 - count += 1 - pipeline.close() - assert count == 1 - - -def test_can_run_statement_that_returns_multiple_records(bolt_driver): - pipeline = bolt_driver.pipeline(flush_every=0) - count = 0 - pipeline.push("unwind(range(1, 10)) AS z RETURN z") - for record in pipeline.pull(): - assert 1 <= record[0] <= 10 - count += 1 - pipeline.close() - assert count == 10 - - -def test_can_return_node(neo4j_driver): - with neo4j_driver.pipeline(flush_every=0) as pipeline: - pipeline.push("CREATE (a:Person {name:'Alice'}) RETURN a") - record_list = list(pipeline.pull()) - assert len(record_list) == 1 - for record in record_list: - alice = record[0] - print(alice) - pytest.skip("FIXME: why does pipeline result not look like a regular result?") - assert isinstance(alice, Node) - assert alice.labels == {"Person"} - assert dict(alice) == {"name": "Alice"} - - -def test_can_return_relationship(neo4j_driver): - with neo4j_driver.pipeline(flush_every=0) as pipeline: - pipeline.push("CREATE ()-[r:KNOWS {since:1999}]->() RETURN r") - record_list = list(pipeline.pull()) - assert len(record_list) == 1 - for record in record_list: - rel = record[0] - print(rel) - pytest.skip("FIXME: why does pipeline result not look like a regular result?") - assert isinstance(rel, Relationship) - assert rel.type == "KNOWS" - assert dict(rel) == {"since": 1999} - - -def test_can_return_path(neo4j_driver): - with neo4j_driver.pipeline(flush_every=0) as pipeline: - test_uid = str(uuid4()) - pipeline.push( - "MERGE p=(alice:Person {name:'Alice', test_uid: $test_uid})" - "-[:KNOWS {test_uid: $test_uid}]->" - "(:Person {name:'Bob', test_uid: $test_uid})" - " RETURN p", - dict(test_uid=test_uid) - ) - record_list = list(pipeline.pull()) - assert len(record_list) == 1 - for record in record_list: - path = record[0] - print(path) - pytest.skip("FIXME: why does pipeline result not look like a regular result?") - assert isinstance(path, Path) - assert path.start_node["name"] == "Alice" - assert path.end_node["name"] == "Bob" - assert path.relationships[0].type == "KNOWS" - assert len(path.nodes) == 2 - assert len(path.relationships) == 1 - - -def test_can_handle_cypher_error(bolt_driver): - with bolt_driver.pipeline(flush_every=0) as pipeline: - pipeline.push("X") - with pytest.raises(Neo4jError): - next(pipeline.pull()) - - -def test_should_not_allow_empty_statements(bolt_driver, requires_bolt_4x): - with bolt_driver.pipeline(flush_every=0) as pipeline: - pipeline.push("") - with pytest.raises(CypherSyntaxError): - next(pipeline.pull()) - - -def test_can_queue_multiple_statements(bolt_driver): - count = 0 - with bolt_driver.pipeline(flush_every=0) as pipeline: - pipeline.push("unwind(range(1, 10)) AS z RETURN z") - pipeline.push("unwind(range(11, 20)) AS z RETURN z") - pipeline.push("unwind(range(21, 30)) AS z RETURN z") - for i in range(3): - for record in pipeline.pull(): - assert (i * 10 + 1) <= record[0] <= ((i + 1) * 10) - count += 1 - assert count == 30 - - -def test_pull_order_exception(bolt_driver): - """If you try and pull when you haven't finished iterating the previous result you get an error""" - pipeline = bolt_driver.pipeline(flush_every=0) - with pytest.raises(PullOrderException): - pipeline.push("unwind(range(1, 10)) AS z RETURN z") - pipeline.push("unwind(range(11, 20)) AS z RETURN z") - generator_one = pipeline.pull() - generator_two = pipeline.pull() - - -def test_pipeline_can_read_own_writes(neo4j_driver): - """I am not sure that we _should_ guarantee this""" - count = 0 - with neo4j_driver.pipeline(flush_every=0) as pipeline: - test_uid = str(uuid4()) - pipeline.push( - "CREATE (a:Person {name:'Alice', test_uid: $test_uid})", - dict(test_uid=test_uid) - ) - pipeline.push( - "MATCH (alice:Person {name:'Alice', test_uid: $test_uid}) " - "MERGE (alice)" - "-[:KNOWS {test_uid: $test_uid}]->" - "(:Person {name:'Bob', test_uid: $test_uid})", - dict(test_uid=test_uid) - ) - pipeline.push("MATCH (n:Person {test_uid: $test_uid}) RETURN n", dict(test_uid=test_uid)) - pipeline.push( - "MATCH" - " p=(:Person {test_uid: $test_uid})-[:KNOWS {test_uid: $test_uid}]->(:Person {test_uid: $test_uid})" - " RETURN p", - dict(test_uid=test_uid) - ) - - # create Alice - # n.b. we have to consume the result - assert next(pipeline.pull(), True) == True - - # merge "knows Bob" - # n.b. we have to consume the result - assert next(pipeline.pull(), True) == True - - # get people - for result in pipeline.pull(): - count += 1 - - assert len(result) == 1 - person = result[0] - print(person) - assert isinstance(person, Structure) - assert person.tag == b'N' - print(person.fields) - assert set(person.fields[1]) == {"Person"} - - print(count) - assert count == 2 - - # get path - for result in pipeline.pull(): - count += 1 - - assert len(result) == 1 - path = result[0] - print(path) - assert isinstance(path, Structure) - assert path.tag == b'P' - - # TODO: return Path / Node / Rel instances rather than Structures - # assert isinstance(path, Path) - # assert path.start_node["name"] == "Alice" - # assert path.end_node["name"] == "Bob" - # assert path.relationships[0].type == "KNOWS" - # assert len(path.nodes) == 2 - # assert len(path.relationships) == 1 - - assert count == 3 - - -def test_automatic_reset_after_failure(bolt_driver): - with bolt_driver.pipeline(flush_every=0) as pipeline: - try: - pipeline.push("X") - next(pipeline.pull()) - except Neo4jError: - pipeline.push("RETURN 1") - record = next(pipeline.pull()) - assert record[0] == 1 - else: - assert False, "A Cypher error should have occurred" diff --git a/tests/integration/test_readme.py b/tests/integration/test_readme.py index 4788411d..fde46639 100644 --- a/tests/integration/test_readme.py +++ b/tests/integration/test_readme.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -18,10 +15,12 @@ # See the License for the specific language governing permissions and # limitations under the License. + import pytest -from neo4j.exceptions import ServiceUnavailable from neo4j._exceptions import BoltHandshakeError +from neo4j.exceptions import ServiceUnavailable + # python -m pytest tests/integration/test_readme.py -s -v diff --git a/tests/integration/test_result.py b/tests/integration/test_result.py index 80e14fb2..b4af833a 100644 --- a/tests/integration/test_result.py +++ b/tests/integration/test_result.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # diff --git a/tests/integration/test_result_graph.py b/tests/integration/test_result_graph.py index 8c2cf81f..15ea7a37 100644 --- a/tests/integration/test_result_graph.py +++ b/tests/integration/test_result_graph.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -21,13 +18,7 @@ import pytest -from neo4j.graph import ( - Node, - Relationship, - Graph, - Path, -) -from neo4j.exceptions import Neo4jError +from neo4j.graph import Graph def test_result_graph_instance(session): diff --git a/tests/integration/test_spatial_types.py b/tests/integration/test_spatial_types.py index 2227715f..71ed4cd5 100644 --- a/tests/integration/test_spatial_types.py +++ b/tests/integration/test_spatial_types.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -21,7 +18,6 @@ import pytest - from neo4j.spatial import ( CartesianPoint, WGS84Point, diff --git a/tests/integration/test_temporal_types.py b/tests/integration/test_temporal_types.py index 0e996db2..e9e64fb6 100644 --- a/tests/integration/test_temporal_types.py +++ b/tests/integration/test_temporal_types.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -19,9 +16,9 @@ # limitations under the License. -import pytest import datetime +import pytest from pytz import ( FixedOffset, timezone, @@ -31,9 +28,9 @@ from neo4j.exceptions import CypherTypeError from neo4j.time import ( Date, - Time, DateTime, Duration, + Time, ) @@ -402,4 +399,4 @@ def test_time_parameter_case3(session): t2 = session.run("RETURN $time", time=t1).single().value() assert isinstance(t2, Time) - assert t1 == t2 \ No newline at end of file + assert t1 == t2 diff --git a/tests/integration/test_tx_functions.py b/tests/integration/test_tx_functions.py index 184efbcd..bfaa2d58 100644 --- a/tests/integration/test_tx_functions.py +++ b/tests/integration/test_tx_functions.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -19,15 +16,17 @@ # limitations under the License. -import pytest from uuid import uuid4 -from neo4j.work.simple import unit_of_work +import pytest + +from neo4j import unit_of_work from neo4j.exceptions import ( - Neo4jError, ClientError, + Neo4jError, ) + # python -m pytest tests/integration/test_tx_functions.py -s -v diff --git a/tests/performance/test_async_results.py b/tests/performance/test_async_results.py new file mode 100644 index 00000000..085001e0 --- /dev/null +++ b/tests/performance/test_async_results.py @@ -0,0 +1,86 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import asyncio +from itertools import product + +from pytest import mark + +from neo4j import AsyncGraphDatabase + +from .tools import RemoteGraphDatabaseServer + + +class AsyncReadWorkload(object): + + server = None + driver = None + loop = None + + @classmethod + def setup_class(cls): + cls.server = server = RemoteGraphDatabaseServer() + server.start() + cls.loop = asyncio.new_event_loop() + asyncio.set_event_loop(cls.loop) + cls.driver = AsyncGraphDatabase.driver(server.server_uri, + auth=server.auth_token, + encrypted=server.encrypted) + + @classmethod + def teardown_class(cls): + try: + cls.loop.run_until_complete(cls.driver.close()) + cls.server.stop() + finally: + cls.loop.stop() + asyncio.set_event_loop(None) + + def work(self, *units_of_work): + async def runner(): + async with self.driver.session() as session: + for unit_of_work in units_of_work: + await session.read_transaction(unit_of_work) + + def sync_runner(): + self.loop.run_until_complete(runner()) + + return sync_runner + + +class TestAsyncReadWorkload(AsyncReadWorkload): + + @staticmethod + def uow(record_count, record_width, value): + + async def _(tx): + s = "UNWIND range(1, $record_count) AS _ RETURN {}".format( + ", ".join("$x AS x{}".format(i) for i in range(record_width))) + p = {"record_count": record_count, "x": value} + async for record in await tx.run(s, p): + assert all(x == value for x in record.values()) + + return _ + + @mark.parametrize("record_count,record_width,value", product( + [1, 1000], # record count + [1, 10], # record width + [1, u'hello, world'], # value + )) + def test_1x1(self, benchmark, record_count, record_width, value): + benchmark(self.work(self.uow(record_count, record_width, value))) diff --git a/tests/performance/test_results.py b/tests/performance/test_results.py index 15971faa..9629c7d8 100644 --- a/tests/performance/test_results.py +++ b/tests/performance/test_results.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -24,6 +21,7 @@ from pytest import mark from neo4j import GraphDatabase + from .tools import RemoteGraphDatabaseServer diff --git a/tests/performance/tools.py b/tests/performance/tools.py index 3c51a4fe..511644e5 100644 --- a/tests/performance/tools.py +++ b/tests/performance/tools.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -19,17 +16,15 @@ # limitations under the License. -from unittest import TestCase, SkipTest - -try: - from urllib.request import urlretrieve -except ImportError: - from urllib import urlretrieve +from unittest import SkipTest from neo4j import GraphDatabase from neo4j.exceptions import AuthError - -from tests.env import NEO4J_USER, NEO4J_PASSWORD, NEO4J_SERVER_URI +from tests.env import ( + NEO4J_PASSWORD, + NEO4J_SERVER_URI, + NEO4J_USER, +) def is_listening(address): diff --git a/tests/requirements.txt b/tests/requirements.txt index b9405952..d3b59806 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -1,7 +1,10 @@ git+https://github.com/neo4j-drivers/boltkit@4.2#egg=boltkit -coverage -pytest -pytest-benchmark -pytest-cov -pytest-mock -teamcity-messages +coverage>=5.5 +pytest>=6.2.5 +pytest-asyncio>=0.16.0 +pytest-benchmark>=3.4.1 +pytest-cov>=3.0.0 +pytest-mock>=3.6.1 +# brings mock.AsyncMock to Python 3.7 (3.8+ ships with built-in support) +mock>=4.0.3; python_version < '3.8' +teamcity-messages>=1.29 diff --git a/tests/stub/conftest.py b/tests/stub/conftest.py index c8d91e69..7738a949 100644 --- a/tests/stub/conftest.py +++ b/tests/stub/conftest.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -19,18 +16,17 @@ # limitations under the License. -import subprocess +import logging import os -import time - from platform import system +import subprocess from threading import Thread from time import sleep from boltkit.server.stub import BoltStubService from pytest import fixture -import logging + log = logging.getLogger("neo4j") # from neo4j.debug import watch diff --git a/tests/stub/scripts/v1/empty_explicit_hello_goodbye.script b/tests/stub/scripts/v1/empty_explicit_hello_goodbye.script index cd48e67a..f55e38cc 100644 --- a/tests/stub/scripts/v1/empty_explicit_hello_goodbye.script +++ b/tests/stub/scripts/v1/empty_explicit_hello_goodbye.script @@ -3,4 +3,4 @@ C: INIT {"user_agent": "test", "scheme": "basic", "principal": "test", "credentials": "test"} S: SUCCESS {"server": "Neo4j/3.3.0", "connection_id": "123e4567-e89b-12d3-a456-426655440000"} C: RESET -S: \ No newline at end of file +S: diff --git a/tests/stub/scripts/v2/empty_explicit_hello_goodbye.script b/tests/stub/scripts/v2/empty_explicit_hello_goodbye.script index 7e8ae848..8e661fde 100644 --- a/tests/stub/scripts/v2/empty_explicit_hello_goodbye.script +++ b/tests/stub/scripts/v2/empty_explicit_hello_goodbye.script @@ -3,4 +3,4 @@ C: INIT {"user_agent": "test", "scheme": "basic", "principal": "test", "credentials": "test"} S: SUCCESS {"server": "Neo4j/3.4.0", "connection_id": "123e4567-e89b-12d3-a456-426655440000"} C: RESET -S: \ No newline at end of file +S: diff --git a/tests/stub/scripts/v3/dbms_cluster_routing_get_routing_table_system.script b/tests/stub/scripts/v3/dbms_cluster_routing_get_routing_table_system.script index 7aea9af8..e47a2e2e 100644 --- a/tests/stub/scripts/v3/dbms_cluster_routing_get_routing_table_system.script +++ b/tests/stub/scripts/v3/dbms_cluster_routing_get_routing_table_system.script @@ -11,4 +11,4 @@ S: SUCCESS {"fields": ["ttl", "servers"]} SUCCESS {} C: GOODBYE -S: \ No newline at end of file +S: diff --git a/tests/stub/scripts/v3/empty_explicit_hello_goodbye.script b/tests/stub/scripts/v3/empty_explicit_hello_goodbye.script index 4ed2b804..be98dd7a 100644 --- a/tests/stub/scripts/v3/empty_explicit_hello_goodbye.script +++ b/tests/stub/scripts/v3/empty_explicit_hello_goodbye.script @@ -4,4 +4,4 @@ C: HELLO {"user_agent": "test", "scheme": "basic", "principal": "test", "credentials": "test"} S: SUCCESS {"server": "Neo4j/3.5.0", "connection_id": "123e4567-e89b-12d3-a456-426655440000"} C: GOODBYE -S: \ No newline at end of file +S: diff --git a/tests/stub/scripts/v3/get_routing_table.script b/tests/stub/scripts/v3/get_routing_table.script index b3ec1fa0..07e1ef20 100644 --- a/tests/stub/scripts/v3/get_routing_table.script +++ b/tests/stub/scripts/v3/get_routing_table.script @@ -8,4 +8,4 @@ C: RUN "CALL dbms.cluster.routing.getRoutingTable($context)" {"context": {"addre PULL_ALL S: SUCCESS {"fields": ["ttl", "servers"]} RECORD [9223372036854775807, [{"addresses": ["127.0.0.1:9001"],"role": "WRITE"}, {"addresses": ["127.0.0.1:9002"], "role": "READ"},{"addresses": ["127.0.0.1:9001", "127.0.0.1:9002"], "role": "ROUTE"}]] - SUCCESS {} \ No newline at end of file + SUCCESS {} diff --git a/tests/stub/scripts/v3/get_routing_table_with_context.script b/tests/stub/scripts/v3/get_routing_table_with_context.script index 65e67bc6..03a23967 100644 --- a/tests/stub/scripts/v3/get_routing_table_with_context.script +++ b/tests/stub/scripts/v3/get_routing_table_with_context.script @@ -8,4 +8,4 @@ C: RUN "CALL dbms.cluster.routing.getRoutingTable($context)" {"context": {"name" PULL_ALL S: SUCCESS {"fields": ["ttl", "servers"]} RECORD [9223372036854775807, [{"addresses": ["127.0.0.1:9001"],"role": "WRITE"}, {"addresses": ["127.0.0.1:9002"], "role": "READ"},{"addresses": ["127.0.0.1:9001", "127.0.0.1:9002"], "role": "ROUTE"}]] - SUCCESS {} \ No newline at end of file + SUCCESS {} diff --git a/tests/stub/scripts/v3/pull_all_port_9001_transaction_function.script b/tests/stub/scripts/v3/pull_all_port_9001_transaction_function.script index 8dd7b406..1003b6ef 100644 --- a/tests/stub/scripts/v3/pull_all_port_9001_transaction_function.script +++ b/tests/stub/scripts/v3/pull_all_port_9001_transaction_function.script @@ -16,4 +16,4 @@ S: RECORD [1] RECORD [4] SUCCESS {"type": "r", "t_last": 500} C: COMMIT -S: SUCCESS {"bookmark": "neo4j:bookmark-test-1"} \ No newline at end of file +S: SUCCESS {"bookmark": "neo4j:bookmark-test-1"} diff --git a/tests/stub/scripts/v3/return_1_port_9001.script b/tests/stub/scripts/v3/return_1_port_9001.script index 06cc51a5..5135a788 100644 --- a/tests/stub/scripts/v3/return_1_port_9001.script +++ b/tests/stub/scripts/v3/return_1_port_9001.script @@ -9,4 +9,4 @@ C: RUN "RETURN 1 AS x" {} {"mode": "r"} PULL_ALL S: SUCCESS {"fields": ["x"]} RECORD [1] - SUCCESS {"bookmark": "neo4j:bookmark-test-1", "type": "r", "t_last": 5} \ No newline at end of file + SUCCESS {"bookmark": "neo4j:bookmark-test-1", "type": "r", "t_last": 5} diff --git a/tests/stub/scripts/v3/router_with_multiple_servers.script b/tests/stub/scripts/v3/router_with_multiple_servers.script index 34ac0f0f..fec9d25d 100644 --- a/tests/stub/scripts/v3/router_with_multiple_servers.script +++ b/tests/stub/scripts/v3/router_with_multiple_servers.script @@ -8,4 +8,4 @@ C: RUN "CALL dbms.cluster.routing.getRoutingTable($context)" {"context": {"addre PULL_ALL S: SUCCESS {"fields": ["ttl", "servers"]} RECORD [300, [{"role":"ROUTE","addresses":["127.0.0.1:9001","127.0.0.1:9002"]},{"role":"READ","addresses":["127.0.0.1:9001","127.0.0.1:9003"]},{"role":"WRITE","addresses":["127.0.0.1:9004"]}]] - SUCCESS {} \ No newline at end of file + SUCCESS {} diff --git a/tests/stub/scripts/v3/rude_reader.script b/tests/stub/scripts/v3/rude_reader.script index 055c06f1..ea323e98 100644 --- a/tests/stub/scripts/v3/rude_reader.script +++ b/tests/stub/scripts/v3/rude_reader.script @@ -7,4 +7,3 @@ C: RUN "RETURN 1" {} {"mode": "r"} PULL_ALL S: - diff --git a/tests/stub/scripts/v3/rude_router.script b/tests/stub/scripts/v3/rude_router.script index 9fa163b8..a29bbbd4 100644 --- a/tests/stub/scripts/v3/rude_router.script +++ b/tests/stub/scripts/v3/rude_router.script @@ -7,4 +7,3 @@ C: RUN "CALL dbms.cluster.routing.getRoutingTable($context)" {"context": {"address": "localhost:9001"}} {"mode": "r"} PULL_ALL S: - diff --git a/tests/stub/scripts/v4x0/dbms_routing_get_routing_table_system_default.script b/tests/stub/scripts/v4x0/dbms_routing_get_routing_table_system_default.script index 05f9a9ab..ec865816 100644 --- a/tests/stub/scripts/v4x0/dbms_routing_get_routing_table_system_default.script +++ b/tests/stub/scripts/v4x0/dbms_routing_get_routing_table_system_default.script @@ -11,4 +11,4 @@ S: SUCCESS {"fields": ["ttl", "servers"]} SUCCESS {"bookmark": "neo4j:bookmark-test-1", "type": "s", "t_last": 15, "db": "system"} C: GOODBYE -S: \ No newline at end of file +S: diff --git a/tests/stub/scripts/v4x0/dbms_routing_get_routing_table_system_neo4j.script b/tests/stub/scripts/v4x0/dbms_routing_get_routing_table_system_neo4j.script index aed4b3c4..ba93794d 100644 --- a/tests/stub/scripts/v4x0/dbms_routing_get_routing_table_system_neo4j.script +++ b/tests/stub/scripts/v4x0/dbms_routing_get_routing_table_system_neo4j.script @@ -11,4 +11,4 @@ S: SUCCESS {"fields": ["ttl", "servers"]} SUCCESS {"bookmark": "neo4j:bookmark-test-1", "type": "r", "t_last": 15, "db": "neo4j"} C: GOODBYE -S: \ No newline at end of file +S: diff --git a/tests/stub/scripts/v4x0/empty.script b/tests/stub/scripts/v4x0/empty.script index 6e97eca9..9b11fd90 100644 --- a/tests/stub/scripts/v4x0/empty.script +++ b/tests/stub/scripts/v4x0/empty.script @@ -2,4 +2,4 @@ !: AUTO HELLO !: AUTO GOODBYE !: AUTO RESET -!: PORT 9001 \ No newline at end of file +!: PORT 9001 diff --git a/tests/stub/scripts/v4x0/empty_explicit_hello_goodbye.script b/tests/stub/scripts/v4x0/empty_explicit_hello_goodbye.script index 29957f14..5a2bdc88 100644 --- a/tests/stub/scripts/v4x0/empty_explicit_hello_goodbye.script +++ b/tests/stub/scripts/v4x0/empty_explicit_hello_goodbye.script @@ -4,4 +4,4 @@ C: HELLO {"user_agent": "test", "scheme": "basic", "principal": "test", "credentials": "test"} S: SUCCESS {"server": "Neo4j/4.0.0", "connection_id": "123e4567-e89b-12d3-a456-426655440000"} C: GOODBYE -S: \ No newline at end of file +S: diff --git a/tests/stub/scripts/v4x0/return_1_four_times_port_9004.script b/tests/stub/scripts/v4x0/return_1_four_times_port_9004.script index 47698633..a1f56d8b 100644 --- a/tests/stub/scripts/v4x0/return_1_four_times_port_9004.script +++ b/tests/stub/scripts/v4x0/return_1_four_times_port_9004.script @@ -26,4 +26,4 @@ C: RUN "RETURN 1" {} {"mode": "r"} PULL {"n": -1} S: SUCCESS {"fields": ["1"]} RECORD [1] - SUCCESS {} \ No newline at end of file + SUCCESS {} diff --git a/tests/stub/scripts/v4x0/return_1_port_9001.script b/tests/stub/scripts/v4x0/return_1_port_9001.script index c0c9c04c..188c1839 100644 --- a/tests/stub/scripts/v4x0/return_1_port_9001.script +++ b/tests/stub/scripts/v4x0/return_1_port_9001.script @@ -9,4 +9,4 @@ C: RUN "RETURN 1 AS x" {} {"mode": "r"} PULL {"n": -1} S: SUCCESS {"fields": ["x"]} RECORD [1] - SUCCESS {"bookmark": "neo4j:bookmark-test-1", "type": "r", "t_last": 5, "db": "system"} \ No newline at end of file + SUCCESS {"bookmark": "neo4j:bookmark-test-1", "type": "r", "t_last": 5, "db": "system"} diff --git a/tests/stub/scripts/v4x0/router_port_9001_one_read_port_9004_one_write_port_9006.script b/tests/stub/scripts/v4x0/router_port_9001_one_read_port_9004_one_write_port_9006.script index cafc5742..b373111f 100644 --- a/tests/stub/scripts/v4x0/router_port_9001_one_read_port_9004_one_write_port_9006.script +++ b/tests/stub/scripts/v4x0/router_port_9001_one_read_port_9004_one_write_port_9006.script @@ -10,4 +10,4 @@ C: RUN "CALL dbms.routing.getRoutingTable($context)" {"context": {"address": "lo PULL {"n": -1} S: SUCCESS {"fields": ["ttl", "servers"]} RECORD [300, [{"role":"ROUTE", "addresses":["127.0.0.1:9001", "127.0.0.1:9002"]}, {"role":"READ", "addresses":["127.0.0.1:9004"]}, {"role":"WRITE", "addresses":["127.0.0.1:9006"]}]] - SUCCESS {"bookmark": "neo4j:bookmark-test-1", "type": "s", "t_last": 5, "db": "system"} \ No newline at end of file + SUCCESS {"bookmark": "neo4j:bookmark-test-1", "type": "s", "t_last": 5, "db": "system"} diff --git a/tests/stub/test_directdriver.py b/tests/stub/test_directdriver.py index 2e7ae7f6..492e323f 100644 --- a/tests/stub/test_directdriver.py +++ b/tests/stub/test_directdriver.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -21,33 +18,14 @@ import pytest -from neo4j.exceptions import ( - ServiceUnavailable, - ConfigurationError, - UnsupportedServerProduct, -) -from neo4j._exceptions import ( - BoltHandshakeError, - BoltSecurityError, -) - from neo4j import ( - GraphDatabase, BoltDriver, - Query, - WRITE_ACCESS, + GraphDatabase, READ_ACCESS, - TRUST_ALL_CERTIFICATES, - TRUST_SYSTEM_CA_SIGNED_CERTIFICATES, - DEFAULT_DATABASE, - Result, - unit_of_work, - Transaction, ) +from neo4j.exceptions import ServiceUnavailable +from tests.stub.conftest import StubCluster -from tests.stub.conftest import ( - StubCluster, -) # python -m pytest tests/stub/test_directdriver.py -s -v diff --git a/tests/stub/test_routingdriver.py b/tests/stub/test_routingdriver.py index 139c225d..c94b5d95 100644 --- a/tests/stub/test_routingdriver.py +++ b/tests/stub/test_routingdriver.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -24,25 +21,11 @@ from neo4j import ( GraphDatabase, Neo4jDriver, - TRUST_ALL_CERTIFICATES, - TRUST_SYSTEM_CA_SIGNED_CERTIFICATES, - DEFAULT_DATABASE, -) -from neo4j.api import ( - READ_ACCESS, - WRITE_ACCESS, -) -from neo4j.exceptions import ( - ServiceUnavailable, - TransientError, - SessionExpired, - ConfigurationError, -) -from neo4j._exceptions import ( - BoltSecurityError, ) +from neo4j.exceptions import ServiceUnavailable from tests.stub.conftest import StubCluster + # python -m pytest tests/stub/test_routingdriver.py -s -v # TODO: those tests will stay until a uniform behavior across the drivers has # been specified and tests are created in testkit diff --git a/tests/unit/README.md b/tests/unit/README.md new file mode 100644 index 00000000..415e2d23 --- /dev/null +++ b/tests/unit/README.md @@ -0,0 +1,10 @@ +Structure: + +``` +. +├── _async_compat # utility code to allow auto-converting async tests to sync tests +├── async # tests that are specifit to async classes, methods, or functions +├── common # tests of classes, methods, and functions that only exist in sync +├── mixed # tests that cannot be auto converted from async to sync +└── sync # auto-genereated sync versions of tetst in async +``` diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py index e69de29b..b81a309d 100644 --- a/tests/unit/__init__.py +++ b/tests/unit/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unit/_async_compat/__init__.py b/tests/unit/_async_compat/__init__.py new file mode 100644 index 00000000..285e3af5 --- /dev/null +++ b/tests/unit/_async_compat/__init__.py @@ -0,0 +1,53 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import sys + + +if sys.version_info >= (3, 8): + from unittest import mock + from unittest.mock import AsyncMockMixin +else: + import mock + from mock.mock import AsyncMockMixin + +from .mark_decorator import ( + mark_async_test, + mark_sync_test, +) + + +AsyncMagicMock = mock.AsyncMock +MagicMock = mock.MagicMock +Mock = mock.Mock +mock.NonCallableMagicMock + + +class AsyncMock(AsyncMockMixin, Mock): + pass + + +__all__ = [ + "mark_async_test", + "mark_sync_test", + "AsyncMagicMock", + "AsyncMock", + "MagicMock", + "Mock", + "mock", +] diff --git a/tests/unit/_async_compat/mark_decorator.py b/tests/unit/_async_compat/mark_decorator.py new file mode 100644 index 00000000..6cd4832d --- /dev/null +++ b/tests/unit/_async_compat/mark_decorator.py @@ -0,0 +1,26 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import pytest + + +mark_async_test = pytest.mark.asyncio + + +def mark_sync_test(f): + return f diff --git a/tests/unit/async_/__init__.py b/tests/unit/async_/__init__.py new file mode 100644 index 00000000..b81a309d --- /dev/null +++ b/tests/unit/async_/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unit/async_/io/__init__.py b/tests/unit/async_/io/__init__.py new file mode 100644 index 00000000..b81a309d --- /dev/null +++ b/tests/unit/async_/io/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unit/async_/io/conftest.py b/tests/unit/async_/io/conftest.py new file mode 100644 index 00000000..1da3e44e --- /dev/null +++ b/tests/unit/async_/io/conftest.py @@ -0,0 +1,156 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from io import BytesIO +from struct import ( + pack as struct_pack, + unpack as struct_unpack, +) + +import pytest + +from neo4j._async.io._common import AsyncMessageInbox +from neo4j.packstream import ( + Packer, + UnpackableBuffer, + Unpacker, +) + + +class AsyncFakeSocket: + + def __init__(self, address): + self.address = address + self.captured = b"" + self.messages = AsyncMessageInbox(self, on_error=print) + + def getsockname(self): + return "127.0.0.1", 0xFFFF + + def getpeername(self): + return self.address + + async def recv_into(self, buffer, nbytes): + data = self.captured[:nbytes] + actual = len(data) + buffer[:actual] = data + self.captured = self.captured[actual:] + return actual + + async def sendall(self, data): + self.captured += data + + def close(self): + return + + async def pop_message(self): + return await self.messages.pop() + + +class AsyncFakeSocket2: + + def __init__(self, address=None, on_send=None): + self.address = address + self.recv_buffer = bytearray() + self._messages = AsyncMessageInbox(self, on_error=print) + self.on_send = on_send + + def getsockname(self): + return "127.0.0.1", 0xFFFF + + def getpeername(self): + return self.address + + async def recv_into(self, buffer, nbytes): + data = self.recv_buffer[:nbytes] + actual = len(data) + buffer[:actual] = data + self.recv_buffer = self.recv_buffer[actual:] + return actual + + async def sendall(self, data): + if callable(self.on_send): + self.on_send(data) + + def close(self): + return + + def inject(self, data): + self.recv_buffer += data + + def _pop_chunk(self): + chunk_size, = struct_unpack(">H", self.recv_buffer[:2]) + print("CHUNK SIZE %r" % chunk_size) + end = 2 + chunk_size + chunk_data, self.recv_buffer = self.recv_buffer[2:end], self.recv_buffer[end:] + return chunk_data + + async def pop_message(self): + data = bytearray() + while True: + chunk = self._pop_chunk() + print("CHUNK %r" % chunk) + if chunk: + data.extend(chunk) + elif data: + break # end of message + else: + continue # NOOP + header = data[0] + n_fields = header % 0x10 + tag = data[1] + buffer = UnpackableBuffer(data[2:]) + unpacker = Unpacker(buffer) + fields = [unpacker.unpack() for _ in range(n_fields)] + return tag, fields + + async def send_message(self, tag, *fields): + data = self.encode_message(tag, *fields) + await self.sendall(struct_pack(">H", len(data)) + data + b"\x00\x00") + + @classmethod + def encode_message(cls, tag, *fields): + b = BytesIO() + packer = Packer(b) + for field in fields: + packer.pack(field) + return bytearray([0xB0 + len(fields), tag]) + b.getvalue() + + +class AsyncFakeSocketPair: + + def __init__(self, address): + self.client = AsyncFakeSocket2(address) + self.server = AsyncFakeSocket2() + self.client.on_send = self.server.inject + self.server.on_send = self.client.inject + + +@pytest.fixture +def fake_socket(): + return AsyncFakeSocket + + +@pytest.fixture +def fake_socket_2(): + return AsyncFakeSocket2 + + +@pytest.fixture +def fake_socket_pair(): + return AsyncFakeSocketPair diff --git a/tests/unit/async_/io/test__common.py b/tests/unit/async_/io/test__common.py new file mode 100644 index 00000000..7fefd5b9 --- /dev/null +++ b/tests/unit/async_/io/test__common.py @@ -0,0 +1,50 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import pytest + +from neo4j._async.io._common import Outbox + + +@pytest.mark.parametrize(("chunk_size", "data", "result"), ( + ( + 2, + (bytes(range(10, 15)),), + bytes((0, 2, 10, 11, 0, 2, 12, 13, 0, 1, 14)) + ), + ( + 2, + (bytes(range(10, 14)),), + bytes((0, 2, 10, 11, 0, 2, 12, 13)) + ), + ( + 2, + (bytes((5, 6, 7)), bytes((8, 9))), + bytes((0, 2, 5, 6, 0, 2, 7, 8, 0, 1, 9)) + ), +)) +def test_async_outbox_chunking(chunk_size, data, result): + outbox = Outbox(max_chunk_size=chunk_size) + assert bytes(outbox.view()) == b"" + for d in data: + outbox.write(d) + assert bytes(outbox.view()) == result + # make sure this works multiple times + assert bytes(outbox.view()) == result + outbox.clear() + assert bytes(outbox.view()) == b"" diff --git a/tests/unit/async_/io/test_class_bolt.py b/tests/unit/async_/io/test_class_bolt.py new file mode 100644 index 00000000..50706d75 --- /dev/null +++ b/tests/unit/async_/io/test_class_bolt.py @@ -0,0 +1,62 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import pytest + +from neo4j._async.io import AsyncBolt + + +# python -m pytest tests/unit/io/test_class_bolt.py -s -v + + +def test_class_method_protocol_handlers(): + # python -m pytest tests/unit/io/test_class_bolt.py -s -v -k test_class_method_protocol_handlers + protocol_handlers = AsyncBolt.protocol_handlers() + assert len(protocol_handlers) == 6 + + +@pytest.mark.parametrize( + "test_input, expected", + [ + ((0, 0), 0), + ((4, 0), 1), + ] +) +def test_class_method_protocol_handlers_with_protocol_version(test_input, expected): + # python -m pytest tests/unit/io/test_class_bolt.py -s -v -k test_class_method_protocol_handlers_with_protocol_version + protocol_handlers = AsyncBolt.protocol_handlers(protocol_version=test_input) + assert len(protocol_handlers) == expected + + +def test_class_method_protocol_handlers_with_invalid_protocol_version(): + # python -m pytest tests/unit/io/test_class_bolt.py -s -v -k test_class_method_protocol_handlers_with_invalid_protocol_version + with pytest.raises(TypeError): + AsyncBolt.protocol_handlers(protocol_version=2) + + +def test_class_method_get_handshake(): + # python -m pytest tests/unit/io/test_class_bolt.py -s -v -k test_class_method_get_handshake + handshake = AsyncBolt.get_handshake() + assert handshake == b"\x00\x02\x04\x04\x00\x00\x01\x04\x00\x00\x00\x04\x00\x00\x00\x03" + + +def test_magic_preamble(): + # python -m pytest tests/unit/io/test_class_bolt.py -s -v -k test_magic_preamble + preamble = 0x6060B017 + preamble_bytes = preamble.to_bytes(4, byteorder="big") + assert AsyncBolt.MAGIC_PREAMBLE == preamble_bytes diff --git a/tests/unit/async_/io/test_class_bolt3.py b/tests/unit/async_/io/test_class_bolt3.py new file mode 100644 index 00000000..b1220da6 --- /dev/null +++ b/tests/unit/async_/io/test_class_bolt3.py @@ -0,0 +1,115 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import pytest + +from neo4j._async.io._bolt3 import AsyncBolt3 +from neo4j.conf import PoolConfig +from neo4j.exceptions import ConfigurationError + +from ..._async_compat import ( + AsyncMagicMock, + mark_async_test, +) + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_stale(fake_socket, set_stale): + address = ("127.0.0.1", 7687) + max_connection_lifetime = 0 + connection = AsyncBolt3(address, fake_socket(address), max_connection_lifetime) + if set_stale: + connection.set_stale() + assert connection.stale() is True + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): + address = ("127.0.0.1", 7687) + max_connection_lifetime = -1 + connection = AsyncBolt3(address, fake_socket(address), max_connection_lifetime) + if set_stale: + connection.set_stale() + assert connection.stale() is set_stale + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_not_stale(fake_socket, set_stale): + address = ("127.0.0.1", 7687) + max_connection_lifetime = 999999999 + connection = AsyncBolt3(address, fake_socket(address), max_connection_lifetime) + if set_stale: + connection.set_stale() + assert connection.stale() is set_stale + + +def test_db_extra_not_supported_in_begin(fake_socket): + address = ("127.0.0.1", 7687) + connection = AsyncBolt3(address, fake_socket(address), PoolConfig.max_connection_lifetime) + with pytest.raises(ConfigurationError): + connection.begin(db="something") + + +def test_db_extra_not_supported_in_run(fake_socket): + address = ("127.0.0.1", 7687) + connection = AsyncBolt3(address, fake_socket(address), PoolConfig.max_connection_lifetime) + with pytest.raises(ConfigurationError): + connection.run("", db="something") + + +@mark_async_test +async def test_simple_discard(fake_socket): + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = AsyncBolt3(address, socket, PoolConfig.max_connection_lifetime) + connection.discard() + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x2F" + assert len(fields) == 0 + + +@mark_async_test +async def test_simple_pull(fake_socket): + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = AsyncBolt3(address, socket, PoolConfig.max_connection_lifetime) + connection.pull() + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x3F" + assert len(fields) == 0 + + +@pytest.mark.parametrize("recv_timeout", (1, -1)) +@mark_async_test +async def test_hint_recv_timeout_seconds_gets_ignored( + fake_socket_pair, recv_timeout +): + address = ("127.0.0.1", 7687) + sockets = fake_socket_pair(address) + sockets.client.settimeout = AsyncMagicMock() + await sockets.server.send_message(0x70, { + "server": "Neo4j/3.5.0", + "hints": {"connection.recv_timeout_seconds": recv_timeout}, + }) + connection = AsyncBolt3( + address, sockets.client, PoolConfig.max_connection_lifetime + ) + await connection.hello() + sockets.client.settimeout.assert_not_called() diff --git a/tests/unit/async_/io/test_class_bolt4x0.py b/tests/unit/async_/io/test_class_bolt4x0.py new file mode 100644 index 00000000..c2623cf1 --- /dev/null +++ b/tests/unit/async_/io/test_class_bolt4x0.py @@ -0,0 +1,209 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from unittest.mock import MagicMock + +import pytest + +from neo4j._async.io._bolt4 import AsyncBolt4x0 +from neo4j.conf import PoolConfig + +from ..._async_compat import mark_async_test + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_stale(fake_socket, set_stale): + address = ("127.0.0.1", 7687) + max_connection_lifetime = 0 + connection = AsyncBolt4x0(address, fake_socket(address), max_connection_lifetime) + if set_stale: + connection.set_stale() + assert connection.stale() is True + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): + address = ("127.0.0.1", 7687) + max_connection_lifetime = -1 + connection = AsyncBolt4x0(address, fake_socket(address), max_connection_lifetime) + if set_stale: + connection.set_stale() + assert connection.stale() is set_stale + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_not_stale(fake_socket, set_stale): + address = ("127.0.0.1", 7687) + max_connection_lifetime = 999999999 + connection = AsyncBolt4x0(address, fake_socket(address), max_connection_lifetime) + if set_stale: + connection.set_stale() + assert connection.stale() is set_stale + + +@mark_async_test +async def test_db_extra_in_begin(fake_socket): + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = AsyncBolt4x0(address, socket, PoolConfig.max_connection_lifetime) + connection.begin(db="something") + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x11" + assert len(fields) == 1 + assert fields[0] == {"db": "something"} + + +@mark_async_test +async def test_db_extra_in_run(fake_socket): + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = AsyncBolt4x0(address, socket, PoolConfig.max_connection_lifetime) + connection.run("", {}, db="something") + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x10" + assert len(fields) == 3 + assert fields[0] == "" + assert fields[1] == {} + assert fields[2] == {"db": "something"} + + +@mark_async_test +async def test_n_extra_in_discard(fake_socket): + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = AsyncBolt4x0(address, socket, PoolConfig.max_connection_lifetime) + connection.discard(n=666) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x2F" + assert len(fields) == 1 + assert fields[0] == {"n": 666} + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (666, {"n": -1, "qid": 666}), + (-1, {"n": -1}), + ] +) +@mark_async_test +async def test_qid_extra_in_discard(fake_socket, test_input, expected): + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = AsyncBolt4x0(address, socket, PoolConfig.max_connection_lifetime) + connection.discard(qid=test_input) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x2F" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (777, {"n": 666, "qid": 777}), + (-1, {"n": 666}), + ] +) +@mark_async_test +async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = AsyncBolt4x0(address, socket, PoolConfig.max_connection_lifetime) + connection.discard(n=666, qid=test_input) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x2F" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (666, {"n": 666}), + (-1, {"n": -1}), + ] +) +@mark_async_test +async def test_n_extra_in_pull(fake_socket, test_input, expected): + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = AsyncBolt4x0(address, socket, PoolConfig.max_connection_lifetime) + connection.pull(n=test_input) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x3F" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (777, {"n": -1, "qid": 777}), + (-1, {"n": -1}), + ] +) +@mark_async_test +async def test_qid_extra_in_pull(fake_socket, test_input, expected): + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = AsyncBolt4x0(address, socket, PoolConfig.max_connection_lifetime) + connection.pull(qid=test_input) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x3F" + assert len(fields) == 1 + assert fields[0] == expected + + +@mark_async_test +async def test_n_and_qid_extras_in_pull(fake_socket): + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = AsyncBolt4x0(address, socket, PoolConfig.max_connection_lifetime) + connection.pull(n=666, qid=777) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x3F" + assert len(fields) == 1 + assert fields[0] == {"n": 666, "qid": 777} + + +@pytest.mark.parametrize("recv_timeout", (1, -1)) +@mark_async_test +async def test_hint_recv_timeout_seconds_gets_ignored( + fake_socket_pair, recv_timeout +): + address = ("127.0.0.1", 7687) + sockets = fake_socket_pair(address) + sockets.client.settimeout = MagicMock() + await sockets.server.send_message(0x70, { + "server": "Neo4j/4.0.0", + "hints": {"connection.recv_timeout_seconds": recv_timeout}, + }) + connection = AsyncBolt4x0( + address, sockets.client, PoolConfig.max_connection_lifetime + ) + await connection.hello() + sockets.client.settimeout.assert_not_called() diff --git a/tests/unit/async_/io/test_class_bolt4x1.py b/tests/unit/async_/io/test_class_bolt4x1.py new file mode 100644 index 00000000..9123e3b0 --- /dev/null +++ b/tests/unit/async_/io/test_class_bolt4x1.py @@ -0,0 +1,227 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import pytest + +from neo4j._async.io._bolt4 import AsyncBolt4x1 +from neo4j.conf import PoolConfig + +from ..._async_compat import ( + AsyncMagicMock, + mark_async_test, +) + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_stale(fake_socket, set_stale): + address = ("127.0.0.1", 7687) + max_connection_lifetime = 0 + connection = AsyncBolt4x1(address, fake_socket(address), max_connection_lifetime) + if set_stale: + connection.set_stale() + assert connection.stale() is True + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): + address = ("127.0.0.1", 7687) + max_connection_lifetime = -1 + connection = AsyncBolt4x1(address, fake_socket(address), max_connection_lifetime) + if set_stale: + connection.set_stale() + assert connection.stale() is set_stale + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_not_stale(fake_socket, set_stale): + address = ("127.0.0.1", 7687) + max_connection_lifetime = 999999999 + connection = AsyncBolt4x1(address, fake_socket(address), max_connection_lifetime) + if set_stale: + connection.set_stale() + assert connection.stale() is set_stale + + +@mark_async_test +async def test_db_extra_in_begin(fake_socket): + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = AsyncBolt4x1(address, socket, PoolConfig.max_connection_lifetime) + connection.begin(db="something") + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x11" + assert len(fields) == 1 + assert fields[0] == {"db": "something"} + + +@mark_async_test +async def test_db_extra_in_run(fake_socket): + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = AsyncBolt4x1(address, socket, PoolConfig.max_connection_lifetime) + connection.run("", {}, db="something") + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x10" + assert len(fields) == 3 + assert fields[0] == "" + assert fields[1] == {} + assert fields[2] == {"db": "something"} + + +@mark_async_test +async def test_n_extra_in_discard(fake_socket): + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = AsyncBolt4x1(address, socket, PoolConfig.max_connection_lifetime) + connection.discard(n=666) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x2F" + assert len(fields) == 1 + assert fields[0] == {"n": 666} + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (666, {"n": -1, "qid": 666}), + (-1, {"n": -1}), + ] +) +@mark_async_test +async def test_qid_extra_in_discard(fake_socket, test_input, expected): + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = AsyncBolt4x1(address, socket, PoolConfig.max_connection_lifetime) + connection.discard(qid=test_input) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x2F" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (777, {"n": 666, "qid": 777}), + (-1, {"n": 666}), + ] +) +@mark_async_test +async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): + # python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_n_and_qid_extras_in_discard + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = AsyncBolt4x1(address, socket, PoolConfig.max_connection_lifetime) + connection.discard(n=666, qid=test_input) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x2F" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (666, {"n": 666}), + (-1, {"n": -1}), + ] +) +@mark_async_test +async def test_n_extra_in_pull(fake_socket, test_input, expected): + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = AsyncBolt4x1(address, socket, PoolConfig.max_connection_lifetime) + connection.pull(n=test_input) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x3F" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (777, {"n": -1, "qid": 777}), + (-1, {"n": -1}), + ] +) +@mark_async_test +async def test_qid_extra_in_pull(fake_socket, test_input, expected): + # python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_qid_extra_in_pull + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = AsyncBolt4x1(address, socket, PoolConfig.max_connection_lifetime) + connection.pull(qid=test_input) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x3F" + assert len(fields) == 1 + assert fields[0] == expected + + +@mark_async_test +async def test_n_and_qid_extras_in_pull(fake_socket): + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = AsyncBolt4x1(address, socket, PoolConfig.max_connection_lifetime) + connection.pull(n=666, qid=777) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x3F" + assert len(fields) == 1 + assert fields[0] == {"n": 666, "qid": 777} + + +@mark_async_test +async def test_hello_passes_routing_metadata(fake_socket_pair): + address = ("127.0.0.1", 7687) + sockets = fake_socket_pair(address) + await sockets.server.send_message(0x70, {"server": "Neo4j/4.1.0"}) + connection = AsyncBolt4x1( + address, sockets.client, PoolConfig.max_connection_lifetime, + routing_context={"foo": "bar"} + ) + await connection.hello() + tag, fields = await sockets.server.pop_message() + assert tag == 0x01 + assert len(fields) == 1 + assert fields[0]["routing"] == {"foo": "bar"} + + +@pytest.mark.parametrize("recv_timeout", (1, -1)) +@mark_async_test +async def test_hint_recv_timeout_seconds_gets_ignored( + fake_socket_pair, recv_timeout +): + address = ("127.0.0.1", 7687) + sockets = fake_socket_pair(address) + sockets.client.settimeout = AsyncMagicMock() + await sockets.server.send_message(0x70, { + "server": "Neo4j/4.1.0", + "hints": {"connection.recv_timeout_seconds": recv_timeout}, + }) + connection = AsyncBolt4x1(address, sockets.client, + PoolConfig.max_connection_lifetime) + await connection.hello() + sockets.client.settimeout.assert_not_called() diff --git a/tests/unit/async_/io/test_class_bolt4x2.py b/tests/unit/async_/io/test_class_bolt4x2.py new file mode 100644 index 00000000..1a575b2b --- /dev/null +++ b/tests/unit/async_/io/test_class_bolt4x2.py @@ -0,0 +1,228 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import pytest + +from neo4j._async.io._bolt4 import AsyncBolt4x2 +from neo4j.conf import PoolConfig + +from ..._async_compat import ( + AsyncMagicMock, + mark_async_test, +) + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_stale(fake_socket, set_stale): + address = ("127.0.0.1", 7687) + max_connection_lifetime = 0 + connection = AsyncBolt4x2(address, fake_socket(address), max_connection_lifetime) + if set_stale: + connection.set_stale() + assert connection.stale() is True + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): + address = ("127.0.0.1", 7687) + max_connection_lifetime = -1 + connection = AsyncBolt4x2(address, fake_socket(address), max_connection_lifetime) + if set_stale: + connection.set_stale() + assert connection.stale() is set_stale + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_not_stale(fake_socket, set_stale): + address = ("127.0.0.1", 7687) + max_connection_lifetime = 999999999 + connection = AsyncBolt4x2(address, fake_socket(address), max_connection_lifetime) + if set_stale: + connection.set_stale() + assert connection.stale() is set_stale + + +@mark_async_test +async def test_db_extra_in_begin(fake_socket): + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = AsyncBolt4x2(address, socket, PoolConfig.max_connection_lifetime) + connection.begin(db="something") + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x11" + assert len(fields) == 1 + assert fields[0] == {"db": "something"} + + +@mark_async_test +async def test_db_extra_in_run(fake_socket): + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = AsyncBolt4x2(address, socket, PoolConfig.max_connection_lifetime) + connection.run("", {}, db="something") + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x10" + assert len(fields) == 3 + assert fields[0] == "" + assert fields[1] == {} + assert fields[2] == {"db": "something"} + + +@mark_async_test +async def test_n_extra_in_discard(fake_socket): + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = AsyncBolt4x2(address, socket, PoolConfig.max_connection_lifetime) + connection.discard(n=666) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x2F" + assert len(fields) == 1 + assert fields[0] == {"n": 666} + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (666, {"n": -1, "qid": 666}), + (-1, {"n": -1}), + ] +) +@mark_async_test +async def test_qid_extra_in_discard(fake_socket, test_input, expected): + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = AsyncBolt4x2(address, socket, PoolConfig.max_connection_lifetime) + connection.discard(qid=test_input) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x2F" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (777, {"n": 666, "qid": 777}), + (-1, {"n": 666}), + ] +) +@mark_async_test +async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): + # python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_n_and_qid_extras_in_discard + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = AsyncBolt4x2(address, socket, PoolConfig.max_connection_lifetime) + connection.discard(n=666, qid=test_input) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x2F" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (666, {"n": 666}), + (-1, {"n": -1}), + ] +) +@mark_async_test +async def test_n_extra_in_pull(fake_socket, test_input, expected): + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = AsyncBolt4x2(address, socket, PoolConfig.max_connection_lifetime) + connection.pull(n=test_input) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x3F" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (777, {"n": -1, "qid": 777}), + (-1, {"n": -1}), + ] +) +@mark_async_test +async def test_qid_extra_in_pull(fake_socket, test_input, expected): + # python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_qid_extra_in_pull + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = AsyncBolt4x2(address, socket, PoolConfig.max_connection_lifetime) + connection.pull(qid=test_input) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x3F" + assert len(fields) == 1 + assert fields[0] == expected + + +@mark_async_test +async def test_n_and_qid_extras_in_pull(fake_socket): + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = AsyncBolt4x2(address, socket, PoolConfig.max_connection_lifetime) + connection.pull(n=666, qid=777) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x3F" + assert len(fields) == 1 + assert fields[0] == {"n": 666, "qid": 777} + + +@mark_async_test +async def test_hello_passes_routing_metadata(fake_socket_pair): + address = ("127.0.0.1", 7687) + sockets = fake_socket_pair(address) + await sockets.server.send_message(0x70, {"server": "Neo4j/4.2.0"}) + connection = AsyncBolt4x2( + address, sockets.client, PoolConfig.max_connection_lifetime, + routing_context={"foo": "bar"} + ) + await connection.hello() + tag, fields = await sockets.server.pop_message() + assert tag == 0x01 + assert len(fields) == 1 + assert fields[0]["routing"] == {"foo": "bar"} + + +@pytest.mark.parametrize("recv_timeout", (1, -1)) +@mark_async_test +async def test_hint_recv_timeout_seconds_gets_ignored( + fake_socket_pair, recv_timeout +): + address = ("127.0.0.1", 7687) + sockets = fake_socket_pair(address) + sockets.client.settimeout = AsyncMagicMock() + await sockets.server.send_message(0x70, { + "server": "Neo4j/4.2.0", + "hints": {"connection.recv_timeout_seconds": recv_timeout}, + }) + connection = AsyncBolt4x2( + address, sockets.client, PoolConfig.max_connection_lifetime + ) + await connection.hello() + sockets.client.settimeout.assert_not_called() diff --git a/tests/unit/async_/io/test_class_bolt4x3.py b/tests/unit/async_/io/test_class_bolt4x3.py new file mode 100644 index 00000000..e3b22af2 --- /dev/null +++ b/tests/unit/async_/io/test_class_bolt4x3.py @@ -0,0 +1,255 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import logging + +import pytest + +from neo4j._async.io._bolt4 import AsyncBolt4x3 +from neo4j.conf import PoolConfig + +from ..._async_compat import ( + AsyncMagicMock, + mark_async_test, +) + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_stale(fake_socket, set_stale): + address = ("127.0.0.1", 7687) + max_connection_lifetime = 0 + connection = AsyncBolt4x3(address, fake_socket(address), max_connection_lifetime) + if set_stale: + connection.set_stale() + assert connection.stale() is True + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): + address = ("127.0.0.1", 7687) + max_connection_lifetime = -1 + connection = AsyncBolt4x3(address, fake_socket(address), max_connection_lifetime) + if set_stale: + connection.set_stale() + assert connection.stale() is set_stale + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_not_stale(fake_socket, set_stale): + address = ("127.0.0.1", 7687) + max_connection_lifetime = 999999999 + connection = AsyncBolt4x3(address, fake_socket(address), max_connection_lifetime) + if set_stale: + connection.set_stale() + assert connection.stale() is set_stale + + +@mark_async_test +async def test_db_extra_in_begin(fake_socket): + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = AsyncBolt4x3(address, socket, PoolConfig.max_connection_lifetime) + connection.begin(db="something") + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x11" + assert len(fields) == 1 + assert fields[0] == {"db": "something"} + + +@mark_async_test +async def test_db_extra_in_run(fake_socket): + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = AsyncBolt4x3(address, socket, PoolConfig.max_connection_lifetime) + connection.run("", {}, db="something") + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x10" + assert len(fields) == 3 + assert fields[0] == "" + assert fields[1] == {} + assert fields[2] == {"db": "something"} + + +@mark_async_test +async def test_n_extra_in_discard(fake_socket): + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = AsyncBolt4x3(address, socket, PoolConfig.max_connection_lifetime) + connection.discard(n=666) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x2F" + assert len(fields) == 1 + assert fields[0] == {"n": 666} + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (666, {"n": -1, "qid": 666}), + (-1, {"n": -1}), + ] +) +@mark_async_test +async def test_qid_extra_in_discard(fake_socket, test_input, expected): + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = AsyncBolt4x3(address, socket, PoolConfig.max_connection_lifetime) + connection.discard(qid=test_input) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x2F" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (777, {"n": 666, "qid": 777}), + (-1, {"n": 666}), + ] +) +@mark_async_test +async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): + # python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_n_and_qid_extras_in_discard + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = AsyncBolt4x3(address, socket, PoolConfig.max_connection_lifetime) + connection.discard(n=666, qid=test_input) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x2F" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (666, {"n": 666}), + (-1, {"n": -1}), + ] +) +@mark_async_test +async def test_n_extra_in_pull(fake_socket, test_input, expected): + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = AsyncBolt4x3(address, socket, PoolConfig.max_connection_lifetime) + connection.pull(n=test_input) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x3F" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (777, {"n": -1, "qid": 777}), + (-1, {"n": -1}), + ] +) +@mark_async_test +async def test_qid_extra_in_pull(fake_socket, test_input, expected): + # python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_qid_extra_in_pull + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = AsyncBolt4x3(address, socket, PoolConfig.max_connection_lifetime) + connection.pull(qid=test_input) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x3F" + assert len(fields) == 1 + assert fields[0] == expected + + +@mark_async_test +async def test_n_and_qid_extras_in_pull(fake_socket): + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = AsyncBolt4x3(address, socket, PoolConfig.max_connection_lifetime) + connection.pull(n=666, qid=777) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x3F" + assert len(fields) == 1 + assert fields[0] == {"n": 666, "qid": 777} + + +@mark_async_test +async def test_hello_passes_routing_metadata(fake_socket_pair): + address = ("127.0.0.1", 7687) + sockets = fake_socket_pair(address) + await sockets.server.send_message(0x70, {"server": "Neo4j/4.3.0"}) + connection = AsyncBolt4x3( + address, sockets.client, PoolConfig.max_connection_lifetime, + routing_context={"foo": "bar"} + ) + await connection.hello() + tag, fields = await sockets.server.pop_message() + assert tag == 0x01 + assert len(fields) == 1 + assert fields[0]["routing"] == {"foo": "bar"} + + +@pytest.mark.parametrize(("hints", "valid"), ( + ({"connection.recv_timeout_seconds": 1}, True), + ({"connection.recv_timeout_seconds": 42}, True), + ({}, True), + ({"whatever_this_is": "ignore me!"}, True), + ({"connection.recv_timeout_seconds": -1}, False), + ({"connection.recv_timeout_seconds": 0}, False), + ({"connection.recv_timeout_seconds": 2.5}, False), + ({"connection.recv_timeout_seconds": None}, False), + ({"connection.recv_timeout_seconds": False}, False), + ({"connection.recv_timeout_seconds": "1"}, False), +)) +@mark_async_test +async def test_hint_recv_timeout_seconds( + fake_socket_pair, hints, valid, caplog +): + address = ("127.0.0.1", 7687) + sockets = fake_socket_pair(address) + sockets.client.settimeout = AsyncMagicMock() + await sockets.server.send_message( + 0x70, {"server": "Neo4j/4.3.0", "hints": hints} + ) + connection = AsyncBolt4x3(address, sockets.client, + PoolConfig.max_connection_lifetime) + with caplog.at_level(logging.INFO): + await connection.hello() + if valid: + if "connection.recv_timeout_seconds" in hints: + sockets.client.settimeout.assert_called_once_with( + hints["connection.recv_timeout_seconds"] + ) + else: + sockets.client.settimeout.assert_not_called() + assert not any("recv_timeout_seconds" in msg + and "invalid" in msg + for msg in caplog.messages) + else: + sockets.client.settimeout.assert_not_called() + assert any(repr(hints["connection.recv_timeout_seconds"]) in msg + and "recv_timeout_seconds" in msg + and "invalid" in msg + for msg in caplog.messages) diff --git a/tests/unit/async_/io/test_class_bolt4x4.py b/tests/unit/async_/io/test_class_bolt4x4.py new file mode 100644 index 00000000..70ecafe0 --- /dev/null +++ b/tests/unit/async_/io/test_class_bolt4x4.py @@ -0,0 +1,271 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import logging +from unittest.mock import MagicMock + +import pytest + +from neo4j._async.io._bolt4 import AsyncBolt4x4 +from neo4j.conf import PoolConfig + +from ..._async_compat import ( + AsyncMagicMock, + mark_async_test, +) + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_stale(fake_socket, set_stale): + address = ("127.0.0.1", 7687) + max_connection_lifetime = 0 + connection = AsyncBolt4x4(address, fake_socket(address), max_connection_lifetime) + if set_stale: + connection.set_stale() + assert connection.stale() is True + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): + address = ("127.0.0.1", 7687) + max_connection_lifetime = -1 + connection = AsyncBolt4x4(address, fake_socket(address), max_connection_lifetime) + if set_stale: + connection.set_stale() + assert connection.stale() is set_stale + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_not_stale(fake_socket, set_stale): + address = ("127.0.0.1", 7687) + max_connection_lifetime = 999999999 + connection = AsyncBolt4x4(address, fake_socket(address), max_connection_lifetime) + if set_stale: + connection.set_stale() + assert connection.stale() is set_stale + + +@pytest.mark.parametrize(("args", "kwargs", "expected_fields"), ( + (("", {}), {"db": "something"}, ({"db": "something"},)), + (("", {}), {"imp_user": "imposter"}, ({"imp_user": "imposter"},)), + ( + ("", {}), + {"db": "something", "imp_user": "imposter"}, + ({"db": "something", "imp_user": "imposter"},) + ), +)) +@mark_async_test +async def test_extra_in_begin(fake_socket, args, kwargs, expected_fields): + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = AsyncBolt4x4(address, socket, PoolConfig.max_connection_lifetime) + connection.begin(*args, **kwargs) + await connection.send_all() + tag, is_fields = await socket.pop_message() + assert tag == b"\x11" + assert tuple(is_fields) == expected_fields + + +@pytest.mark.parametrize(("args", "kwargs", "expected_fields"), ( + (("", {}), {"db": "something"}, ("", {}, {"db": "something"})), + (("", {}), {"imp_user": "imposter"}, ("", {}, {"imp_user": "imposter"})), + ( + ("", {}), + {"db": "something", "imp_user": "imposter"}, + ("", {}, {"db": "something", "imp_user": "imposter"}) + ), +)) +@mark_async_test +async def test_extra_in_run(fake_socket, args, kwargs, expected_fields): + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = AsyncBolt4x4(address, socket, PoolConfig.max_connection_lifetime) + connection.run(*args, **kwargs) + await connection.send_all() + tag, is_fields = await socket.pop_message() + assert tag == b"\x10" + assert tuple(is_fields) == expected_fields + + +@mark_async_test +async def test_n_extra_in_discard(fake_socket): + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = AsyncBolt4x4(address, socket, PoolConfig.max_connection_lifetime) + connection.discard(n=666) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x2F" + assert len(fields) == 1 + assert fields[0] == {"n": 666} + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (666, {"n": -1, "qid": 666}), + (-1, {"n": -1}), + ] +) +@mark_async_test +async def test_qid_extra_in_discard(fake_socket, test_input, expected): + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = AsyncBolt4x4(address, socket, PoolConfig.max_connection_lifetime) + connection.discard(qid=test_input) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x2F" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (777, {"n": 666, "qid": 777}), + (-1, {"n": 666}), + ] +) +@mark_async_test +async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): + # python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_n_and_qid_extras_in_discard + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = AsyncBolt4x4(address, socket, PoolConfig.max_connection_lifetime) + connection.discard(n=666, qid=test_input) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x2F" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (666, {"n": 666}), + (-1, {"n": -1}), + ] +) +@mark_async_test +async def test_n_extra_in_pull(fake_socket, test_input, expected): + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = AsyncBolt4x4(address, socket, PoolConfig.max_connection_lifetime) + connection.pull(n=test_input) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x3F" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (777, {"n": -1, "qid": 777}), + (-1, {"n": -1}), + ] +) +@mark_async_test +async def test_qid_extra_in_pull(fake_socket, test_input, expected): + # python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_qid_extra_in_pull + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = AsyncBolt4x4(address, socket, PoolConfig.max_connection_lifetime) + connection.pull(qid=test_input) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x3F" + assert len(fields) == 1 + assert fields[0] == expected + + +@mark_async_test +async def test_n_and_qid_extras_in_pull(fake_socket): + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = AsyncBolt4x4(address, socket, PoolConfig.max_connection_lifetime) + connection.pull(n=666, qid=777) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x3F" + assert len(fields) == 1 + assert fields[0] == {"n": 666, "qid": 777} + + +@mark_async_test +async def test_hello_passes_routing_metadata(fake_socket_pair): + address = ("127.0.0.1", 7687) + sockets = fake_socket_pair(address) + await sockets.server.send_message(0x70, {"server": "Neo4j/4.4.0"}) + connection = AsyncBolt4x4( + address, sockets.client, PoolConfig.max_connection_lifetime, + routing_context={"foo": "bar"} + ) + await connection.hello() + tag, fields = await sockets.server.pop_message() + assert tag == 0x01 + assert len(fields) == 1 + assert fields[0]["routing"] == {"foo": "bar"} + + +@pytest.mark.parametrize(("hints", "valid"), ( + ({"connection.recv_timeout_seconds": 1}, True), + ({"connection.recv_timeout_seconds": 42}, True), + ({}, True), + ({"whatever_this_is": "ignore me!"}, True), + ({"connection.recv_timeout_seconds": -1}, False), + ({"connection.recv_timeout_seconds": 0}, False), + ({"connection.recv_timeout_seconds": 2.5}, False), + ({"connection.recv_timeout_seconds": None}, False), + ({"connection.recv_timeout_seconds": False}, False), + ({"connection.recv_timeout_seconds": "1"}, False), +)) +@mark_async_test +async def test_hint_recv_timeout_seconds( + fake_socket_pair, hints, valid, caplog +): + address = ("127.0.0.1", 7687) + sockets = fake_socket_pair(address) + sockets.client.settimeout = MagicMock() + await sockets.server.send_message( + 0x70, {"server": "Neo4j/4.3.4", "hints": hints} + ) + connection = AsyncBolt4x4( + address, sockets.client, PoolConfig.max_connection_lifetime + ) + with caplog.at_level(logging.INFO): + await connection.hello() + if valid: + if "connection.recv_timeout_seconds" in hints: + sockets.client.settimeout.assert_called_once_with( + hints["connection.recv_timeout_seconds"] + ) + else: + sockets.client.settimeout.assert_not_called() + assert not any("recv_timeout_seconds" in msg + and "invalid" in msg + for msg in caplog.messages) + else: + sockets.client.settimeout.assert_not_called() + assert any(repr(hints["connection.recv_timeout_seconds"]) in msg + and "recv_timeout_seconds" in msg + and "invalid" in msg + for msg in caplog.messages) diff --git a/tests/unit/async_/io/test_direct.py b/tests/unit/async_/io/test_direct.py new file mode 100644 index 00000000..266b9779 --- /dev/null +++ b/tests/unit/async_/io/test_direct.py @@ -0,0 +1,231 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import pytest + +from neo4j import ( + Config, + PoolConfig, + WorkspaceConfig, +) +from neo4j._async.io import AsyncBolt +from neo4j._async.io._pool import AsyncIOPool +from neo4j.exceptions import ( + ClientError, + ServiceUnavailable, +) + +from ..._async_compat import ( + AsyncMock, + mark_async_test, + mock, +) + + +class AsyncFakeSocket: + def __init__(self, address): + self.address = address + + def getpeername(self): + return self.address + + async def sendall(self, data): + return + + def close(self): + return + + +class AsyncQuickConnection: + def __init__(self, socket): + self.socket = socket + self.address = socket.getpeername() + + @property + def is_reset(self): + return True + + def stale(self): + return False + + async def reset(self): + pass + + def close(self): + self.socket.close() + + def closed(self): + return False + + def defunct(self): + return False + + def timedout(self): + return False + + +class AsyncFakeBoltPool(AsyncIOPool): + + def __init__(self, address, *, auth=None, **config): + self.pool_config, self.workspace_config = Config.consume_chain(config, PoolConfig, WorkspaceConfig) + if config: + raise ValueError("Unexpected config keys: %s" % ", ".join(config.keys())) + + async def opener(addr, timeout): + return AsyncQuickConnection(AsyncFakeSocket(addr)) + + super().__init__(opener, self.pool_config, self.workspace_config) + self.address = address + + async def acquire( + self, access_mode=None, timeout=None, database=None, bookmarks=None + ): + return await self._acquire(self.address, timeout) + + +@mark_async_test +async def test_bolt_connection_open(): + with pytest.raises(ServiceUnavailable): + await AsyncBolt.open(("localhost", 9999), auth=("test", "test")) + + +@mark_async_test +async def test_bolt_connection_open_timeout(): + with pytest.raises(ServiceUnavailable): + await AsyncBolt.open(("localhost", 9999), auth=("test", "test"), + timeout=1) + + +@mark_async_test +async def test_bolt_connection_ping(): + protocol_version = await AsyncBolt.ping(("localhost", 9999)) + assert protocol_version is None + + +@mark_async_test +async def test_bolt_connection_ping_timeout(): + protocol_version = await AsyncBolt.ping(("localhost", 9999), timeout=1) + assert protocol_version is None + + +@pytest.fixture +async def pool(): + async with AsyncFakeBoltPool(("127.0.0.1", 7687)) as pool: + yield pool + + +def assert_pool_size( address, expected_active, expected_inactive, pool): + try: + connections = pool.connections[address] + except KeyError: + assert 0 == expected_active + assert 0 == expected_inactive + else: + assert expected_active == len([cx for cx in connections if cx.in_use]) + assert (expected_inactive + == len([cx for cx in connections if not cx.in_use])) + + +@mark_async_test +async def test_pool_can_acquire(pool): + address = ("127.0.0.1", 7687) + connection = await pool._acquire(address, timeout=3) + assert connection.address == address + assert_pool_size(address, 1, 0, pool) + + +@mark_async_test +async def test_pool_can_acquire_twice(pool): + address = ("127.0.0.1", 7687) + connection_1 = await pool._acquire(address, timeout=3) + connection_2 = await pool._acquire(address, timeout=3) + assert connection_1.address == address + assert connection_2.address == address + assert connection_1 is not connection_2 + assert_pool_size(address, 2, 0, pool) + + +@mark_async_test +async def test_pool_can_acquire_two_addresses(pool): + address_1 = ("127.0.0.1", 7687) + address_2 = ("127.0.0.1", 7474) + connection_1 = await pool._acquire(address_1, timeout=3) + connection_2 = await pool._acquire(address_2, timeout=3) + assert connection_1.address == address_1 + assert connection_2.address == address_2 + assert_pool_size(address_1, 1, 0, pool) + assert_pool_size(address_2, 1, 0, pool) + + +@mark_async_test +async def test_pool_can_acquire_and_release(pool): + address = ("127.0.0.1", 7687) + connection = await pool._acquire(address, timeout=3) + assert_pool_size(address, 1, 0, pool) + await pool.release(connection) + assert_pool_size(address, 0, 1, pool) + + +@mark_async_test +async def test_pool_releasing_twice(pool): + address = ("127.0.0.1", 7687) + connection = await pool._acquire(address, timeout=3) + await pool.release(connection) + assert_pool_size(address, 0, 1, pool) + await pool.release(connection) + assert_pool_size(address, 0, 1, pool) + + +@mark_async_test +async def test_pool_in_use_count(pool): + address = ("127.0.0.1", 7687) + assert pool.in_use_connection_count(address) == 0 + connection = await pool._acquire(address, timeout=3) + assert pool.in_use_connection_count(address) == 1 + await pool.release(connection) + assert pool.in_use_connection_count(address) == 0 + + +@mark_async_test +async def test_pool_max_conn_pool_size(pool): + async with AsyncFakeBoltPool((), max_connection_pool_size=1) as pool: + address = ("127.0.0.1", 7687) + await pool._acquire(address, timeout=0) + assert pool.in_use_connection_count(address) == 1 + with pytest.raises(ClientError): + await pool._acquire(address, timeout=0) + assert pool.in_use_connection_count(address) == 1 + + +@pytest.mark.parametrize("is_reset", (True, False)) +@mark_async_test +async def test_pool_reset_when_released(is_reset, pool): + address = ("127.0.0.1", 7687) + quick_connection_name = AsyncQuickConnection.__name__ + with mock.patch(f"{__name__}.{quick_connection_name}.is_reset", + new_callable=mock.PropertyMock) as is_reset_mock: + with mock.patch(f"{__name__}.{quick_connection_name}.reset", + new_callable=AsyncMock) as reset_mock: + is_reset_mock.return_value = is_reset + connection = await pool._acquire(address, timeout=3) + assert isinstance(connection, AsyncQuickConnection) + assert is_reset_mock.call_count == 0 + assert reset_mock.call_count == 0 + await pool.release(connection) + assert is_reset_mock.call_count == 1 + assert reset_mock.call_count == int(not is_reset) diff --git a/tests/unit/async_/io/test_neo4j_pool.py b/tests/unit/async_/io/test_neo4j_pool.py new file mode 100644 index 00000000..3962a4de --- /dev/null +++ b/tests/unit/async_/io/test_neo4j_pool.py @@ -0,0 +1,259 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from unittest.mock import Mock + +import pytest + +from neo4j import ( + READ_ACCESS, + WRITE_ACCESS, +) +from neo4j._async.io import AsyncNeo4jPool +from neo4j.addressing import ResolvedAddress +from neo4j.conf import ( + PoolConfig, + RoutingConfig, + WorkspaceConfig, +) + +from ..._async_compat import ( + AsyncMock, + mark_async_test, +) +from ..work import AsyncFakeConnection + + +ROUTER_ADDRESS = ResolvedAddress(("1.2.3.1", 9001), host_name="host") +READER_ADDRESS = ResolvedAddress(("1.2.3.1", 9002), host_name="host") +WRITER_ADDRESS = ResolvedAddress(("1.2.3.1", 9003), host_name="host") + + +@pytest.fixture() +def opener(): + async def open_(addr, timeout): + connection = AsyncFakeConnection() + connection.addr = addr + connection.timeout = timeout + route_mock = AsyncMock() + route_mock.return_value = [{ + "ttl": 1000, + "servers": [ + {"addresses": [str(ROUTER_ADDRESS)], "role": "ROUTE"}, + {"addresses": [str(READER_ADDRESS)], "role": "READ"}, + {"addresses": [str(WRITER_ADDRESS)], "role": "WRITE"}, + ], + }] + connection.attach_mock(route_mock, "route") + opener_.connections.append(connection) + return connection + + opener_ = AsyncMock() + opener_.connections = [] + opener_.side_effect = open_ + return opener_ + + +@mark_async_test +async def test_acquires_new_routing_table_if_deleted(opener): + pool = AsyncNeo4jPool( + opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS + ) + cx = await pool.acquire(READ_ACCESS, 30, "test_db", None) + await pool.release(cx) + assert pool.routing_tables.get("test_db") + + del pool.routing_tables["test_db"] + + cx = await pool.acquire(READ_ACCESS, 30, "test_db", None) + await pool.release(cx) + assert pool.routing_tables.get("test_db") + + +@mark_async_test +async def test_acquires_new_routing_table_if_stale(opener): + pool = AsyncNeo4jPool( + opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS + ) + cx = await pool.acquire(READ_ACCESS, 30, "test_db", None) + await pool.release(cx) + assert pool.routing_tables.get("test_db") + + old_value = pool.routing_tables["test_db"].last_updated_time + pool.routing_tables["test_db"].ttl = 0 + + cx = await pool.acquire(READ_ACCESS, 30, "test_db", None) + await pool.release(cx) + assert pool.routing_tables["test_db"].last_updated_time > old_value + + +@mark_async_test +async def test_removes_old_routing_table(opener): + pool = AsyncNeo4jPool( + opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS + ) + cx = await pool.acquire(READ_ACCESS, 30, "test_db1", None) + await pool.release(cx) + assert pool.routing_tables.get("test_db1") + cx = await pool.acquire(READ_ACCESS, 30, "test_db2", None) + await pool.release(cx) + assert pool.routing_tables.get("test_db2") + + old_value = pool.routing_tables["test_db1"].last_updated_time + pool.routing_tables["test_db1"].ttl = 0 + pool.routing_tables["test_db2"].ttl = \ + -RoutingConfig.routing_table_purge_delay + + cx = await pool.acquire(READ_ACCESS, 30, "test_db1", None) + await pool.release(cx) + assert pool.routing_tables["test_db1"].last_updated_time > old_value + assert "test_db2" not in pool.routing_tables + + +@pytest.mark.parametrize("type_", ("r", "w")) +@mark_async_test +async def test_chooses_right_connection_type(opener, type_): + pool = AsyncNeo4jPool( + opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS + ) + cx1 = await pool.acquire(READ_ACCESS if type_ == "r" else WRITE_ACCESS, + 30, "test_db", None) + await pool.release(cx1) + if type_ == "r": + assert cx1.addr == READER_ADDRESS + else: + assert cx1.addr == WRITER_ADDRESS + + +@mark_async_test +async def test_reuses_connection(opener): + pool = AsyncNeo4jPool( + opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS + ) + cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None) + await pool.release(cx1) + cx2 = await pool.acquire(READ_ACCESS, 30, "test_db", None) + assert cx1 is cx2 + + +@pytest.mark.parametrize("break_on_close", (True, False)) +@mark_async_test +async def test_closes_stale_connections(opener, break_on_close): + def break_connection(): + pool.deactivate(cx1.addr) + + if cx_close_mock_side_effect: + cx_close_mock_side_effect() + + pool = AsyncNeo4jPool( + opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS + ) + cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None) + await pool.release(cx1) + assert cx1 in pool.connections[cx1.addr] + # simulate connection going stale (e.g. exceeding) and then breaking when + # the pool tries to close the connection + cx1.stale.return_value = True + cx_close_mock = cx1.close + if break_on_close: + cx_close_mock_side_effect = cx_close_mock.side_effect + cx_close_mock.side_effect = break_connection + cx2 = await pool.acquire(READ_ACCESS, 30, "test_db", None) + await pool.release(cx2) + if break_on_close: + cx1.close.assert_called() + else: + cx1.close.assert_called_once() + assert cx2 is not cx1 + assert cx2.addr == cx1.addr + assert cx1 not in pool.connections[cx1.addr] + assert cx2 in pool.connections[cx2.addr] + + +@mark_async_test +async def test_does_not_close_stale_connections_in_use(opener): + pool = AsyncNeo4jPool( + opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS + ) + cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None) + assert cx1 in pool.connections[cx1.addr] + # simulate connection going stale (e.g. exceeding) while being in use + cx1.stale.return_value = True + cx2 = await pool.acquire(READ_ACCESS, 30, "test_db", None) + await pool.release(cx2) + cx1.close.assert_not_called() + assert cx2 is not cx1 + assert cx2.addr == cx1.addr + assert cx1 in pool.connections[cx1.addr] + assert cx2 in pool.connections[cx2.addr] + + await pool.release(cx1) + # now that cx1 is back in the pool and still stale, + # it should be closed when trying to acquire the next connection + cx1.close.assert_not_called() + + cx3 = await pool.acquire(READ_ACCESS, 30, "test_db", None) + await pool.release(cx3) + cx1.close.assert_called_once() + assert cx2 is cx3 + assert cx3.addr == cx1.addr + assert cx1 not in pool.connections[cx1.addr] + assert cx3 in pool.connections[cx2.addr] + + +@mark_async_test +async def test_release_resets_connections(opener): + pool = AsyncNeo4jPool( + opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS + ) + cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None) + cx1.is_reset_mock.return_value = False + cx1.is_reset_mock.reset_mock() + await pool.release(cx1) + cx1.is_reset_mock.assert_called_once() + cx1.reset.assert_called_once() + + +@mark_async_test +async def test_release_does_not_resets_closed_connections(opener): + pool = AsyncNeo4jPool( + opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS + ) + cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None) + cx1.closed.return_value = True + cx1.closed.reset_mock() + cx1.is_reset_mock.reset_mock() + await pool.release(cx1) + cx1.closed.assert_called_once() + cx1.is_reset_mock.asset_not_called() + cx1.reset.asset_not_called() + + +@mark_async_test +async def test_release_does_not_resets_defunct_connections(opener): + pool = AsyncNeo4jPool( + opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS + ) + cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None) + cx1.defunct.return_value = True + cx1.defunct.reset_mock() + cx1.is_reset_mock.reset_mock() + await pool.release(cx1) + cx1.defunct.assert_called_once() + cx1.is_reset_mock.asset_not_called() + cx1.reset.asset_not_called() diff --git a/tests/unit/async_/test_addressing.py b/tests/unit/async_/test_addressing.py new file mode 100644 index 00000000..69a5556f --- /dev/null +++ b/tests/unit/async_/test_addressing.py @@ -0,0 +1,125 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from socket import ( + AF_INET, + AF_INET6, +) +import unittest.mock as mock + +import pytest + +from neo4j import ( + Address, + IPv4Address, +) +from neo4j._async_compat.network import AsyncNetworkUtil +from neo4j._async_compat.util import AsyncUtil + +from .._async_compat import mark_async_test + + +mock_socket_ipv4 = mock.Mock() +mock_socket_ipv4.getpeername = lambda: ("127.0.0.1", 7687) # (address, port) + +mock_socket_ipv6 = mock.Mock() +mock_socket_ipv6.getpeername = lambda: ("[::1]", 7687, 0, 0) # (address, port, flow info, scope id) + + +@mark_async_test +async def test_address_resolve(): + address = Address(("127.0.0.1", 7687)) + resolved = AsyncNetworkUtil.resolve_address(address) + resolved = await AsyncUtil.list(resolved) + assert isinstance(resolved, Address) is False + assert isinstance(resolved, list) is True + assert len(resolved) == 1 + assert resolved[0] == IPv4Address(('127.0.0.1', 7687)) + + +@mark_async_test +async def test_address_resolve_with_custom_resolver_none(): + address = Address(("127.0.0.1", 7687)) + resolved = AsyncNetworkUtil.resolve_address(address, resolver=None) + resolved = await AsyncUtil.list(resolved) + assert isinstance(resolved, Address) is False + assert isinstance(resolved, list) is True + assert len(resolved) == 1 + assert resolved[0] == IPv4Address(('127.0.0.1', 7687)) + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (Address(("127.0.0.1", "abcd")), ValueError), + (Address((None, None)), ValueError), + ] + +) +@mark_async_test +async def test_address_resolve_with_unresolvable_address(test_input, expected): + with pytest.raises(expected): + await AsyncUtil.list( + AsyncNetworkUtil.resolve_address(test_input, resolver=None) + ) + + +@mark_async_test +@pytest.mark.parametrize("resolver_type", ("sync", "async")) +async def test_address_resolve_with_custom_resolver(resolver_type): + def custom_resolver_sync(_): + return [("127.0.0.1", 7687), ("localhost", 1234)] + + async def custom_resolver_async(_): + return [("127.0.0.1", 7687), ("localhost", 1234)] + + if resolver_type == "sync": + custom_resolver = custom_resolver_sync + else: + custom_resolver = custom_resolver_async + + address = Address(("127.0.0.1", 7687)) + resolved = AsyncNetworkUtil.resolve_address( + address, family=AF_INET, resolver=custom_resolver + ) + resolved = await AsyncUtil.list(resolved) + assert isinstance(resolved, Address) is False + assert isinstance(resolved, list) is True + assert len(resolved) == 2 # IPv4 only + assert resolved[0] == IPv4Address(('127.0.0.1', 7687)) + assert resolved[1] == IPv4Address(('127.0.0.1', 1234)) + + +@mark_async_test +async def test_address_unresolve(): + custom_resolved = [("127.0.0.1", 7687), ("localhost", 4321)] + custom_resolver = lambda _: custom_resolved + + address = Address(("foobar", 1234)) + unresolved = address.unresolved + assert address.__class__ == unresolved.__class__ + assert address == unresolved + resolved = AsyncNetworkUtil.resolve_address( + address, family=AF_INET, resolver=custom_resolver + ) + resolved = await AsyncUtil.list(resolved) + custom_resolved = sorted(Address(a) for a in custom_resolved) + unresolved = sorted(a.unresolved for a in resolved) + assert custom_resolved == unresolved + assert (list(map(lambda a: a.__class__, custom_resolved)) + == list(map(lambda a: a.__class__, unresolved))) diff --git a/tests/unit/async_/test_driver.py b/tests/unit/async_/test_driver.py new file mode 100644 index 00000000..7f23af09 --- /dev/null +++ b/tests/unit/async_/test_driver.py @@ -0,0 +1,157 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import pytest + +from neo4j import ( + AsyncBoltDriver, + AsyncGraphDatabase, + AsyncNeo4jDriver, + TRUST_ALL_CERTIFICATES, + TRUST_SYSTEM_CA_SIGNED_CERTIFICATES, +) +from neo4j.api import WRITE_ACCESS +from neo4j.exceptions import ConfigurationError + +from .._async_compat import ( + mark_async_test, + mock, +) + + +@pytest.mark.parametrize("protocol", ("bolt://", "bolt+s://", "bolt+ssc://")) +@pytest.mark.parametrize("host", ("localhost", "127.0.0.1", + "[::1]", "[0:0:0:0:0:0:0:1]")) +@pytest.mark.parametrize("port", (":1234", "", ":7687")) +@pytest.mark.parametrize("auth_token", (("test", "test"), None)) +def test_direct_driver_constructor(protocol, host, port, auth_token): + uri = protocol + host + port + driver = AsyncGraphDatabase.driver(uri, auth=auth_token) + assert isinstance(driver, AsyncBoltDriver) + + +@pytest.mark.parametrize("protocol", ("neo4j://", "neo4j+s://", "neo4j+ssc://")) +@pytest.mark.parametrize("host", ("localhost", "127.0.0.1", + "[::1]", "[0:0:0:0:0:0:0:1]")) +@pytest.mark.parametrize("port", (":1234", "", ":7687")) +@pytest.mark.parametrize("auth_token", (("test", "test"), None)) +def test_routing_driver_constructor(protocol, host, port, auth_token): + uri = protocol + host + port + driver = AsyncGraphDatabase.driver(uri, auth=auth_token) + assert isinstance(driver, AsyncNeo4jDriver) + + +@pytest.mark.parametrize("test_uri", ( + "bolt+ssc://127.0.0.1:9001", + "bolt+s://127.0.0.1:9001", + "bolt://127.0.0.1:9001", + "neo4j+ssc://127.0.0.1:9001", + "neo4j+s://127.0.0.1:9001", + "neo4j://127.0.0.1:9001", +)) +@pytest.mark.parametrize( + ("test_config", "expected_failure", "expected_failure_message"), + ( + ({"encrypted": False}, ConfigurationError, "The config settings"), + ({"encrypted": True}, ConfigurationError, "The config settings"), + ( + {"encrypted": True, "trust": TRUST_ALL_CERTIFICATES}, + ConfigurationError, "The config settings" + ), + ( + {"trust": TRUST_ALL_CERTIFICATES}, + ConfigurationError, "The config settings" + ), + ( + {"trust": TRUST_SYSTEM_CA_SIGNED_CERTIFICATES}, + ConfigurationError, "The config settings" + ), + ) +) +def test_driver_config_error( + test_uri, test_config, expected_failure, expected_failure_message +): + if "+" in test_uri: + # `+s` and `+ssc` are short hand syntax for not having to configure the + # encryption behavior of the driver. Specifying both is invalid. + with pytest.raises(expected_failure, match=expected_failure_message): + AsyncGraphDatabase.driver(test_uri, **test_config) + else: + AsyncGraphDatabase.driver(test_uri, **test_config) + + +@pytest.mark.parametrize("test_uri", ( + "http://localhost:9001", + "ftp://localhost:9001", + "x://localhost:9001", +)) +def test_invalid_protocol(test_uri): + with pytest.raises(ConfigurationError, match="scheme"): + AsyncGraphDatabase.driver(test_uri) + + +@pytest.mark.parametrize( + ("test_config", "expected_failure", "expected_failure_message"), + ( + ({"trust": 1}, ConfigurationError, "The config setting `trust`"), + ({"trust": True}, ConfigurationError, "The config setting `trust`"), + ({"trust": None}, ConfigurationError, "The config setting `trust`"), + ) +) +def test_driver_trust_config_error( + test_config, expected_failure, expected_failure_message +): + with pytest.raises(expected_failure, match=expected_failure_message): + AsyncGraphDatabase.driver("bolt://127.0.0.1:9001", **test_config) + + +@pytest.mark.parametrize("uri", ( + "bolt://127.0.0.1:9000", + "neo4j://127.0.0.1:9000", +)) +@mark_async_test +async def test_driver_opens_write_session_by_default(uri, mocker): + driver = AsyncGraphDatabase.driver(uri) + from neo4j import AsyncTransaction + + # we set a specific db, because else the driver would try to fetch a RT + # to get hold of the actual home database (which won't work in this + # unittest) + async with driver.session(database="foobar") as session: + with mock.patch.object( + session._pool, "acquire", autospec=True + ) as acquire_mock: + with mock.patch.object( + AsyncTransaction, "_begin", autospec=True + ) as tx_begin_mock: + tx = await session.begin_transaction() + acquire_mock.assert_called_once_with( + access_mode=WRITE_ACCESS, + timeout=mocker.ANY, + database=mocker.ANY, + bookmarks=mocker.ANY + ) + tx_begin_mock.assert_called_once_with( + tx, + mocker.ANY, + mocker.ANY, + mocker.ANY, + WRITE_ACCESS, + mocker.ANY, + mocker.ANY + ) diff --git a/tests/unit/async_/work/__init__.py b/tests/unit/async_/work/__init__.py new file mode 100644 index 00000000..3bfbf0ed --- /dev/null +++ b/tests/unit/async_/work/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from ._fake_connection import ( + async_fake_connection, + AsyncFakeConnection, +) diff --git a/tests/unit/async_/work/_fake_connection.py b/tests/unit/async_/work/_fake_connection.py new file mode 100644 index 00000000..c3bf9b96 --- /dev/null +++ b/tests/unit/async_/work/_fake_connection.py @@ -0,0 +1,111 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import inspect + +import pytest + +from neo4j import ServerInfo +from neo4j._async.io import AsyncBolt + +from ..._async_compat import ( + AsyncMock, + mock, + Mock, +) + + +class AsyncFakeConnection(mock.NonCallableMagicMock): + callbacks = [] + server_info = ServerInfo("127.0.0.1", (4, 3)) + + def __init__(self, *args, **kwargs): + kwargs["spec"] = AsyncBolt + super().__init__(*args, **kwargs) + self.attach_mock(Mock(return_value=True), "is_reset_mock") + self.attach_mock(Mock(return_value=False), "defunct") + self.attach_mock(Mock(return_value=False), "stale") + self.attach_mock(Mock(return_value=False), "closed") + self.attach_mock(Mock(), "unresolved_address") + + def close_side_effect(): + self.closed.return_value = True + + self.attach_mock(AsyncMock(side_effect=close_side_effect), + "close") + + @property + def is_reset(self): + if self.closed.return_value or self.defunct.return_value: + raise AssertionError( + "is_reset should not be called on a closed or defunct " + "connection." + ) + return self.is_reset_mock() + + async def fetch_message(self, *args, **kwargs): + if self.callbacks: + cb = self.callbacks.pop(0) + await cb() + return await super().__getattr__("fetch_message")(*args, **kwargs) + + async def fetch_all(self, *args, **kwargs): + while self.callbacks: + cb = self.callbacks.pop(0) + cb() + return await super().__getattr__("fetch_all")(*args, **kwargs) + + def __getattr__(self, name): + parent = super() + + def build_message_handler(name): + def func(*args, **kwargs): + async def callback(): + for cb_name, param_count in ( + ("on_success", 1), + ("on_summary", 0) + ): + cb = kwargs.get(cb_name, None) + if callable(cb): + try: + param_count = \ + len(inspect.signature(cb).parameters) + except ValueError: + # e.g. built-in method as cb + pass + if param_count == 1: + res = cb({}) + else: + res = cb() + try: + await res # maybe the callback is async + except TypeError: + pass # or maybe it wasn't ;) + self.callbacks.append(callback) + + return func + + method_mock = parent.__getattr__(name) + if name in ("run", "commit", "pull", "rollback", "discard"): + method_mock.side_effect = build_message_handler(name) + return method_mock + + +@pytest.fixture +def async_fake_connection(): + return AsyncFakeConnection() diff --git a/tests/unit/async_/work/test_result.py b/tests/unit/async_/work/test_result.py new file mode 100644 index 00000000..d73acaab --- /dev/null +++ b/tests/unit/async_/work/test_result.py @@ -0,0 +1,456 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from unittest import mock + +import pytest + +from neo4j import ( + Address, + AsyncResult, + Record, + ResultSummary, + ServerInfo, + SummaryCounters, + Version, +) +from neo4j._async_compat.util import AsyncUtil +from neo4j.data import DataHydrator + +from ..._async_compat import mark_async_test + + +class Records: + def __init__(self, fields, records): + assert all(len(fields) == len(r) for r in records) + self.fields = fields + # self.records = [{"record_values": r} for r in records] + self.records = records + + def __len__(self): + return self.records.__len__() + + def __iter__(self): + return self.records.__iter__() + + def __getitem__(self, item): + return self.records.__getitem__(item) + + +class AsyncConnectionStub: + class Message: + def __init__(self, message, *args, **kwargs): + self.message = message + self.args = args + self.kwargs = kwargs + + async def _cb(self, cb_name, *args, **kwargs): + # print(self.message, cb_name.upper(), args, kwargs) + cb = self.kwargs.get(cb_name) + await AsyncUtil.callback(cb, *args, **kwargs) + + async def on_success(self, metadata): + await self._cb("on_success", metadata) + + async def on_summary(self): + await self._cb("on_summary") + + async def on_records(self, records): + await self._cb("on_records", records) + + def __eq__(self, other): + return self.message == other + + def __repr__(self): + return "Message(%s)" % self.message + + def __init__(self, records=None, run_meta=None, summary_meta=None, + force_qid=False): + self._multi_result = isinstance(records, (list, tuple)) + if self._multi_result: + self._records = records + self._use_qid = True + else: + self._records = records, + self._use_qid = force_qid + self.fetch_idx = 0 + self._qid = -1 + self.most_recent_qid = None + self.record_idxs = [0] * len(self._records) + self.to_pull = [None] * len(self._records) + self._exhausted = [False] * len(self._records) + self.queued = [] + self.sent = [] + self.run_meta = run_meta + self.summary_meta = summary_meta + AsyncConnectionStub.server_info.update({"server": "Neo4j/4.3.0"}) + self.unresolved_address = None + + async def send_all(self): + self.sent += self.queued + self.queued = [] + + async def fetch_message(self): + if self.fetch_idx >= len(self.sent): + pytest.fail("Waits for reply to never sent message") + msg = self.sent[self.fetch_idx] + if msg == "RUN": + self.fetch_idx += 1 + self._qid += 1 + meta = {"fields": self._records[self._qid].fields, + **(self.run_meta or {})} + if self._use_qid: + meta.update(qid=self._qid) + await msg.on_success(meta) + elif msg == "DISCARD": + self.fetch_idx += 1 + qid = msg.kwargs.get("qid", -1) + if qid < 0: + qid = self._qid + self.record_idxs[qid] = len(self._records[qid]) + await msg.on_success(self.summary_meta or {}) + await msg.on_summary() + elif msg == "PULL": + qid = msg.kwargs.get("qid", -1) + if qid < 0: + qid = self._qid + if self._exhausted[qid]: + pytest.fail("PULLing exhausted result") + if self.to_pull[qid] is None: + n = msg.kwargs.get("n", -1) + if n < 0: + n = len(self._records[qid]) + self.to_pull[qid] = \ + min(n, len(self._records[qid]) - self.record_idxs[qid]) + # if to == len(self._records): + # self.fetch_idx += 1 + if self.to_pull[qid] > 0: + record = self._records[qid][self.record_idxs[qid]] + self.record_idxs[qid] += 1 + self.to_pull[qid] -= 1 + await msg.on_records([record]) + elif self.to_pull[qid] == 0: + self.to_pull[qid] = None + self.fetch_idx += 1 + if self.record_idxs[qid] < len(self._records[qid]): + await msg.on_success({"has_more": True}) + else: + await msg.on_success( + {"bookmark": "foo", **(self.summary_meta or {})} + ) + self._exhausted[qid] = True + await msg.on_summary() + + async def fetch_all(self): + while self.fetch_idx < len(self.sent): + await self.fetch_message() + + def run(self, *args, **kwargs): + self.queued.append(AsyncConnectionStub.Message("RUN", *args, **kwargs)) + + def discard(self, *args, **kwargs): + self.queued.append(AsyncConnectionStub.Message("DISCARD", *args, **kwargs)) + + def pull(self, *args, **kwargs): + self.queued.append(AsyncConnectionStub.Message("PULL", *args, **kwargs)) + + server_info = ServerInfo(Address(("bolt://localhost", 7687)), Version(4, 3)) + + def defunct(self): + return False + + +class HydratorStub(DataHydrator): + def hydrate(self, values): + return values + + +def noop(*_, **__): + pass + + +async def fetch_and_compare_all_records( + result, key, expected_records, method, limit=None +): + received_records = [] + if method == "for loop": + async for record in result: + assert isinstance(record, Record) + received_records.append([record.data().get(key, None)]) + if limit is not None and len(received_records) == limit: + break + if limit is None: + assert result._closed + elif method == "next": + iter_ = AsyncUtil.iter(result) + n = len(expected_records) if limit is None else limit + for _ in range(n): + record = await AsyncUtil.next(iter_) + received_records.append([record.get(key, None)]) + if limit is None: + with pytest.raises(StopAsyncIteration): + await AsyncUtil.next(iter_) + assert result._closed + elif method == "new iter": + n = len(expected_records) if limit is None else limit + for _ in range(n): + iter_ = AsyncUtil.iter(result) + record = await AsyncUtil.next(iter_) + received_records.append([record.get(key, None)]) + if limit is None: + iter_ = AsyncUtil.iter(result) + with pytest.raises(StopAsyncIteration): + await AsyncUtil.next(iter_) + assert result._closed + else: + raise ValueError() + assert received_records == expected_records + + +@pytest.mark.parametrize("method", ("for loop", "next", "new iter")) +@pytest.mark.parametrize("records", ( + [], + [[42]], + [[1], [2], [3], [4], [5]], +)) +@mark_async_test +async def test_result_iteration(method, records): + connection = AsyncConnectionStub(records=Records(["x"], records)) + result = AsyncResult(connection, HydratorStub(), 2, noop, noop) + await result._run("CYPHER", {}, None, None, "r", None) + await fetch_and_compare_all_records(result, "x", records, method) + + +@pytest.mark.parametrize("method", ("for loop", "next", "new iter")) +@pytest.mark.parametrize("invert_fetch", (True, False)) +@mark_async_test +async def test_parallel_result_iteration(method, invert_fetch): + records1 = [[i] for i in range(1, 6)] + records2 = [[i] for i in range(6, 11)] + connection = AsyncConnectionStub( + records=(Records(["x"], records1), Records(["x"], records2)) + ) + result1 = AsyncResult(connection, HydratorStub(), 2, noop, noop) + await result1._run("CYPHER1", {}, None, None, "r", None) + result2 = AsyncResult(connection, HydratorStub(), 2, noop, noop) + await result2._run("CYPHER2", {}, None, None, "r", None) + if invert_fetch: + await fetch_and_compare_all_records( + result2, "x", records2, method + ) + await fetch_and_compare_all_records( + result1, "x", records1, method + ) + else: + await fetch_and_compare_all_records( + result1, "x", records1, method + ) + await fetch_and_compare_all_records( + result2, "x", records2, method + ) + + +@pytest.mark.parametrize("method", ("for loop", "next", "new iter")) +@pytest.mark.parametrize("invert_fetch", (True, False)) +@mark_async_test +async def test_interwoven_result_iteration(method, invert_fetch): + records1 = [[i] for i in range(1, 10)] + records2 = [[i] for i in range(11, 20)] + connection = AsyncConnectionStub( + records=(Records(["x"], records1), Records(["y"], records2)) + ) + result1 = AsyncResult(connection, HydratorStub(), 2, noop, noop) + await result1._run("CYPHER1", {}, None, None, "r", None) + result2 = AsyncResult(connection, HydratorStub(), 2, noop, noop) + await result2._run("CYPHER2", {}, None, None, "r", None) + start = 0 + for n in (1, 2, 3, 1, None): + end = n if n is None else start + n + if invert_fetch: + await fetch_and_compare_all_records( + result2, "y", records2[start:end], method, n + ) + await fetch_and_compare_all_records( + result1, "x", records1[start:end], method, n + ) + else: + await fetch_and_compare_all_records( + result1, "x", records1[start:end], method, n + ) + await fetch_and_compare_all_records( + result2, "y", records2[start:end], method, n + ) + start = end + + +@pytest.mark.parametrize("records", ([[1], [2]], [[1]], [])) +@pytest.mark.parametrize("fetch_size", (1, 2)) +@mark_async_test +async def test_result_peek(records, fetch_size): + connection = AsyncConnectionStub(records=Records(["x"], records)) + result = AsyncResult(connection, HydratorStub(), fetch_size, noop, noop) + await result._run("CYPHER", {}, None, None, "r", None) + for i in range(len(records) + 1): + record = await result.peek() + if i == len(records): + assert record is None + else: + assert isinstance(record, Record) + assert record.get("x") == records[i][0] + iter_ = AsyncUtil.iter(result) + await AsyncUtil.next(iter_) # consume the record + + +@pytest.mark.parametrize("records", ([[1], [2]], [[1]], [])) +@pytest.mark.parametrize("fetch_size", (1, 2)) +@mark_async_test +async def test_result_single(records, fetch_size): + connection = AsyncConnectionStub(records=Records(["x"], records)) + result = AsyncResult(connection, HydratorStub(), fetch_size, noop, noop) + await result._run("CYPHER", {}, None, None, "r", None) + with pytest.warns(None) as warning_record: + record = await result.single() + if not records: + assert not warning_record + assert record is None + else: + if len(records) > 1: + assert len(warning_record) == 1 + else: + assert not warning_record + assert isinstance(record, Record) + assert record.get("x") == records[0][0] + + +@mark_async_test +async def test_keys_are_available_before_and_after_stream(): + connection = AsyncConnectionStub(records=Records(["x"], [[1], [2]])) + result = AsyncResult(connection, HydratorStub(), 1, noop, noop) + await result._run("CYPHER", {}, None, None, "r", None) + assert list(result.keys()) == ["x"] + await AsyncUtil.list(result) + assert list(result.keys()) == ["x"] + + +@pytest.mark.parametrize("records", ([[1], [2]], [[1]], [])) +@pytest.mark.parametrize("consume_one", (True, False)) +@pytest.mark.parametrize("summary_meta", (None, {"database": "foobar"})) +@mark_async_test +async def test_consume(records, consume_one, summary_meta): + connection = AsyncConnectionStub( + records=Records(["x"], records), summary_meta=summary_meta + ) + result = AsyncResult(connection, HydratorStub(), 1, noop, noop) + await result._run("CYPHER", {}, None, None, "r", None) + if consume_one: + try: + await AsyncUtil.next(AsyncUtil.iter(result)) + except StopAsyncIteration: + pass + summary = await result.consume() + assert isinstance(summary, ResultSummary) + if summary_meta and "db" in summary_meta: + assert summary.database == summary_meta["db"] + else: + assert summary.database is None + server_info = summary.server + assert isinstance(server_info, ServerInfo) + assert server_info.version_info() == Version(4, 3) + assert server_info.protocol_version == Version(4, 3) + assert isinstance(summary.counters, SummaryCounters) + + +@pytest.mark.parametrize("t_first", (None, 0, 1, 123456789)) +@pytest.mark.parametrize("t_last", (None, 0, 1, 123456789)) +@mark_async_test +async def test_time_in_summary(t_first, t_last): + run_meta = None + if t_first is not None: + run_meta = {"t_first": t_first} + summary_meta = None + if t_last is not None: + summary_meta = {"t_last": t_last} + connection = AsyncConnectionStub( + records=Records(["n"], [[i] for i in range(100)]), run_meta=run_meta, + summary_meta=summary_meta + ) + + result = AsyncResult(connection, HydratorStub(), 1, noop, noop) + await result._run("CYPHER", {}, None, None, "r", None) + summary = await result.consume() + + if t_first is not None: + assert isinstance(summary.result_available_after, int) + assert summary.result_available_after == t_first + else: + assert summary.result_available_after is None + if t_last is not None: + assert isinstance(summary.result_consumed_after, int) + assert summary.result_consumed_after == t_last + else: + assert summary.result_consumed_after is None + assert not hasattr(summary, "t_first") + assert not hasattr(summary, "t_last") + + +@mark_async_test +async def test_counts_in_summary(): + connection = AsyncConnectionStub(records=Records(["n"], [[1], [2]])) + + result = AsyncResult(connection, HydratorStub(), 1, noop, noop) + await result._run("CYPHER", {}, None, None, "r", None) + summary = await result.consume() + + assert isinstance(summary.counters, SummaryCounters) + + +@pytest.mark.parametrize("query_type", ("r", "w", "rw", "s")) +@mark_async_test +async def test_query_type(query_type): + connection = AsyncConnectionStub( + records=Records(["n"], [[1], [2]]), summary_meta={"type": query_type} + ) + + result = AsyncResult(connection, HydratorStub(), 1, noop, noop) + await result._run("CYPHER", {}, None, None, "r", None) + summary = await result.consume() + + assert isinstance(summary.query_type, str) + assert summary.query_type == query_type + + +@pytest.mark.parametrize("num_records", range(0, 5)) +@mark_async_test +async def test_data(num_records): + connection = AsyncConnectionStub( + records=Records(["n"], [[i + 1] for i in range(num_records)]) + ) + + result = AsyncResult(connection, HydratorStub(), 1, noop, noop) + await result._run("CYPHER", {}, None, None, "r", None) + await result._buffer_all() + records = result._record_buffer.copy() + assert len(records) == num_records + expected_data = [] + for i, record in enumerate(records): + record.data = mock.Mock() + expected_data.append("magic_return_%s" % i) + record.data.return_value = expected_data[-1] + assert await result.data("hello", "world") == expected_data + for record in records: + assert record.data.called_once_with("hello", "world") diff --git a/tests/unit/async_/work/test_session.py b/tests/unit/async_/work/test_session.py new file mode 100644 index 00000000..5da701da --- /dev/null +++ b/tests/unit/async_/work/test_session.py @@ -0,0 +1,285 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from contextlib import contextmanager + +import pytest + +from neo4j import ( + AsyncSession, + AsyncTransaction, + SessionConfig, + unit_of_work, +) +from neo4j._async.io._pool import AsyncIOPool + +from ..._async_compat import ( + AsyncMock, + mark_async_test, + mock, +) +from ._fake_connection import AsyncFakeConnection + + +@pytest.fixture() +def pool(): + pool = AsyncMock(spec=AsyncIOPool) + pool.acquire.side_effect = iter(AsyncFakeConnection, 0) + return pool + + +@mark_async_test +async def test_session_context_calls_close(): + s = AsyncSession(None, SessionConfig()) + with mock.patch.object(s, 'close', autospec=True) as mock_close: + async with s: + pass + mock_close.assert_called_once_with() + + +@pytest.mark.parametrize("test_run_args", ( + ("RETURN $x", {"x": 1}), ("RETURN 1",) +)) +@pytest.mark.parametrize(("repetitions", "consume"), ( + (1, False), (2, False), (2, True) +)) +@mark_async_test +async def test_opens_connection_on_run( + pool, test_run_args, repetitions, consume +): + async with AsyncSession(pool, SessionConfig()) as session: + assert session._connection is None + result = await session.run(*test_run_args) + assert session._connection is not None + if consume: + await result.consume() + + +@pytest.mark.parametrize("test_run_args", ( + ("RETURN $x", {"x": 1}), ("RETURN 1",) +)) +@pytest.mark.parametrize("repetitions", range(1, 3)) +@mark_async_test +async def test_closes_connection_after_consume( + pool, test_run_args, repetitions +): + async with AsyncSession(pool, SessionConfig()) as session: + result = await session.run(*test_run_args) + await result.consume() + assert session._connection is None + assert session._connection is None + + +@pytest.mark.parametrize("test_run_args", ( + ("RETURN $x", {"x": 1}), ("RETURN 1",) +)) +@mark_async_test +async def test_keeps_connection_until_last_result_consumed( + pool, test_run_args +): + async with AsyncSession(pool, SessionConfig()) as session: + result1 = await session.run(*test_run_args) + result2 = await session.run(*test_run_args) + assert session._connection is not None + await result1.consume() + assert session._connection is not None + await result2.consume() + assert session._connection is None + + +@mark_async_test +async def test_opens_connection_on_tx_begin(pool): + async with AsyncSession(pool, SessionConfig()) as session: + assert session._connection is None + async with await session.begin_transaction() as _: + assert session._connection is not None + + +@pytest.mark.parametrize("test_run_args", ( + ("RETURN $x", {"x": 1}), ("RETURN 1",) +)) +@pytest.mark.parametrize("repetitions", range(1, 3)) +@mark_async_test +async def test_keeps_connection_on_tx_run(pool, test_run_args, repetitions): + async with AsyncSession(pool, SessionConfig()) as session: + async with await session.begin_transaction() as tx: + for _ in range(repetitions): + await tx.run(*test_run_args) + assert session._connection is not None + + +@pytest.mark.parametrize("test_run_args", ( + ("RETURN $x", {"x": 1}), ("RETURN 1",) +)) +@pytest.mark.parametrize("repetitions", range(1, 3)) +@mark_async_test +async def test_keeps_connection_on_tx_consume( + pool, test_run_args, repetitions +): + async with AsyncSession(pool, SessionConfig()) as session: + async with await session.begin_transaction() as tx: + for _ in range(repetitions): + result = await tx.run(*test_run_args) + await result.consume() + assert session._connection is not None + + +@pytest.mark.parametrize("test_run_args", ( + ("RETURN $x", {"x": 1}), ("RETURN 1",) +)) +@mark_async_test +async def test_closes_connection_after_tx_close(pool, test_run_args): + async with AsyncSession(pool, SessionConfig()) as session: + async with await session.begin_transaction() as tx: + for _ in range(2): + result = await tx.run(*test_run_args) + await result.consume() + await tx.close() + assert session._connection is None + assert session._connection is None + + +@pytest.mark.parametrize("test_run_args", ( + ("RETURN $x", {"x": 1}), ("RETURN 1",) +)) +@mark_async_test +async def test_closes_connection_after_tx_commit(pool, test_run_args): + async with AsyncSession(pool, SessionConfig()) as session: + async with await session.begin_transaction() as tx: + for _ in range(2): + result = await tx.run(*test_run_args) + await result.consume() + await tx.commit() + assert session._connection is None + assert session._connection is None + + +@pytest.mark.parametrize("bookmarks", (None, [], ["abc"], ["foo", "bar"])) +@mark_async_test +async def test_session_returns_bookmark_directly(pool, bookmarks): + async with AsyncSession( + pool, SessionConfig(bookmarks=bookmarks) + ) as session: + if bookmarks: + assert await session.last_bookmark() == bookmarks[-1] + else: + assert await session.last_bookmark() is None + + +@pytest.mark.parametrize(("query", "error_type"), ( + (None, ValueError), + (1234, TypeError), + ({"how about": "no?"}, TypeError), + (["I don't", "think so"], TypeError), +)) +@mark_async_test +async def test_session_run_wrong_types(pool, query, error_type): + async with AsyncSession(pool, SessionConfig()) as session: + with pytest.raises(error_type): + await session.run(query) + + +@pytest.mark.parametrize("tx_type", ("write_transaction", "read_transaction")) +@mark_async_test +async def test_tx_function_argument_type(pool, tx_type): + async def work(tx): + assert isinstance(tx, AsyncTransaction) + + async with AsyncSession(pool, SessionConfig()) as session: + getattr(session, tx_type)(work) + + +@pytest.mark.parametrize("tx_type", ("write_transaction", "read_transaction")) +@pytest.mark.parametrize("decorator_kwargs", ( + {}, + {"timeout": 5}, + {"metadata": {"foo": "bar"}}, + {"timeout": 5, "metadata": {"foo": "bar"}}, + +)) +@mark_async_test +async def test_decorated_tx_function_argument_type(pool, tx_type, decorator_kwargs): + @unit_of_work(**decorator_kwargs) + async def work(tx): + assert isinstance(tx, AsyncTransaction) + + async with AsyncSession(pool, SessionConfig()) as session: + getattr(session, tx_type)(work) + + +@mark_async_test +async def test_session_tx_type(pool): + async with AsyncSession(pool, SessionConfig()) as session: + tx = await session.begin_transaction() + assert isinstance(tx, AsyncTransaction) + + +@pytest.mark.parametrize(("parameters", "error_type"), ( + ({"x": None}, None), + ({"x": True}, None), + ({"x": False}, None), + ({"x": 123456789}, None), + ({"x": 3.1415926}, None), + ({"x": float("nan")}, None), + ({"x": float("inf")}, None), + ({"x": float("-inf")}, None), + ({"x": "foo"}, None), + ({"x": bytearray([0x00, 0x33, 0x66, 0x99, 0xCC, 0xFF])}, None), + ({"x": b"\x00\x33\x66\x99\xcc\xff"}, None), + ({"x": [1, 2, 3]}, None), + ({"x": ["a", "b", "c"]}, None), + ({"x": ["a", 2, 1.234]}, None), + ({"x": ["a", 2, ["c"]]}, None), + ({"x": {"one": "eins", "two": "zwei", "three": "drei"}}, None), + ({"x": {"one": ["eins", "uno", 1], "two": ["zwei", "dos", 2]}}, None), + + # maps must have string keys + ({"x": {1: 'eins', 2: 'zwei', 3: 'drei'}}, TypeError), + ({"x": {(1, 2): '1+2i', (2, 0): '2'}}, TypeError), +)) +@pytest.mark.parametrize("run_type", ("auto", "unmanaged", "managed")) +@mark_async_test +async def test_session_run_with_parameters( + pool, parameters, error_type, run_type +): + @contextmanager + def raises(): + if error_type is not None: + with pytest.raises(error_type) as exc: + yield exc + else: + yield None + + async with AsyncSession(pool, SessionConfig()) as session: + if run_type == "auto": + with raises(): + await session.run("RETURN $x", **parameters) + elif run_type == "unmanaged": + tx = await session.begin_transaction() + with raises(): + await tx.run("RETURN $x", **parameters) + elif run_type == "managed": + async def work(tx): + with raises() as exc: + await tx.run("RETURN $x", **parameters) + if exc is not None: + raise exc + with raises(): + await session.write_transaction(work) + else: + raise ValueError(run_type) diff --git a/tests/unit/async_/work/test_transaction.py b/tests/unit/async_/work/test_transaction.py new file mode 100644 index 00000000..6eafbb1a --- /dev/null +++ b/tests/unit/async_/work/test_transaction.py @@ -0,0 +1,185 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from unittest.mock import MagicMock +from uuid import uuid4 + +import pytest + +from neo4j import ( + Query, + Transaction, +) + +from ._fake_connection import async_fake_connection + + +@pytest.mark.parametrize(("explicit_commit", "close"), ( + (False, False), + (True, False), + (True, True), +)) +def test_transaction_context_when_committing(mocker, async_fake_connection, + explicit_commit, close): + on_closed = MagicMock() + on_error = MagicMock() + tx = Transaction(async_fake_connection, 2, on_closed, on_error) + mock_commit = mocker.patch.object(tx, "commit", wraps=tx.commit) + mock_rollback = mocker.patch.object(tx, "rollback", wraps=tx.rollback) + with tx as tx_: + assert mock_commit.call_count == 0 + assert mock_rollback.call_count == 0 + assert tx is tx_ + if explicit_commit: + tx_.commit() + mock_commit.assert_called_once_with() + assert tx.closed() + if close: + tx_.close() + assert tx_.closed() + mock_commit.assert_called_once_with() + assert mock_rollback.call_count == 0 + assert tx_.closed() + + +@pytest.mark.parametrize(("rollback", "close"), ( + (True, False), + (False, True), + (True, True), +)) +def test_transaction_context_with_explicit_rollback(mocker, async_fake_connection, + rollback, close): + on_closed = MagicMock() + on_error = MagicMock() + tx = Transaction(async_fake_connection, 2, on_closed, on_error) + mock_commit = mocker.patch.object(tx, "commit", wraps=tx.commit) + mock_rollback = mocker.patch.object(tx, "rollback", wraps=tx.rollback) + with tx as tx_: + assert mock_commit.call_count == 0 + assert mock_rollback.call_count == 0 + assert tx is tx_ + if rollback: + tx_.rollback() + mock_rollback.assert_called_once_with() + assert tx_.closed() + if close: + tx_.close() + mock_rollback.assert_called_once_with() + assert tx_.closed() + assert mock_commit.call_count == 0 + mock_rollback.assert_called_once_with() + assert tx_.closed() + + +def test_transaction_context_calls_rollback_on_error(mocker, async_fake_connection): + class OopsError(RuntimeError): + pass + + on_closed = MagicMock() + on_error = MagicMock() + tx = Transaction(async_fake_connection, 2, on_closed, on_error) + mock_commit = mocker.patch.object(tx, "commit", wraps=tx.commit) + mock_rollback = mocker.patch.object(tx, "rollback", wraps=tx.rollback) + with pytest.raises(OopsError): + with tx as tx_: + assert mock_commit.call_count == 0 + assert mock_rollback.call_count == 0 + assert tx is tx_ + raise OopsError + assert mock_commit.call_count == 0 + mock_rollback.assert_called_once_with() + assert tx_.closed() + + +@pytest.mark.parametrize(("parameters", "error_type"), ( + # maps must have string keys + ({"x": {1: 'eins', 2: 'zwei', 3: 'drei'}}, TypeError), + ({"x": {(1, 2): '1+2i', (2, 0): '2'}}, TypeError), + ({"x": uuid4()}, TypeError), +)) +def test_transaction_run_with_invalid_parameters(async_fake_connection, parameters, + error_type): + on_closed = MagicMock() + on_error = MagicMock() + tx = Transaction(async_fake_connection, 2, on_closed, on_error) + with pytest.raises(error_type): + tx.run("RETURN $x", **parameters) + + +def test_transaction_run_takes_no_query_object(async_fake_connection): + on_closed = MagicMock() + on_error = MagicMock() + tx = Transaction(async_fake_connection, 2, on_closed, on_error) + with pytest.raises(ValueError): + tx.run(Query("RETURN 1")) + + +def test_transaction_rollbacks_on_open_connections(async_fake_connection): + tx = Transaction(async_fake_connection, 2, + lambda *args, **kwargs: None, + lambda *args, **kwargs: None) + with tx as tx_: + async_fake_connection.is_reset_mock.return_value = False + async_fake_connection.is_reset_mock.reset_mock() + tx_.rollback() + async_fake_connection.is_reset_mock.assert_called_once() + async_fake_connection.reset.assert_not_called() + async_fake_connection.rollback.assert_called_once() + + +def test_transaction_no_rollback_on_reset_connections(async_fake_connection): + tx = Transaction(async_fake_connection, 2, + lambda *args, **kwargs: None, + lambda *args, **kwargs: None) + with tx as tx_: + async_fake_connection.is_reset_mock.return_value = True + async_fake_connection.is_reset_mock.reset_mock() + tx_.rollback() + async_fake_connection.is_reset_mock.assert_called_once() + async_fake_connection.reset.asset_not_called() + async_fake_connection.rollback.asset_not_called() + + +def test_transaction_no_rollback_on_closed_connections(async_fake_connection): + tx = Transaction(async_fake_connection, 2, + lambda *args, **kwargs: None, + lambda *args, **kwargs: None) + with tx as tx_: + async_fake_connection.closed.return_value = True + async_fake_connection.closed.reset_mock() + async_fake_connection.is_reset_mock.reset_mock() + tx_.rollback() + async_fake_connection.closed.assert_called_once() + async_fake_connection.is_reset_mock.asset_not_called() + async_fake_connection.reset.asset_not_called() + async_fake_connection.rollback.asset_not_called() + + +def test_transaction_no_rollback_on_defunct_connections(async_fake_connection): + tx = Transaction(async_fake_connection, 2, + lambda *args, **kwargs: None, + lambda *args, **kwargs: None) + with tx as tx_: + async_fake_connection.defunct.return_value = True + async_fake_connection.defunct.reset_mock() + async_fake_connection.is_reset_mock.reset_mock() + tx_.rollback() + async_fake_connection.defunct.assert_called_once() + async_fake_connection.is_reset_mock.asset_not_called() + async_fake_connection.reset.asset_not_called() + async_fake_connection.rollback.asset_not_called() diff --git a/tests/unit/common/__init__.py b/tests/unit/common/__init__.py new file mode 100644 index 00000000..b81a309d --- /dev/null +++ b/tests/unit/common/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unit/common/data/__init__.py b/tests/unit/common/data/__init__.py new file mode 100644 index 00000000..b81a309d --- /dev/null +++ b/tests/unit/common/data/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unit/data/test_packing.py b/tests/unit/common/data/test_packing.py similarity index 96% rename from tests/unit/data/test_packing.py rename to tests/unit/common/data/test_packing.py index 6fc296af..39d32576 100644 --- a/tests/unit/data/test_packing.py +++ b/tests/unit/common/data/test_packing.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -19,16 +16,20 @@ # limitations under the License. -import struct -from collections import OrderedDict from io import BytesIO from math import pi +import struct from unittest import TestCase from uuid import uuid4 from pytest import raises -from neo4j.packstream import Packer, UnpackableBuffer, Unpacker, Structure +from neo4j.packstream import ( + Packer, + Structure, + UnpackableBuffer, + Unpacker, +) class PackStreamTestCase(TestCase): @@ -254,7 +255,7 @@ def test_empty_map(self): def test_tiny_maps(self): for size in range(0x10): - data_in = OrderedDict() + data_in = dict() data_out = bytearray([0xA0 + size]) for el in range(1, size + 1): data_in[chr(64 + el)] = el @@ -262,17 +263,17 @@ def test_tiny_maps(self): self.assert_packable(data_in, bytes(data_out)) def test_map_8(self): - d = OrderedDict([(u"A%s" % i, 1) for i in range(40)]) + d = dict([(u"A%s" % i, 1) for i in range(40)]) b = b"".join(self.packb(u"A%s" % i, 1) for i in range(40)) self.assert_packable(d, b"\xD8\x28" + b) def test_map_16(self): - d = OrderedDict([(u"A%s" % i, 1) for i in range(40000)]) + d = dict([(u"A%s" % i, 1) for i in range(40000)]) b = b"".join(self.packb(u"A%s" % i, 1) for i in range(40000)) self.assert_packable(d, b"\xD9\x9C\x40" + b) def test_map_32(self): - d = OrderedDict([(u"A%s" % i, 1) for i in range(80000)]) + d = dict([(u"A%s" % i, 1) for i in range(80000)]) b = b"".join(self.packb(u"A%s" % i, 1) for i in range(80000)) self.assert_packable(d, b"\xDA\x00\x01\x38\x80" + b) diff --git a/tests/unit/common/io/__init__.py b/tests/unit/common/io/__init__.py new file mode 100644 index 00000000..b81a309d --- /dev/null +++ b/tests/unit/common/io/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unit/io/test_routing.py b/tests/unit/common/io/test_routing.py similarity index 84% rename from tests/unit/io/test_routing.py rename to tests/unit/common/io/test_routing.py index ff555bcb..41f09171 100644 --- a/tests/unit/io/test_routing.py +++ b/tests/unit/common/io/test_routing.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -19,19 +16,14 @@ # limitations under the License. -from unittest import TestCase +import pytest -from neo4j.io import ( - Bolt, - Neo4jPool, -) +from neo4j.api import DEFAULT_DATABASE from neo4j.routing import ( OrderedSet, RoutingTable, ) -from neo4j.api import ( - DEFAULT_DATABASE, -) + VALID_ROUTING_RECORD = { "ttl": 300, @@ -53,7 +45,7 @@ } -class OrderedSetTestCase(TestCase): +class TestOrderedSet: def test_should_repr_as_set(self): s = OrderedSet([1, 2, 3]) assert repr(s) == "{1, 2, 3}" @@ -68,14 +60,14 @@ def test_should_not_contain_non_element(self): def test_should_be_able_to_get_item_if_empty(self): s = OrderedSet([]) - with self.assertRaises(IndexError): + with pytest.raises(IndexError): _ = s[0] def test_should_be_able_to_get_items_by_index(self): s = OrderedSet([1, 2, 3]) - self.assertEqual(s[0], 1) - self.assertEqual(s[1], 2) - self.assertEqual(s[2], 3) + assert s[0] == 1 + assert s[1] == 2 + assert s[2] == 3 def test_should_be_iterable(self): s = OrderedSet([1, 2, 3]) @@ -117,7 +109,7 @@ def test_should_be_able_to_remove_existing(self): def test_should_not_be_able_to_remove_non_existing(self): s = OrderedSet([1, 2, 3]) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): s.remove(4) def test_should_be_able_to_update(self): @@ -131,14 +123,14 @@ def test_should_be_able_to_replace(self): assert list(s) == [3, 4, 5] -class RoutingTableConstructionTestCase(TestCase): +class TestRoutingTableConstruction: def test_should_be_initially_stale(self): table = RoutingTable(database=DEFAULT_DATABASE) assert not table.is_fresh(readonly=True) assert not table.is_fresh(readonly=False) -class RoutingTableParseRoutingInfoTestCase(TestCase): +class TestRoutingTableParseRoutingInfo: def test_should_return_routing_table_on_valid_record(self): table = RoutingTable.parse_routing_info( database=DEFAULT_DATABASE, @@ -162,7 +154,7 @@ def test_should_return_routing_table_on_valid_record_with_extra_role(self): assert table.ttl == 300 -class RoutingTableServersTestCase(TestCase): +class TestRoutingTableServers: def test_should_return_all_distinct_servers_in_routing_table(self): routing_table = { "ttl": 300, @@ -180,7 +172,7 @@ def test_should_return_all_distinct_servers_in_routing_table(self): assert table.servers() == {('127.0.0.1', 9001), ('127.0.0.1', 9002), ('127.0.0.1', 9003), ('127.0.0.1', 9005)} -class RoutingTableFreshnessTestCase(TestCase): +class TestRoutingTableFreshness: def test_should_be_fresh_after_update(self): table = RoutingTable.parse_routing_info( database=DEFAULT_DATABASE, @@ -221,16 +213,20 @@ def test_should_become_stale_if_no_writers(self): assert not table.is_fresh(readonly=False) -class RoutingTableUpdateTestCase(TestCase): - def setUp(self): - self.table = RoutingTable( +class TestRoutingTableUpdate: + @pytest.fixture + def table(self): + return RoutingTable( database=DEFAULT_DATABASE, routers=[("192.168.1.1", 7687), ("192.168.1.2", 7687)], readers=[("192.168.1.3", 7687)], writers=[], ttl=0, ) - self.new_table = RoutingTable( + + @pytest.fixture + def new_table(self): + return RoutingTable( database=DEFAULT_DATABASE, routers=[("127.0.0.1", 9001), ("127.0.0.1", 9002), ("127.0.0.1", 9003)], readers=[("127.0.0.1", 9004), ("127.0.0.1", 9005)], @@ -238,18 +234,19 @@ def setUp(self): ttl=300, ) - def test_update_should_replace_routers(self): - self.table.update(self.new_table) - assert self.table.routers == {("127.0.0.1", 9001), ("127.0.0.1", 9002), ("127.0.0.1", 9003)} + def test_update_should_replace_routers(self, table, new_table): + table.update(new_table) + assert table.routers == {("127.0.0.1", 9001), ("127.0.0.1", 9002), + ("127.0.0.1", 9003)} - def test_update_should_replace_readers(self): - self.table.update(self.new_table) - assert self.table.readers == {("127.0.0.1", 9004), ("127.0.0.1", 9005)} + def test_update_should_replace_readers(self, table, new_table): + table.update(new_table) + assert table.readers == {("127.0.0.1", 9004), ("127.0.0.1", 9005)} - def test_update_should_replace_writers(self): - self.table.update(self.new_table) - assert self.table.writers == {("127.0.0.1", 9006)} + def test_update_should_replace_writers(self, table, new_table): + table.update(new_table) + assert table.writers == {("127.0.0.1", 9006)} - def test_update_should_replace_ttl(self): - self.table.update(self.new_table) - assert self.table.ttl == 300 + def test_update_should_replace_ttl(self, table, new_table): + table.update(new_table) + assert table.ttl == 300 diff --git a/tests/unit/common/spatial/__init__.py b/tests/unit/common/spatial/__init__.py new file mode 100644 index 00000000..b81a309d --- /dev/null +++ b/tests/unit/common/spatial/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unit/spatial/test_cartesian_point.py b/tests/unit/common/spatial/test_cartesian_point.py similarity index 98% rename from tests/unit/spatial/test_cartesian_point.py rename to tests/unit/common/spatial/test_cartesian_point.py index ee86e5b9..c33a90a1 100644 --- a/tests/unit/spatial/test_cartesian_point.py +++ b/tests/unit/common/spatial/test_cartesian_point.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# coding: utf-8 - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -18,6 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. + import io import struct from unittest import TestCase diff --git a/tests/unit/spatial/test_point.py b/tests/unit/common/spatial/test_point.py similarity index 98% rename from tests/unit/spatial/test_point.py rename to tests/unit/common/spatial/test_point.py index 082f95c5..74eaa3a7 100644 --- a/tests/unit/spatial/test_point.py +++ b/tests/unit/common/spatial/test_point.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# coding: utf-8 - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -18,6 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. + import io import struct from unittest import TestCase diff --git a/tests/unit/spatial/test_wgs84_point.py b/tests/unit/common/spatial/test_wgs84_point.py similarity index 98% rename from tests/unit/spatial/test_wgs84_point.py rename to tests/unit/common/spatial/test_wgs84_point.py index 8f725a58..0dee1913 100644 --- a/tests/unit/spatial/test_wgs84_point.py +++ b/tests/unit/common/spatial/test_wgs84_point.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# coding: utf-8 - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -18,6 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. + import io import struct from unittest import TestCase diff --git a/tests/unit/test_addressing.py b/tests/unit/common/test_addressing.py similarity index 61% rename from tests/unit/test_addressing.py rename to tests/unit/common/test_addressing.py index 317ae5bf..925e7017 100644 --- a/tests/unit/test_addressing.py +++ b/tests/unit/common/test_addressing.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -19,19 +16,19 @@ # limitations under the License. -import pytest -import unittest.mock as mock from socket import ( AF_INET, AF_INET6, ) +import unittest.mock as mock -from neo4j.addressing import ( +import pytest + +from neo4j import ( Address, IPv4Address, - IPv6Address, ) -from neo4j import GraphDatabase + mock_socket_ipv4 = mock.Mock() mock_socket_ipv4.getpeername = lambda: ("127.0.0.1", 7687) # (address, port) @@ -39,8 +36,6 @@ mock_socket_ipv6 = mock.Mock() mock_socket_ipv6.getpeername = lambda: ("[::1]", 7687, 0, 0) # (address, port, flow info, scope id) -# python -m pytest tests/unit/test_addressing.py -s - @pytest.mark.parametrize( "test_input, expected", @@ -57,7 +52,6 @@ ] ) def test_address_initialization(test_input, expected): - # python -m pytest tests/unit/test_addressing.py -s -k test_address_initialization address = Address(test_input) assert address.family == expected["family"] assert address.host == expected["host"] @@ -74,7 +68,6 @@ def test_address_initialization(test_input, expected): ] ) def test_address_init_with_address_object_returns_same_instance(test_input): - # python -m pytest tests/unit/test_addressing.py -s -k test_address_init_with_address_object_returns_same_instance address = Address(test_input) assert address is test_input assert id(address) == id(test_input) @@ -90,7 +83,6 @@ def test_address_init_with_address_object_returns_same_instance(test_input): ] ) def test_address_initialization_with_incorrect_input(test_input, expected): - # python -m pytest tests/unit/test_addressing.py -s -k test_address_initialization_with_incorrect_input with pytest.raises(expected): address = Address(test_input) @@ -103,14 +95,12 @@ def test_address_initialization_with_incorrect_input(test_input, expected): ] ) def test_address_from_socket(test_input, expected): - # python -m pytest tests/unit/test_addressing.py -s -k test_address_from_socket address = Address.from_socket(test_input) assert address == expected def test_address_from_socket_with_none(): - # python -m pytest tests/unit/test_addressing.py -s -k test_address_from_socket_with_none with pytest.raises(AttributeError): address = Address.from_socket(None) @@ -128,7 +118,6 @@ def test_address_from_socket_with_none(): ] ) def test_address_parse_with_ipv4(test_input, expected): - # python -m pytest tests/unit/test_addressing.py -s -k test_address_parse_with_ipv4 parsed = Address.parse(test_input) assert parsed == expected @@ -143,7 +132,6 @@ def test_address_parse_with_ipv4(test_input, expected): ] ) def test_address_should_parse_ipv6(test_input, expected): - # python -m pytest tests/unit/test_addressing.py -s -k test_address_should_parse_ipv6 parsed = Address.parse(test_input) assert parsed == expected @@ -159,7 +147,6 @@ def test_address_should_parse_ipv6(test_input, expected): ] ) def test_address_parse_with_invalid_input(test_input, expected): - # python -m pytest tests/unit/test_addressing.py -s -k test_address_parse_with_invalid_input with pytest.raises(expected): parsed = Address.parse(test_input) @@ -174,7 +161,6 @@ def test_address_parse_with_invalid_input(test_input, expected): ] ) def test_address_parse_list(test_input, expected): - # python -m pytest tests/unit/test_addressing.py -s -k test_address_parse_list addresses = Address.parse_list(*test_input) assert len(addresses) == expected @@ -190,69 +176,5 @@ def test_address_parse_list(test_input, expected): ] ) def test_address_parse_list_with_invalid_input(test_input, expected): - # python -m pytest tests/unit/test_addressing.py -s -k test_address_parse_list_with_invalid_input with pytest.raises(TypeError): addresses = Address.parse_list(*test_input) - - -def test_address_resolve(): - # python -m pytest tests/unit/test_addressing.py -s -k test_address_resolve - address = Address(("127.0.0.1", 7687)) - resolved = address.resolve() - assert isinstance(resolved, Address) is False - assert isinstance(resolved, list) is True - assert len(resolved) == 1 - assert resolved[0] == IPv4Address(('127.0.0.1', 7687)) - - -def test_address_resolve_with_custom_resolver_none(): - # python -m pytest tests/unit/test_addressing.py -s -k test_address_resolve_with_custom_resolver_none - address = Address(("127.0.0.1", 7687)) - resolved = address.resolve(resolver=None) - assert isinstance(resolved, Address) is False - assert isinstance(resolved, list) is True - assert len(resolved) == 1 - assert resolved[0] == IPv4Address(('127.0.0.1', 7687)) - - -@pytest.mark.parametrize( - "test_input, expected", - [ - (Address(("127.0.0.1", "abcd")), ValueError), - (Address((None, None)), ValueError), - ] -) -def test_address_resolve_with_unresolvable_address(test_input, expected): - # python -m pytest tests/unit/test_addressing.py -s -k test_address_resolve_with_unresolvable_address - with pytest.raises(expected): - test_input.resolve(resolver=None) - - -def test_address_resolve_with_custom_resolver(): - # python -m pytest tests/unit/test_addressing.py -s -k test_address_resolve_with_custom_resolver - custom_resolver = lambda _: [("127.0.0.1", 7687), ("localhost", 1234)] - - address = Address(("127.0.0.1", 7687)) - resolved = address.resolve(family=AF_INET, resolver=custom_resolver) - assert isinstance(resolved, Address) is False - assert isinstance(resolved, list) is True - assert len(resolved) == 2 # IPv4 only - assert resolved[0] == IPv4Address(('127.0.0.1', 7687)) - assert resolved[1] == IPv4Address(('127.0.0.1', 1234)) - - -def test_address_unresolve(): - # python -m pytest tests/unit/test_addressing.py -s -k test_address_unresolve - custom_resolved = [("127.0.0.1", 7687), ("localhost", 4321)] - custom_resolver = lambda _: custom_resolved - - address = Address(("foobar", 1234)) - unresolved = address.unresolved - assert address.__class__ == unresolved.__class__ - assert address == unresolved - resolved = address.resolve(family=AF_INET, resolver=custom_resolver) - custom_resolved = sorted(Address(a) for a in custom_resolved) - unresolved = sorted(a.unresolved for a in resolved) - assert custom_resolved == unresolved - assert (list(map(lambda a: a.__class__, custom_resolved)) - == list(map(lambda a: a.__class__, unresolved))) diff --git a/tests/unit/test_api.py b/tests/unit/common/test_api.py similarity index 83% rename from tests/unit/test_api.py rename to tests/unit/common/test_api.py index 3e06bf8b..6e3caa07 100644 --- a/tests/unit/test_api.py +++ b/tests/unit/common/test_api.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -19,27 +16,24 @@ # limitations under the License. -import pytest from uuid import uuid4 +import pytest + import neo4j.api from neo4j.data import DataDehydrator -from neo4j.exceptions import ( - ConfigurationError, -) +from neo4j.exceptions import ConfigurationError + standard_ascii = [chr(i) for i in range(128)] not_ascii = "♥O◘♦♥O◘♦" -# python -m pytest tests/unit/test_api.py -s - def dehydrated_value(value): return DataDehydrator.fix_parameters({"_": value})["_"] def test_value_dehydration_should_allow_none(): - # python -m pytest tests/unit/test_api.py -s -k test_value_dehydration_should_allow_none assert dehydrated_value(None) is None @@ -51,7 +45,6 @@ def test_value_dehydration_should_allow_none(): ] ) def test_value_dehydration_should_allow_boolean(test_input, expected): - # python -m pytest tests/unit/test_api.py -s -k test_value_dehydration_should_allow_boolean assert dehydrated_value(test_input) is expected @@ -67,7 +60,6 @@ def test_value_dehydration_should_allow_boolean(test_input, expected): ] ) def test_value_dehydration_should_allow_integer(test_input, expected): - # python -m pytest tests/unit/test_api.py -s -k test_value_dehydration_should_allow_integer assert dehydrated_value(test_input) == expected @@ -79,7 +71,6 @@ def test_value_dehydration_should_allow_integer(test_input, expected): ] ) def test_value_dehydration_should_disallow_oversized_integer(test_input, expected): - # python -m pytest tests/unit/test_api.py -s -k test_value_dehydration_should_disallow_oversized_integer with pytest.raises(expected): dehydrated_value(test_input) @@ -94,7 +85,6 @@ def test_value_dehydration_should_disallow_oversized_integer(test_input, expecte ] ) def test_value_dehydration_should_allow_float(test_input, expected): - # python -m pytest tests/unit/test_api.py -s -k test_value_dehydration_should_allow_float assert dehydrated_value(test_input) == expected @@ -107,7 +97,6 @@ def test_value_dehydration_should_allow_float(test_input, expected): ] ) def test_value_dehydration_should_allow_string(test_input, expected): - # python -m pytest tests/unit/test_api.py -s -k test_value_dehydration_should_allow_string assert dehydrated_value(test_input) == expected @@ -119,7 +108,6 @@ def test_value_dehydration_should_allow_string(test_input, expected): ] ) def test_value_dehydration_should_allow_bytes(test_input, expected): - # python -m pytest tests/unit/test_api.py -s -k test_value_dehydration_should_allow_bytes assert dehydrated_value(test_input) == expected @@ -132,7 +120,6 @@ def test_value_dehydration_should_allow_bytes(test_input, expected): ] ) def test_value_dehydration_should_allow_list(test_input, expected): - # python -m pytest tests/unit/test_api.py -s -k test_value_dehydration_should_allow_list assert dehydrated_value(test_input) == expected @@ -146,7 +133,6 @@ def test_value_dehydration_should_allow_list(test_input, expected): ] ) def test_value_dehydration_should_allow_dict(test_input, expected): - # python -m pytest tests/unit/test_api.py -s -k test_value_dehydration_should_allow_dict assert dehydrated_value(test_input) == expected @@ -158,13 +144,11 @@ def test_value_dehydration_should_allow_dict(test_input, expected): ] ) def test_value_dehydration_should_disallow_object(test_input, expected): - # python -m pytest tests/unit/test_api.py -s -k test_value_dehydration_should_disallow_object with pytest.raises(expected): dehydrated_value(test_input) def test_bookmark_initialization_with_no_values(): - # python -m pytest tests/unit/test_api.py -s -k test_bookmark_initialization_with_no_values bookmark = neo4j.api.Bookmark() assert bookmark.values == frozenset() assert bool(bookmark) is False @@ -182,7 +166,6 @@ def test_bookmark_initialization_with_no_values(): ] ) def test_bookmark_initialization_with_values_none(test_input, expected_values, expected_bool, expected_repr): - # python -m pytest tests/unit/test_api.py -s -k test_bookmark_initialization_with_values_none bookmark = neo4j.api.Bookmark(*test_input) assert bookmark.values == expected_values assert bool(bookmark) is expected_bool @@ -200,7 +183,6 @@ def test_bookmark_initialization_with_values_none(test_input, expected_values, e ] ) def test_bookmark_initialization_with_values_empty_string(test_input, expected_values, expected_bool, expected_repr): - # python -m pytest tests/unit/test_api.py -s -k test_bookmark_initialization_with_values_empty_string bookmark = neo4j.api.Bookmark(*test_input) assert bookmark.values == expected_values assert bool(bookmark) is expected_bool @@ -216,7 +198,6 @@ def test_bookmark_initialization_with_values_empty_string(test_input, expected_v ] ) def test_bookmark_initialization_with_valid_strings(test_input, expected_values, expected_bool, expected_repr): - # python -m pytest tests/unit/test_api.py -s -k test_bookmark_initialization_with_valid_strings bookmark = neo4j.api.Bookmark(*test_input) assert bookmark.values == expected_values assert bool(bookmark) is expected_bool @@ -232,7 +213,6 @@ def test_bookmark_initialization_with_valid_strings(test_input, expected_values, ] ) def test_bookmark_initialization_with_invalid_strings(test_input, expected): - # python -m pytest tests/unit/test_api.py -s -k test_bookmark_initialization_with_invalid_strings with pytest.raises(expected) as e: bookmark = neo4j.api.Bookmark(*test_input) @@ -251,7 +231,6 @@ def test_bookmark_initialization_with_invalid_strings(test_input, expected): ] ) def test_version_initialization(test_input, expected_str, expected_repr): - # python -m pytest tests/unit/test_api.py -s -k test_version_initialization version = neo4j.api.Version(*test_input) assert str(version) == expected_str assert repr(version) == expected_repr @@ -268,7 +247,6 @@ def test_version_initialization(test_input, expected_str, expected_repr): ] ) def test_version_from_bytes_with_valid_bolt_version_handshake(test_input, expected_str, expected_repr): - # python -m pytest tests/unit/test_api.py -s -k test_version_from_bytes_with_valid_bolt_version_handshake version = neo4j.api.Version.from_bytes(test_input) assert str(version) == expected_str assert repr(version) == expected_repr @@ -285,7 +263,6 @@ def test_version_from_bytes_with_valid_bolt_version_handshake(test_input, expect ] ) def test_version_from_bytes_with_not_valid_bolt_version_handshake(test_input, expected): - # python -m pytest tests/unit/test_api.py -s -k test_version_from_bytes_with_not_valid_bolt_version_handshake with pytest.raises(expected): version = neo4j.api.Version.from_bytes(test_input) @@ -303,13 +280,11 @@ def test_version_from_bytes_with_not_valid_bolt_version_handshake(test_input, ex ] ) def test_version_to_bytes_with_valid_bolt_version(test_input, expected): - # python -m pytest tests/unit/test_api.py -s -k test_version_to_bytes_with_valid_bolt_version version = neo4j.api.Version(*test_input) assert version.to_bytes() == expected def test_serverinfo_initialization(): - # python -m pytest tests/unit/test_api.py -s -k test_serverinfo_initialization from neo4j.addressing import Address @@ -332,7 +307,6 @@ def test_serverinfo_initialization(): ] ) def test_serverinfo_with_metadata(test_input, expected_agent, expected_version_info): - # python -m pytest tests/unit/test_api.py -s -k test_serverinfo_with_metadata from neo4j.addressing import Address address = Address(("bolt://localhost", 7687)) @@ -365,7 +339,6 @@ def test_serverinfo_with_metadata(test_input, expected_agent, expected_version_i ] ) def test_uri_scheme(test_input, expected_driver_type, expected_security_type, expected_error): - # python -m pytest tests/unit/test_api.py -s -k test_uri_scheme if expected_error: with pytest.raises(expected_error): neo4j.api.parse_neo4j_uri(test_input) @@ -376,18 +349,15 @@ def test_uri_scheme(test_input, expected_driver_type, expected_security_type, ex def test_parse_routing_context(): - # python -m pytest tests/unit/test_api.py -s -v -k test_parse_routing_context context = neo4j.api.parse_routing_context(query="name=molly&color=white") assert context == {"name": "molly", "color": "white"} def test_parse_routing_context_should_error_when_value_missing(): - # python -m pytest tests/unit/test_api.py -s -v -k test_parse_routing_context_should_error_when_value_missing with pytest.raises(ConfigurationError): neo4j.api.parse_routing_context("name=&color=white") def test_parse_routing_context_should_error_when_key_duplicate(): - # python -m pytest tests/unit/test_api.py -s -v -k test_parse_routing_context_should_error_when_key_duplicate with pytest.raises(ConfigurationError): neo4j.api.parse_routing_context("name=molly&name=white") diff --git a/tests/unit/test_conf.py b/tests/unit/common/test_conf.py similarity index 98% rename from tests/unit/test_conf.py rename to tests/unit/common/test_conf.py index 9eb71819..192bad34 100644 --- a/tests/unit/test_conf.py +++ b/tests/unit/common/test_conf.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -21,24 +18,23 @@ import pytest -from neo4j.exceptions import ( - ConfigurationError, +from neo4j.api import ( + READ_ACCESS, + TRUST_SYSTEM_CA_SIGNED_CERTIFICATES, + WRITE_ACCESS, ) from neo4j.conf import ( Config, PoolConfig, - WorkspaceConfig, SessionConfig, + WorkspaceConfig, ) -from neo4j.api import ( - TRUST_SYSTEM_CA_SIGNED_CERTIFICATES, - WRITE_ACCESS, - READ_ACCESS, -) +from neo4j.debug import watch +from neo4j.exceptions import ConfigurationError + # python -m pytest tests/unit/test_conf.py -s -v -from neo4j.debug import watch watch("neo4j") test_pool_config = { diff --git a/tests/unit/test_data.py b/tests/unit/common/test_data.py similarity index 97% rename from tests/unit/test_data.py rename to tests/unit/common/test_data.py index 7f48af9d..14577efd 100644 --- a/tests/unit/test_data.py +++ b/tests/unit/common/test_data.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -19,11 +16,10 @@ # limitations under the License. -import pytest - from neo4j.data import DataHydrator from neo4j.packstream import Structure + # python -m pytest -s -v tests/unit/test_data.py diff --git a/tests/unit/test_exceptions.py b/tests/unit/common/test_exceptions.py similarity index 98% rename from tests/unit/test_exceptions.py rename to tests/unit/common/test_exceptions.py index f574ce67..802e480f 100644 --- a/tests/unit/test_exceptions.py +++ b/tests/unit/common/test_exceptions.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -21,52 +18,48 @@ import pytest +from neo4j._exceptions import ( + BoltConnectionBroken, + BoltConnectionClosed, + BoltConnectionError, + BoltError, + BoltFailure, + BoltHandshakeError, + BoltProtocolError, + BoltSecurityError, +) +from neo4j._sync.io import Bolt from neo4j.exceptions import ( - Neo4jError, + AuthConfigurationError, + AuthError, + CertificateConfigurationError, + CLASSIFICATION_CLIENT, + CLASSIFICATION_DATABASE, + CLASSIFICATION_TRANSIENT, ClientError, + ConfigurationError, + ConstraintError, CypherSyntaxError, CypherTypeError, - ConstraintError, - AuthError, - Forbidden, - ForbiddenOnReadOnlyDatabase, - NotALeader, DatabaseError, - TransientError, DatabaseUnavailable, DriverError, + Forbidden, + ForbiddenOnReadOnlyDatabase, + IncompleteCommit, + Neo4jError, + NotALeader, + ReadServiceUnavailable, + ResultConsumedError, + RoutingServiceUnavailable, + ServiceUnavailable, + SessionExpired, TransactionError, TransactionNestingError, - SessionExpired, - ServiceUnavailable, - RoutingServiceUnavailable, + TransientError, WriteServiceUnavailable, - ReadServiceUnavailable, - IncompleteCommit, - ConfigurationError, - AuthConfigurationError, - CertificateConfigurationError, - ResultConsumedError, - CLASSIFICATION_CLIENT, - CLASSIFICATION_DATABASE, - CLASSIFICATION_TRANSIENT, ) -from neo4j._exceptions import ( - BoltError, - BoltHandshakeError, - BoltConnectionError, - BoltSecurityError, - BoltConnectionBroken, - BoltConnectionClosed, - BoltFailure, - BoltProtocolError, -) - -from neo4j.io import Bolt - - -# python -m pytest tests/unit/test_exceptions.py -s -v def test_bolt_error(): with pytest.raises(BoltError) as e: diff --git a/tests/unit/common/test_import_neo4j.py b/tests/unit/common/test_import_neo4j.py new file mode 100644 index 00000000..89704d9e --- /dev/null +++ b/tests/unit/common/test_import_neo4j.py @@ -0,0 +1,164 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +def test_import_dunder_version(): + from neo4j import __version__ + + +def test_import_graphdatabase(): + from neo4j import GraphDatabase + + +def test_import_async_graphdatabase(): + from neo4j import AsyncGraphDatabase + + +def test_import_driver(): + from neo4j import Driver + + +def test_import_async_driver(): + from neo4j import AsyncDriver + + +def test_import_boltdriver(): + from neo4j import BoltDriver + + +def test_import_async_boltdriver(): + from neo4j import AsyncBoltDriver + + +def test_import_neo4jdriver(): + from neo4j import Neo4jDriver + + +def test_import_async_neo4jdriver(): + from neo4j import AsyncNeo4jDriver + + +def test_import_auth(): + from neo4j import Auth + + +def test_import_authtoken(): + from neo4j import AuthToken + + +def test_import_basic_auth(): + from neo4j import basic_auth + + +def test_import_bearer_auth(): + from neo4j import bearer_auth + + +def test_import_kerberos_auth(): + from neo4j import kerberos_auth + + +def test_import_custom_auth(): + from neo4j import custom_auth + + +def test_import_read_access(): + from neo4j import READ_ACCESS + + +def test_import_write_access(): + from neo4j import WRITE_ACCESS + + +def test_import_transaction(): + from neo4j import Transaction + + +def test_import_async_transaction(): + from neo4j import AsyncTransaction + + +def test_import_record(): + from neo4j import Record + + +def test_import_session(): + from neo4j import Session + + +def test_import_async_session(): + from neo4j import AsyncSession + + +def test_import_sessionconfig(): + from neo4j import SessionConfig + + +def test_import_query(): + from neo4j import Query + + +def test_import_result(): + from neo4j import Result + + +def test_import_async_result(): + from neo4j import AsyncResult + + +def test_import_resultsummary(): + from neo4j import ResultSummary + + +def test_import_unit_of_work(): + from neo4j import unit_of_work + + +def test_import_config(): + from neo4j import Config + + +def test_import_poolconfig(): + from neo4j import PoolConfig + + +def test_import_graph(): + import neo4j.graph as graph + + +def test_import_graph_node(): + from neo4j.graph import Node + + +def test_import_graph_path(): + from neo4j.graph import Path + + +def test_import_graph_graph(): + from neo4j.graph import Graph + + +def test_import_spatial(): + import neo4j.spatial as spatial + + +def test_import_time(): + import neo4j.time as time + + +def test_import_exceptions(): + import neo4j.exceptions as exceptions diff --git a/tests/unit/test_record.py b/tests/unit/common/test_record.py similarity index 99% rename from tests/unit/test_record.py rename to tests/unit/common/test_record.py index 778b7ea3..85e0bc1f 100644 --- a/tests/unit/test_record.py +++ b/tests/unit/common/test_record.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -27,6 +24,7 @@ Record, ) + # python -m pytest -s -v tests/unit/test_record.py diff --git a/tests/unit/test_security.py b/tests/unit/common/test_security.py similarity index 98% rename from tests/unit/test_security.py rename to tests/unit/common/test_security.py index 3ac760fc..247cb4f5 100644 --- a/tests/unit/test_security.py +++ b/tests/unit/common/test_security.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -20,12 +17,13 @@ from neo4j.api import ( - kerberos_auth, basic_auth, bearer_auth, custom_auth, + kerberos_auth, ) + # python -m pytest -s -v tests/unit/test_security.py diff --git a/tests/unit/test_types.py b/tests/unit/common/test_types.py similarity index 98% rename from tests/unit/test_types.py rename to tests/unit/common/test_types.py index c2454316..979c4080 100644 --- a/tests/unit/test_types.py +++ b/tests/unit/common/test_types.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -18,18 +15,18 @@ # See the License for the specific language governing permissions and # limitations under the License. + from itertools import product import pytest -from neo4j.data import DataHydrator from neo4j.graph import ( + Graph, Node, Path, - Graph, Relationship, ) -from neo4j.packstream import Structure + # python -m pytest -s -v tests/unit/test_types.py diff --git a/tests/unit/common/time/__init__.py b/tests/unit/common/time/__init__.py new file mode 100644 index 00000000..b81a309d --- /dev/null +++ b/tests/unit/common/time/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unit/time/test_clock.py b/tests/unit/common/time/test_clock.py similarity index 96% rename from tests/unit/time/test_clock.py rename to tests/unit/common/time/test_clock.py index 4b4c18a7..c7bef0e4 100644 --- a/tests/unit/time/test_clock.py +++ b/tests/unit/common/time/test_clock.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# coding: utf-8 - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -21,7 +18,10 @@ from unittest import TestCase -from neo4j.time import Clock, ClockTime +from neo4j.time import ( + Clock, + ClockTime, +) class TestClock(TestCase): diff --git a/tests/unit/time/test_clocktime.py b/tests/unit/common/time/test_clocktime.py similarity index 97% rename from tests/unit/time/test_clocktime.py rename to tests/unit/common/time/test_clocktime.py index ca801d74..253672cd 100644 --- a/tests/unit/time/test_clocktime.py +++ b/tests/unit/common/time/test_clocktime.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# coding: utf-8 - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -21,7 +18,10 @@ from unittest import TestCase -from neo4j.time import ClockTime, Duration +from neo4j.time import ( + ClockTime, + Duration, +) class TestClockTime(TestCase): diff --git a/tests/unit/time/test_date.py b/tests/unit/common/time/test_date.py similarity index 99% rename from tests/unit/time/test_date.py rename to tests/unit/common/time/test_date.py index 3b34bddf..bf48b719 100644 --- a/tests/unit/time/test_date.py +++ b/tests/unit/common/time/test_date.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# coding: utf-8 - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -19,14 +16,19 @@ # limitations under the License. +import copy from datetime import date from time import struct_time from unittest import TestCase import pytz -import copy -from neo4j.time import Duration, Date, UnixEpoch, ZeroDate +from neo4j.time import ( + Date, + Duration, + UnixEpoch, + ZeroDate, +) eastern = pytz.timezone("US/Eastern") diff --git a/tests/unit/time/test_datetime.py b/tests/unit/common/time/test_datetime.py similarity index 98% rename from tests/unit/time/test_datetime.py rename to tests/unit/common/time/test_datetime.py index cc3f1ef2..ffc551e6 100644 --- a/tests/unit/time/test_datetime.py +++ b/tests/unit/common/time/test_datetime.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# coding: utf-8 - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -18,24 +15,25 @@ # See the License for the specific language governing permissions and # limitations under the License. + import copy -from decimal import Decimal from datetime import ( datetime, timedelta, ) +from decimal import Decimal import pytest from pytz import ( - timezone, FixedOffset, + timezone, ) from neo4j.time import ( DateTime, - MIN_YEAR, - MAX_YEAR, Duration, + MAX_YEAR, + MIN_YEAR, ) from neo4j.time.arithmetic import ( nano_add, @@ -45,17 +43,7 @@ Clock, ClockTime, ) -from neo4j.time.hydration import ( - hydrate_date, - dehydrate_date, - hydrate_time, - dehydrate_time, - hydrate_datetime, - dehydrate_datetime, - hydrate_duration, - dehydrate_duration, - dehydrate_timedelta, -) + timezone_us_eastern = timezone("US/Eastern") timezone_utc = timezone("UTC") diff --git a/tests/unit/time/test_duration.py b/tests/unit/common/time/test_duration.py similarity index 99% rename from tests/unit/time/test_duration.py rename to tests/unit/common/time/test_duration.py index aa29f750..5edc2c9c 100644 --- a/tests/unit/time/test_duration.py +++ b/tests/unit/common/time/test_duration.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# coding: utf-8 - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -19,9 +16,9 @@ # limitations under the License. +import copy from datetime import timedelta from decimal import Decimal -import copy import pytest diff --git a/tests/unit/time/test_hydration.py b/tests/unit/common/time/test_hydration.py similarity index 97% rename from tests/unit/time/test_hydration.py rename to tests/unit/common/time/test_hydration.py index d66f6949..0efe868d 100644 --- a/tests/unit/time/test_hydration.py +++ b/tests/unit/common/time/test_hydration.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# coding: utf-8 - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # diff --git a/tests/unit/time/test_time.py b/tests/unit/common/time/test_time.py similarity index 99% rename from tests/unit/time/test_time.py rename to tests/unit/common/time/test_time.py index 0336781f..b3bdc938 100644 --- a/tests/unit/time/test_time.py +++ b/tests/unit/common/time/test_time.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# coding: utf-8 - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -24,9 +21,8 @@ import pytest from pytz import ( - build_tzinfo, - timezone, FixedOffset, + timezone, ) from neo4j.time import Time diff --git a/tests/unit/data/__init__.py b/tests/unit/data/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/unit/io/__init__.py b/tests/unit/io/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/unit/io/test__common.py b/tests/unit/io/test__common.py deleted file mode 100644 index 3b61c710..00000000 --- a/tests/unit/io/test__common.py +++ /dev/null @@ -1,32 +0,0 @@ -import pytest - -from neo4j.io._common import Outbox - - -@pytest.mark.parametrize(("chunk_size", "data", "result"), ( - ( - 2, - (bytes(range(10, 15)),), - bytes((0, 2, 10, 11, 0, 2, 12, 13, 0, 1, 14)) - ), - ( - 2, - (bytes(range(10, 14)),), - bytes((0, 2, 10, 11, 0, 2, 12, 13)) - ), - ( - 2, - (bytes((5, 6, 7)), bytes((8, 9))), - bytes((0, 2, 5, 6, 0, 2, 7, 8, 0, 1, 9)) - ), -)) -def test_outbox_chunking(chunk_size, data, result): - outbox = Outbox(max_chunk_size=chunk_size) - assert bytes(outbox.view()) == b"" - for d in data: - outbox.write(d) - assert bytes(outbox.view()) == result - # make sure this works multiple times - assert bytes(outbox.view()) == result - outbox.clear() - assert bytes(outbox.view()) == b"" diff --git a/tests/unit/io/test_direct.py b/tests/unit/io/test_direct.py deleted file mode 100644 index 9e572fb2..00000000 --- a/tests/unit/io/test_direct.py +++ /dev/null @@ -1,309 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -# Copyright (c) "Neo4j" -# Neo4j Sweden AB [http://neo4j.com] -# -# This file is part of Neo4j. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from unittest import ( - mock, - TestCase, -) -import pytest -from threading import ( - Condition, - Event, - Lock, - Thread, -) -import time - -from neo4j import ( - Config, - PoolConfig, - WorkspaceConfig, -) -from neo4j.io import ( - Bolt, - BoltPool, - IOPool -) -from neo4j.exceptions import ( - ClientError, - ServiceUnavailable, -) - - -class FakeSocket: - def __init__(self, address): - self.address = address - - def getpeername(self): - return self.address - - def sendall(self, data): - return - - def close(self): - return - - -class QuickConnection: - - def __init__(self, socket): - self.socket = socket - self.address = socket.getpeername() - - @property - def is_reset(self): - return True - - def stale(self): - return False - - def reset(self): - pass - - def close(self): - self.socket.close() - - def closed(self): - return False - - def defunct(self): - return False - - def timedout(self): - return False - - -class FakeBoltPool(IOPool): - - def __init__(self, address, *, auth=None, **config): - self.pool_config, self.workspace_config = Config.consume_chain(config, PoolConfig, WorkspaceConfig) - if config: - raise ValueError("Unexpected config keys: %s" % ", ".join(config.keys())) - - def opener(addr, timeout): - return QuickConnection(FakeSocket(addr)) - - super().__init__(opener, self.pool_config, self.workspace_config) - self.address = address - - def acquire(self, access_mode=None, timeout=None, database=None, - bookmarks=None): - return self._acquire(self.address, timeout) - - -class BoltTestCase(TestCase): - - def test_open(self): - with pytest.raises(ServiceUnavailable): - connection = Bolt.open(("localhost", 9999), auth=("test", "test")) - - def test_open_timeout(self): - with pytest.raises(ServiceUnavailable): - connection = Bolt.open(("localhost", 9999), auth=("test", "test"), timeout=1) - - def test_ping(self): - protocol_version = Bolt.ping(("localhost", 9999)) - assert protocol_version is None - - def test_ping_timeout(self): - protocol_version = Bolt.ping(("localhost", 9999), timeout=1) - assert protocol_version is None - - -class MultiEvent: - # Adopted from threading.Event - - def __init__(self): - super().__init__() - self._cond = Condition(Lock()) - self._counter = 0 - - def _reset_internal_locks(self): - # private! called by Thread._reset_internal_locks by _after_fork() - self._cond.__init__(Lock()) - - def counter(self): - return self._counter - - def increment(self): - with self._cond: - self._counter += 1 - self._cond.notify_all() - - def decrement(self): - with self._cond: - self._counter -= 1 - self._cond.notify_all() - - def clear(self): - with self._cond: - self._counter = 0 - self._cond.notify_all() - - def wait(self, value=0, timeout=None): - with self._cond: - t_start = time.time() - while True: - if value == self._counter: - return True - if timeout is None: - time_left = None - else: - time_left = timeout - (time.time() - t_start) - if time_left <= 0: - return False - if not self._cond.wait(time_left): - return False - - -class ConnectionPoolTestCase(TestCase): - - def setUp(self): - self.pool = FakeBoltPool(("127.0.0.1", 7687)) - - def tearDown(self): - self.pool.close() - - def assert_pool_size(self, address, expected_active, expected_inactive, pool=None): - if pool is None: - pool = self.pool - try: - connections = pool.connections[address] - except KeyError: - self.assertEqual(0, expected_active) - self.assertEqual(0, expected_inactive) - else: - self.assertEqual(expected_active, len([cx for cx in connections if cx.in_use])) - self.assertEqual(expected_inactive, len([cx for cx in connections if not cx.in_use])) - - def test_can_acquire(self): - address = ("127.0.0.1", 7687) - connection = self.pool._acquire(address, timeout=3) - assert connection.address == address - self.assert_pool_size(address, 1, 0) - - def test_can_acquire_twice(self): - address = ("127.0.0.1", 7687) - connection_1 = self.pool._acquire(address, timeout=3) - connection_2 = self.pool._acquire(address, timeout=3) - assert connection_1.address == address - assert connection_2.address == address - assert connection_1 is not connection_2 - self.assert_pool_size(address, 2, 0) - - def test_can_acquire_two_addresses(self): - address_1 = ("127.0.0.1", 7687) - address_2 = ("127.0.0.1", 7474) - connection_1 = self.pool._acquire(address_1, timeout=3) - connection_2 = self.pool._acquire(address_2, timeout=3) - assert connection_1.address == address_1 - assert connection_2.address == address_2 - self.assert_pool_size(address_1, 1, 0) - self.assert_pool_size(address_2, 1, 0) - - def test_can_acquire_and_release(self): - address = ("127.0.0.1", 7687) - connection = self.pool._acquire(address, timeout=3) - self.assert_pool_size(address, 1, 0) - self.pool.release(connection) - self.assert_pool_size(address, 0, 1) - - def test_releasing_twice(self): - address = ("127.0.0.1", 7687) - connection = self.pool._acquire(address, timeout=3) - self.pool.release(connection) - self.assert_pool_size(address, 0, 1) - self.pool.release(connection) - self.assert_pool_size(address, 0, 1) - - def test_in_use_count(self): - address = ("127.0.0.1", 7687) - self.assertEqual(self.pool.in_use_connection_count(address), 0) - connection = self.pool._acquire(address, timeout=3) - self.assertEqual(self.pool.in_use_connection_count(address), 1) - self.pool.release(connection) - self.assertEqual(self.pool.in_use_connection_count(address), 0) - - def test_max_conn_pool_size(self): - with FakeBoltPool((), max_connection_pool_size=1) as pool: - address = ("127.0.0.1", 7687) - pool._acquire(address, timeout=0) - self.assertEqual(pool.in_use_connection_count(address), 1) - with self.assertRaises(ClientError): - pool._acquire(address, timeout=0) - self.assertEqual(pool.in_use_connection_count(address), 1) - - def test_multithread(self): - with FakeBoltPool((), max_connection_pool_size=5) as pool: - address = ("127.0.0.1", 7687) - acquired_counter = MultiEvent() - release_event = Event() - - # start 10 threads competing for connections from a pool of size 5 - threads = [] - for i in range(10): - t = Thread( - target=acquire_release_conn, - args=(pool, address, acquired_counter, release_event), - daemon=True - ) - t.start() - threads.append(t) - - if not acquired_counter.wait(5, timeout=1): - raise RuntimeError("Acquire threads not fast enough") - # The pool size should be 5, all are in-use - self.assert_pool_size(address, 5, 0, pool) - # Now we allow thread to release connections they obtained from pool - release_event.set() - - # wait for all threads to release connections back to pool - for t in threads: - t.join(timeout=1) - # The pool size is still 5, but all are free - self.assert_pool_size(address, 0, 5, pool) - - def test_reset_when_released(self): - def test(is_reset): - with mock.patch(__name__ + ".QuickConnection.is_reset", - new_callable=mock.PropertyMock) as is_reset_mock: - with mock.patch(__name__ + ".QuickConnection.reset", - new_callable=mock.MagicMock) as reset_mock: - is_reset_mock.return_value = is_reset - connection = self.pool._acquire(address, timeout=3) - self.assertIsInstance(connection, QuickConnection) - self.assertEqual(is_reset_mock.call_count, 0) - self.assertEqual(reset_mock.call_count, 0) - self.pool.release(connection) - self.assertEqual(is_reset_mock.call_count, 1) - self.assertEqual(reset_mock.call_count, int(not is_reset)) - - address = ("127.0.0.1", 7687) - for is_reset in (True, False): - with self.subTest(): - test(is_reset) - - -def acquire_release_conn(pool, address, acquired_counter, release_event): - conn = pool._acquire(address, timeout=3) - acquired_counter.increment() - release_event.wait() - pool.release(conn) diff --git a/tests/unit/mixed/__init__.py b/tests/unit/mixed/__init__.py new file mode 100644 index 00000000..b81a309d --- /dev/null +++ b/tests/unit/mixed/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unit/mixed/async_compat/__init__.py b/tests/unit/mixed/async_compat/__init__.py new file mode 100644 index 00000000..b81a309d --- /dev/null +++ b/tests/unit/mixed/async_compat/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unit/mixed/async_compat/test_concurrency.py b/tests/unit/mixed/async_compat/test_concurrency.py new file mode 100644 index 00000000..1bce780d --- /dev/null +++ b/tests/unit/mixed/async_compat/test_concurrency.py @@ -0,0 +1,115 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import asyncio + +import pytest + +from neo4j._async_compat.concurrency import AsyncRLock + + +@pytest.mark.asyncio +async def test_async_r_lock(): + counter = 1 + lock = AsyncRLock() + + async def worker(): + nonlocal counter + async with lock: + counter_ = counter + counter += 1 + await asyncio.sleep(0) + # assert no one else touched the counter + assert counter == counter_ + 1 + + assert not lock.locked() + await asyncio.gather(worker(), worker(), worker()) + assert not lock.locked() + + +@pytest.mark.asyncio +async def test_async_r_lock_is_reentrant(): + lock = AsyncRLock() + + async def worker(): + async with lock: + assert lock._count == 1 + async with lock: + assert lock._count == 2 + assert lock.locked() + + assert not lock.locked() + await asyncio.gather(worker(), worker(), worker()) + assert not lock.locked() + + +@pytest.mark.asyncio +async def test_async_r_lock_acquire_timeout_blocked(): + lock = AsyncRLock() + + async def blocker(): + await lock.acquire() + + async def waiter(): + # make sure blocker has a chance to acquire the lock + await asyncio.sleep(0) + await lock.acquire(0.1) + + assert not lock.locked() + with pytest.raises(asyncio.TimeoutError): + await asyncio.gather(blocker(), waiter()) + + +@pytest.mark.asyncio +async def test_async_r_lock_acquire_timeout_released(): + lock = AsyncRLock() + + async def blocker(): + await lock.acquire() + await asyncio.sleep(0) + # waiter: lock.acquire(0.1) + lock.release() + + async def waiter(): + await asyncio.sleep(0) + # blocker: lock.acquire() + await lock.acquire(0.1) + # blocker: lock.release() + + assert not lock.locked() + await asyncio.gather(blocker(), waiter()) + assert lock.locked() # waiter still owns it! + + +@pytest.mark.asyncio +async def test_async_r_lock_acquire_timeout_reentrant(): + lock = AsyncRLock() + assert not lock.locked() + + await lock.acquire() + assert lock._count == 1 + await lock.acquire() + assert lock._count == 2 + await lock.acquire(0.1) + assert lock._count == 3 + await lock.acquire(0.1) + assert lock._count == 4 + for _ in range(4): + lock.release() + + assert not lock.locked() diff --git a/tests/unit/mixed/io/__init__.py b/tests/unit/mixed/io/__init__.py new file mode 100644 index 00000000..b81a309d --- /dev/null +++ b/tests/unit/mixed/io/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unit/mixed/io/test_direct.py b/tests/unit/mixed/io/test_direct.py new file mode 100644 index 00000000..4cf6afa5 --- /dev/null +++ b/tests/unit/mixed/io/test_direct.py @@ -0,0 +1,243 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import asyncio +from asyncio import ( + Condition as AsyncCondition, + Event as AsyncEvent, + Lock as AsyncLock, +) +from threading import ( + Condition, + Event, + Lock, + Thread, +) +import time +from unittest import ( + mock, + TestCase, +) + +import pytest + +from neo4j import ( + Config, + PoolConfig, + WorkspaceConfig, +) +from neo4j._async.io import AsyncBolt +from neo4j._async.io._pool import AsyncIOPool +from neo4j.exceptions import ( + ClientError, + ServiceUnavailable, +) + +from ...async_.io.test_direct import AsyncFakeBoltPool +from ...sync.io.test_direct import FakeBoltPool + + +class MultiEvent: + # Adopted from threading.Event + + def __init__(self): + super().__init__() + self._cond = Condition(Lock()) + self._counter = 0 + + def _reset_internal_locks(self): + # private! called by Thread._reset_internal_locks by _after_fork() + self._cond.__init__(Lock()) + + def counter(self): + return self._counter + + def increment(self): + with self._cond: + self._counter += 1 + self._cond.notify_all() + + def decrement(self): + with self._cond: + self._counter -= 1 + self._cond.notify_all() + + def clear(self): + with self._cond: + self._counter = 0 + self._cond.notify_all() + + def wait(self, value=0, timeout=None): + with self._cond: + t_start = time.time() + while True: + if value == self._counter: + return True + if timeout is None: + time_left = None + else: + time_left = timeout - (time.time() - t_start) + if time_left <= 0: + return False + if not self._cond.wait(time_left): + return False + + +class AsyncMultiEvent: + # Adopted from threading.Event + + def __init__(self): + super().__init__() + self._cond = AsyncCondition() + self._counter = 0 + + def _reset_internal_locks(self): + # private! called by Thread._reset_internal_locks by _after_fork() + self._cond.__init__(AsyncLock()) + + def counter(self): + return self._counter + + async def increment(self): + async with self._cond: + self._counter += 1 + self._cond.notify_all() + + async def decrement(self): + async with self._cond: + self._counter -= 1 + self._cond.notify_all() + + async def clear(self): + async with self._cond: + self._counter = 0 + self._cond.notify_all() + + async def wait(self, value=0, timeout=None): + async with self._cond: + t_start = time.time() + while True: + if value == self._counter: + return True + if timeout is None: + time_left = None + else: + time_left = timeout - (time.time() - t_start) + if time_left <= 0: + return False + try: + await asyncio.wait_for(self._cond.wait(), time_left) + except asyncio.TimeoutError: + return False + + +class TestMixedConnectionPoolTestCase: + def assert_pool_size(self, address, expected_active, expected_inactive, + pool): + try: + connections = pool.connections[address] + except KeyError: + assert 0 == expected_active + assert 0 == expected_inactive + else: + assert (expected_active + == len([cx for cx in connections if cx.in_use])) + assert (expected_inactive + == len([cx for cx in connections if not cx.in_use])) + + def test_multithread(self): + def acquire_release_conn(pool, address, acquired_counter, + release_event): + conn = pool._acquire(address, timeout=3) + acquired_counter.increment() + release_event.wait() + pool.release(conn) + + with FakeBoltPool((), max_connection_pool_size=5) as pool: + address = ("127.0.0.1", 7687) + acquired_counter = MultiEvent() + release_event = Event() + + # start 10 threads competing for connections from a pool of size 5 + threads = [] + for i in range(10): + t = Thread( + target=acquire_release_conn, + args=(pool, address, acquired_counter, release_event), + daemon=True + ) + t.start() + threads.append(t) + + if not acquired_counter.wait(5, timeout=1): + raise RuntimeError("Acquire threads not fast enough") + # The pool size should be 5, all are in-use + self.assert_pool_size(address, 5, 0, pool) + # Now we allow the threads to release connections they obtained + # from the pool + release_event.set() + + # wait for all threads to release connections back to pool + for t in threads: + t.join(timeout=1) + # The pool size is still 5, but all are free + self.assert_pool_size(address, 0, 5, pool) + + @pytest.mark.asyncio + async def test_multi_coroutine(self): + async def acquire_release_conn(pool_, address_, acquired_counter_, + release_event_): + try: + conn = await pool_._acquire(address_, timeout=3) + await acquired_counter_.increment() + await release_event_.wait() + await pool_.release(conn) + except ClientError: + raise + + async def waiter(pool_, acquired_counter_, release_event_): + if not await acquired_counter_.wait(5, timeout=1): + raise RuntimeError("Acquire coroutines not fast enough") + # The pool size should be 5, all are in-use + self.assert_pool_size(address, 5, 0, pool_) + # Now we allow the coroutines to release connections they obtained + # from the pool + release_event_.set() + + # wait for all coroutines to release connections back to pool + if not await acquired_counter_.wait(10, timeout=5): + raise RuntimeError("Acquire coroutines not fast enough") + # The pool size is still 5, but all are free + self.assert_pool_size(address, 0, 5, pool_) + + async with AsyncFakeBoltPool((), max_connection_pool_size=5) as pool: + address = ("127.0.0.1", 7687) + acquired_counter = AsyncMultiEvent() + release_event = AsyncEvent() + + # start 10 coroutines competing for connections from a pool of size + # 5 + coroutines = [ + acquire_release_conn( + pool, address, acquired_counter, release_event + ) for _ in range(10) + ] + await asyncio.gather( + waiter(pool, acquired_counter, release_event), + *coroutines + ) diff --git a/tests/unit/spatial/__init__.py b/tests/unit/spatial/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/unit/sync/__init__.py b/tests/unit/sync/__init__.py new file mode 100644 index 00000000..b81a309d --- /dev/null +++ b/tests/unit/sync/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unit/sync/io/__init__.py b/tests/unit/sync/io/__init__.py new file mode 100644 index 00000000..b81a309d --- /dev/null +++ b/tests/unit/sync/io/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unit/io/conftest.py b/tests/unit/sync/io/conftest.py similarity index 92% rename from tests/unit/io/conftest.py rename to tests/unit/sync/io/conftest.py index 0f8d21ea..33309fc9 100644 --- a/tests/unit/io/conftest.py +++ b/tests/unit/sync/io/conftest.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -20,12 +17,19 @@ from io import BytesIO -from struct import pack as struct_pack, unpack as struct_unpack +from struct import ( + pack as struct_pack, + unpack as struct_unpack, +) import pytest -from neo4j.io._common import MessageInbox -from neo4j.packstream import Packer, UnpackableBuffer, Unpacker +from neo4j._sync.io._common import MessageInbox +from neo4j.packstream import ( + Packer, + UnpackableBuffer, + Unpacker, +) class FakeSocket: @@ -89,7 +93,7 @@ def close(self): def inject(self, data): self.recv_buffer += data - def pop_chunk(self): + def _pop_chunk(self): chunk_size, = struct_unpack(">H", self.recv_buffer[:2]) print("CHUNK SIZE %r" % chunk_size) end = 2 + chunk_size @@ -99,7 +103,7 @@ def pop_chunk(self): def pop_message(self): data = bytearray() while True: - chunk = self.pop_chunk() + chunk = self._pop_chunk() print("CHUNK %r" % chunk) if chunk: data.extend(chunk) diff --git a/tests/unit/sync/io/test__common.py b/tests/unit/sync/io/test__common.py new file mode 100644 index 00000000..5106a2da --- /dev/null +++ b/tests/unit/sync/io/test__common.py @@ -0,0 +1,50 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import pytest + +from neo4j._sync.io._common import Outbox + + +@pytest.mark.parametrize(("chunk_size", "data", "result"), ( + ( + 2, + (bytes(range(10, 15)),), + bytes((0, 2, 10, 11, 0, 2, 12, 13, 0, 1, 14)) + ), + ( + 2, + (bytes(range(10, 14)),), + bytes((0, 2, 10, 11, 0, 2, 12, 13)) + ), + ( + 2, + (bytes((5, 6, 7)), bytes((8, 9))), + bytes((0, 2, 5, 6, 0, 2, 7, 8, 0, 1, 9)) + ), +)) +def test_async_outbox_chunking(chunk_size, data, result): + outbox = Outbox(max_chunk_size=chunk_size) + assert bytes(outbox.view()) == b"" + for d in data: + outbox.write(d) + assert bytes(outbox.view()) == result + # make sure this works multiple times + assert bytes(outbox.view()) == result + outbox.clear() + assert bytes(outbox.view()) == b"" diff --git a/tests/unit/io/test_class_bolt.py b/tests/unit/sync/io/test_class_bolt.py similarity index 96% rename from tests/unit/io/test_class_bolt.py rename to tests/unit/sync/io/test_class_bolt.py index 2001546c..b7d1e6c5 100644 --- a/tests/unit/io/test_class_bolt.py +++ b/tests/unit/sync/io/test_class_bolt.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -20,7 +17,9 @@ import pytest -from neo4j.io import Bolt + +from neo4j._sync.io import Bolt + # python -m pytest tests/unit/io/test_class_bolt.py -s -v diff --git a/tests/unit/io/test_class_bolt3.py b/tests/unit/sync/io/test_class_bolt3.py similarity index 88% rename from tests/unit/io/test_class_bolt3.py rename to tests/unit/sync/io/test_class_bolt3.py index f7d63e85..b42512d0 100644 --- a/tests/unit/io/test_class_bolt3.py +++ b/tests/unit/sync/io/test_class_bolt3.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -18,17 +15,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -from unittest.mock import MagicMock import pytest -from neo4j.io._bolt3 import Bolt3 +from neo4j._sync.io._bolt3 import Bolt3 from neo4j.conf import PoolConfig -from neo4j.exceptions import ( - ConfigurationError, -) +from neo4j.exceptions import ConfigurationError -# python -m pytest tests/unit/io/test_class_bolt3.py -s -v +from ..._async_compat import ( + MagicMock, + mark_sync_test, +) @pytest.mark.parametrize("set_stale", (True, False)) @@ -75,6 +72,7 @@ def test_db_extra_not_supported_in_run(fake_socket): connection.run("", db="something") +@mark_sync_test def test_simple_discard(fake_socket): address = ("127.0.0.1", 7687) socket = fake_socket(address) @@ -86,6 +84,7 @@ def test_simple_discard(fake_socket): assert len(fields) == 0 +@mark_sync_test def test_simple_pull(fake_socket): address = ("127.0.0.1", 7687) socket = fake_socket(address) @@ -98,7 +97,10 @@ def test_simple_pull(fake_socket): @pytest.mark.parametrize("recv_timeout", (1, -1)) -def test_hint_recv_timeout_seconds_gets_ignored(fake_socket_pair, recv_timeout): +@mark_sync_test +def test_hint_recv_timeout_seconds_gets_ignored( + fake_socket_pair, recv_timeout +): address = ("127.0.0.1", 7687) sockets = fake_socket_pair(address) sockets.client.settimeout = MagicMock() @@ -106,7 +108,8 @@ def test_hint_recv_timeout_seconds_gets_ignored(fake_socket_pair, recv_timeout): "server": "Neo4j/3.5.0", "hints": {"connection.recv_timeout_seconds": recv_timeout}, }) - connection = Bolt3(address, sockets.client, - PoolConfig.max_connection_lifetime) + connection = Bolt3( + address, sockets.client, PoolConfig.max_connection_lifetime + ) connection.hello() sockets.client.settimeout.assert_not_called() diff --git a/tests/unit/io/test_class_bolt4x0.py b/tests/unit/sync/io/test_class_bolt4x0.py similarity index 92% rename from tests/unit/io/test_class_bolt4x0.py rename to tests/unit/sync/io/test_class_bolt4x0.py index 3879acb0..5f94d5c0 100644 --- a/tests/unit/io/test_class_bolt4x0.py +++ b/tests/unit/sync/io/test_class_bolt4x0.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -18,13 +15,16 @@ # See the License for the specific language governing permissions and # limitations under the License. + from unittest.mock import MagicMock import pytest -from neo4j.io._bolt4 import Bolt4x0 +from neo4j._sync.io._bolt4 import Bolt4x0 from neo4j.conf import PoolConfig +from ..._async_compat import mark_sync_test + @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_stale(fake_socket, set_stale): @@ -56,6 +56,7 @@ def test_conn_is_not_stale(fake_socket, set_stale): assert connection.stale() is set_stale +@mark_sync_test def test_db_extra_in_begin(fake_socket): address = ("127.0.0.1", 7687) socket = fake_socket(address) @@ -68,6 +69,7 @@ def test_db_extra_in_begin(fake_socket): assert fields[0] == {"db": "something"} +@mark_sync_test def test_db_extra_in_run(fake_socket): address = ("127.0.0.1", 7687) socket = fake_socket(address) @@ -82,6 +84,7 @@ def test_db_extra_in_run(fake_socket): assert fields[2] == {"db": "something"} +@mark_sync_test def test_n_extra_in_discard(fake_socket): address = ("127.0.0.1", 7687) socket = fake_socket(address) @@ -101,6 +104,7 @@ def test_n_extra_in_discard(fake_socket): (-1, {"n": -1}), ] ) +@mark_sync_test def test_qid_extra_in_discard(fake_socket, test_input, expected): address = ("127.0.0.1", 7687) socket = fake_socket(address) @@ -120,8 +124,8 @@ def test_qid_extra_in_discard(fake_socket, test_input, expected): (-1, {"n": 666}), ] ) +@mark_sync_test def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): - # python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_n_and_qid_extras_in_discard address = ("127.0.0.1", 7687) socket = fake_socket(address) connection = Bolt4x0(address, socket, PoolConfig.max_connection_lifetime) @@ -140,6 +144,7 @@ def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): (-1, {"n": -1}), ] ) +@mark_sync_test def test_n_extra_in_pull(fake_socket, test_input, expected): address = ("127.0.0.1", 7687) socket = fake_socket(address) @@ -159,8 +164,8 @@ def test_n_extra_in_pull(fake_socket, test_input, expected): (-1, {"n": -1}), ] ) +@mark_sync_test def test_qid_extra_in_pull(fake_socket, test_input, expected): - # python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_qid_extra_in_pull address = ("127.0.0.1", 7687) socket = fake_socket(address) connection = Bolt4x0(address, socket, PoolConfig.max_connection_lifetime) @@ -172,6 +177,7 @@ def test_qid_extra_in_pull(fake_socket, test_input, expected): assert fields[0] == expected +@mark_sync_test def test_n_and_qid_extras_in_pull(fake_socket): address = ("127.0.0.1", 7687) socket = fake_socket(address) @@ -185,7 +191,10 @@ def test_n_and_qid_extras_in_pull(fake_socket): @pytest.mark.parametrize("recv_timeout", (1, -1)) -def test_hint_recv_timeout_seconds_gets_ignored(fake_socket_pair, recv_timeout): +@mark_sync_test +def test_hint_recv_timeout_seconds_gets_ignored( + fake_socket_pair, recv_timeout +): address = ("127.0.0.1", 7687) sockets = fake_socket_pair(address) sockets.client.settimeout = MagicMock() @@ -193,7 +202,8 @@ def test_hint_recv_timeout_seconds_gets_ignored(fake_socket_pair, recv_timeout): "server": "Neo4j/4.0.0", "hints": {"connection.recv_timeout_seconds": recv_timeout}, }) - connection = Bolt4x0(address, sockets.client, - PoolConfig.max_connection_lifetime) + connection = Bolt4x0( + address, sockets.client, PoolConfig.max_connection_lifetime + ) connection.hello() sockets.client.settimeout.assert_not_called() diff --git a/tests/unit/io/test_class_bolt4x1.py b/tests/unit/sync/io/test_class_bolt4x1.py similarity index 92% rename from tests/unit/io/test_class_bolt4x1.py rename to tests/unit/sync/io/test_class_bolt4x1.py index 663d3cbe..2d69b9de 100644 --- a/tests/unit/io/test_class_bolt4x1.py +++ b/tests/unit/sync/io/test_class_bolt4x1.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -18,13 +15,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -from unittest.mock import MagicMock import pytest -from neo4j.io._bolt4 import Bolt4x1 +from neo4j._sync.io._bolt4 import Bolt4x1 from neo4j.conf import PoolConfig +from ..._async_compat import ( + MagicMock, + mark_sync_test, +) + @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_stale(fake_socket, set_stale): @@ -56,6 +57,7 @@ def test_conn_is_not_stale(fake_socket, set_stale): assert connection.stale() is set_stale +@mark_sync_test def test_db_extra_in_begin(fake_socket): address = ("127.0.0.1", 7687) socket = fake_socket(address) @@ -68,6 +70,7 @@ def test_db_extra_in_begin(fake_socket): assert fields[0] == {"db": "something"} +@mark_sync_test def test_db_extra_in_run(fake_socket): address = ("127.0.0.1", 7687) socket = fake_socket(address) @@ -82,6 +85,7 @@ def test_db_extra_in_run(fake_socket): assert fields[2] == {"db": "something"} +@mark_sync_test def test_n_extra_in_discard(fake_socket): address = ("127.0.0.1", 7687) socket = fake_socket(address) @@ -101,6 +105,7 @@ def test_n_extra_in_discard(fake_socket): (-1, {"n": -1}), ] ) +@mark_sync_test def test_qid_extra_in_discard(fake_socket, test_input, expected): address = ("127.0.0.1", 7687) socket = fake_socket(address) @@ -120,6 +125,7 @@ def test_qid_extra_in_discard(fake_socket, test_input, expected): (-1, {"n": 666}), ] ) +@mark_sync_test def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): # python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_n_and_qid_extras_in_discard address = ("127.0.0.1", 7687) @@ -140,6 +146,7 @@ def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): (-1, {"n": -1}), ] ) +@mark_sync_test def test_n_extra_in_pull(fake_socket, test_input, expected): address = ("127.0.0.1", 7687) socket = fake_socket(address) @@ -159,6 +166,7 @@ def test_n_extra_in_pull(fake_socket, test_input, expected): (-1, {"n": -1}), ] ) +@mark_sync_test def test_qid_extra_in_pull(fake_socket, test_input, expected): # python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_qid_extra_in_pull address = ("127.0.0.1", 7687) @@ -172,6 +180,7 @@ def test_qid_extra_in_pull(fake_socket, test_input, expected): assert fields[0] == expected +@mark_sync_test def test_n_and_qid_extras_in_pull(fake_socket): address = ("127.0.0.1", 7687) socket = fake_socket(address) @@ -184,12 +193,15 @@ def test_n_and_qid_extras_in_pull(fake_socket): assert fields[0] == {"n": 666, "qid": 777} +@mark_sync_test def test_hello_passes_routing_metadata(fake_socket_pair): address = ("127.0.0.1", 7687) sockets = fake_socket_pair(address) sockets.server.send_message(0x70, {"server": "Neo4j/4.1.0"}) - connection = Bolt4x1(address, sockets.client, PoolConfig.max_connection_lifetime, - routing_context={"foo": "bar"}) + connection = Bolt4x1( + address, sockets.client, PoolConfig.max_connection_lifetime, + routing_context={"foo": "bar"} + ) connection.hello() tag, fields = sockets.server.pop_message() assert tag == 0x01 @@ -198,7 +210,10 @@ def test_hello_passes_routing_metadata(fake_socket_pair): @pytest.mark.parametrize("recv_timeout", (1, -1)) -def test_hint_recv_timeout_seconds_gets_ignored(fake_socket_pair, recv_timeout): +@mark_sync_test +def test_hint_recv_timeout_seconds_gets_ignored( + fake_socket_pair, recv_timeout +): address = ("127.0.0.1", 7687) sockets = fake_socket_pair(address) sockets.client.settimeout = MagicMock() diff --git a/tests/unit/io/test_class_bolt4x2.py b/tests/unit/sync/io/test_class_bolt4x2.py similarity index 91% rename from tests/unit/io/test_class_bolt4x2.py rename to tests/unit/sync/io/test_class_bolt4x2.py index 470adf5c..03605796 100644 --- a/tests/unit/io/test_class_bolt4x2.py +++ b/tests/unit/sync/io/test_class_bolt4x2.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -19,13 +16,16 @@ # limitations under the License. -from unittest.mock import MagicMock - import pytest -from neo4j.io._bolt4 import Bolt4x2 +from neo4j._sync.io._bolt4 import Bolt4x2 from neo4j.conf import PoolConfig +from ..._async_compat import ( + MagicMock, + mark_sync_test, +) + @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_stale(fake_socket, set_stale): @@ -57,6 +57,7 @@ def test_conn_is_not_stale(fake_socket, set_stale): assert connection.stale() is set_stale +@mark_sync_test def test_db_extra_in_begin(fake_socket): address = ("127.0.0.1", 7687) socket = fake_socket(address) @@ -69,6 +70,7 @@ def test_db_extra_in_begin(fake_socket): assert fields[0] == {"db": "something"} +@mark_sync_test def test_db_extra_in_run(fake_socket): address = ("127.0.0.1", 7687) socket = fake_socket(address) @@ -83,6 +85,7 @@ def test_db_extra_in_run(fake_socket): assert fields[2] == {"db": "something"} +@mark_sync_test def test_n_extra_in_discard(fake_socket): address = ("127.0.0.1", 7687) socket = fake_socket(address) @@ -102,6 +105,7 @@ def test_n_extra_in_discard(fake_socket): (-1, {"n": -1}), ] ) +@mark_sync_test def test_qid_extra_in_discard(fake_socket, test_input, expected): address = ("127.0.0.1", 7687) socket = fake_socket(address) @@ -121,6 +125,7 @@ def test_qid_extra_in_discard(fake_socket, test_input, expected): (-1, {"n": 666}), ] ) +@mark_sync_test def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): # python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_n_and_qid_extras_in_discard address = ("127.0.0.1", 7687) @@ -141,6 +146,7 @@ def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): (-1, {"n": -1}), ] ) +@mark_sync_test def test_n_extra_in_pull(fake_socket, test_input, expected): address = ("127.0.0.1", 7687) socket = fake_socket(address) @@ -160,6 +166,7 @@ def test_n_extra_in_pull(fake_socket, test_input, expected): (-1, {"n": -1}), ] ) +@mark_sync_test def test_qid_extra_in_pull(fake_socket, test_input, expected): # python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_qid_extra_in_pull address = ("127.0.0.1", 7687) @@ -173,6 +180,7 @@ def test_qid_extra_in_pull(fake_socket, test_input, expected): assert fields[0] == expected +@mark_sync_test def test_n_and_qid_extras_in_pull(fake_socket): address = ("127.0.0.1", 7687) socket = fake_socket(address) @@ -185,12 +193,15 @@ def test_n_and_qid_extras_in_pull(fake_socket): assert fields[0] == {"n": 666, "qid": 777} +@mark_sync_test def test_hello_passes_routing_metadata(fake_socket_pair): address = ("127.0.0.1", 7687) sockets = fake_socket_pair(address) sockets.server.send_message(0x70, {"server": "Neo4j/4.2.0"}) - connection = Bolt4x2(address, sockets.client, PoolConfig.max_connection_lifetime, - routing_context={"foo": "bar"}) + connection = Bolt4x2( + address, sockets.client, PoolConfig.max_connection_lifetime, + routing_context={"foo": "bar"} + ) connection.hello() tag, fields = sockets.server.pop_message() assert tag == 0x01 @@ -199,7 +210,10 @@ def test_hello_passes_routing_metadata(fake_socket_pair): @pytest.mark.parametrize("recv_timeout", (1, -1)) -def test_hint_recv_timeout_seconds_gets_ignored(fake_socket_pair, recv_timeout): +@mark_sync_test +def test_hint_recv_timeout_seconds_gets_ignored( + fake_socket_pair, recv_timeout +): address = ("127.0.0.1", 7687) sockets = fake_socket_pair(address) sockets.client.settimeout = MagicMock() @@ -207,7 +221,8 @@ def test_hint_recv_timeout_seconds_gets_ignored(fake_socket_pair, recv_timeout): "server": "Neo4j/4.2.0", "hints": {"connection.recv_timeout_seconds": recv_timeout}, }) - connection = Bolt4x2(address, sockets.client, - PoolConfig.max_connection_lifetime) + connection = Bolt4x2( + address, sockets.client, PoolConfig.max_connection_lifetime + ) connection.hello() sockets.client.settimeout.assert_not_called() diff --git a/tests/unit/io/test_class_bolt4x3.py b/tests/unit/sync/io/test_class_bolt4x3.py similarity index 92% rename from tests/unit/io/test_class_bolt4x3.py rename to tests/unit/sync/io/test_class_bolt4x3.py index fc08f5b9..43469fc9 100644 --- a/tests/unit/io/test_class_bolt4x3.py +++ b/tests/unit/sync/io/test_class_bolt4x3.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -18,14 +15,19 @@ # See the License for the specific language governing permissions and # limitations under the License. + import logging -from unittest.mock import MagicMock import pytest -from neo4j.io._bolt4 import Bolt4x3 +from neo4j._sync.io._bolt4 import Bolt4x3 from neo4j.conf import PoolConfig +from ..._async_compat import ( + MagicMock, + mark_sync_test, +) + @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_stale(fake_socket, set_stale): @@ -57,6 +59,7 @@ def test_conn_is_not_stale(fake_socket, set_stale): assert connection.stale() is set_stale +@mark_sync_test def test_db_extra_in_begin(fake_socket): address = ("127.0.0.1", 7687) socket = fake_socket(address) @@ -69,6 +72,7 @@ def test_db_extra_in_begin(fake_socket): assert fields[0] == {"db": "something"} +@mark_sync_test def test_db_extra_in_run(fake_socket): address = ("127.0.0.1", 7687) socket = fake_socket(address) @@ -83,6 +87,7 @@ def test_db_extra_in_run(fake_socket): assert fields[2] == {"db": "something"} +@mark_sync_test def test_n_extra_in_discard(fake_socket): address = ("127.0.0.1", 7687) socket = fake_socket(address) @@ -102,6 +107,7 @@ def test_n_extra_in_discard(fake_socket): (-1, {"n": -1}), ] ) +@mark_sync_test def test_qid_extra_in_discard(fake_socket, test_input, expected): address = ("127.0.0.1", 7687) socket = fake_socket(address) @@ -121,6 +127,7 @@ def test_qid_extra_in_discard(fake_socket, test_input, expected): (-1, {"n": 666}), ] ) +@mark_sync_test def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): # python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_n_and_qid_extras_in_discard address = ("127.0.0.1", 7687) @@ -141,6 +148,7 @@ def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): (-1, {"n": -1}), ] ) +@mark_sync_test def test_n_extra_in_pull(fake_socket, test_input, expected): address = ("127.0.0.1", 7687) socket = fake_socket(address) @@ -160,6 +168,7 @@ def test_n_extra_in_pull(fake_socket, test_input, expected): (-1, {"n": -1}), ] ) +@mark_sync_test def test_qid_extra_in_pull(fake_socket, test_input, expected): # python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_qid_extra_in_pull address = ("127.0.0.1", 7687) @@ -173,6 +182,7 @@ def test_qid_extra_in_pull(fake_socket, test_input, expected): assert fields[0] == expected +@mark_sync_test def test_n_and_qid_extras_in_pull(fake_socket): address = ("127.0.0.1", 7687) socket = fake_socket(address) @@ -185,13 +195,15 @@ def test_n_and_qid_extras_in_pull(fake_socket): assert fields[0] == {"n": 666, "qid": 777} +@mark_sync_test def test_hello_passes_routing_metadata(fake_socket_pair): address = ("127.0.0.1", 7687) sockets = fake_socket_pair(address) sockets.server.send_message(0x70, {"server": "Neo4j/4.3.0"}) - connection = Bolt4x3(address, sockets.client, - PoolConfig.max_connection_lifetime, - routing_context={"foo": "bar"}) + connection = Bolt4x3( + address, sockets.client, PoolConfig.max_connection_lifetime, + routing_context={"foo": "bar"} + ) connection.hello() tag, fields = sockets.server.pop_message() assert tag == 0x01 @@ -211,12 +223,16 @@ def test_hello_passes_routing_metadata(fake_socket_pair): ({"connection.recv_timeout_seconds": False}, False), ({"connection.recv_timeout_seconds": "1"}, False), )) -def test_hint_recv_timeout_seconds(fake_socket_pair, hints, valid, - caplog): +@mark_sync_test +def test_hint_recv_timeout_seconds( + fake_socket_pair, hints, valid, caplog +): address = ("127.0.0.1", 7687) sockets = fake_socket_pair(address) sockets.client.settimeout = MagicMock() - sockets.server.send_message(0x70, {"server": "Neo4j/4.3.0", "hints": hints}) + sockets.server.send_message( + 0x70, {"server": "Neo4j/4.3.0", "hints": hints} + ) connection = Bolt4x3(address, sockets.client, PoolConfig.max_connection_lifetime) with caplog.at_level(logging.INFO): diff --git a/tests/unit/io/test_class_bolt4x4.py b/tests/unit/sync/io/test_class_bolt4x4.py similarity index 92% rename from tests/unit/io/test_class_bolt4x4.py rename to tests/unit/sync/io/test_class_bolt4x4.py index 19378a1c..b2523b1c 100644 --- a/tests/unit/io/test_class_bolt4x4.py +++ b/tests/unit/sync/io/test_class_bolt4x4.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -18,14 +15,20 @@ # See the License for the specific language governing permissions and # limitations under the License. + import logging from unittest.mock import MagicMock import pytest -from neo4j.io._bolt4 import Bolt4x4 +from neo4j._sync.io._bolt4 import Bolt4x4 from neo4j.conf import PoolConfig +from ..._async_compat import ( + MagicMock, + mark_sync_test, +) + @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_stale(fake_socket, set_stale): @@ -56,6 +59,7 @@ def test_conn_is_not_stale(fake_socket, set_stale): connection.set_stale() assert connection.stale() is set_stale + @pytest.mark.parametrize(("args", "kwargs", "expected_fields"), ( (("", {}), {"db": "something"}, ({"db": "something"},)), (("", {}), {"imp_user": "imposter"}, ({"imp_user": "imposter"},)), @@ -65,6 +69,7 @@ def test_conn_is_not_stale(fake_socket, set_stale): ({"db": "something", "imp_user": "imposter"},) ), )) +@mark_sync_test def test_extra_in_begin(fake_socket, args, kwargs, expected_fields): address = ("127.0.0.1", 7687) socket = fake_socket(address) @@ -75,6 +80,7 @@ def test_extra_in_begin(fake_socket, args, kwargs, expected_fields): assert tag == b"\x11" assert tuple(is_fields) == expected_fields + @pytest.mark.parametrize(("args", "kwargs", "expected_fields"), ( (("", {}), {"db": "something"}, ("", {}, {"db": "something"})), (("", {}), {"imp_user": "imposter"}, ("", {}, {"imp_user": "imposter"})), @@ -84,6 +90,7 @@ def test_extra_in_begin(fake_socket, args, kwargs, expected_fields): ("", {}, {"db": "something", "imp_user": "imposter"}) ), )) +@mark_sync_test def test_extra_in_run(fake_socket, args, kwargs, expected_fields): address = ("127.0.0.1", 7687) socket = fake_socket(address) @@ -95,6 +102,7 @@ def test_extra_in_run(fake_socket, args, kwargs, expected_fields): assert tuple(is_fields) == expected_fields +@mark_sync_test def test_n_extra_in_discard(fake_socket): address = ("127.0.0.1", 7687) socket = fake_socket(address) @@ -114,6 +122,7 @@ def test_n_extra_in_discard(fake_socket): (-1, {"n": -1}), ] ) +@mark_sync_test def test_qid_extra_in_discard(fake_socket, test_input, expected): address = ("127.0.0.1", 7687) socket = fake_socket(address) @@ -133,6 +142,7 @@ def test_qid_extra_in_discard(fake_socket, test_input, expected): (-1, {"n": 666}), ] ) +@mark_sync_test def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): # python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_n_and_qid_extras_in_discard address = ("127.0.0.1", 7687) @@ -153,6 +163,7 @@ def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): (-1, {"n": -1}), ] ) +@mark_sync_test def test_n_extra_in_pull(fake_socket, test_input, expected): address = ("127.0.0.1", 7687) socket = fake_socket(address) @@ -172,6 +183,7 @@ def test_n_extra_in_pull(fake_socket, test_input, expected): (-1, {"n": -1}), ] ) +@mark_sync_test def test_qid_extra_in_pull(fake_socket, test_input, expected): # python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_qid_extra_in_pull address = ("127.0.0.1", 7687) @@ -185,6 +197,7 @@ def test_qid_extra_in_pull(fake_socket, test_input, expected): assert fields[0] == expected +@mark_sync_test def test_n_and_qid_extras_in_pull(fake_socket): address = ("127.0.0.1", 7687) socket = fake_socket(address) @@ -197,13 +210,15 @@ def test_n_and_qid_extras_in_pull(fake_socket): assert fields[0] == {"n": 666, "qid": 777} +@mark_sync_test def test_hello_passes_routing_metadata(fake_socket_pair): address = ("127.0.0.1", 7687) sockets = fake_socket_pair(address) sockets.server.send_message(0x70, {"server": "Neo4j/4.4.0"}) - connection = Bolt4x4(address, sockets.client, - PoolConfig.max_connection_lifetime, - routing_context={"foo": "bar"}) + connection = Bolt4x4( + address, sockets.client, PoolConfig.max_connection_lifetime, + routing_context={"foo": "bar"} + ) connection.hello() tag, fields = sockets.server.pop_message() assert tag == 0x01 @@ -223,14 +238,19 @@ def test_hello_passes_routing_metadata(fake_socket_pair): ({"connection.recv_timeout_seconds": False}, False), ({"connection.recv_timeout_seconds": "1"}, False), )) -def test_hint_recv_timeout_seconds(fake_socket_pair, hints, valid, - caplog): +@mark_sync_test +def test_hint_recv_timeout_seconds( + fake_socket_pair, hints, valid, caplog +): address = ("127.0.0.1", 7687) sockets = fake_socket_pair(address) sockets.client.settimeout = MagicMock() - sockets.server.send_message(0x70, {"server": "Neo4j/4.3.4", "hints": hints}) - connection = Bolt4x4(address, sockets.client, - PoolConfig.max_connection_lifetime) + sockets.server.send_message( + 0x70, {"server": "Neo4j/4.3.4", "hints": hints} + ) + connection = Bolt4x4( + address, sockets.client, PoolConfig.max_connection_lifetime + ) with caplog.at_level(logging.INFO): connection.hello() if valid: diff --git a/tests/unit/sync/io/test_direct.py b/tests/unit/sync/io/test_direct.py new file mode 100644 index 00000000..d5ff16cb --- /dev/null +++ b/tests/unit/sync/io/test_direct.py @@ -0,0 +1,231 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import pytest + +from neo4j import ( + Config, + PoolConfig, + WorkspaceConfig, +) +from neo4j._sync.io import Bolt +from neo4j._sync.io._pool import IOPool +from neo4j.exceptions import ( + ClientError, + ServiceUnavailable, +) + +from ..._async_compat import ( + mark_sync_test, + Mock, + mock, +) + + +class FakeSocket: + def __init__(self, address): + self.address = address + + def getpeername(self): + return self.address + + def sendall(self, data): + return + + def close(self): + return + + +class QuickConnection: + def __init__(self, socket): + self.socket = socket + self.address = socket.getpeername() + + @property + def is_reset(self): + return True + + def stale(self): + return False + + def reset(self): + pass + + def close(self): + self.socket.close() + + def closed(self): + return False + + def defunct(self): + return False + + def timedout(self): + return False + + +class FakeBoltPool(IOPool): + + def __init__(self, address, *, auth=None, **config): + self.pool_config, self.workspace_config = Config.consume_chain(config, PoolConfig, WorkspaceConfig) + if config: + raise ValueError("Unexpected config keys: %s" % ", ".join(config.keys())) + + def opener(addr, timeout): + return QuickConnection(FakeSocket(addr)) + + super().__init__(opener, self.pool_config, self.workspace_config) + self.address = address + + def acquire( + self, access_mode=None, timeout=None, database=None, bookmarks=None + ): + return self._acquire(self.address, timeout) + + +@mark_sync_test +def test_bolt_connection_open(): + with pytest.raises(ServiceUnavailable): + Bolt.open(("localhost", 9999), auth=("test", "test")) + + +@mark_sync_test +def test_bolt_connection_open_timeout(): + with pytest.raises(ServiceUnavailable): + Bolt.open(("localhost", 9999), auth=("test", "test"), + timeout=1) + + +@mark_sync_test +def test_bolt_connection_ping(): + protocol_version = Bolt.ping(("localhost", 9999)) + assert protocol_version is None + + +@mark_sync_test +def test_bolt_connection_ping_timeout(): + protocol_version = Bolt.ping(("localhost", 9999), timeout=1) + assert protocol_version is None + + +@pytest.fixture +def pool(): + with FakeBoltPool(("127.0.0.1", 7687)) as pool: + yield pool + + +def assert_pool_size( address, expected_active, expected_inactive, pool): + try: + connections = pool.connections[address] + except KeyError: + assert 0 == expected_active + assert 0 == expected_inactive + else: + assert expected_active == len([cx for cx in connections if cx.in_use]) + assert (expected_inactive + == len([cx for cx in connections if not cx.in_use])) + + +@mark_sync_test +def test_pool_can_acquire(pool): + address = ("127.0.0.1", 7687) + connection = pool._acquire(address, timeout=3) + assert connection.address == address + assert_pool_size(address, 1, 0, pool) + + +@mark_sync_test +def test_pool_can_acquire_twice(pool): + address = ("127.0.0.1", 7687) + connection_1 = pool._acquire(address, timeout=3) + connection_2 = pool._acquire(address, timeout=3) + assert connection_1.address == address + assert connection_2.address == address + assert connection_1 is not connection_2 + assert_pool_size(address, 2, 0, pool) + + +@mark_sync_test +def test_pool_can_acquire_two_addresses(pool): + address_1 = ("127.0.0.1", 7687) + address_2 = ("127.0.0.1", 7474) + connection_1 = pool._acquire(address_1, timeout=3) + connection_2 = pool._acquire(address_2, timeout=3) + assert connection_1.address == address_1 + assert connection_2.address == address_2 + assert_pool_size(address_1, 1, 0, pool) + assert_pool_size(address_2, 1, 0, pool) + + +@mark_sync_test +def test_pool_can_acquire_and_release(pool): + address = ("127.0.0.1", 7687) + connection = pool._acquire(address, timeout=3) + assert_pool_size(address, 1, 0, pool) + pool.release(connection) + assert_pool_size(address, 0, 1, pool) + + +@mark_sync_test +def test_pool_releasing_twice(pool): + address = ("127.0.0.1", 7687) + connection = pool._acquire(address, timeout=3) + pool.release(connection) + assert_pool_size(address, 0, 1, pool) + pool.release(connection) + assert_pool_size(address, 0, 1, pool) + + +@mark_sync_test +def test_pool_in_use_count(pool): + address = ("127.0.0.1", 7687) + assert pool.in_use_connection_count(address) == 0 + connection = pool._acquire(address, timeout=3) + assert pool.in_use_connection_count(address) == 1 + pool.release(connection) + assert pool.in_use_connection_count(address) == 0 + + +@mark_sync_test +def test_pool_max_conn_pool_size(pool): + with FakeBoltPool((), max_connection_pool_size=1) as pool: + address = ("127.0.0.1", 7687) + pool._acquire(address, timeout=0) + assert pool.in_use_connection_count(address) == 1 + with pytest.raises(ClientError): + pool._acquire(address, timeout=0) + assert pool.in_use_connection_count(address) == 1 + + +@pytest.mark.parametrize("is_reset", (True, False)) +@mark_sync_test +def test_pool_reset_when_released(is_reset, pool): + address = ("127.0.0.1", 7687) + quick_connection_name = QuickConnection.__name__ + with mock.patch(f"{__name__}.{quick_connection_name}.is_reset", + new_callable=mock.PropertyMock) as is_reset_mock: + with mock.patch(f"{__name__}.{quick_connection_name}.reset", + new_callable=Mock) as reset_mock: + is_reset_mock.return_value = is_reset + connection = pool._acquire(address, timeout=3) + assert isinstance(connection, QuickConnection) + assert is_reset_mock.call_count == 0 + assert reset_mock.call_count == 0 + pool.release(connection) + assert is_reset_mock.call_count == 1 + assert reset_mock.call_count == int(not is_reset) diff --git a/tests/unit/io/test_neo4j_pool.py b/tests/unit/sync/io/test_neo4j_pool.py similarity index 85% rename from tests/unit/io/test_neo4j_pool.py rename to tests/unit/sync/io/test_neo4j_pool.py index a5df0a90..6fb57b98 100644 --- a/tests/unit/io/test_neo4j_pool.py +++ b/tests/unit/sync/io/test_neo4j_pool.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -23,19 +20,23 @@ import pytest -from ..work import FakeConnection - from neo4j import ( READ_ACCESS, WRITE_ACCESS, ) +from neo4j._sync.io import Neo4jPool from neo4j.addressing import ResolvedAddress from neo4j.conf import ( PoolConfig, RoutingConfig, - WorkspaceConfig + WorkspaceConfig, ) -from neo4j.io import Neo4jPool + +from ..._async_compat import ( + mark_sync_test, + Mock, +) +from ..work import FakeConnection ROUTER_ADDRESS = ResolvedAddress(("1.2.3.1", 9001), host_name="host") @@ -68,8 +69,11 @@ def open_(addr, timeout): return opener_ +@mark_sync_test def test_acquires_new_routing_table_if_deleted(opener): - pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS) + pool = Neo4jPool( + opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS + ) cx = pool.acquire(READ_ACCESS, 30, "test_db", None) pool.release(cx) assert pool.routing_tables.get("test_db") @@ -81,8 +85,11 @@ def test_acquires_new_routing_table_if_deleted(opener): assert pool.routing_tables.get("test_db") +@mark_sync_test def test_acquires_new_routing_table_if_stale(opener): - pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS) + pool = Neo4jPool( + opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS + ) cx = pool.acquire(READ_ACCESS, 30, "test_db", None) pool.release(cx) assert pool.routing_tables.get("test_db") @@ -95,8 +102,11 @@ def test_acquires_new_routing_table_if_stale(opener): assert pool.routing_tables["test_db"].last_updated_time > old_value +@mark_sync_test def test_removes_old_routing_table(opener): - pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS) + pool = Neo4jPool( + opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS + ) cx = pool.acquire(READ_ACCESS, 30, "test_db1", None) pool.release(cx) assert pool.routing_tables.get("test_db1") @@ -116,8 +126,11 @@ def test_removes_old_routing_table(opener): @pytest.mark.parametrize("type_", ("r", "w")) +@mark_sync_test def test_chooses_right_connection_type(opener, type_): - pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS) + pool = Neo4jPool( + opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS + ) cx1 = pool.acquire(READ_ACCESS if type_ == "r" else WRITE_ACCESS, 30, "test_db", None) pool.release(cx1) @@ -127,8 +140,11 @@ def test_chooses_right_connection_type(opener, type_): assert cx1.addr == WRITER_ADDRESS +@mark_sync_test def test_reuses_connection(opener): - pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS) + pool = Neo4jPool( + opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS + ) cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None) pool.release(cx1) cx2 = pool.acquire(READ_ACCESS, 30, "test_db", None) @@ -136,6 +152,7 @@ def test_reuses_connection(opener): @pytest.mark.parametrize("break_on_close", (True, False)) +@mark_sync_test def test_closes_stale_connections(opener, break_on_close): def break_connection(): pool.deactivate(cx1.addr) @@ -143,7 +160,9 @@ def break_connection(): if cx_close_mock_side_effect: cx_close_mock_side_effect() - pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS) + pool = Neo4jPool( + opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS + ) cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None) pool.release(cx1) assert cx1 in pool.connections[cx1.addr] @@ -166,8 +185,11 @@ def break_connection(): assert cx2 in pool.connections[cx2.addr] +@mark_sync_test def test_does_not_close_stale_connections_in_use(opener): - pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS) + pool = Neo4jPool( + opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS + ) cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None) assert cx1 in pool.connections[cx1.addr] # simulate connection going stale (e.g. exceeding) while being in use @@ -194,8 +216,11 @@ def test_does_not_close_stale_connections_in_use(opener): assert cx3 in pool.connections[cx2.addr] +@mark_sync_test def test_release_resets_connections(opener): - pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS) + pool = Neo4jPool( + opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS + ) cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None) cx1.is_reset_mock.return_value = False cx1.is_reset_mock.reset_mock() @@ -204,8 +229,11 @@ def test_release_resets_connections(opener): cx1.reset.assert_called_once() +@mark_sync_test def test_release_does_not_resets_closed_connections(opener): - pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS) + pool = Neo4jPool( + opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS + ) cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None) cx1.closed.return_value = True cx1.closed.reset_mock() @@ -216,8 +244,11 @@ def test_release_does_not_resets_closed_connections(opener): cx1.reset.asset_not_called() +@mark_sync_test def test_release_does_not_resets_defunct_connections(opener): - pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS) + pool = Neo4jPool( + opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS + ) cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None) cx1.defunct.return_value = True cx1.defunct.reset_mock() diff --git a/tests/unit/sync/test_addressing.py b/tests/unit/sync/test_addressing.py new file mode 100644 index 00000000..4fd814b2 --- /dev/null +++ b/tests/unit/sync/test_addressing.py @@ -0,0 +1,125 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from socket import ( + AF_INET, + AF_INET6, +) +import unittest.mock as mock + +import pytest + +from neo4j import ( + Address, + IPv4Address, +) +from neo4j._async_compat.network import NetworkUtil +from neo4j._async_compat.util import Util + +from .._async_compat import mark_sync_test + + +mock_socket_ipv4 = mock.Mock() +mock_socket_ipv4.getpeername = lambda: ("127.0.0.1", 7687) # (address, port) + +mock_socket_ipv6 = mock.Mock() +mock_socket_ipv6.getpeername = lambda: ("[::1]", 7687, 0, 0) # (address, port, flow info, scope id) + + +@mark_sync_test +def test_address_resolve(): + address = Address(("127.0.0.1", 7687)) + resolved = NetworkUtil.resolve_address(address) + resolved = Util.list(resolved) + assert isinstance(resolved, Address) is False + assert isinstance(resolved, list) is True + assert len(resolved) == 1 + assert resolved[0] == IPv4Address(('127.0.0.1', 7687)) + + +@mark_sync_test +def test_address_resolve_with_custom_resolver_none(): + address = Address(("127.0.0.1", 7687)) + resolved = NetworkUtil.resolve_address(address, resolver=None) + resolved = Util.list(resolved) + assert isinstance(resolved, Address) is False + assert isinstance(resolved, list) is True + assert len(resolved) == 1 + assert resolved[0] == IPv4Address(('127.0.0.1', 7687)) + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (Address(("127.0.0.1", "abcd")), ValueError), + (Address((None, None)), ValueError), + ] + +) +@mark_sync_test +def test_address_resolve_with_unresolvable_address(test_input, expected): + with pytest.raises(expected): + Util.list( + NetworkUtil.resolve_address(test_input, resolver=None) + ) + + +@mark_sync_test +@pytest.mark.parametrize("resolver_type", ("sync", "async")) +def test_address_resolve_with_custom_resolver(resolver_type): + def custom_resolver_sync(_): + return [("127.0.0.1", 7687), ("localhost", 1234)] + + def custom_resolver_async(_): + return [("127.0.0.1", 7687), ("localhost", 1234)] + + if resolver_type == "sync": + custom_resolver = custom_resolver_sync + else: + custom_resolver = custom_resolver_async + + address = Address(("127.0.0.1", 7687)) + resolved = NetworkUtil.resolve_address( + address, family=AF_INET, resolver=custom_resolver + ) + resolved = Util.list(resolved) + assert isinstance(resolved, Address) is False + assert isinstance(resolved, list) is True + assert len(resolved) == 2 # IPv4 only + assert resolved[0] == IPv4Address(('127.0.0.1', 7687)) + assert resolved[1] == IPv4Address(('127.0.0.1', 1234)) + + +@mark_sync_test +def test_address_unresolve(): + custom_resolved = [("127.0.0.1", 7687), ("localhost", 4321)] + custom_resolver = lambda _: custom_resolved + + address = Address(("foobar", 1234)) + unresolved = address.unresolved + assert address.__class__ == unresolved.__class__ + assert address == unresolved + resolved = NetworkUtil.resolve_address( + address, family=AF_INET, resolver=custom_resolver + ) + resolved = Util.list(resolved) + custom_resolved = sorted(Address(a) for a in custom_resolved) + unresolved = sorted(a.unresolved for a in resolved) + assert custom_resolved == unresolved + assert (list(map(lambda a: a.__class__, custom_resolved)) + == list(map(lambda a: a.__class__, unresolved))) diff --git a/tests/unit/test_driver.py b/tests/unit/sync/test_driver.py similarity index 92% rename from tests/unit/test_driver.py rename to tests/unit/sync/test_driver.py index 0c1192e4..93579b2e 100644 --- a/tests/unit/test_driver.py +++ b/tests/unit/sync/test_driver.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -18,18 +15,24 @@ # See the License for the specific language governing permissions and # limitations under the License. + import pytest from neo4j import ( BoltDriver, GraphDatabase, Neo4jDriver, - TRUST_SYSTEM_CA_SIGNED_CERTIFICATES, TRUST_ALL_CERTIFICATES, + TRUST_SYSTEM_CA_SIGNED_CERTIFICATES, ) from neo4j.api import WRITE_ACCESS from neo4j.exceptions import ConfigurationError +from .._async_compat import ( + mark_sync_test, + mock, +) + @pytest.mark.parametrize("protocol", ("bolt://", "bolt+s://", "bolt+ssc://")) @pytest.mark.parametrize("host", ("localhost", "127.0.0.1", @@ -121,18 +124,22 @@ def test_driver_trust_config_error( "bolt://127.0.0.1:9000", "neo4j://127.0.0.1:9000", )) +@mark_sync_test def test_driver_opens_write_session_by_default(uri, mocker): driver = GraphDatabase.driver(uri) - from neo4j.work.transaction import Transaction + from neo4j import Transaction + # we set a specific db, because else the driver would try to fetch a RT # to get hold of the actual home database (which won't work in this # unittest) with driver.session(database="foobar") as session: - acquire_mock = mocker.patch.object(session._pool, "acquire", - autospec=True) - tx_begin_mock = mocker.patch.object(Transaction, "_begin", - autospec=True) - tx = session.begin_transaction() + with mock.patch.object( + session._pool, "acquire", autospec=True + ) as acquire_mock: + with mock.patch.object( + Transaction, "_begin", autospec=True + ) as tx_begin_mock: + tx = session.begin_transaction() acquire_mock.assert_called_once_with( access_mode=WRITE_ACCESS, timeout=mocker.ANY, diff --git a/tests/unit/work/__init__.py b/tests/unit/sync/work/__init__.py similarity index 93% rename from tests/unit/work/__init__.py rename to tests/unit/sync/work/__init__.py index 238e61d3..2613b53d 100644 --- a/tests/unit/work/__init__.py +++ b/tests/unit/sync/work/__init__.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -18,7 +15,8 @@ # See the License for the specific language governing permissions and # limitations under the License. + from ._fake_connection import ( - FakeConnection, fake_connection, + FakeConnection, ) diff --git a/tests/unit/work/_fake_connection.py b/tests/unit/sync/work/_fake_connection.py similarity index 73% rename from tests/unit/work/_fake_connection.py rename to tests/unit/sync/work/_fake_connection.py index fef0b580..1748ea61 100644 --- a/tests/unit/work/_fake_connection.py +++ b/tests/unit/sync/work/_fake_connection.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -20,11 +17,16 @@ import inspect -from unittest import mock import pytest from neo4j import ServerInfo +from neo4j._sync.io import Bolt + +from ..._async_compat import ( + Mock, + mock, +) class FakeConnection(mock.NonCallableMagicMock): @@ -32,22 +34,27 @@ class FakeConnection(mock.NonCallableMagicMock): server_info = ServerInfo("127.0.0.1", (4, 3)) def __init__(self, *args, **kwargs): + kwargs["spec"] = Bolt super().__init__(*args, **kwargs) - self.attach_mock(mock.Mock(return_value=True), "is_reset_mock") - self.attach_mock(mock.Mock(return_value=False), "defunct") - self.attach_mock(mock.Mock(return_value=False), "stale") - self.attach_mock(mock.Mock(return_value=False), "closed") + self.attach_mock(Mock(return_value=True), "is_reset_mock") + self.attach_mock(Mock(return_value=False), "defunct") + self.attach_mock(Mock(return_value=False), "stale") + self.attach_mock(Mock(return_value=False), "closed") + self.attach_mock(Mock(), "unresolved_address") def close_side_effect(): self.closed.return_value = True - self.attach_mock(mock.Mock(side_effect=close_side_effect), "close") + self.attach_mock(Mock(side_effect=close_side_effect), + "close") @property def is_reset(self): if self.closed.return_value or self.defunct.return_value: - raise AssertionError("is_reset should not be called on a closed or " - "defunct connection.") + raise AssertionError( + "is_reset should not be called on a closed or defunct " + "connection." + ) return self.is_reset_mock() def fetch_message(self, *args, **kwargs): @@ -81,9 +88,13 @@ def callback(): # e.g. built-in method as cb pass if param_count == 1: - cb({}) + res = cb({}) else: - cb() + res = cb() + try: + res # maybe the callback is async + except TypeError: + pass # or maybe it wasn't ;) self.callbacks.append(callback) return func diff --git a/tests/unit/work/test_result.py b/tests/unit/sync/work/test_result.py similarity index 83% rename from tests/unit/work/test_result.py rename to tests/unit/sync/work/test_result.py index c4b5aa13..863df838 100644 --- a/tests/unit/work/test_result.py +++ b/tests/unit/sync/work/test_result.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -26,13 +23,16 @@ from neo4j import ( Address, Record, + Result, ResultSummary, ServerInfo, SummaryCounters, Version, ) +from neo4j._async_compat.util import Util from neo4j.data import DataHydrator -from neo4j.work.result import Result + +from ..._async_compat import mark_sync_test class Records: @@ -62,8 +62,7 @@ def __init__(self, message, *args, **kwargs): def _cb(self, cb_name, *args, **kwargs): # print(self.message, cb_name.upper(), args, kwargs) cb = self.kwargs.get(cb_name) - if callable(self.kwargs.get(cb_name)): - cb(*args, **kwargs) + Util.callback(cb, *args, **kwargs) def on_success(self, metadata): self._cb("on_success", metadata) @@ -151,8 +150,9 @@ def fetch_message(self): if self.record_idxs[qid] < len(self._records[qid]): msg.on_success({"has_more": True}) else: - msg.on_success({"bookmark": "foo", - **(self.summary_meta or {})}) + msg.on_success( + {"bookmark": "foo", **(self.summary_meta or {})} + ) self._exhausted[qid] = True msg.on_summary() @@ -184,8 +184,9 @@ def noop(*_, **__): pass -def _fetch_and_compare_all_records(result, key, expected_records, method, - limit=None): +def fetch_and_compare_all_records( + result, key, expected_records, method, limit=None +): received_records = [] if method == "for loop": for record in result: @@ -196,42 +197,48 @@ def _fetch_and_compare_all_records(result, key, expected_records, method, if limit is None: assert result._closed elif method == "next": - iter_ = iter(result) + iter_ = Util.iter(result) n = len(expected_records) if limit is None else limit for _ in range(n): - received_records.append([next(iter_).get(key, None)]) + record = Util.next(iter_) + received_records.append([record.get(key, None)]) if limit is None: with pytest.raises(StopIteration): - received_records.append([next(iter_).get(key, None)]) + Util.next(iter_) assert result._closed elif method == "new iter": n = len(expected_records) if limit is None else limit for _ in range(n): - received_records.append([next(iter(result)).get(key, None)]) + iter_ = Util.iter(result) + record = Util.next(iter_) + received_records.append([record.get(key, None)]) if limit is None: + iter_ = Util.iter(result) with pytest.raises(StopIteration): - received_records.append([next(iter(result)).get(key, None)]) + Util.next(iter_) assert result._closed else: raise ValueError() assert received_records == expected_records -@pytest.mark.parametrize("method", ("for loop", "next", "new iter")) +@pytest.mark.parametrize("method", ("for loop", "next", "new iter")) @pytest.mark.parametrize("records", ( [], [[42]], [[1], [2], [3], [4], [5]], )) +@mark_sync_test def test_result_iteration(method, records): connection = ConnectionStub(records=Records(["x"], records)) result = Result(connection, HydratorStub(), 2, noop, noop) result._run("CYPHER", {}, None, None, "r", None) - _fetch_and_compare_all_records(result, "x", records, method) + fetch_and_compare_all_records(result, "x", records, method) -@pytest.mark.parametrize("method", ("for loop", "next", "new iter")) +@pytest.mark.parametrize("method", ("for loop", "next", "new iter")) @pytest.mark.parametrize("invert_fetch", (True, False)) +@mark_sync_test def test_parallel_result_iteration(method, invert_fetch): records1 = [[i] for i in range(1, 6)] records2 = [[i] for i in range(6, 11)] @@ -243,15 +250,24 @@ def test_parallel_result_iteration(method, invert_fetch): result2 = Result(connection, HydratorStub(), 2, noop, noop) result2._run("CYPHER2", {}, None, None, "r", None) if invert_fetch: - _fetch_and_compare_all_records(result2, "x", records2, method) - _fetch_and_compare_all_records(result1, "x", records1, method) + fetch_and_compare_all_records( + result2, "x", records2, method + ) + fetch_and_compare_all_records( + result1, "x", records1, method + ) else: - _fetch_and_compare_all_records(result1, "x", records1, method) - _fetch_and_compare_all_records(result2, "x", records2, method) + fetch_and_compare_all_records( + result1, "x", records1, method + ) + fetch_and_compare_all_records( + result2, "x", records2, method + ) -@pytest.mark.parametrize("method", ("for loop", "next", "new iter")) +@pytest.mark.parametrize("method", ("for loop", "next", "new iter")) @pytest.mark.parametrize("invert_fetch", (True, False)) +@mark_sync_test def test_interwoven_result_iteration(method, invert_fetch): records1 = [[i] for i in range(1, 10)] records2 = [[i] for i in range(11, 20)] @@ -266,20 +282,25 @@ def test_interwoven_result_iteration(method, invert_fetch): for n in (1, 2, 3, 1, None): end = n if n is None else start + n if invert_fetch: - _fetch_and_compare_all_records(result2, "y", records2[start:end], - method, n) - _fetch_and_compare_all_records(result1, "x", records1[start:end], - method, n) + fetch_and_compare_all_records( + result2, "y", records2[start:end], method, n + ) + fetch_and_compare_all_records( + result1, "x", records1[start:end], method, n + ) else: - _fetch_and_compare_all_records(result1, "x", records1[start:end], - method, n) - _fetch_and_compare_all_records(result2, "y", records2[start:end], - method, n) + fetch_and_compare_all_records( + result1, "x", records1[start:end], method, n + ) + fetch_and_compare_all_records( + result2, "y", records2[start:end], method, n + ) start = end @pytest.mark.parametrize("records", ([[1], [2]], [[1]], [])) @pytest.mark.parametrize("fetch_size", (1, 2)) +@mark_sync_test def test_result_peek(records, fetch_size): connection = ConnectionStub(records=Records(["x"], records)) result = Result(connection, HydratorStub(), fetch_size, noop, noop) @@ -291,11 +312,13 @@ def test_result_peek(records, fetch_size): else: assert isinstance(record, Record) assert record.get("x") == records[i][0] - next(iter(result)) # consume the record + iter_ = Util.iter(result) + Util.next(iter_) # consume the record @pytest.mark.parametrize("records", ([[1], [2]], [[1]], [])) @pytest.mark.parametrize("fetch_size", (1, 2)) +@mark_sync_test def test_result_single(records, fetch_size): connection = ConnectionStub(records=Records(["x"], records)) result = Result(connection, HydratorStub(), fetch_size, noop, noop) @@ -314,26 +337,29 @@ def test_result_single(records, fetch_size): assert record.get("x") == records[0][0] +@mark_sync_test def test_keys_are_available_before_and_after_stream(): connection = ConnectionStub(records=Records(["x"], [[1], [2]])) result = Result(connection, HydratorStub(), 1, noop, noop) result._run("CYPHER", {}, None, None, "r", None) assert list(result.keys()) == ["x"] - list(result) + Util.list(result) assert list(result.keys()) == ["x"] @pytest.mark.parametrize("records", ([[1], [2]], [[1]], [])) @pytest.mark.parametrize("consume_one", (True, False)) @pytest.mark.parametrize("summary_meta", (None, {"database": "foobar"})) +@mark_sync_test def test_consume(records, consume_one, summary_meta): - connection = ConnectionStub(records=Records(["x"], records), - summary_meta=summary_meta) + connection = ConnectionStub( + records=Records(["x"], records), summary_meta=summary_meta + ) result = Result(connection, HydratorStub(), 1, noop, noop) result._run("CYPHER", {}, None, None, "r", None) if consume_one: try: - next(iter(result)) + Util.next(Util.iter(result)) except StopIteration: pass summary = result.consume() @@ -351,6 +377,7 @@ def test_consume(records, consume_one, summary_meta): @pytest.mark.parametrize("t_first", (None, 0, 1, 123456789)) @pytest.mark.parametrize("t_last", (None, 0, 1, 123456789)) +@mark_sync_test def test_time_in_summary(t_first, t_last): run_meta = None if t_first is not None: @@ -358,9 +385,10 @@ def test_time_in_summary(t_first, t_last): summary_meta = None if t_last is not None: summary_meta = {"t_last": t_last} - connection = ConnectionStub(records=Records(["n"], - [[i] for i in range(100)]), - run_meta=run_meta, summary_meta=summary_meta) + connection = ConnectionStub( + records=Records(["n"], [[i] for i in range(100)]), run_meta=run_meta, + summary_meta=summary_meta + ) result = Result(connection, HydratorStub(), 1, noop, noop) result._run("CYPHER", {}, None, None, "r", None) @@ -380,6 +408,7 @@ def test_time_in_summary(t_first, t_last): assert not hasattr(summary, "t_last") +@mark_sync_test def test_counts_in_summary(): connection = ConnectionStub(records=Records(["n"], [[1], [2]])) @@ -391,9 +420,11 @@ def test_counts_in_summary(): @pytest.mark.parametrize("query_type", ("r", "w", "rw", "s")) +@mark_sync_test def test_query_type(query_type): - connection = ConnectionStub(records=Records(["n"], [[1], [2]]), - summary_meta={"type": query_type}) + connection = ConnectionStub( + records=Records(["n"], [[1], [2]]), summary_meta={"type": query_type} + ) result = Result(connection, HydratorStub(), 1, noop, noop) result._run("CYPHER", {}, None, None, "r", None) @@ -404,6 +435,7 @@ def test_query_type(query_type): @pytest.mark.parametrize("num_records", range(0, 5)) +@mark_sync_test def test_data(num_records): connection = ConnectionStub( records=Records(["n"], [[i + 1] for i in range(num_records)]) diff --git a/tests/unit/work/test_session.py b/tests/unit/sync/work/test_session.py similarity index 87% rename from tests/unit/work/test_session.py rename to tests/unit/sync/work/test_session.py index 71a3bde8..4a12f695 100644 --- a/tests/unit/work/test_session.py +++ b/tests/unit/sync/work/test_session.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -18,6 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. + from contextlib import contextmanager import pytest @@ -28,23 +26,30 @@ Transaction, unit_of_work, ) +from neo4j._sync.io._pool import IOPool +from ..._async_compat import ( + mark_sync_test, + Mock, + mock, +) from ._fake_connection import FakeConnection @pytest.fixture() -def pool(mocker): - pool = mocker.MagicMock() - pool.acquire = mocker.MagicMock(side_effect=iter(FakeConnection, 0)) +def pool(): + pool = Mock(spec=IOPool) + pool.acquire.side_effect = iter(FakeConnection, 0) return pool -def test_session_context_calls_close(mocker): +@mark_sync_test +def test_session_context_calls_close(): s = Session(None, SessionConfig()) - mock_close = mocker.patch.object(s, 'close', autospec=True) - with s: - pass - mock_close.assert_called_once_with() + with mock.patch.object(s, 'close', autospec=True) as mock_close: + with s: + pass + mock_close.assert_called_once_with() @pytest.mark.parametrize("test_run_args", ( @@ -53,7 +58,10 @@ def test_session_context_calls_close(mocker): @pytest.mark.parametrize(("repetitions", "consume"), ( (1, False), (2, False), (2, True) )) -def test_opens_connection_on_run(pool, test_run_args, repetitions, consume): +@mark_sync_test +def test_opens_connection_on_run( + pool, test_run_args, repetitions, consume +): with Session(pool, SessionConfig()) as session: assert session._connection is None result = session.run(*test_run_args) @@ -66,7 +74,10 @@ def test_opens_connection_on_run(pool, test_run_args, repetitions, consume): ("RETURN $x", {"x": 1}), ("RETURN 1",) )) @pytest.mark.parametrize("repetitions", range(1, 3)) -def test_closes_connection_after_consume(pool, test_run_args, repetitions): +@mark_sync_test +def test_closes_connection_after_consume( + pool, test_run_args, repetitions +): with Session(pool, SessionConfig()) as session: result = session.run(*test_run_args) result.consume() @@ -77,7 +88,10 @@ def test_closes_connection_after_consume(pool, test_run_args, repetitions): @pytest.mark.parametrize("test_run_args", ( ("RETURN $x", {"x": 1}), ("RETURN 1",) )) -def test_keeps_connection_until_last_result_consumed(pool, test_run_args): +@mark_sync_test +def test_keeps_connection_until_last_result_consumed( + pool, test_run_args +): with Session(pool, SessionConfig()) as session: result1 = session.run(*test_run_args) result2 = session.run(*test_run_args) @@ -88,6 +102,7 @@ def test_keeps_connection_until_last_result_consumed(pool, test_run_args): assert session._connection is None +@mark_sync_test def test_opens_connection_on_tx_begin(pool): with Session(pool, SessionConfig()) as session: assert session._connection is None @@ -99,6 +114,7 @@ def test_opens_connection_on_tx_begin(pool): ("RETURN $x", {"x": 1}), ("RETURN 1",) )) @pytest.mark.parametrize("repetitions", range(1, 3)) +@mark_sync_test def test_keeps_connection_on_tx_run(pool, test_run_args, repetitions): with Session(pool, SessionConfig()) as session: with session.begin_transaction() as tx: @@ -111,7 +127,10 @@ def test_keeps_connection_on_tx_run(pool, test_run_args, repetitions): ("RETURN $x", {"x": 1}), ("RETURN 1",) )) @pytest.mark.parametrize("repetitions", range(1, 3)) -def test_keeps_connection_on_tx_consume(pool, test_run_args, repetitions): +@mark_sync_test +def test_keeps_connection_on_tx_consume( + pool, test_run_args, repetitions +): with Session(pool, SessionConfig()) as session: with session.begin_transaction() as tx: for _ in range(repetitions): @@ -123,6 +142,7 @@ def test_keeps_connection_on_tx_consume(pool, test_run_args, repetitions): @pytest.mark.parametrize("test_run_args", ( ("RETURN $x", {"x": 1}), ("RETURN 1",) )) +@mark_sync_test def test_closes_connection_after_tx_close(pool, test_run_args): with Session(pool, SessionConfig()) as session: with session.begin_transaction() as tx: @@ -137,6 +157,7 @@ def test_closes_connection_after_tx_close(pool, test_run_args): @pytest.mark.parametrize("test_run_args", ( ("RETURN $x", {"x": 1}), ("RETURN 1",) )) +@mark_sync_test def test_closes_connection_after_tx_commit(pool, test_run_args): with Session(pool, SessionConfig()) as session: with session.begin_transaction() as tx: @@ -149,8 +170,11 @@ def test_closes_connection_after_tx_commit(pool, test_run_args): @pytest.mark.parametrize("bookmarks", (None, [], ["abc"], ["foo", "bar"])) +@mark_sync_test def test_session_returns_bookmark_directly(pool, bookmarks): - with Session(pool, SessionConfig(bookmarks=bookmarks)) as session: + with Session( + pool, SessionConfig(bookmarks=bookmarks) + ) as session: if bookmarks: assert session.last_bookmark() == bookmarks[-1] else: @@ -163,6 +187,7 @@ def test_session_returns_bookmark_directly(pool, bookmarks): ({"how about": "no?"}, TypeError), (["I don't", "think so"], TypeError), )) +@mark_sync_test def test_session_run_wrong_types(pool, query, error_type): with Session(pool, SessionConfig()) as session: with pytest.raises(error_type): @@ -170,6 +195,7 @@ def test_session_run_wrong_types(pool, query, error_type): @pytest.mark.parametrize("tx_type", ("write_transaction", "read_transaction")) +@mark_sync_test def test_tx_function_argument_type(pool, tx_type): def work(tx): assert isinstance(tx, Transaction) @@ -186,6 +212,7 @@ def work(tx): {"timeout": 5, "metadata": {"foo": "bar"}}, )) +@mark_sync_test def test_decorated_tx_function_argument_type(pool, tx_type, decorator_kwargs): @unit_of_work(**decorator_kwargs) def work(tx): @@ -195,6 +222,7 @@ def work(tx): getattr(session, tx_type)(work) +@mark_sync_test def test_session_tx_type(pool): with Session(pool, SessionConfig()) as session: tx = session.begin_transaction() @@ -225,7 +253,10 @@ def test_session_tx_type(pool): ({"x": {(1, 2): '1+2i', (2, 0): '2'}}, TypeError), )) @pytest.mark.parametrize("run_type", ("auto", "unmanaged", "managed")) -def test_session_run_with_parameters(pool, parameters, error_type, run_type): +@mark_sync_test +def test_session_run_with_parameters( + pool, parameters, error_type, run_type +): @contextmanager def raises(): if error_type is not None: diff --git a/tests/unit/work/test_transaction.py b/tests/unit/sync/work/test_transaction.py similarity index 98% rename from tests/unit/work/test_transaction.py rename to tests/unit/sync/work/test_transaction.py index 06e75566..b5b40283 100644 --- a/tests/unit/work/test_transaction.py +++ b/tests/unit/sync/work/test_transaction.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - # Copyright (c) "Neo4j" # Neo4j Sweden AB [http://neo4j.com] # @@ -18,11 +15,9 @@ # See the License for the specific language governing permissions and # limitations under the License. + +from unittest.mock import MagicMock from uuid import uuid4 -from unittest.mock import ( - MagicMock, - NonCallableMagicMock, -) import pytest diff --git a/tests/unit/test_import_neo4j.py b/tests/unit/test_import_neo4j.py deleted file mode 100644 index a7ab16e1..00000000 --- a/tests/unit/test_import_neo4j.py +++ /dev/null @@ -1,171 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -# Copyright (c) "Neo4j" -# Neo4j Sweden AB [http://neo4j.com] -# -# This file is part of Neo4j. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import pytest - -# python -m pytest tests/unit/test_import_neo4j.py -s -v - - -def test_import_dunder_version(): - # python -m pytest tests/unit/test_import_neo4j.py -s -v -k test_import_dunder_version - from neo4j import __version__ - - -def test_import_graphdatabase(): - # python -m pytest tests/unit/test_import_neo4j.py -s -v -k test_import_graphdatabase - from neo4j import GraphDatabase - - -def test_import_driver(): - # python -m pytest tests/unit/test_import_neo4j.py -s -v -k test_import_driver - from neo4j import Driver - - -def test_import_boltdriver(): - # python -m pytest tests/unit/test_import_neo4j.py -s -v -k test_import_boltdriver - from neo4j import BoltDriver - - -def test_import_neo4jdriver(): - # python -m pytest tests/unit/test_import_neo4j.py -s -v -k test_import_neo4jdriver - from neo4j import Neo4jDriver - - -def test_import_auth(): - # python -m pytest tests/unit/test_import_neo4j.py -s -v -k test_import_auth - from neo4j import Auth - - -def test_import_authtoken(): - # python -m pytest tests/unit/test_import_neo4j.py -s -v -k test_import_authtoken - from neo4j import AuthToken - - -def test_import_basic_auth(): - # python -m pytest tests/unit/test_import_neo4j.py -s -v -k test_import_auth - from neo4j import basic_auth - - -def test_import_kerberos_auth(): - # python -m pytest tests/unit/test_import_neo4j.py -s -v -k test_import_kerberos_auth - from neo4j import kerberos_auth - - -def test_import_custom_auth(): - # python -m pytest tests/unit/test_import_neo4j.py -s -v -k test_import_custom_auth - from neo4j import custom_auth - - -def test_import_read_access(): - # python -m pytest tests/unit/test_import_neo4j.py -s -v -k test_import_read_access - from neo4j import READ_ACCESS - - -def test_import_write_access(): - # python -m pytest tests/unit/test_import_neo4j.py -s -v -k test_import_write_access - from neo4j import WRITE_ACCESS - - -def test_import_transaction(): - # python -m pytest tests/unit/test_import_neo4j.py -s -v -k test_import_transaction - from neo4j import Transaction - - -def test_import_record(): - # python -m pytest tests/unit/test_import_neo4j.py -s -v -k test_import_record - from neo4j import Record - - -def test_import_session(): - # python -m pytest tests/unit/test_import_neo4j.py -s -v -k test_import_session - from neo4j import Session - - -def test_import_sessionconfig(): - # python -m pytest tests/unit/test_import_neo4j.py -s -v -k test_import_sessionconfig - from neo4j import SessionConfig - - -def test_import_query(): - # python -m pytest tests/unit/test_import_neo4j.py -s -v -k test_import_query - from neo4j import Query - - -def test_import_result(): - # python -m pytest tests/unit/test_import_neo4j.py -s -v -k test_import_result - from neo4j import Result - - -def test_import_resultsummary(): - # python -m pytest tests/unit/test_import_neo4j.py -s -v -k test_import_resultsummary - from neo4j import ResultSummary - - -def test_import_unit_of_work(): - # python -m pytest tests/unit/test_import_neo4j.py -s -v -k test_import_unit_of_work - from neo4j import unit_of_work - - -def test_import_config(): - # python -m pytest tests/unit/test_import_neo4j.py -s -v -k test_import_config - from neo4j import Config - - -def test_import_poolconfig(): - # python -m pytest tests/unit/test_import_neo4j.py -s -v -k test_import_poolconfig - from neo4j import PoolConfig - - -def test_import_graph(): - # python -m pytest tests/unit/test_import_neo4j.py -s -v -k test_import_graph - import neo4j.graph as graph - - -def test_import_graph_node(): - # python -m pytest tests/unit/test_import_neo4j.py -s -v -k test_import_graph_node - from neo4j.graph import Node - - -def test_import_graph_path(): - # python -m pytest tests/unit/test_import_neo4j.py -s -v -k test_import_graph_path - from neo4j.graph import Path - - -def test_import_graph_graph(): - # python -m pytest tests/unit/test_import_neo4j.py -s -v -k test_import_graph_graph - from neo4j.graph import Graph - - -def test_import_spatial(): - # python -m pytest tests/unit/test_import_neo4j.py -s -v -k test_import_spatial - import neo4j.spatial as spatial - - -def test_import_time(): - # python -m pytest tests/unit/test_import_neo4j.py -s -v -k test_import_time - import neo4j.time as time - - -def test_import_exceptions(): - # python -m pytest tests/unit/test_import_neo4j.py -s -v -k test_import_exceptions - import neo4j.exceptions as exceptions - -