diff --git a/.editorconfig b/.editorconfig
new file mode 100644
index 00000000..823ad91d
--- /dev/null
+++ b/.editorconfig
@@ -0,0 +1,22 @@
+# top-most EditorConfig file
+root = true
+
+# Unix-style newlines with a newline ending every file
+[*]
+end_of_line = lf
+insert_final_newline = true
+charset = utf-8
+
+[*.{py,js,rst,txt,sh,bat}]
+trim_trailing_whitespace = true
+
+[{Makefile,Dockerfile}]
+trim_trailing_whitespace = true
+
+[*.bat]
+end_of_line = crlf
+
+[*.py]
+max_line_length = 79
+indent_style = space
+indent_size = 4
diff --git a/.gitignore b/.gitignore
index 6f5d483d..f485e71b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -29,3 +29,4 @@ neo4j-enterprise-*
*.so
testkit/CAs
+testkit/CustomCAs
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 00000000..c4eafcb5
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,32 @@
+# See https://pre-commit.com for more information
+# See https://pre-commit.com/hooks.html for more hooks
+repos:
+ - repo: https://github.com/pre-commit/pre-commit-hooks
+ rev: v4.0.1
+ hooks:
+ - id: check-case-conflict
+ - id: check-docstring-first
+ - id: check-symlinks
+ - id: destroyed-symlinks
+ - id: end-of-file-fixer
+ exclude_types:
+ - image
+ - id: fix-encoding-pragma
+ args: [ --remove ]
+ - id: mixed-line-ending
+ args: [ --fix=lf ]
+ exclude_types:
+ - batch
+ - id: trailing-whitespace
+ args: [ --markdown-linebreak-ext=md ]
+ - repo: https://github.com/pycqa/isort
+ rev: 5.10.0
+ hooks:
+ - id: isort
+ - repo: local
+ hooks:
+ - id: unasync
+ name: unasync
+ entry: bin/make-unasync
+ language: system
+ files: "^(neo4j/_async|tests/unit/async_|testkitbackend/_async)/.*"
diff --git a/CHANGELOG.md b/CHANGELOG.md
index e98eeeb8..4f6e2000 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -4,7 +4,10 @@
- Python 3.10 support added
- Python 3.6 support has been dropped.
-
+- `Result`, `Session`, and `Transaction` can no longer be imported from
+ `neo4j.work`. They should've been imported from `neo4j` all along.
+- Experimental pipelines feature has been removed.
+- Experimental async driver has been added.
## Version 4.4
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index b6ebbe8d..d2656fd6 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -51,12 +51,32 @@ Occasionally, we might also have logistical, commercial, or legal reasons why we
Remember that many community members have become regular contributors and some are now even Neo employees!
+## Specifically for this project:
+
+All code in `_sync` or `sync` folders is auto-generated. Don't change it, but
+install the pre-commit hooks as described below instead. They will take care of
+updating the code if necessary.
+
+Setting up the development environment:
+ * Install Python 3.7+
+ * Install the requirements
+ ```bash
+ $ python3 -m pip install -U pip
+ $ python3 -m pip install -Ur requirements-dev.txt
+ ```
+ * Install the pre-commit hook, which will do some code-format-checking every time
+ you commit.
+ ```bash
+ $ pre-commit install
+ ```
+
+
## Got an idea for a new project?
If you have an idea for a new tool or library, start by talking to other people in the community.
Chances are that someone has a similar idea or may have already started working on it.
The best software comes from getting like minds together to solve a problem.
-And we'll do our best to help you promote and co-ordinate your Neo ecosystem projects.
+And we'll do our best to help you promote and co-ordinate your Neo4j ecosystem projects.
## Further reading
diff --git a/ISSUE_TEMPLATE.md b/ISSUE_TEMPLATE.md
index d413cc9c..0a9addf8 100644
--- a/ISSUE_TEMPLATE.md
+++ b/ISSUE_TEMPLATE.md
@@ -9,7 +9,7 @@ If you simply want to get started or have a question on how to use a particular
[StackOverflow](http://stackoverflow.com/questions/tagged/neo4j) also hosts a ton of questions and might already have a discussion around your problem.
Make sure you have a look there too.
-If you want to make a feature request, please prefix your issue title with `[Feature Request]` so that it is clear to us.
+If you want to make a feature request, please prefix your issue title with `[Feature Request]` so that it is clear to us.
If you have a bug report however, please continue reading.
To help us understand your issue, please specify important details, primarily:
diff --git a/bin/make-unasync b/bin/make-unasync
new file mode 100755
index 00000000..c217a539
--- /dev/null
+++ b/bin/make-unasync
@@ -0,0 +1,311 @@
+#!/usr/bin/env python
+
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import collections
+import errno
+import os
+from pathlib import Path
+import re
+import sys
+import tokenize as std_tokenize
+
+import isort
+import isort.files
+import unasync
+
+
+ROOT_DIR = Path(__file__).parents[1].absolute()
+ASYNC_DIR = ROOT_DIR / "neo4j" / "_async"
+SYNC_DIR = ROOT_DIR / "neo4j" / "_sync"
+ASYNC_TEST_DIR = ROOT_DIR / "tests" / "unit" / "async_"
+SYNC_TEST_DIR = ROOT_DIR / "tests" / "unit" / "sync"
+ASYNC_TESTKIT_BACKEND_DIR = ROOT_DIR / "testkitbackend" / "_async"
+SYNC_TESTKIT_BACKEND_DIR = ROOT_DIR / "testkitbackend" / "_sync"
+UNASYNC_SUFFIX = ".unasync"
+
+PY_FILE_EXTENSIONS = {".py", ".pyi"}
+
+
+# copy from unasync for customization -----------------------------------------
+# https://github.com/python-trio/unasync
+# License: MIT and Apache2
+
+
+Token = collections.namedtuple(
+ "Token", ["type", "string", "start", "end", "line"]
+)
+
+
+def _makedirs_existok(dir):
+ try:
+ os.makedirs(dir)
+ except OSError as e:
+ if e.errno != errno.EEXIST:
+ raise
+
+
+def _get_tokens(f):
+ if sys.version_info[0] == 2:
+ for tok in std_tokenize.generate_tokens(f.readline):
+ type_, string, start, end, line = tok
+ yield Token(type_, string, start, end, line)
+ else:
+ for tok in std_tokenize.tokenize(f.readline):
+ if tok.type == std_tokenize.ENCODING:
+ continue
+ yield tok
+
+
+def _tokenize(f):
+ last_end = (1, 0)
+ for tok in _get_tokens(f):
+ if last_end[0] < tok.start[0]:
+ yield "", std_tokenize.STRING, " \\\n"
+ last_end = (tok.start[0], 0)
+
+ space = ""
+ if tok.start > last_end:
+ assert tok.start[0] == last_end[0]
+ space = " " * (tok.start[1] - last_end[1])
+ yield space, tok.type, tok.string
+
+ last_end = tok.end
+ if tok.type in [std_tokenize.NEWLINE, std_tokenize.NL]:
+ last_end = (tok.end[0] + 1, 0)
+
+
+def _untokenize(tokens):
+ return "".join(space + tokval for space, tokval in tokens)
+
+
+# end of copy -----------------------------------------------------------------
+
+
+class CustomRule(unasync.Rule):
+ def __init__(self, *args, **kwargs):
+ super(CustomRule, self).__init__(*args, **kwargs)
+ self.out_files = []
+
+ def _unasync_tokens(self, tokens):
+ # copy from unasync to hook into string handling
+ # https://github.com/python-trio/unasync
+ # License: MIT and Apache2
+ # TODO __await__, ...?
+ used_space = None
+ for space, toknum, tokval in tokens:
+ if tokval in ["async", "await"]:
+ # When removing async or await, we want to use the whitespace
+ # that was before async/await before the next token so that
+ # `print(await stuff)` becomes `print(stuff)` and not
+ # `print( stuff)`
+ used_space = space
+ else:
+ if toknum == std_tokenize.NAME:
+ tokval = self._unasync_name(tokval)
+ elif toknum == std_tokenize.STRING:
+ if tokval[0] == tokval[1] and len(tokval) > 2:
+ # multiline string (`"""..."""` or `'''...'''`)
+ left_quote, name, right_quote = \
+ tokval[:3], tokval[3:-3], tokval[-3:]
+ else:
+ # simple string (`"..."` or `'...'`)
+ left_quote, name, right_quote = \
+ tokval[:1], tokval[1:-1], tokval[-1:]
+ tokval = \
+ left_quote + self._unasync_string(name) + right_quote
+ if used_space is None:
+ used_space = space
+ yield (used_space, tokval)
+ used_space = None
+
+ def _unasync_string(self, name):
+ start = 0
+ end = 1
+ out = ""
+ while end < len(name):
+ sub_name = name[start:end]
+ if sub_name.isidentifier():
+ end += 1
+ else:
+ if end == start + 1:
+ out += sub_name
+ start += 1
+ end += 1
+ else:
+ out += self._unasync_prefix(name[start:(end - 1)])
+ start = end - 1
+
+ sub_name = name[start:]
+ if sub_name.isidentifier():
+ out += self._unasync_prefix(name[start:])
+ else:
+ out += sub_name
+
+ # very boiled down unasync version that removes "async" and "await"
+ # substrings.
+ out = re.subn(r"(^|\s+|(?<=\W))(?:async|await)\s+", r"\1", out,
+ flags=re.MULTILINE)[0]
+ # Convert doc-reference names from 'async-xyz' to 'xyz'
+ out = re.subn(r":ref:`async-", ":ref:`", out)[0]
+ return out
+
+ def _unasync_prefix(self, name):
+ # Convert class names from 'AsyncXyz' to 'Xyz'
+ if len(name) > 5 and name.startswith("Async") and name[5].isupper():
+ return name[5:]
+ # Convert variable/method/function names from 'async_xyz' to 'xyz'
+ elif len(name) > 6 and name.startswith("async_"):
+ return name[6:]
+ return name
+
+ def _unasync_name(self, name):
+ # copy from unasync to customize renaming rules
+ # https://github.com/python-trio/unasync
+ # License: MIT and Apache2
+ if name in self.token_replacements:
+ return self.token_replacements[name]
+ return self._unasync_prefix(name)
+
+ def _unasync_file(self, filepath):
+ # copy from unasync to append file suffix to out path
+ # https://github.com/python-trio/unasync
+ # License: MIT and Apache2
+ with open(filepath, "rb") as f:
+ write_kwargs = {}
+ if sys.version_info[0] >= 3:
+ encoding, _ = std_tokenize.detect_encoding(f.readline)
+ write_kwargs["encoding"] = encoding
+ f.seek(0)
+ tokens = _tokenize(f)
+ tokens = self._unasync_tokens(tokens)
+ result = _untokenize(tokens)
+ outfile_path = filepath.replace(self.fromdir, self.todir)
+ outfile_path += UNASYNC_SUFFIX
+ self.out_files.append(outfile_path)
+ _makedirs_existok(os.path.dirname(outfile_path))
+ with open(outfile_path, "w", **write_kwargs) as f:
+ print(result, file=f, end="")
+
+
+def apply_unasync(files):
+ """Generate sync code from async code."""
+
+ additional_main_replacements = {}
+ additional_test_replacements = {
+ "_async": "_sync",
+ "mark_async_test": "mark_sync_test",
+ }
+ additional_testkit_backend_replacements = {}
+ rules = [
+ CustomRule(
+ fromdir=str(ASYNC_DIR),
+ todir=str(SYNC_DIR),
+ additional_replacements=additional_main_replacements,
+ ),
+ CustomRule(
+ fromdir=str(ASYNC_TEST_DIR),
+ todir=str(SYNC_TEST_DIR),
+ additional_replacements=additional_test_replacements,
+ ),
+ CustomRule(
+ fromdir=str(ASYNC_TESTKIT_BACKEND_DIR),
+ todir=str(SYNC_TESTKIT_BACKEND_DIR),
+ additional_replacements=additional_testkit_backend_replacements,
+ ),
+ ]
+
+ if not files:
+ paths = list(ASYNC_DIR.rglob("*"))
+ paths += list(ASYNC_TEST_DIR.rglob("*"))
+ paths += list(ASYNC_TESTKIT_BACKEND_DIR.rglob("*"))
+ else:
+ paths = [ROOT_DIR / Path(f) for f in files]
+ filtered_paths = []
+ for path in paths:
+ if path.suffix in PY_FILE_EXTENSIONS:
+ filtered_paths.append(path)
+
+ unasync.unasync_files(map(str, filtered_paths), rules)
+
+ return [Path(path) for rule in rules for path in rule.out_files]
+
+
+def apply_isort(paths):
+ """Sort imports in generated sync code.
+
+ Since classes in imports are renamed from AsyncXyz to Xyz, the alphabetical
+ order of the import can change.
+ """
+ isort_config = isort.Config(settings_path=str(ROOT_DIR), quiet=True)
+
+ for path in paths:
+ isort.file(str(path), config=isort_config)
+
+ return paths
+
+
+def apply_changes(paths):
+ def files_equal(path1, path2):
+ with open(path1, "rb") as f1:
+ with open(path2, "rb") as f2:
+ data1 = f1.read(1024)
+ data2 = f2.read(1024)
+ while data1 or data2:
+ if data1 != data2:
+ changed_paths[path1] = path2
+ return False
+ data1 = f1.read(1024)
+ data2 = f2.read(1024)
+ return True
+
+ changed_paths = {}
+
+ for in_path in paths:
+ out_path = Path(str(in_path)[:-len(UNASYNC_SUFFIX)])
+ if not out_path.is_file():
+ changed_paths[in_path] = out_path
+ continue
+ if not files_equal(in_path, out_path):
+ changed_paths[in_path] = out_path
+ continue
+ in_path.unlink()
+
+ for in_path, out_path in changed_paths.items():
+ in_path.replace(out_path)
+
+ return list(changed_paths.values())
+
+
+def main():
+ files = None
+ if len(sys.argv) >= 1:
+ files = sys.argv[1:]
+ paths = apply_unasync(files)
+ paths = apply_isort(paths)
+ changed_paths = apply_changes(paths)
+
+ if changed_paths:
+ for path in changed_paths:
+ print("Updated " + str(path))
+ exit(1)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/docs/source/_static/anchor_new_target.js b/docs/source/_static/anchor_new_target.js
index 367af0a2..015df714 100644
--- a/docs/source/_static/anchor_new_target.js
+++ b/docs/source/_static/anchor_new_target.js
@@ -1,4 +1,4 @@
$(document).ready(function () {
$('a[href^="http://"], a[href^="https://"]').not('a[class*=internal]').attr('target', '_blank');
-});
\ No newline at end of file
+});
diff --git a/docs/source/api.rst b/docs/source/api.rst
index 3481480f..f10c71b9 100644
--- a/docs/source/api.rst
+++ b/docs/source/api.rst
@@ -11,34 +11,35 @@ GraphDatabase
Driver Construction
===================
-The :class:`neo4j.Driver` construction is via a `classmethod` on the :class:`neo4j.GraphDatabase` class.
+The :class:`neo4j.Driver` construction is done via a `classmethod` on the :class:`neo4j.GraphDatabase` class.
.. autoclass:: neo4j.GraphDatabase
:members: driver
-Example, driver creation:
+Driver creation example:
.. code-block:: python
from neo4j import GraphDatabase
uri = "neo4j://example.com:7687"
- driver = GraphDatabase.driver(uri, auth=("neo4j", "password"), max_connection_lifetime=1000)
+ driver = GraphDatabase.driver(uri, auth=("neo4j", "password"))
driver.close() # close the driver object
-For basic auth, this can be a simple tuple, for example:
+For basic authentication, `auth` can be a simple tuple, for example:
.. code-block:: python
auth = ("neo4j", "password")
-This will implicitly create a :class:`neo4j.Auth` with a ``scheme="basic"``
+This will implicitly create a :class:`neo4j.Auth` with a ``scheme="basic"``.
+Other authentication methods are described under :ref:`auth-ref`.
-Example, with block context:
+``with`` block context example:
.. code-block:: python
@@ -330,7 +331,8 @@ For example:
Connection details held by the :class:`neo4j.Driver` are immutable.
Therefore if, for example, a password is changed, a replacement :class:`neo4j.Driver` object must be created.
-More than one :class:`.Driver` may be required if connections to multiple databases, or connections as multiple users, are required.
+More than one :class:`.Driver` may be required if connections to multiple databases, or connections as multiple users, are required,
+except when using impersonation (:ref:`impersonated-user-ref`).
:class:`neo4j.Driver` objects are thread-safe but cannot be shared across processes.
Therefore, ``multithreading`` should generally be preferred over ``multiprocessing`` for parallel database access.
@@ -345,11 +347,9 @@ BoltDriver
URI schemes:
``bolt``, ``bolt+ssc``, ``bolt+s``
-Driver subclass:
- :class:`neo4j.BoltDriver`
+Will result in:
-..
- .. autoclass:: neo4j.BoltDriver
+.. autoclass:: neo4j.BoltDriver
.. _neo4j-driver-ref:
@@ -360,11 +360,9 @@ Neo4jDriver
URI schemes:
``neo4j``, ``neo4j+ssc``, ``neo4j+s``
-Driver subclass:
- :class:`neo4j.Neo4jDriver`
+Will result in:
-..
- .. autoclass:: neo4j.Neo4jDriver
+.. autoclass:: neo4j.Neo4jDriver
***********************
@@ -374,7 +372,7 @@ All database activity is co-ordinated through two mechanisms: the :class:`neo4j.
A :class:`neo4j.Session` is a logical container for any number of causally-related transactional units of work.
Sessions automatically provide guarantees of causal consistency within a clustered environment but multiple sessions can also be causally chained if required.
-Sessions provide the top-level of containment for database activity.
+Sessions provide the top level of containment for database activity.
Session creation is a lightweight operation and *sessions are not thread safe*.
Connections are drawn from the :class:`neo4j.Driver` connection pool as required.
@@ -604,7 +602,8 @@ Example:
def create_person(driver, name):
with driver.session(default_access_mode=neo4j.WRITE_ACCESS) as session:
- result = session.run("CREATE (a:Person { name: $name }) RETURN id(a) AS node_id", name=name)
+ query = "CREATE (a:Person { name: $name }) RETURN id(a) AS node_id"
+ result = session.run(query, name=name)
record = result.single()
return record["node_id"]
@@ -665,13 +664,15 @@ Example:
tx.close()
def create_person_node(tx):
+ query = "CREATE (a:Person { name: $name }) RETURN id(a) AS node_id"
name = "default_name"
- result = tx.run("CREATE (a:Person { name: $name }) RETURN id(a) AS node_id", name=name)
+ result = tx.run(query, name=name)
record = result.single()
return record["node_id"]
def set_person_name(tx, node_id, name):
- result = tx.run("MATCH (a:Person) WHERE id(a) = $id SET a.name = $name", id=node_id, name=name)
+ query = "MATCH (a:Person) WHERE id(a) = $id SET a.name = $name"
+ result = tx.run(query, id=node_id, name=name)
info = result.consume()
# use the info for logging etc.
@@ -698,7 +699,8 @@ Example:
node_id = session.write_transaction(create_person_tx, name)
def create_person_tx(tx, name):
- result = tx.run("CREATE (a:Person { name: $name }) RETURN id(a) AS node_id", name=name)
+ query = "CREATE (a:Person { name: $name }) RETURN id(a) AS node_id"
+ result = tx.run(query, name=name)
record = result.single()
return record["node_id"]
@@ -708,12 +710,6 @@ To exert more control over how a transaction function is carried out, the :func:
-
-
-
-
-
-
******
Result
******
diff --git a/docs/source/async_api.rst b/docs/source/async_api.rst
new file mode 100644
index 00000000..989ab200
--- /dev/null
+++ b/docs/source/async_api.rst
@@ -0,0 +1,498 @@
+.. _async-api-documentation:
+
+#######################
+Async API Documentation
+#######################
+
+.. warning::
+ The whole async API is currently in experimental phase.
+
+ This means everything documented on this page might be removed or change
+ its API at any time (including in patch releases).
+
+******************
+AsyncGraphDatabase
+******************
+
+Async Driver Construction
+=========================
+
+The :class:`neo4j.AsyncDriver` construction is done via a `classmethod` on the :class:`neo4j.AsyncGraphDatabase` class.
+
+.. autoclass:: neo4j.AsyncGraphDatabase
+ :members: driver
+
+
+Driver creation example:
+
+.. code-block:: python
+
+ import asyncio
+
+ from neo4j import AsyncGraphDatabase
+
+ async def main():
+ uri = "neo4j://example.com:7687"
+ driver = AsyncGraphDatabase.driver(uri, auth=("neo4j", "password"))
+
+ await driver.close() # close the driver object
+
+ asyncio.run(main())
+
+
+For basic authentication, ``auth`` can be a simple tuple, for example:
+
+.. code-block:: python
+
+ auth = ("neo4j", "password")
+
+This will implicitly create a :class:`neo4j.Auth` with a ``scheme="basic"``.
+Other authentication methods are described under :ref:`auth-ref`.
+
+``with`` block context example:
+
+.. code-block:: python
+
+ import asyncio
+
+ from neo4j import AsyncGraphDatabase
+
+ async def main():
+ uri = "neo4j://example.com:7687"
+ auth = ("neo4j", "password")
+ async with AsyncGraphDatabase.driver(uri, auth=auth) as driver:
+ # use the driver
+ ...
+
+ asyncio.run(main())
+
+
+.. _async-uri-ref:
+
+URI
+===
+
+On construction, the `scheme` of the URI determines the type of :class:`neo4j.AsyncDriver` object created.
+
+Available valid URIs:
+
++ ``bolt://host[:port]``
++ ``bolt+ssc://host[:port]``
++ ``bolt+s://host[:port]``
++ ``neo4j://host[:port][?routing_context]``
++ ``neo4j+ssc://host[:port][?routing_context]``
++ ``neo4j+s://host[:port][?routing_context]``
+
+.. code-block:: python
+
+ uri = "bolt://example.com:7687"
+
+.. code-block:: python
+
+ uri = "neo4j://example.com:7687"
+
+Each supported scheme maps to a particular :class:`neo4j.AsyncDriver` subclass that implements a specific behaviour.
+
++------------------------+---------------------------------------------------------------------------------------------------------------------------------------------+
+| URI Scheme | Driver Object and Setting |
++========================+=============================================================================================================================================+
+| bolt | :ref:`async-bolt-driver-ref` with no encryption. |
++------------------------+---------------------------------------------------------------------------------------------------------------------------------------------+
+| bolt+ssc | :ref:`async-bolt-driver-ref` with encryption (accepts self signed certificates). |
++------------------------+---------------------------------------------------------------------------------------------------------------------------------------------+
+| bolt+s | :ref:`async-bolt-driver-ref` with encryption (accepts only certificates signed by a certificate authority), full certificate checks. |
++------------------------+---------------------------------------------------------------------------------------------------------------------------------------------+
+| neo4j | :ref:`async-neo4j-driver-ref` with no encryption. |
++------------------------+---------------------------------------------------------------------------------------------------------------------------------------------+
+| neo4j+ssc | :ref:`async-neo4j-driver-ref` with encryption (accepts self signed certificates). |
++------------------------+---------------------------------------------------------------------------------------------------------------------------------------------+
+| neo4j+s | :ref:`async-neo4j-driver-ref` with encryption (accepts only certificates signed by a certificate authority), full certificate checks. |
++------------------------+---------------------------------------------------------------------------------------------------------------------------------------------+
+
+.. note::
+
+ See https://neo4j.com/docs/operations-manual/current/configuration/ports/ for Neo4j ports.
+
+
+
+***********
+AsyncDriver
+***********
+
+Every Neo4j-backed application will require a :class:`neo4j.AsyncDriver` object.
+
+This object holds the details required to establish connections with a Neo4j database, including server URIs, credentials and other configuration.
+:class:`neo4j.AsyncDriver` objects hold a connection pool from which :class:`neo4j.AsyncSession` objects can borrow connections.
+Closing a driver will immediately shut down all connections in the pool.
+
+.. autoclass:: neo4j.AsyncDriver()
+ :members: session, close
+
+
+.. _async-driver-configuration-ref:
+
+Async Driver Configuration
+==========================
+
+:class:`neo4j.AsyncDriver` is configured exactly like :class:`neo4j.Driver`
+(see :ref:`driver-configuration-ref`). The only difference is that the async
+driver accepts an async custom resolver function:
+
+.. _async-resolver-ref:
+
+``resolver``
+------------
+A custom resolver function to resolve host and port values ahead of DNS resolution.
+This function is called with a 2-tuple of (host, port) and should return an iterable of 2-tuples (host, port).
+
+If no custom resolver function is supplied, the internal resolver moves straight to regular DNS resolution.
+
+The custom resolver function can but does not have to be a coroutine.
+
+For example:
+
+.. code-block:: python
+
+ from neo4j import AsyncGraphDatabase
+
+ async def custom_resolver(socket_address):
+ if socket_address == ("example.com", 9999):
+ yield "::1", 7687
+ yield "127.0.0.1", 7687
+ else:
+ from socket import gaierror
+ raise gaierror("Unexpected socket address %r" % socket_address)
+
+ # alternatively
+ def custom_resolver(socket_address):
+ ...
+
+ driver = AsyncGraphDatabase.driver("neo4j://example.com:9999",
+ auth=("neo4j", "password"),
+ resolver=custom_resolver)
+
+
+:Default: ``None``
+
+
+
+Driver Object Lifetime
+======================
+
+For general applications, it is recommended to create one top-level :class:`neo4j.AsyncDriver` object that lives for the lifetime of the application.
+
+For example:
+
+.. code-block:: python
+
+ from neo4j import AsyncGraphDatabase
+
+ class Application:
+
+        def __init__(self, uri, user, password):
+ self.driver = AsyncGraphDatabase.driver(uri, auth=(user, password))
+
+ async def close(self):
+ await self.driver.close()
+
+Connection details held by the :class:`neo4j.AsyncDriver` are immutable.
+Therefore if, for example, a password is changed, a replacement :class:`neo4j.AsyncDriver` object must be created.
+More than one :class:`.AsyncDriver` may be required if connections to multiple databases, or connections as multiple users, are required,
+except when using impersonation (:ref:`impersonated-user-ref`).
+
+:class:`neo4j.AsyncDriver` objects are safe to be used in concurrent coroutines.
+They are not thread-safe.
+
+
+.. _async-bolt-driver-ref:
+
+AsyncBoltDriver
+===============
+
+URI schemes:
+ ``bolt``, ``bolt+ssc``, ``bolt+s``
+
+Will result in:
+
+.. autoclass:: neo4j.AsyncBoltDriver
+
+
+.. _async-neo4j-driver-ref:
+
+AsyncNeo4jDriver
+================
+
+URI schemes:
+ ``neo4j``, ``neo4j+ssc``, ``neo4j+s``
+
+Will result in:
+
+.. autoclass:: neo4j.AsyncNeo4jDriver
+
+
+*********************************
+AsyncSessions & AsyncTransactions
+*********************************
+All database activity is co-ordinated through two mechanisms: the :class:`neo4j.AsyncSession` and the :class:`neo4j.AsyncTransaction`.
+
+A :class:`neo4j.AsyncSession` is a logical container for any number of causally-related transactional units of work.
+Sessions automatically provide guarantees of causal consistency within a clustered environment but multiple sessions can also be causally chained if required.
+Sessions provide the top level of containment for database activity.
+Session creation is a lightweight operation and *sessions cannot be shared between coroutines*.
+
+Connections are drawn from the :class:`neo4j.AsyncDriver` connection pool as required.
+
+A :class:`neo4j.AsyncTransaction` is a unit of work that is either committed in its entirety or is rolled back on failure.
+
+
+.. _async-session-construction-ref:
+
+*************************
+AsyncSession Construction
+*************************
+
+To construct a :class:`neo4j.AsyncSession` use the :meth:`neo4j.AsyncDriver.session` method.
+
+.. code-block:: python
+
+ import asyncio
+
+ from neo4j import AsyncGraphDatabase
+
+ async def main():
+        driver = AsyncGraphDatabase.driver(uri, auth=(user, password))
+ session = driver.session()
+ result = await session.run("MATCH (a:Person) RETURN a.name AS name")
+ names = [record["name"] async for record in result]
+ await session.close()
+ await driver.close()
+
+ asyncio.run(main())
+
+
+Sessions will often be created and destroyed using a *with block context*.
+
+.. code-block:: python
+
+ async with driver.session() as session:
+ result = await session.run("MATCH (a:Person) RETURN a.name AS name")
+ # do something with the result...
+
+
+Sessions will often be created with some configuration settings, see :ref:`async-session-configuration-ref`.
+
+.. code-block:: python
+
+ async with driver.session(database="example_database",
+ fetch_size=100) as session:
+ result = await session.run("MATCH (a:Person) RETURN a.name AS name")
+ # do something with the result...
+
+
+************
+AsyncSession
+************
+
+.. autoclass:: neo4j.AsyncSession()
+
+ .. automethod:: close
+
+ .. automethod:: run
+
+ .. automethod:: last_bookmark
+
+ .. automethod:: begin_transaction
+
+ .. automethod:: read_transaction
+
+ .. automethod:: write_transaction
+
+
+
+.. _async-session-configuration-ref:
+
+Session Configuration
+=====================
+
+:class:`neo4j.AsyncSession` is configured exactly like :class:`neo4j.Session`
+(see :ref:`session-configuration-ref`).
+
+
+****************
+AsyncTransaction
+****************
+
+Neo4j supports three kinds of async transaction:
+
++ :ref:`async-auto-commit-transactions-ref`
++ :ref:`async-explicit-transactions-ref`
++ :ref:`async-managed-transactions-ref`
+
+Each has pros and cons but if in doubt, use a managed transaction with a `transaction function`.
+
+
+.. _async-auto-commit-transactions-ref:
+
+Async Auto-commit Transactions
+==============================
+Auto-commit transactions are the simplest form of transaction, available via :py:meth:`neo4j.AsyncSession.run`.
+
+These are easy to use but support only one statement per transaction and are not automatically retried on failure.
+Auto-commit transactions are also the only way to run ``PERIODIC COMMIT`` statements, since this Cypher clause manages its own transactions internally.
+
+Example:
+
+.. code-block:: python
+
+ import neo4j
+
+ async def create_person(driver, name):
+ async with driver.session(
+ default_access_mode=neo4j.WRITE_ACCESS
+ ) as session:
+ query = "CREATE (a:Person { name: $name }) RETURN id(a) AS node_id"
+ result = await session.run(query, name=name)
+ record = await result.single()
+ return record["node_id"]
+
+Example:
+
+.. code-block:: python
+
+ import neo4j
+
+ async def get_numbers(driver):
+ numbers = []
+ async with driver.session(
+ default_access_mode=neo4j.READ_ACCESS
+ ) as session:
+ result = await session.run("UNWIND [1, 2, 3] AS x RETURN x")
+ async for record in result:
+ numbers.append(record["x"])
+ return numbers
+
+
+.. _async-explicit-transactions-ref:
+
+Explicit Async Transactions
+===========================
+Explicit transactions support multiple statements and must be created with an explicit :py:meth:`neo4j.AsyncSession.begin_transaction` call.
+
+This creates a new :class:`neo4j.AsyncTransaction` object that can be used to run Cypher.
+
+It also gives applications the ability to directly control `commit` and `rollback` activity.
+
+.. autoclass:: neo4j.AsyncTransaction()
+
+ .. automethod:: run
+
+ .. automethod:: close
+
+ .. automethod:: closed
+
+ .. automethod:: commit
+
+ .. automethod:: rollback
+
+Closing an explicit transaction can either happen automatically at the end of a ``async with`` block,
+or can be explicitly controlled through the :py:meth:`neo4j.AsyncTransaction.commit`, :py:meth:`neo4j.AsyncTransaction.rollback` or :py:meth:`neo4j.AsyncTransaction.close` methods.
+
+Explicit transactions are most useful for applications that need to distribute Cypher execution across multiple functions for the same transaction.
+
+Example:
+
+.. code-block:: python
+
+ import neo4j
+
+ async def create_person(driver, name):
+ async with driver.session(
+ default_access_mode=neo4j.WRITE_ACCESS
+ ) as session:
+ tx = await session.begin_transaction()
+ node_id = await create_person_node(tx)
+ await set_person_name(tx, node_id, name)
+ await tx.commit()
+ await tx.close()
+
+ async def create_person_node(tx):
+ query = "CREATE (a:Person { name: $name }) RETURN id(a) AS node_id"
+ name = "default_name"
+ result = await tx.run(query, name=name)
+ record = await result.single()
+ return record["node_id"]
+
+ async def set_person_name(tx, node_id, name):
+ query = "MATCH (a:Person) WHERE id(a) = $id SET a.name = $name"
+ result = await tx.run(query, id=node_id, name=name)
+ info = await result.consume()
+ # use the info for logging etc.
+
+.. _async-managed-transactions-ref:
+
+
+Managed Async Transactions (`transaction functions`)
+====================================================
+Transaction functions are the most powerful form of transaction, providing access mode override and retry capabilities.
+
++ :py:meth:`neo4j.AsyncSession.write_transaction`
++ :py:meth:`neo4j.AsyncSession.read_transaction`
+
+These allow a function object representing the transactional unit of work to be passed as a parameter.
+This function is called one or more times, within a configurable time limit, until it succeeds.
+Results should be fully consumed within the function and only aggregate or status values should be returned.
+Returning a live result object would prevent the driver from correctly managing connections and would break retry guarantees.
+
+Example:
+
+.. code-block:: python
+
+    async def create_person(driver, name):
+ async with driver.session() as session:
+ node_id = await session.write_transaction(create_person_tx, name)
+
+ async def create_person_tx(tx, name):
+ query = "CREATE (a:Person { name: $name }) RETURN id(a) AS node_id"
+ result = await tx.run(query, name=name)
+ record = await result.single()
+ return record["node_id"]
+
+To exert more control over how a transaction function is carried out, the :func:`neo4j.unit_of_work` decorator can be used.
+
+
+
+***********
+AsyncResult
+***********
+
+Every time a query is executed, a :class:`neo4j.AsyncResult` is returned.
+
+This provides a handle to the result of the query, giving access to the records within it as well as the result metadata.
+
+Results also contain a buffer that automatically stores unconsumed records when results are consumed out of order.
+
+A :class:`neo4j.AsyncResult` is attached to an active connection, through a :class:`neo4j.AsyncSession`, until all its content has been buffered or consumed.
+
+.. autoclass:: neo4j.AsyncResult()
+
+ .. describe:: iter(result)
+
+ .. automethod:: keys
+
+ .. automethod:: consume
+
+ .. automethod:: single
+
+ .. automethod:: peek
+
+ .. automethod:: graph
+
+ **This is experimental.** (See :ref:`filter-warnings-ref`)
+
+ .. automethod:: value
+
+ .. automethod:: values
+
+ .. automethod:: data
+
+See https://neo4j.com/docs/driver-manual/current/cypher-workflow/#driver-type-mapping for more about type mapping.
diff --git a/docs/source/conf.py b/docs/source/conf.py
index dae62ce1..72111d76 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-#
# Neo4j Bolt Driver for Python documentation build configuration file, created by
# sphinx-quickstart on Mon Sep 21 11:48:02 2015.
#
@@ -13,16 +10,20 @@
# All configuration values have a default; values that are commented out
# serve to show the default.
-import sys
+
import os
-import shlex
+import sys
+
# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
sys.path.insert(0, os.path.abspath(os.path.join("..", "..")))
+
+
from neo4j.meta import version as project_version
+
# -- General configuration ------------------------------------------------
# If your documentation needs a minimal Sphinx version, state it here.
diff --git a/docs/source/errors.rst b/docs/source/errors.rst
index 0a4d5927..14db7d5a 100644
--- a/docs/source/errors.rst
+++ b/docs/source/errors.rst
@@ -43,4 +43,4 @@ Python Version:
Python Driver Version:
Neo4j Version:
-the code block with a description that produced the error and the error message.
\ No newline at end of file
+the code block with a description that produced the error and the error message.
diff --git a/docs/source/index.rst b/docs/source/index.rst
index ac4bda16..24ef2386 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -23,6 +23,8 @@ Topics
+ :ref:`api-documentation`
++ :ref:`async-api-documentation` (experimental)
+
+ :ref:`temporal-data-types`
+ :ref:`breaking-changes`
@@ -32,6 +34,7 @@ Topics
:hidden:
api.rst
+ async_api.rst
temporal_types.rst
breaking_changes.rst
diff --git a/docs/source/transactions.rst b/docs/source/transactions.rst
index 93dc3510..1bb4db2a 100644
--- a/docs/source/transactions.rst
+++ b/docs/source/transactions.rst
@@ -11,7 +11,7 @@ Sessions automatically provide guarantees of causal consistency within a cluster
Sessions
========
-Sessions provide the top-level of containment for database activity.
+Sessions provide the top level of containment for database activity.
Session creation is a lightweight operation and sessions are `not` thread safe.
Connections are drawn from the :class:`neo4j.Driver` connection pool as required; an idle session will not hold onto a connection.
diff --git a/neo4j/__init__.py b/neo4j/__init__.py
index cd0139b6..f54aa5e1 100644
--- a/neo4j/__init__.py
+++ b/neo4j/__init__.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -21,441 +18,117 @@
__all__ = [
"__version__",
- "GraphDatabase",
- "Driver",
- "BoltDriver",
- "Neo4jDriver",
+ "Address",
+ "AsyncBoltDriver",
+ "AsyncDriver",
+ "AsyncGraphDatabase",
+ "AsyncNeo4jDriver",
+ "AsyncResult",
+ "AsyncSession",
+ "AsyncTransaction",
"Auth",
"AuthToken",
"basic_auth",
- "kerberos_auth",
"bearer_auth",
- "custom_auth",
+ "BoltDriver",
"Bookmark",
- "ServerInfo",
- "Version",
- "READ_ACCESS",
- "WRITE_ACCESS",
+ "Config",
+ "custom_auth",
"DEFAULT_DATABASE",
- "TRUST_ALL_CERTIFICATES",
- "TRUST_SYSTEM_CA_SIGNED_CERTIFICATES",
- "Address",
+ "Driver",
+ "ExperimentalWarning",
+ "GraphDatabase",
"IPv4Address",
"IPv6Address",
- "Config",
+ "kerberos_auth",
+ "Neo4jDriver",
"PoolConfig",
- "WorkspaceConfig",
- "SessionConfig",
+ "Query",
+ "READ_ACCESS",
"Record",
- "Transaction",
"Result",
"ResultSummary",
- "SummaryCounters",
- "Query",
+ "ServerInfo",
"Session",
+ "SessionConfig",
+ "SummaryCounters",
+ "Transaction",
+ "TRUST_ALL_CERTIFICATES",
+ "TRUST_SYSTEM_CA_SIGNED_CERTIFICATES",
"unit_of_work",
- "ExperimentalWarning",
+ "Version",
+ "WorkspaceConfig",
+ "WRITE_ACCESS",
]
-from logging import getLogger
+from logging import getLogger
-from neo4j.addressing import (
+from ._async.driver import (
+ AsyncBoltDriver,
+ AsyncDriver,
+ AsyncGraphDatabase,
+ AsyncNeo4jDriver,
+)
+from ._async.work import (
+ AsyncResult,
+ AsyncSession,
+ AsyncTransaction,
+)
+from ._sync.driver import (
+ BoltDriver,
+ Driver,
+ GraphDatabase,
+ Neo4jDriver,
+)
+from ._sync.work import (
+ Result,
+ Session,
+ Transaction,
+)
+from .addressing import (
Address,
IPv4Address,
IPv6Address,
)
-from neo4j.api import (
+from .api import (
Auth, # TODO: Validate naming for Auth compared to other drivers.
+)
+from .api import (
AuthToken,
basic_auth,
- kerberos_auth,
bearer_auth,
- custom_auth,
Bookmark,
- ServerInfo,
- Version,
+ custom_auth,
+ DEFAULT_DATABASE,
+ kerberos_auth,
READ_ACCESS,
- WRITE_ACCESS,
+ ServerInfo,
SYSTEM_DATABASE,
- DEFAULT_DATABASE,
TRUST_ALL_CERTIFICATES,
TRUST_SYSTEM_CA_SIGNED_CERTIFICATES,
+ Version,
+ WRITE_ACCESS,
)
-from neo4j.conf import (
+from .conf import (
Config,
PoolConfig,
- WorkspaceConfig,
SessionConfig,
+ WorkspaceConfig,
)
-from neo4j.meta import (
+from .data import Record
+from .meta import (
experimental,
ExperimentalWarning,
get_user_agent,
version as __version__,
)
-from neo4j.data import (
- Record,
-)
-from neo4j.work.simple import (
+from .work import (
Query,
- Session,
- unit_of_work,
-)
-from neo4j.work.transaction import (
- Transaction,
-)
-from neo4j.work.result import (
- Result,
-)
-from neo4j.work.summary import (
ResultSummary,
SummaryCounters,
+ unit_of_work,
)
log = getLogger("neo4j")
-
-
-class GraphDatabase:
- """Accessor for :class:`neo4j.Driver` construction.
- """
-
- @classmethod
- def driver(cls, uri, *, auth=None, **config):
- """Create a driver.
-
- :param uri: the connection URI for the driver, see :ref:`uri-ref` for available URIs.
- :param auth: the authentication details, see :ref:`auth-ref` for available authentication details.
- :param config: driver configuration key-word arguments, see :ref:`driver-configuration-ref` for available key-word arguments.
-
- :return: :ref:`neo4j-driver-ref` or :ref:`bolt-driver-ref`
- """
-
- from neo4j.api import (
- parse_neo4j_uri,
- parse_routing_context,
- DRIVER_BOLT,
- DRIVER_NEO4j,
- SECURITY_TYPE_NOT_SECURE,
- SECURITY_TYPE_SELF_SIGNED_CERTIFICATE,
- SECURITY_TYPE_SECURE,
- URI_SCHEME_BOLT,
- URI_SCHEME_NEO4J,
- URI_SCHEME_BOLT_SELF_SIGNED_CERTIFICATE,
- URI_SCHEME_BOLT_SECURE,
- URI_SCHEME_NEO4J_SELF_SIGNED_CERTIFICATE,
- URI_SCHEME_NEO4J_SECURE,
- )
-
- driver_type, security_type, parsed = parse_neo4j_uri(uri)
-
- if "trust" in config.keys():
- if config.get("trust") not in [TRUST_ALL_CERTIFICATES, TRUST_SYSTEM_CA_SIGNED_CERTIFICATES]:
- from neo4j.exceptions import ConfigurationError
- raise ConfigurationError("The config setting `trust` values are {!r}".format(
- [
- TRUST_ALL_CERTIFICATES,
- TRUST_SYSTEM_CA_SIGNED_CERTIFICATES,
- ]
- ))
-
- if security_type in [SECURITY_TYPE_SELF_SIGNED_CERTIFICATE, SECURITY_TYPE_SECURE] and ("encrypted" in config.keys() or "trust" in config.keys()):
- from neo4j.exceptions import ConfigurationError
- raise ConfigurationError("The config settings 'encrypted' and 'trust' can only be used with the URI schemes {!r}. Use the other URI schemes {!r} for setting encryption settings.".format(
- [
- URI_SCHEME_BOLT,
- URI_SCHEME_NEO4J,
- ],
- [
- URI_SCHEME_BOLT_SELF_SIGNED_CERTIFICATE,
- URI_SCHEME_BOLT_SECURE,
- URI_SCHEME_NEO4J_SELF_SIGNED_CERTIFICATE,
- URI_SCHEME_NEO4J_SECURE,
- ]
- ))
-
- if security_type == SECURITY_TYPE_SECURE:
- config["encrypted"] = True
- elif security_type == SECURITY_TYPE_SELF_SIGNED_CERTIFICATE:
- config["encrypted"] = True
- config["trust"] = TRUST_ALL_CERTIFICATES
-
- if driver_type == DRIVER_BOLT:
- return cls.bolt_driver(parsed.netloc, auth=auth, **config)
- elif driver_type == DRIVER_NEO4j:
- routing_context = parse_routing_context(parsed.query)
- return cls.neo4j_driver(parsed.netloc, auth=auth, routing_context=routing_context, **config)
-
- @classmethod
- def bolt_driver(cls, target, *, auth=None, **config):
- """ Create a driver for direct Bolt server access that uses
- socket I/O and thread-based concurrency.
- """
- from neo4j._exceptions import BoltHandshakeError, BoltSecurityError
-
- try:
- return BoltDriver.open(target, auth=auth, **config)
- except (BoltHandshakeError, BoltSecurityError) as error:
- from neo4j.exceptions import ServiceUnavailable
- raise ServiceUnavailable(str(error)) from error
-
- @classmethod
- def neo4j_driver(cls, *targets, auth=None, routing_context=None, **config):
- """ Create a driver for routing-capable Neo4j service access
- that uses socket I/O and thread-based concurrency.
- """
- from neo4j._exceptions import BoltHandshakeError, BoltSecurityError
-
- try:
- return Neo4jDriver.open(*targets, auth=auth, routing_context=routing_context, **config)
- except (BoltHandshakeError, BoltSecurityError) as error:
- from neo4j.exceptions import ServiceUnavailable
- raise ServiceUnavailable(str(error)) from error
-
-
-class Direct:
-
- default_host = "localhost"
- default_port = 7687
-
- default_target = ":"
-
- def __init__(self, address):
- self._address = address
-
- @property
- def address(self):
- return self._address
-
- @classmethod
- def parse_target(cls, target):
- """ Parse a target string to produce an address.
- """
- if not target:
- target = cls.default_target
- address = Address.parse(target, default_host=cls.default_host,
- default_port=cls.default_port)
- return address
-
-
-class Routing:
-
- default_host = "localhost"
- default_port = 7687
-
- default_targets = ": :17601 :17687"
-
- def __init__(self, initial_addresses):
- self._initial_addresses = initial_addresses
-
- @property
- def initial_addresses(self):
- return self._initial_addresses
-
- @classmethod
- def parse_targets(cls, *targets):
- """ Parse a sequence of target strings to produce an address
- list.
- """
- targets = " ".join(targets)
- if not targets:
- targets = cls.default_targets
- addresses = Address.parse_list(targets, default_host=cls.default_host, default_port=cls.default_port)
- return addresses
-
-
-class Driver:
- """ Base class for all types of :class:`neo4j.Driver`, instances of which are
- used as the primary access point to Neo4j.
- """
-
- #: Connection pool
- _pool = None
-
- def __init__(self, pool):
- assert pool is not None
- self._pool = pool
-
- def __del__(self):
- self.close()
-
- def __enter__(self):
- return self
-
- def __exit__(self, exc_type, exc_value, traceback):
- self.close()
-
- @property
- def encrypted(self):
- return bool(self._pool.pool_config.encrypted)
-
- def session(self, **config):
- """Create a session, see :ref:`session-construction-ref`
-
- :param config: session configuration key-word arguments, see :ref:`session-configuration-ref` for available key-word arguments.
-
- :returns: new :class:`neo4j.Session` object
- """
- raise NotImplementedError
-
- @experimental("The pipeline API is experimental and may be removed or changed in a future release")
- def pipeline(self, **config):
- """ Create a pipeline.
- """
- raise NotImplementedError
-
- def close(self):
- """ Shut down, closing any open connections in the pool.
- """
- self._pool.close()
-
- @experimental("The configuration may change in the future.")
- def verify_connectivity(self, **config):
- """ This verifies if the driver can connect to a remote server or a cluster
- by establishing a network connection with the remote and possibly exchanging
- a few data before closing the connection. It throws exception if fails to connect.
-
- Use the exception to further understand the cause of the connectivity problem.
-
- Note: Even if this method throws an exception, the driver still need to be closed via close() to free up all resources.
- """
- raise NotImplementedError
-
- @experimental("Feature support query, based on Bolt Protocol Version and Neo4j Server Version will change in the future.")
- def supports_multi_db(self):
- """ Check if the server or cluster supports multi-databases.
-
- :return: Returns true if the server or cluster the driver connects to supports multi-databases, otherwise false.
- :rtype: bool
- """
- with self.session() as session:
- session._connect(READ_ACCESS)
- return session._connection.supports_multiple_databases
-
-
-class BoltDriver(Direct, Driver):
- """ A :class:`.BoltDriver` is created from a ``bolt`` URI and addresses
- a single database machine. This may be a standalone server or could be a
- specific member of a cluster.
-
- Connections established by a :class:`.BoltDriver` are always made to the
- exact host and port detailed in the URI.
- """
-
- @classmethod
- def open(cls, target, *, auth=None, **config):
- """
- :param target:
- :param auth:
- :param config: The values that can be specified are found in :class: `neo4j.PoolConfig` and :class: `neo4j.WorkspaceConfig`
-
- :return:
- :rtype: :class: `neo4j.BoltDriver`
- """
- from neo4j.io import BoltPool
- address = cls.parse_target(target)
- pool_config, default_workspace_config = Config.consume_chain(config, PoolConfig, WorkspaceConfig)
- pool = BoltPool.open(address, auth=auth, pool_config=pool_config, workspace_config=default_workspace_config)
- return cls(pool, default_workspace_config)
-
- def __init__(self, pool, default_workspace_config):
- Direct.__init__(self, pool.address)
- Driver.__init__(self, pool)
- self._default_workspace_config = default_workspace_config
-
- def session(self, **config):
- """
- :param config: The values that can be specified are found in :class: `neo4j.SessionConfig`
-
- :return:
- :rtype: :class: `neo4j.Session`
- """
- from neo4j.work.simple import Session
- session_config = SessionConfig(self._default_workspace_config, config)
- SessionConfig.consume(config) # Consume the config
- return Session(self._pool, session_config)
-
- def pipeline(self, **config):
- from neo4j.work.pipelining import Pipeline, PipelineConfig
- pipeline_config = PipelineConfig(self._default_workspace_config, config)
- PipelineConfig.consume(config) # Consume the config
- return Pipeline(self._pool, pipeline_config)
-
- @experimental("The configuration may change in the future.")
- def verify_connectivity(self, **config):
- server_agent = None
- config["fetch_size"] = -1
- with self.session(**config) as session:
- result = session.run("RETURN 1 AS x")
- value = result.single().value()
- summary = result.consume()
- server_agent = summary.server.agent
- return server_agent
-
-
-class Neo4jDriver(Routing, Driver):
- """ A :class:`.Neo4jDriver` is created from a ``neo4j`` URI. The
- routing behaviour works in tandem with Neo4j's `Causal Clustering
- `_
- feature by directing read and write behaviour to appropriate
- cluster members.
- """
-
- @classmethod
- def open(cls, *targets, auth=None, routing_context=None, **config):
- from neo4j.io import Neo4jPool
- addresses = cls.parse_targets(*targets)
- pool_config, default_workspace_config = Config.consume_chain(config, PoolConfig, WorkspaceConfig)
- pool = Neo4jPool.open(*addresses, auth=auth, routing_context=routing_context, pool_config=pool_config, workspace_config=default_workspace_config)
- return cls(pool, default_workspace_config)
-
- def __init__(self, pool, default_workspace_config):
- Routing.__init__(self, pool.get_default_database_initial_router_addresses())
- Driver.__init__(self, pool)
- self._default_workspace_config = default_workspace_config
-
- def session(self, **config):
- session_config = SessionConfig(self._default_workspace_config, config)
- SessionConfig.consume(config) # Consume the config
- return Session(self._pool, session_config)
-
- def pipeline(self, **config):
- from neo4j.work.pipelining import Pipeline, PipelineConfig
- pipeline_config = PipelineConfig(self._default_workspace_config, config)
- PipelineConfig.consume(config) # Consume the config
- return Pipeline(self._pool, pipeline_config)
-
- @experimental("The configuration may change in the future.")
- def verify_connectivity(self, **config):
- """
- :raise ServiceUnavailable: raised if the server does not support routing or if routing support is broken.
- """
- # TODO: Improve and update Stub Test Server to be able to test.
- return self._verify_routing_connectivity()
-
- def _verify_routing_connectivity(self):
- from neo4j.exceptions import (
- Neo4jError,
- ServiceUnavailable,
- SessionExpired,
- )
-
- table = self._pool.get_routing_table_for_default_database()
- routing_info = {}
- for ix in list(table.routers):
- try:
- routing_info[ix] = self._pool.fetch_routing_info(
- address=table.routers[0],
- database=self._default_workspace_config.database,
- imp_user=self._default_workspace_config.impersonated_user,
- bookmarks=None,
- timeout=self._default_workspace_config
- .connection_acquisition_timeout
- )
- except (ServiceUnavailable, SessionExpired, Neo4jError):
- routing_info[ix] = None
- for key, val in routing_info.items():
- if val is not None:
- return routing_info
- raise ServiceUnavailable("Could not connect to any routing servers.")
diff --git a/tests/unit/time/__init__.py b/neo4j/_async/__init__.py
similarity index 94%
rename from tests/unit/time/__init__.py
rename to neo4j/_async/__init__.py
index 0665bdc9..b81a309d 100644
--- a/tests/unit/time/__init__.py
+++ b/neo4j/_async/__init__.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# coding: utf-8
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
diff --git a/neo4j/_async/driver.py b/neo4j/_async/driver.py
new file mode 100644
index 00000000..009efc32
--- /dev/null
+++ b/neo4j/_async/driver.py
@@ -0,0 +1,380 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import asyncio
+
+from .._async_compat.util import AsyncUtil
+from ..addressing import Address
+from ..api import (
+ READ_ACCESS,
+ TRUST_ALL_CERTIFICATES,
+ TRUST_SYSTEM_CA_SIGNED_CERTIFICATES,
+)
+from ..conf import (
+ Config,
+ PoolConfig,
+ SessionConfig,
+ WorkspaceConfig,
+)
+from ..meta import experimental
+
+
+class AsyncGraphDatabase:
+    """Accessor for :class:`neo4j.AsyncDriver` construction.
+ """
+
+ @classmethod
+ @AsyncUtil.experimental_async(
+ "neo4j async is in experimental phase. It might be removed or changed "
+ "at any time (including patch releases)."
+ )
+ def driver(cls, uri, *, auth=None, **config):
+ """Create a driver.
+
+ :param uri: the connection URI for the driver, see :ref:`async-uri-ref` for available URIs.
+ :param auth: the authentication details, see :ref:`auth-ref` for available authentication details.
+ :param config: driver configuration key-word arguments, see :ref:`async-driver-configuration-ref` for available key-word arguments.
+
+ :rtype: AsyncNeo4jDriver or AsyncBoltDriver
+ """
+
+ from ..api import (
+ DRIVER_BOLT,
+ DRIVER_NEO4j,
+ parse_neo4j_uri,
+ parse_routing_context,
+ SECURITY_TYPE_NOT_SECURE,
+ SECURITY_TYPE_SECURE,
+ SECURITY_TYPE_SELF_SIGNED_CERTIFICATE,
+ URI_SCHEME_BOLT,
+ URI_SCHEME_BOLT_SECURE,
+ URI_SCHEME_BOLT_SELF_SIGNED_CERTIFICATE,
+ URI_SCHEME_NEO4J,
+ URI_SCHEME_NEO4J_SECURE,
+ URI_SCHEME_NEO4J_SELF_SIGNED_CERTIFICATE,
+ )
+
+ driver_type, security_type, parsed = parse_neo4j_uri(uri)
+
+ if "trust" in config.keys():
+ if config.get("trust") not in [TRUST_ALL_CERTIFICATES, TRUST_SYSTEM_CA_SIGNED_CERTIFICATES]:
+ from neo4j.exceptions import ConfigurationError
+ raise ConfigurationError("The config setting `trust` values are {!r}".format(
+ [
+ TRUST_ALL_CERTIFICATES,
+ TRUST_SYSTEM_CA_SIGNED_CERTIFICATES,
+ ]
+ ))
+
+ if security_type in [SECURITY_TYPE_SELF_SIGNED_CERTIFICATE, SECURITY_TYPE_SECURE] and ("encrypted" in config.keys() or "trust" in config.keys()):
+ from neo4j.exceptions import ConfigurationError
+ raise ConfigurationError("The config settings 'encrypted' and 'trust' can only be used with the URI schemes {!r}. Use the other URI schemes {!r} for setting encryption settings.".format(
+ [
+ URI_SCHEME_BOLT,
+ URI_SCHEME_NEO4J,
+ ],
+ [
+ URI_SCHEME_BOLT_SELF_SIGNED_CERTIFICATE,
+ URI_SCHEME_BOLT_SECURE,
+ URI_SCHEME_NEO4J_SELF_SIGNED_CERTIFICATE,
+ URI_SCHEME_NEO4J_SECURE,
+ ]
+ ))
+
+ if security_type == SECURITY_TYPE_SECURE:
+ config["encrypted"] = True
+ elif security_type == SECURITY_TYPE_SELF_SIGNED_CERTIFICATE:
+ config["encrypted"] = True
+ config["trust"] = TRUST_ALL_CERTIFICATES
+
+ if driver_type == DRIVER_BOLT:
+ return cls.bolt_driver(parsed.netloc, auth=auth, **config)
+ elif driver_type == DRIVER_NEO4j:
+ routing_context = parse_routing_context(parsed.query)
+ return cls.neo4j_driver(parsed.netloc, auth=auth, routing_context=routing_context, **config)
+
+ @classmethod
+ def bolt_driver(cls, target, *, auth=None, **config):
+        """ Create a driver for direct Bolt server access that uses
+        socket I/O and coroutine-based concurrency.
+ """
+ from .._exceptions import (
+ BoltHandshakeError,
+ BoltSecurityError,
+ )
+
+ try:
+ return AsyncBoltDriver.open(target, auth=auth, **config)
+ except (BoltHandshakeError, BoltSecurityError) as error:
+ from neo4j.exceptions import ServiceUnavailable
+ raise ServiceUnavailable(str(error)) from error
+
+ @classmethod
+ def neo4j_driver(cls, *targets, auth=None, routing_context=None, **config):
+        """ Create a driver for routing-capable Neo4j service access
+        that uses socket I/O and coroutine-based concurrency.
+ """
+ from neo4j._exceptions import (
+ BoltHandshakeError,
+ BoltSecurityError,
+ )
+
+ try:
+ return AsyncNeo4jDriver.open(*targets, auth=auth, routing_context=routing_context, **config)
+ except (BoltHandshakeError, BoltSecurityError) as error:
+ from neo4j.exceptions import ServiceUnavailable
+ raise ServiceUnavailable(str(error)) from error
+
+
+class _Direct:
+
+ default_host = "localhost"
+ default_port = 7687
+
+ default_target = ":"
+
+ def __init__(self, address):
+ self._address = address
+
+ @property
+ def address(self):
+ return self._address
+
+ @classmethod
+ def parse_target(cls, target):
+ """ Parse a target string to produce an address.
+ """
+ if not target:
+ target = cls.default_target
+ address = Address.parse(target, default_host=cls.default_host,
+ default_port=cls.default_port)
+ return address
+
+
+class _Routing:
+
+ default_host = "localhost"
+ default_port = 7687
+
+ default_targets = ": :17601 :17687"
+
+ def __init__(self, initial_addresses):
+ self._initial_addresses = initial_addresses
+
+ @property
+ def initial_addresses(self):
+ return self._initial_addresses
+
+ @classmethod
+ def parse_targets(cls, *targets):
+ """ Parse a sequence of target strings to produce an address
+ list.
+ """
+ targets = " ".join(targets)
+ if not targets:
+ targets = cls.default_targets
+ addresses = Address.parse_list(targets, default_host=cls.default_host, default_port=cls.default_port)
+ return addresses
+
+
+class AsyncDriver:
+ """ Base class for all types of :class:`neo4j.AsyncDriver`, instances of
+ which are used as the primary access point to Neo4j.
+ """
+
+ #: Connection pool
+ _pool = None
+
+ def __init__(self, pool):
+ assert pool is not None
+ self._pool = pool
+
+ async def __aenter__(self):
+ return self
+
+ async def __aexit__(self, exc_type, exc_value, traceback):
+ await self.close()
+
+ def __del__(self):
+ if not asyncio.iscoroutinefunction(self.close):
+ self.close()
+
+ @property
+ def encrypted(self):
+ return bool(self._pool.pool_config.encrypted)
+
+ def session(self, **config):
+ """Create a session, see :ref:`async-session-construction-ref`
+
+ :param config: session configuration key-word arguments,
+ see :ref:`async-session-configuration-ref` for available key-word
+ arguments.
+
+ :returns: new :class:`neo4j.AsyncSession` object
+ """
+ raise NotImplementedError
+
+ async def close(self):
+ """ Shut down, closing any open connections in the pool.
+ """
+ await self._pool.close()
+
+ @experimental("The configuration may change in the future.")
+ async def verify_connectivity(self, **config):
+        """ This verifies if the driver can connect to a remote server or a cluster
+        by establishing a network connection with the remote and possibly exchanging
+        some data before closing the connection. It raises an exception if it fails to connect.
+
+        Use the exception to further understand the cause of the connectivity problem.
+
+        Note: Even if this method raises an exception, the driver still needs to be closed via close() to free up all resources.
+        """
+ raise NotImplementedError
+
+ @experimental("Feature support query, based on Bolt Protocol Version and Neo4j Server Version will change in the future.")
+ async def supports_multi_db(self):
+ """ Check if the server or cluster supports multi-databases.
+
+ :return: Returns true if the server or cluster the driver connects to supports multi-databases, otherwise false.
+ :rtype: bool
+ """
+ async with self.session() as session:
+ await session._connect(READ_ACCESS)
+ return session._connection.supports_multiple_databases
+
+
+class AsyncBoltDriver(_Direct, AsyncDriver):
+ """:class:`.AsyncBoltDriver` is instantiated for ``bolt`` URIs and
+ addresses a single database machine. This may be a standalone server or
+ could be a specific member of a cluster.
+
+ Connections established by a :class:`.AsyncBoltDriver` are always made to
+ the exact host and port detailed in the URI.
+
+ This class is not supposed to be instantiated externally. Use
+ :meth:`AsyncGraphDatabase.driver` instead.
+ """
+
+ @classmethod
+ def open(cls, target, *, auth=None, **config):
+ """
+ :param target:
+ :param auth:
+ :param config: The values that can be specified are found in :class: `neo4j.PoolConfig` and :class: `neo4j.WorkspaceConfig`
+
+ :return:
+        :rtype: :class:`neo4j.AsyncBoltDriver`
+ """
+ from .io import AsyncBoltPool
+ address = cls.parse_target(target)
+ pool_config, default_workspace_config = Config.consume_chain(config, PoolConfig, WorkspaceConfig)
+ pool = AsyncBoltPool.open(address, auth=auth, pool_config=pool_config, workspace_config=default_workspace_config)
+ return cls(pool, default_workspace_config)
+
+ def __init__(self, pool, default_workspace_config):
+ _Direct.__init__(self, pool.address)
+ AsyncDriver.__init__(self, pool)
+ self._default_workspace_config = default_workspace_config
+
+ def session(self, **config):
+ """
+ :param config: The values that can be specified are found in :class: `neo4j.SessionConfig`
+
+ :return:
+        :rtype: :class:`neo4j.AsyncSession`
+ """
+ from .work import AsyncSession
+ session_config = SessionConfig(self._default_workspace_config, config)
+ SessionConfig.consume(config) # Consume the config
+ return AsyncSession(self._pool, session_config)
+
+ @experimental("The configuration may change in the future.")
+ async def verify_connectivity(self, **config):
+ server_agent = None
+ config["fetch_size"] = -1
+ async with self.session(**config) as session:
+ result = await session.run("RETURN 1 AS x")
+            value = (await result.single()).value()
+ summary = await result.consume()
+ server_agent = summary.server.agent
+ return server_agent
+
+
+class AsyncNeo4jDriver(_Routing, AsyncDriver):
+ """:class:`.AsyncNeo4jDriver` is instantiated for ``neo4j`` URIs. The
+ routing behaviour works in tandem with Neo4j's `Causal Clustering
+    <https://neo4j.com/docs/operations-manual/current/clustering/>`_
+ feature by directing read and write behaviour to appropriate
+ cluster members.
+
+ This class is not supposed to be instantiated externally. Use
+ :meth:`AsyncGraphDatabase.driver` instead.
+ """
+
+ @classmethod
+ def open(cls, *targets, auth=None, routing_context=None, **config):
+ from .io import AsyncNeo4jPool
+ addresses = cls.parse_targets(*targets)
+ pool_config, default_workspace_config = Config.consume_chain(config, PoolConfig, WorkspaceConfig)
+ pool = AsyncNeo4jPool.open(*addresses, auth=auth, routing_context=routing_context, pool_config=pool_config, workspace_config=default_workspace_config)
+ return cls(pool, default_workspace_config)
+
+ def __init__(self, pool, default_workspace_config):
+ _Routing.__init__(self, pool.get_default_database_initial_router_addresses())
+ AsyncDriver.__init__(self, pool)
+ self._default_workspace_config = default_workspace_config
+
+ def session(self, **config):
+ from .work import AsyncSession
+ session_config = SessionConfig(self._default_workspace_config, config)
+ SessionConfig.consume(config) # Consume the config
+ return AsyncSession(self._pool, session_config)
+
+ @experimental("The configuration may change in the future.")
+ async def verify_connectivity(self, **config):
+ """
+ :raise ServiceUnavailable: raised if the server does not support routing or if routing support is broken.
+ """
+ # TODO: Improve and update Stub Test Server to be able to test.
+ return await self._verify_routing_connectivity()
+
+ async def _verify_routing_connectivity(self):
+ from ..exceptions import (
+ Neo4jError,
+ ServiceUnavailable,
+ SessionExpired,
+ )
+
+ table = self._pool.get_routing_table_for_default_database()
+ routing_info = {}
+ for ix in list(table.routers):
+ try:
+ routing_info[ix] = await self._pool.fetch_routing_info(
+ address=table.routers[0],
+ database=self._default_workspace_config.database,
+ imp_user=self._default_workspace_config.impersonated_user,
+ bookmarks=None,
+ timeout=self._default_workspace_config
+ .connection_acquisition_timeout
+ )
+ except (ServiceUnavailable, SessionExpired, Neo4jError):
+ routing_info[ix] = None
+ for key, val in routing_info.items():
+ if val is not None:
+ return routing_info
+ raise ServiceUnavailable("Could not connect to any routing servers.")
diff --git a/neo4j/_async/io/__init__.py b/neo4j/_async/io/__init__.py
new file mode 100644
index 00000000..32aca11f
--- /dev/null
+++ b/neo4j/_async/io/__init__.py
@@ -0,0 +1,43 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This module contains the low-level functionality required for speaking
+Bolt. It is not intended to be used directly by driver users. Instead,
+the `session` module provides the main user-facing abstractions.
+"""
+
+
+__all__ = [
+ "AsyncBolt",
+ "AsyncBoltPool",
+ "AsyncNeo4jPool",
+ "check_supported_server_product",
+ "ConnectionErrorHandler",
+]
+
+
+from ._bolt import AsyncBolt
+from ._common import (
+ check_supported_server_product,
+ ConnectionErrorHandler,
+)
+from ._pool import (
+ AsyncBoltPool,
+ AsyncNeo4jPool,
+)
diff --git a/neo4j/_async/io/_bolt.py b/neo4j/_async/io/_bolt.py
new file mode 100644
index 00000000..503ebeaf
--- /dev/null
+++ b/neo4j/_async/io/_bolt.py
@@ -0,0 +1,571 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import abc
+import asyncio
+from collections import deque
+from logging import getLogger
+from time import perf_counter
+
+from ..._async_compat.network import AsyncBoltSocket
+from ..._exceptions import BoltHandshakeError
+from ...addressing import Address
+from ...api import (
+ ServerInfo,
+ Version,
+)
+from ...conf import PoolConfig
+from ...exceptions import (
+ AuthError,
+ IncompleteCommit,
+ ServiceUnavailable,
+ SessionExpired,
+)
+from ...meta import get_user_agent
+from ...packstream import (
+ Packer,
+ Unpacker,
+)
+from ._common import (
+ AsyncInbox,
+ CommitResponse,
+ Outbox,
+)
+
+
+# Set up logger
+log = getLogger("neo4j")
+
+
+class AsyncBolt:
+ """ Server connection for Bolt protocol.
+
+ A :class:`.Bolt` should be constructed following a
+ successful Bolt handshake, carried out via the
+ :meth:`.open` class method; it takes over the
+ socket on which that handshake was performed.
+
+ """
+
+ MAGIC_PREAMBLE = b"\x60\x60\xB0\x17"
+
+ PROTOCOL_VERSION = None
+
+ # flag if connection needs RESET to go back to READY state
+ is_reset = False
+
+ # flag if the connection is currently in use
+ in_use = False
+
+ # flag if the connection has been closed
+ _closed = False
+
+ # flag if the connection has become defunct (unusable after an error)
+ _defunct = False
+
+ #: The pool of which this connection is a member
+ pool = None
+
+ # Store the id of the most recently run query to be able to reduce sent bits by
+ # using the default (-1) to refer to the most recent query when pulling
+ # results for it.
+ most_recent_qid = None
+
+ def __init__(self, unresolved_address, sock, max_connection_lifetime, *,
+ auth=None, user_agent=None, routing_context=None):
+ self.unresolved_address = unresolved_address
+ self.socket = sock
+ self.server_info = ServerInfo(Address(sock.getpeername()),
+ self.PROTOCOL_VERSION)
+ # so far `connection.recv_timeout_seconds` is the only available
+ # configuration hint that exists. Therefore, all hints can be stored at
+ # connection level. This might change in the future.
+ self.configuration_hints = {}
+ self.outbox = Outbox()
+ self.inbox = AsyncInbox(self.socket, on_error=self._set_defunct_read)
+ self.packer = Packer(self.outbox)
+ self.unpacker = Unpacker(self.inbox)
+ self.responses = deque()
+ self._max_connection_lifetime = max_connection_lifetime
+ self._creation_timestamp = perf_counter()
+ self.routing_context = routing_context
+
+ # Determine the user agent
+ if user_agent:
+ self.user_agent = user_agent
+ else:
+ self.user_agent = get_user_agent()
+
+ # Determine auth details
+ if not auth:
+ self.auth_dict = {}
+ elif isinstance(auth, tuple) and 2 <= len(auth) <= 3:
+ from neo4j import Auth
+ self.auth_dict = vars(Auth("basic", *auth))
+ else:
+ try:
+ self.auth_dict = vars(auth)
+ except (KeyError, TypeError):
+ raise AuthError("Cannot determine auth details from %r" % auth)
+
+ # Check for missing password
+ try:
+ credentials = self.auth_dict["credentials"]
+ except KeyError:
+ pass
+ else:
+ if credentials is None:
+ raise AuthError("Password cannot be None")
+
+ def __del__(self):
+ if not asyncio.iscoroutinefunction(self.close):
+ self.close()
+
+ @property
+ @abc.abstractmethod
+ def supports_multiple_results(self):
+ """ Boolean flag to indicate if the connection version supports multiple
+ queries to be buffered on the server side (True) or if all results need
+ to be eagerly pulled before sending the next RUN (False).
+ """
+ pass
+
+ @property
+ @abc.abstractmethod
+ def supports_multiple_databases(self):
+ """ Boolean flag to indicate if the connection version supports multiple
+ databases.
+ """
+ pass
+
+ @classmethod
+ def protocol_handlers(cls, protocol_version=None):
+ """ Return a dictionary of available Bolt protocol handlers,
+ keyed by version tuple. If an explicit protocol version is
+ provided, the dictionary will contain either zero or one items,
+ depending on whether that version is supported. If no protocol
+ version is provided, all available versions will be returned.
+
+ :param protocol_version: tuple identifying a specific protocol
+ version (e.g. (3, 5)) or None
+ :return: dictionary of version tuple to handler class for all
+ relevant and supported protocol versions
+ :raise TypeError: if protocol version is not passed in a tuple
+ """
+
+ # Carry out Bolt subclass imports locally to avoid circular dependency issues.
+ from ._bolt3 import AsyncBolt3
+ from ._bolt4 import (
+ AsyncBolt4x0,
+ AsyncBolt4x1,
+ AsyncBolt4x2,
+ AsyncBolt4x3,
+ AsyncBolt4x4,
+ )
+
+ handlers = {
+ AsyncBolt3.PROTOCOL_VERSION: AsyncBolt3,
+ AsyncBolt4x0.PROTOCOL_VERSION: AsyncBolt4x0,
+ AsyncBolt4x1.PROTOCOL_VERSION: AsyncBolt4x1,
+ AsyncBolt4x2.PROTOCOL_VERSION: AsyncBolt4x2,
+ AsyncBolt4x3.PROTOCOL_VERSION: AsyncBolt4x3,
+ AsyncBolt4x4.PROTOCOL_VERSION: AsyncBolt4x4,
+ }
+
+ if protocol_version is None:
+ return handlers
+
+ if not isinstance(protocol_version, tuple):
+ raise TypeError("Protocol version must be specified as a tuple")
+
+ if protocol_version in handlers:
+ return {protocol_version: handlers[protocol_version]}
+
+ return {}
+
+ @classmethod
+ def version_list(cls, versions, limit=4):
+ """ Return a list of supported protocol versions in order of
+ preference. The number of protocol versions (or ranges)
+ returned is limited to four.
+ """
+ # In fact, 4.3 is the first version to support ranges. However, the range
+ # support got backported to 4.2. But even if the server is too old to
+ # have the backport, negotiating BOLT 4.1 is no problem as it's
+ # equivalent to 4.2
+ first_with_range_support = Version(4, 2)
+ result = []
+ for version in versions:
+ if (result
+ and version >= first_with_range_support
+ and result[-1][0] == version[0]
+ and result[-1][1][1] == version[1] + 1):
+ # can use range to encompass this version
+ result[-1][1][1] = version[1]
+ continue
+ result.append(Version(version[0], [version[1], version[1]]))
+ if len(result) == 4:
+ break
+ return result
+
+ @classmethod
+ def get_handshake(cls):
+ """ Return the supported Bolt versions as bytes.
+ The length is 16 bytes as specified in the Bolt version negotiation.
+ :return: bytes
+ """
+ supported_versions = sorted(cls.protocol_handlers().keys(), reverse=True)
+ offered_versions = cls.version_list(supported_versions)
+ return b"".join(version.to_bytes() for version in offered_versions).ljust(16, b"\x00")
+
+ @classmethod
+ async def ping(cls, address, *, timeout=None, **config):
+ """ Attempt to establish a Bolt connection, returning the
+ agreed Bolt protocol version if successful.
+ """
+ config = PoolConfig.consume(config)
+ try:
+ s, protocol_version, handshake, data = \
+ await AsyncBoltSocket.connect(
+ address,
+ timeout=timeout,
+ custom_resolver=config.resolver,
+ ssl_context=config.get_ssl_context(),
+ keep_alive=config.keep_alive,
+ )
+ except (ServiceUnavailable, SessionExpired, BoltHandshakeError):
+ return None
+ else:
+ AsyncBoltSocket.close_socket(s)
+ return protocol_version
+
+ @classmethod
+ async def open(
+ cls, address, *, auth=None, timeout=None, routing_context=None, **pool_config
+ ):
+ """ Open a new Bolt connection to a given server address.
+
+ :param address:
+ :param auth:
+ :param timeout: the connection timeout in seconds
+ :param routing_context: dict containing routing context
+ :param pool_config:
+ :return:
+ :raise BoltHandshakeError: raised if the Bolt Protocol can not negotiate a protocol version.
+ :raise ServiceUnavailable: raised if there was a connection issue.
+ """
+ pool_config = PoolConfig.consume(pool_config)
+ s, pool_config.protocol_version, handshake, data = \
+ await AsyncBoltSocket.connect(
+ address,
+ timeout=timeout,
+ custom_resolver=pool_config.resolver,
+ ssl_context=pool_config.get_ssl_context(),
+ keep_alive=pool_config.keep_alive,
+ )
+
+ # Carry out Bolt subclass imports locally to avoid circular dependency
+ # issues.
+ if pool_config.protocol_version == (3, 0):
+ from ._bolt3 import AsyncBolt3
+ bolt_cls = AsyncBolt3
+ elif pool_config.protocol_version == (4, 0):
+ from ._bolt4 import AsyncBolt4x0
+ bolt_cls = AsyncBolt4x0
+ elif pool_config.protocol_version == (4, 1):
+ from ._bolt4 import AsyncBolt4x1
+ bolt_cls = AsyncBolt4x1
+ elif pool_config.protocol_version == (4, 2):
+ from ._bolt4 import AsyncBolt4x2
+ bolt_cls = AsyncBolt4x2
+ elif pool_config.protocol_version == (4, 3):
+ from ._bolt4 import AsyncBolt4x3
+ bolt_cls = AsyncBolt4x3
+ elif pool_config.protocol_version == (4, 4):
+ from ._bolt4 import AsyncBolt4x4
+ bolt_cls = AsyncBolt4x4
+ else:
+ log.debug("[#%04X] S: <CLOSE>", s.getsockname()[1])
+ AsyncBoltSocket.close_socket(s)
+
+ supported_versions = cls.protocol_handlers().keys()
+ raise BoltHandshakeError("The Neo4J server does not support communication with this driver. This driver have support for Bolt Protocols {}".format(supported_versions), address=address, request_data=handshake, response_data=data)
+
+ connection = bolt_cls(
+ address, s, pool_config.max_connection_lifetime, auth=auth,
+ user_agent=pool_config.user_agent, routing_context=routing_context
+ )
+
+ try:
+ await connection.hello()
+ except Exception:
+ await connection.close()
+ raise
+
+ return connection
+
+ @property
+ @abc.abstractmethod
+ def encrypted(self):
+ pass
+
+ @property
+ @abc.abstractmethod
+ def der_encoded_server_certificate(self):
+ pass
+
+ @property
+ @abc.abstractmethod
+ def local_port(self):
+ pass
+
+ @abc.abstractmethod
+ async def hello(self):
+ """ Appends a HELLO message to the outgoing queue, sends it and consumes
+ all remaining messages.
+ """
+ pass
+
+ @abc.abstractmethod
+ async def route(self, database=None, imp_user=None, bookmarks=None):
+ """ Fetch a routing table from the server for the given
+ `database`. For Bolt 4.3 and above, this appends a ROUTE
+ message; for earlier versions, a procedure call is made via
+ the regular Cypher execution mechanism. In all cases, this is
+ sent to the network, and a response is fetched.
+
+ :param database: database for which to fetch a routing table
+ Requires Bolt 4.0+.
+ :param imp_user: the user to impersonate
+ Requires Bolt 4.4+.
+ :param bookmarks: iterable of bookmark values after which this
+ transaction should begin
+ :return: dictionary of raw routing data
+ """
+ pass
+
+ @abc.abstractmethod
+ def run(self, query, parameters=None, mode=None, bookmarks=None,
+ metadata=None, timeout=None, db=None, imp_user=None, **handlers):
+ """ Appends a RUN message to the output queue.
+
+ :param query: Cypher query string
+ :param parameters: dictionary of Cypher parameters
+ :param mode: access mode for routing - "READ" or "WRITE" (default)
+ :param bookmarks: iterable of bookmark values after which this transaction should begin
+ :param metadata: custom metadata dictionary to attach to the transaction
+ :param timeout: timeout for transaction execution (seconds)
+ :param db: name of the database against which to begin the transaction
+ Requires Bolt 4.0+.
+ :param imp_user: the user to impersonate
+ Requires Bolt 4.4+.
+ :param handlers: handler functions passed into the returned Response object
+ :return: Response object
+ """
+ pass
+
+ @abc.abstractmethod
+ def discard(self, n=-1, qid=-1, **handlers):
+ """ Appends a DISCARD message to the output queue.
+
+ :param n: number of records to discard, default = -1 (ALL)
+ :param qid: query ID to discard for, default = -1 (last query)
+ :param handlers: handler functions passed into the returned Response object
+ :return: Response object
+ """
+ pass
+
+ @abc.abstractmethod
+ def pull(self, n=-1, qid=-1, **handlers):
+ """ Appends a PULL message to the output queue.
+
+ :param n: number of records to pull, default = -1 (ALL)
+ :param qid: query ID to pull for, default = -1 (last query)
+ :param handlers: handler functions passed into the returned Response object
+ :return: Response object
+ """
+ pass
+
+ @abc.abstractmethod
+ def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None,
+ db=None, imp_user=None, **handlers):
+ """ Appends a BEGIN message to the output queue.
+
+ :param mode: access mode for routing - "READ" or "WRITE" (default)
+ :param bookmarks: iterable of bookmark values after which this transaction should begin
+ :param metadata: custom metadata dictionary to attach to the transaction
+ :param timeout: timeout for transaction execution (seconds)
+ :param db: name of the database against which to begin the transaction
+ Requires Bolt 4.0+.
+ :param imp_user: the user to impersonate
+ Requires Bolt 4.4+
+ :param handlers: handler functions passed into the returned Response object
+ :return: Response object
+ """
+ pass
+
+ @abc.abstractmethod
+ def commit(self, **handlers):
+ """ Appends a COMMIT message to the output queue."""
+ pass
+
+ @abc.abstractmethod
+ def rollback(self, **handlers):
+ """ Appends a ROLLBACK message to the output queue."""
+ pass
+
+ @abc.abstractmethod
+ async def reset(self):
+ """ Appends a RESET message to the outgoing queue, sends it and consumes
+ all remaining messages.
+ """
+ pass
+
+ def _append(self, signature, fields=(), response=None):
+ """ Appends a message to the outgoing queue.
+
+ :param signature: the signature of the message
+ :param fields: the fields of the message as a tuple
+ :param response: a response object to handle callbacks
+ """
+ self.packer.pack_struct(signature, fields)
+ self.outbox.wrap_message()
+ self.responses.append(response)
+
+ async def _send_all(self):
+ data = self.outbox.view()
+ if data:
+ try:
+ await self.socket.sendall(data)
+ except OSError as error:
+ await self._set_defunct_write(error)
+ self.outbox.clear()
+
+ async def send_all(self):
+ """ Send all queued messages to the server.
+ """
+ if self.closed():
+ raise ServiceUnavailable("Failed to write to closed connection {!r} ({!r})".format(
+ self.unresolved_address, self.server_info.address))
+
+ if self.defunct():
+ raise ServiceUnavailable("Failed to write to defunct connection {!r} ({!r})".format(
+ self.unresolved_address, self.server_info.address))
+
+ await self._send_all()
+
+ @abc.abstractmethod
+ async def fetch_message(self):
+ """ Receive at most one message from the server, if available.
+
+ :return: 2-tuple of number of detail messages and number of summary
+ messages fetched
+ """
+ pass
+
+ async def fetch_all(self):
+ """ Fetch all outstanding messages.
+
+ :return: 2-tuple of number of detail messages and number of summary
+ messages fetched
+ """
+ detail_count = summary_count = 0
+ while self.responses:
+ response = self.responses[0]
+ while not response.complete:
+ detail_delta, summary_delta = await self.fetch_message()
+ detail_count += detail_delta
+ summary_count += summary_delta
+ return detail_count, summary_count
+
+ async def _set_defunct_read(self, error=None, silent=False):
+ message = "Failed to read from defunct connection {!r} ({!r})".format(
+ self.unresolved_address, self.server_info.address
+ )
+ await self._set_defunct(message, error=error, silent=silent)
+
+ async def _set_defunct_write(self, error=None, silent=False):
+ message = "Failed to write data to connection {!r} ({!r})".format(
+ self.unresolved_address, self.server_info.address
+ )
+ await self._set_defunct(message, error=error, silent=silent)
+
+ async def _set_defunct(self, message, error=None, silent=False):
+ from ._pool import AsyncBoltPool
+ direct_driver = isinstance(self.pool, AsyncBoltPool)
+
+ if error:
+ log.debug("[#%04X] %s", self.socket.getsockname()[1], error)
+ log.error(message)
+ # We were attempting to receive data but the connection
+ # has unexpectedly terminated. So, we need to close the
+ # connection from the client side, and remove the address
+ # from the connection pool.
+ self._defunct = True
+ await self.close()
+ if self.pool:
+ await self.pool.deactivate(address=self.unresolved_address)
+ # Iterate through the outstanding responses, and if any correspond
+ # to COMMIT requests then raise an error to signal that we are
+ # unable to confirm that the COMMIT completed successfully.
+ if silent:
+ return
+ for response in self.responses:
+ if isinstance(response, CommitResponse):
+ if error:
+ raise IncompleteCommit(message) from error
+ else:
+ raise IncompleteCommit(message)
+
+ if direct_driver:
+ if error:
+ raise ServiceUnavailable(message) from error
+ else:
+ raise ServiceUnavailable(message)
+ else:
+ if error:
+ raise SessionExpired(message) from error
+ else:
+ raise SessionExpired(message)
+
+ def stale(self):
+ return (self._stale
+ or (0 <= self._max_connection_lifetime
+ <= perf_counter() - self._creation_timestamp))
+
+ _stale = False
+
+ def set_stale(self):
+ self._stale = True
+
+ @abc.abstractmethod
+ async def close(self):
+ """ Close the connection.
+ """
+ pass
+
+ @abc.abstractmethod
+ def closed(self):
+ pass
+
+ @abc.abstractmethod
+ def defunct(self):
+ pass
+
+
+AsyncBoltSocket.Bolt = AsyncBolt
diff --git a/neo4j/_async/io/_bolt3.py b/neo4j/_async/io/_bolt3.py
new file mode 100644
index 00000000..8bf7e939
--- /dev/null
+++ b/neo4j/_async/io/_bolt3.py
@@ -0,0 +1,396 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from enum import Enum
+from logging import getLogger
+from ssl import SSLSocket
+
+from ..._async_compat.util import AsyncUtil
+from ..._exceptions import (
+ BoltError,
+ BoltProtocolError,
+)
+from ...api import (
+ READ_ACCESS,
+ Version,
+)
+from ...exceptions import (
+ ConfigurationError,
+ DatabaseUnavailable,
+ DriverError,
+ ForbiddenOnReadOnlyDatabase,
+ Neo4jError,
+ NotALeader,
+ ServiceUnavailable,
+)
+from ._bolt import AsyncBolt
+from ._common import (
+ check_supported_server_product,
+ CommitResponse,
+ InitResponse,
+ Response,
+)
+
+
+log = getLogger("neo4j")
+
+
+class ServerStates(Enum):
+ CONNECTED = "CONNECTED"
+ READY = "READY"
+ STREAMING = "STREAMING"
+ TX_READY_OR_TX_STREAMING = "TX_READY||TX_STREAMING"
+ FAILED = "FAILED"
+
+
+class ServerStateManager:
+ _STATE_TRANSITIONS = {
+ ServerStates.CONNECTED: {
+ "hello": ServerStates.READY,
+ },
+ ServerStates.READY: {
+ "run": ServerStates.STREAMING,
+ "begin": ServerStates.TX_READY_OR_TX_STREAMING,
+ },
+ ServerStates.STREAMING: {
+ "pull": ServerStates.READY,
+ "discard": ServerStates.READY,
+ "reset": ServerStates.READY,
+ },
+ ServerStates.TX_READY_OR_TX_STREAMING: {
+ "commit": ServerStates.READY,
+ "rollback": ServerStates.READY,
+ "reset": ServerStates.READY,
+ },
+ ServerStates.FAILED: {
+ "reset": ServerStates.READY,
+ }
+ }
+
+ def __init__(self, init_state, on_change=None):
+ self.state = init_state
+ self._on_change = on_change
+
+ def transition(self, message, metadata):
+ if metadata.get("has_more"):
+ return
+ state_before = self.state
+ self.state = self._STATE_TRANSITIONS\
+ .get(self.state, {})\
+ .get(message, self.state)
+ if state_before != self.state and callable(self._on_change):
+ self._on_change(state_before, self.state)
+
+
+class AsyncBolt3(AsyncBolt):
+ """ Protocol handler for Bolt 3.
+
+ This is supported by Neo4j versions 3.5, 4.0, 4.1, 4.2, 4.3, and 4.4.
+ """
+
+ PROTOCOL_VERSION = Version(3, 0)
+
+ supports_multiple_results = False
+
+ supports_multiple_databases = False
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._server_state_manager = ServerStateManager(
+ ServerStates.CONNECTED, on_change=self._on_server_state_change
+ )
+
+ def _on_server_state_change(self, old_state, new_state):
+ log.debug("[#%04X] State: %s > %s", self.local_port,
+ old_state.name, new_state.name)
+
+ @property
+ def is_reset(self):
+ # We can't be sure of the server's state if there are still pending
+ # responses. Unless the last message we sent was RESET. In that case
+ # the server state will always be READY when we're done.
+ if (self.responses and self.responses[-1]
+ and self.responses[-1].message == "reset"):
+ return True
+ return self._server_state_manager.state == ServerStates.READY
+
+ @property
+ def encrypted(self):
+ return isinstance(self.socket, SSLSocket)
+
+ @property
+ def der_encoded_server_certificate(self):
+ return self.socket.getpeercert(binary_form=True)
+
+ @property
+ def local_port(self):
+ try:
+ return self.socket.getsockname()[1]
+ except OSError:
+ return 0
+
+ def get_base_headers(self):
+ return {
+ "user_agent": self.user_agent,
+ }
+
+ async def hello(self):
+ headers = self.get_base_headers()
+ headers.update(self.auth_dict)
+ logged_headers = dict(headers)
+ if "credentials" in logged_headers:
+ logged_headers["credentials"] = "*******"
+ log.debug("[#%04X] C: HELLO %r", self.local_port, logged_headers)
+ self._append(b"\x01", (headers,),
+ response=InitResponse(self, "hello",
+ on_success=self.server_info.update))
+ await self.send_all()
+ await self.fetch_all()
+ check_supported_server_product(self.server_info.agent)
+
+ async def route(self, database=None, imp_user=None, bookmarks=None):
+ if database is not None:
+ raise ConfigurationError(
+ "Database name parameter for selecting database is not "
+ "supported in Bolt Protocol {!r}. Database name {!r}. "
+ "Server Agent {!r}".format(
+ self.PROTOCOL_VERSION, database, self.server_info.agent
+ )
+ )
+ if imp_user is not None:
+ raise ConfigurationError(
+ "Impersonation is not supported in Bolt Protocol {!r}. "
+ "Trying to impersonate {!r}.".format(
+ self.PROTOCOL_VERSION, imp_user
+ )
+ )
+
+ metadata = {}
+ records = []
+
+ # Ignoring database and bookmarks because there is no multi-db support.
+ # The bookmarks are only relevant for making sure a previously created
+ # db exists before querying a routing table for it.
+ self.run(
+ "CALL dbms.cluster.routing.getRoutingTable($context)", # This is an internal procedure call. Only available if Neo4j 3.5 is set up with clustering.
+ {"context": self.routing_context},
+ mode="r", # Bolt Protocol Version(3, 0) supports mode="r"
+ on_success=metadata.update
+ )
+ self.pull(on_success=metadata.update, on_records=records.extend)
+ await self.send_all()
+ await self.fetch_all()
+ routing_info = [dict(zip(metadata.get("fields", ()), values)) for values in records]
+ return routing_info
+
+ def run(self, query, parameters=None, mode=None, bookmarks=None,
+ metadata=None, timeout=None, db=None, imp_user=None, **handlers):
+ if db is not None:
+ raise ConfigurationError(
+ "Database name parameter for selecting database is not "
+ "supported in Bolt Protocol {!r}. Database name {!r}.".format(
+ self.PROTOCOL_VERSION, db
+ )
+ )
+ if imp_user is not None:
+ raise ConfigurationError(
+ "Impersonation is not supported in Bolt Protocol {!r}. "
+ "Trying to impersonate {!r}.".format(
+ self.PROTOCOL_VERSION, imp_user
+ )
+ )
+ if not parameters:
+ parameters = {}
+ extra = {}
+ if mode in (READ_ACCESS, "r"):
+ extra["mode"] = "r" # It will default to mode "w" if nothing is specified
+ if bookmarks:
+ try:
+ extra["bookmarks"] = list(bookmarks)
+ except TypeError:
+ raise TypeError("Bookmarks must be provided within an iterable")
+ if metadata:
+ try:
+ extra["tx_metadata"] = dict(metadata)
+ except TypeError:
+ raise TypeError("Metadata must be coercible to a dict")
+ if timeout:
+ try:
+ extra["tx_timeout"] = int(1000 * timeout)
+ except TypeError:
+ raise TypeError("Timeout must be specified as a number of seconds")
+ fields = (query, parameters, extra)
+ log.debug("[#%04X] C: RUN %s", self.local_port, " ".join(map(repr, fields)))
+ if query.upper() == u"COMMIT":
+ self._append(b"\x10", fields, CommitResponse(self, "run",
+ **handlers))
+ else:
+ self._append(b"\x10", fields, Response(self, "run", **handlers))
+
+ def discard(self, n=-1, qid=-1, **handlers):
+ # Just ignore n and qid, it is not supported in the Bolt 3 Protocol.
+ log.debug("[#%04X] C: DISCARD_ALL", self.local_port)
+ self._append(b"\x2F", (), Response(self, "discard", **handlers))
+
+ def pull(self, n=-1, qid=-1, **handlers):
+ # Just ignore n and qid, it is not supported in the Bolt 3 Protocol.
+ log.debug("[#%04X] C: PULL_ALL", self.local_port)
+ self._append(b"\x3F", (), Response(self, "pull", **handlers))
+
+ def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None,
+ db=None, imp_user=None, **handlers):
+ if db is not None:
+ raise ConfigurationError(
+ "Database name parameter for selecting database is not "
+ "supported in Bolt Protocol {!r}. Database name {!r}.".format(
+ self.PROTOCOL_VERSION, db
+ )
+ )
+ if imp_user is not None:
+ raise ConfigurationError(
+ "Impersonation is not supported in Bolt Protocol {!r}. "
+ "Trying to impersonate {!r}.".format(
+ self.PROTOCOL_VERSION, imp_user
+ )
+ )
+ extra = {}
+ if mode in (READ_ACCESS, "r"):
+ extra["mode"] = "r" # It will default to mode "w" if nothing is specified
+ if bookmarks:
+ try:
+ extra["bookmarks"] = list(bookmarks)
+ except TypeError:
+ raise TypeError("Bookmarks must be provided within an iterable")
+ if metadata:
+ try:
+ extra["tx_metadata"] = dict(metadata)
+ except TypeError:
+ raise TypeError("Metadata must be coercible to a dict")
+ if timeout:
+ try:
+ extra["tx_timeout"] = int(1000 * timeout)
+ except TypeError:
+ raise TypeError("Timeout must be specified as a number of seconds")
+ log.debug("[#%04X] C: BEGIN %r", self.local_port, extra)
+ self._append(b"\x11", (extra,), Response(self, "begin", **handlers))
+
+ def commit(self, **handlers):
+ log.debug("[#%04X] C: COMMIT", self.local_port)
+ self._append(b"\x12", (), CommitResponse(self, "commit", **handlers))
+
+ def rollback(self, **handlers):
+ log.debug("[#%04X] C: ROLLBACK", self.local_port)
+ self._append(b"\x13", (), Response(self, "rollback", **handlers))
+
+ async def reset(self):
+ """ Add a RESET message to the outgoing queue, send
+ it and consume all remaining messages.
+ """
+
+ def fail(metadata):
+ raise BoltProtocolError("RESET failed %r" % metadata, address=self.unresolved_address)
+
+ log.debug("[#%04X] C: RESET", self.local_port)
+ self._append(b"\x0F", response=Response(self, "reset", on_failure=fail))
+ await self.send_all()
+ await self.fetch_all()
+
+ async def fetch_message(self):
+ """ Receive at most one message from the server, if available.
+
+ :return: 2-tuple of number of detail messages and number of summary
+ messages fetched
+ """
+ if self._closed:
+ raise ServiceUnavailable("Failed to read from closed connection {!r} ({!r})".format(
+ self.unresolved_address, self.server_info.address))
+
+ if self._defunct:
+ raise ServiceUnavailable("Failed to read from defunct connection {!r} ({!r})".format(
+ self.unresolved_address, self.server_info.address))
+
+ if not self.responses:
+ return 0, 0
+
+ # Receive exactly one message
+ details, summary_signature, summary_metadata = \
+ await AsyncUtil.next(self.inbox)
+
+ if details:
+ log.debug("[#%04X] S: RECORD * %d", self.local_port, len(details)) # Do not log any data
+ await self.responses[0].on_records(details)
+
+ if summary_signature is None:
+ return len(details), 0
+
+ response = self.responses.popleft()
+ response.complete = True
+ if summary_signature == b"\x70":
+ log.debug("[#%04X] S: SUCCESS %r", self.local_port, summary_metadata)
+ self._server_state_manager.transition(response.message,
+ summary_metadata)
+ await response.on_success(summary_metadata or {})
+ elif summary_signature == b"\x7E":
+ log.debug("[#%04X] S: IGNORED", self.local_port)
+ await response.on_ignored(summary_metadata or {})
+ elif summary_signature == b"\x7F":
+ log.debug("[#%04X] S: FAILURE %r", self.local_port, summary_metadata)
+ self._server_state_manager.state = ServerStates.FAILED
+ try:
+ await response.on_failure(summary_metadata or {})
+ except (ServiceUnavailable, DatabaseUnavailable):
+ if self.pool:
+ await self.pool.deactivate(address=self.unresolved_address)
+ raise
+ except (NotALeader, ForbiddenOnReadOnlyDatabase):
+ if self.pool:
+ self.pool.on_write_failure(address=self.unresolved_address)
+ raise
+ except Neo4jError as e:
+ if self.pool and e.invalidates_all_connections():
+ await self.pool.mark_all_stale()
+ raise
+ else:
+ raise BoltProtocolError("Unexpected response message with signature %02X" % summary_signature, address=self.unresolved_address)
+
+ return len(details), 1
+
+ async def close(self):
+ """ Close the connection.
+ """
+ if not self._closed:
+ if not self._defunct:
+ log.debug("[#%04X] C: GOODBYE", self.local_port)
+ self._append(b"\x02", ())
+ try:
+ await self._send_all()
+ except (OSError, BoltError, DriverError):
+ pass
+ log.debug("[#%04X] C: <CLOSE>", self.local_port)
+ try:
+ self.socket.close()
+ except OSError:
+ pass
+ finally:
+ self._closed = True
+
+ def closed(self):
+ return self._closed
+
+ def defunct(self):
+ return self._defunct
diff --git a/neo4j/_async/io/_bolt4.py b/neo4j/_async/io/_bolt4.py
new file mode 100644
index 00000000..326e7212
--- /dev/null
+++ b/neo4j/_async/io/_bolt4.py
@@ -0,0 +1,537 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from logging import getLogger
+from ssl import SSLSocket
+
+from ..._async_compat.util import AsyncUtil
+from ..._exceptions import (
+ BoltError,
+ BoltProtocolError,
+)
+from ...api import (
+ READ_ACCESS,
+ SYSTEM_DATABASE,
+ Version,
+)
+from ...exceptions import (
+ ConfigurationError,
+ DatabaseUnavailable,
+ DriverError,
+ ForbiddenOnReadOnlyDatabase,
+ Neo4jError,
+ NotALeader,
+ ServiceUnavailable,
+)
+from ._bolt3 import (
+ ServerStateManager,
+ ServerStates,
+)
+from ._bolt import AsyncBolt
+from ._common import (
+ check_supported_server_product,
+ CommitResponse,
+ InitResponse,
+ Response,
+)
+
+
+log = getLogger("neo4j")
+
+
class AsyncBolt4x0(AsyncBolt):
    """ Protocol handler for Bolt 4.0.

    This is supported by Neo4j versions 4.0, 4.1 and 4.2.
    """

    PROTOCOL_VERSION = Version(4, 0)

    # Bolt >= 4.0 can multiplex several result streams on one connection
    # (see the qid parameter of pull/discard).
    supports_multiple_results = True

    # Bolt >= 4.0 can address individual databases by name (db in extras).
    supports_multiple_databases = True

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Mirrors the server-side connection state machine (defined in
        # _bolt3); transitions are logged via _on_server_state_change.
        self._server_state_manager = ServerStateManager(
            ServerStates.CONNECTED, on_change=self._on_server_state_change
        )

    def _on_server_state_change(self, old_state, new_state):
        # Purely diagnostic: trace every server state transition.
        log.debug("[#%04X] State: %s > %s", self.local_port,
                  old_state.name, new_state.name)
+
    @property
    def is_reset(self):
        """Whether the server side of this connection is (or will end up)
        in the READY state."""
        # We can't be sure of the server's state if there are still pending
        # responses. Unless the last message we sent was RESET. In that case
        # the server state will always be READY when we're done.
        if (self.responses and self.responses[-1]
                and self.responses[-1].message == "reset"):
            return True
        return self._server_state_manager.state == ServerStates.READY

    @property
    def encrypted(self):
        """Whether the underlying socket is TLS-wrapped."""
        return isinstance(self.socket, SSLSocket)

    @property
    def der_encoded_server_certificate(self):
        """The server's TLS certificate in binary (DER) form."""
        return self.socket.getpeercert(binary_form=True)

    @property
    def local_port(self):
        """Local TCP port of the socket, or 0 if it cannot be determined."""
        try:
            return self.socket.getsockname()[1]
        except OSError:
            return 0

    def get_base_headers(self):
        """Headers common to every HELLO message for this protocol version."""
        return {
            "user_agent": self.user_agent,
        }
+
    async def hello(self):
        """Send HELLO (signature 0x01), wait for the response, and verify
        the server product.

        On success, ``self.server_info`` is updated from the response
        metadata.

        :raises UnsupportedServerProduct: if the server agent is not Neo4j
        """
        headers = self.get_base_headers()
        headers.update(self.auth_dict)
        # Never write credentials to the log.
        logged_headers = dict(headers)
        if "credentials" in logged_headers:
            logged_headers["credentials"] = "*******"
        log.debug("[#%04X] C: HELLO %r", self.local_port, logged_headers)
        self._append(b"\x01", (headers,),
                     response=InitResponse(self, "hello",
                                           on_success=self.server_info.update))
        await self.send_all()
        await self.fetch_all()
        check_supported_server_product(self.server_info.agent)
+
    async def route(self, database=None, imp_user=None, bookmarks=None):
        """Fetch a routing table.

        Bolt 4.0 has no dedicated ROUTE message, so the table is obtained
        by running ``dbms.routing.getRoutingTable`` against the system
        database.

        :param database: database to route for (``None`` for the default)
        :param imp_user: not supported before Bolt 4.4; must be ``None``
        :param bookmarks: iterable of bookmarks to pass along
        :return: list of routing-info dicts (one per returned record)
        :raises ConfigurationError: if ``imp_user`` is given
        """
        if imp_user is not None:
            raise ConfigurationError(
                "Impersonation is not supported in Bolt Protocol {!r}. "
                "Trying to impersonate {!r}.".format(
                    self.PROTOCOL_VERSION, imp_user
                )
            )
        metadata = {}
        records = []

        if database is None: # default database
            self.run(
                "CALL dbms.routing.getRoutingTable($context)",
                {"context": self.routing_context},
                mode="r",
                bookmarks=bookmarks,
                db=SYSTEM_DATABASE,
                on_success=metadata.update
            )
        else:
            self.run(
                "CALL dbms.routing.getRoutingTable($context, $database)",
                {"context": self.routing_context, "database": database},
                mode="r",
                bookmarks=bookmarks,
                db=SYSTEM_DATABASE,
                on_success=metadata.update
            )
        self.pull(on_success=metadata.update, on_records=records.extend)
        await self.send_all()
        await self.fetch_all()
        # Re-key each record by the field names reported in the metadata.
        routing_info = [dict(zip(metadata.get("fields", ()), values)) for values in records]
        return routing_info
+
    def run(self, query, parameters=None, mode=None, bookmarks=None,
            metadata=None, timeout=None, db=None, imp_user=None, **handlers):
        """Queue a RUN message (signature 0x10); does not send it.

        :param query: Cypher query text
        :param parameters: query parameters (``None`` treated as ``{}``)
        :param mode: access mode; read access maps to ``"r"``
        :param bookmarks: iterable of bookmark strings
        :param metadata: transaction metadata, coercible to a dict
        :param timeout: transaction timeout in seconds (sent as ms)
        :param db: target database name
        :param imp_user: not supported before Bolt 4.4; must be ``None``
        :param handlers: response handlers (``on_success``, ``on_records``, ...)
        :raises ConfigurationError: if ``imp_user`` is given
        :raises TypeError: on bad bookmarks/metadata/timeout types
        """
        if imp_user is not None:
            raise ConfigurationError(
                "Impersonation is not supported in Bolt Protocol {!r}. "
                "Trying to impersonate {!r}.".format(
                    self.PROTOCOL_VERSION, imp_user
                )
            )
        if not parameters:
            parameters = {}
        extra = {}
        if mode in (READ_ACCESS, "r"):
            extra["mode"] = "r" # It will default to mode "w" if nothing is specified
        if db:
            extra["db"] = db
        if bookmarks:
            try:
                extra["bookmarks"] = list(bookmarks)
            except TypeError:
                raise TypeError("Bookmarks must be provided within an iterable")
        if metadata:
            try:
                extra["tx_metadata"] = dict(metadata)
            except TypeError:
                raise TypeError("Metadata must be coercible to a dict")
        if timeout:
            try:
                extra["tx_timeout"] = int(1000 * timeout)
            except TypeError:
                raise TypeError("Timeout must be specified as a number of seconds")
        fields = (query, parameters, extra)
        log.debug("[#%04X] C: RUN %s", self.local_port, " ".join(map(repr, fields)))
        # A COMMIT query gets a CommitResponse so bookmark handling applies.
        if query.upper() == u"COMMIT":
            self._append(b"\x10", fields, CommitResponse(self, "run",
                                                         **handlers))
        else:
            self._append(b"\x10", fields, Response(self, "run", **handlers))
+
    def discard(self, n=-1, qid=-1, **handlers):
        """Queue a DISCARD message (0x2F): drop up to ``n`` records
        (-1 for all) of result stream ``qid`` (-1 for the default)."""
        extra = {"n": n}
        if qid != -1:
            extra["qid"] = qid
        log.debug("[#%04X] C: DISCARD %r", self.local_port, extra)
        self._append(b"\x2F", (extra,), Response(self, "discard", **handlers))

    def pull(self, n=-1, qid=-1, **handlers):
        """Queue a PULL message (0x3F): fetch up to ``n`` records
        (-1 for all) of result stream ``qid`` (-1 for the default)."""
        extra = {"n": n}
        if qid != -1:
            extra["qid"] = qid
        log.debug("[#%04X] C: PULL %r", self.local_port, extra)
        self._append(b"\x3F", (extra,), Response(self, "pull", **handlers))

    def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None,
              db=None, imp_user=None, **handlers):
        """Queue a BEGIN message (0x11) opening an explicit transaction.

        Parameters mirror :meth:`run` (minus query/parameters).

        :raises ConfigurationError: if ``imp_user`` is given
        :raises TypeError: on bad bookmarks/metadata/timeout types
        """
        if imp_user is not None:
            raise ConfigurationError(
                "Impersonation is not supported in Bolt Protocol {!r}. "
                "Trying to impersonate {!r}.".format(
                    self.PROTOCOL_VERSION, imp_user
                )
            )
        extra = {}
        if mode in (READ_ACCESS, "r"):
            extra["mode"] = "r" # It will default to mode "w" if nothing is specified
        if db:
            extra["db"] = db
        if bookmarks:
            try:
                extra["bookmarks"] = list(bookmarks)
            except TypeError:
                raise TypeError("Bookmarks must be provided within an iterable")
        if metadata:
            try:
                extra["tx_metadata"] = dict(metadata)
            except TypeError:
                raise TypeError("Metadata must be coercible to a dict")
        if timeout:
            try:
                extra["tx_timeout"] = int(1000 * timeout)
            except TypeError:
                raise TypeError("Timeout must be specified as a number of seconds")
        log.debug("[#%04X] C: BEGIN %r", self.local_port, extra)
        self._append(b"\x11", (extra,), Response(self, "begin", **handlers))

    def commit(self, **handlers):
        """Queue a COMMIT message (0x12) for the open explicit transaction."""
        log.debug("[#%04X] C: COMMIT", self.local_port)
        self._append(b"\x12", (), CommitResponse(self, "commit", **handlers))

    def rollback(self, **handlers):
        """Queue a ROLLBACK message (0x13) for the open explicit transaction."""
        log.debug("[#%04X] C: ROLLBACK", self.local_port)
        self._append(b"\x13", (), Response(self, "rollback", **handlers))
+
    async def reset(self):
        """ Add a RESET message to the outgoing queue, send
        it and consume all remaining messages.

        :raises BoltProtocolError: if the server answers RESET with FAILURE
        """

        def fail(metadata):
            # A failed RESET leaves the connection in an unknown state.
            raise BoltProtocolError("RESET failed %r" % metadata, self.unresolved_address)

        log.debug("[#%04X] C: RESET", self.local_port)
        self._append(b"\x0F", response=Response(self, "reset", on_failure=fail))
        await self.send_all()
        await self.fetch_all()
+
    async def fetch_message(self):
        """ Receive at most one message from the server, if available.

        :return: 2-tuple of number of detail messages and number of summary
            messages fetched
        :raises ServiceUnavailable: if the connection is closed or defunct,
            or as a result of a server FAILURE
        """
        if self._closed:
            raise ServiceUnavailable("Failed to read from closed connection {!r} ({!r})".format(
                self.unresolved_address, self.server_info.address))

        if self._defunct:
            raise ServiceUnavailable("Failed to read from defunct connection {!r} ({!r})".format(
                self.unresolved_address, self.server_info.address))

        if not self.responses:
            # Nothing outstanding, so nothing to fetch.
            return 0, 0

        # Receive exactly one message
        details, summary_signature, summary_metadata = \
            await AsyncUtil.next(self.inbox)

        if details:
            log.debug("[#%04X] S: RECORD * %d", self.local_port, len(details)) # Do not log any data
            await self.responses[0].on_records(details)

        if summary_signature is None:
            # Pure detail message; the response is not complete yet.
            return len(details), 0

        response = self.responses.popleft()
        response.complete = True
        if summary_signature == b"\x70":  # SUCCESS
            log.debug("[#%04X] S: SUCCESS %r", self.local_port, summary_metadata)
            self._server_state_manager.transition(response.message,
                                                  summary_metadata)
            await response.on_success(summary_metadata or {})
        elif summary_signature == b"\x7E":  # IGNORED
            log.debug("[#%04X] S: IGNORED", self.local_port)
            await response.on_ignored(summary_metadata or {})
        elif summary_signature == b"\x7F":  # FAILURE
            log.debug("[#%04X] S: FAILURE %r", self.local_port, summary_metadata)
            self._server_state_manager.state = ServerStates.FAILED
            try:
                await response.on_failure(summary_metadata or {})
            except (ServiceUnavailable, DatabaseUnavailable):
                # Server (or this database) is gone; drop the address.
                if self.pool:
                    await self.pool.deactivate(address=self.unresolved_address)
                raise
            except (NotALeader, ForbiddenOnReadOnlyDatabase):
                # Writes must be rerouted; let the pool react.
                if self.pool:
                    self.pool.on_write_failure(address=self.unresolved_address)
                raise
            except Neo4jError as e:
                if self.pool and e.invalidates_all_connections():
                    await self.pool.mark_all_stale()
                raise
        else:
            raise BoltProtocolError("Unexpected response message with signature "
                                    "%02X" % ord(summary_signature), self.unresolved_address)

        return len(details), 1
+
+ async def close(self):
+ """ Close the connection.
+ """
+ if not self._closed:
+ if not self._defunct:
+ log.debug("[#%04X] C: GOODBYE", self.local_port)
+ self._append(b"\x02", ())
+ try:
+ await self._send_all()
+ except (OSError, BoltError, DriverError):
+ pass
+ log.debug("[#%04X] C: ", self.local_port)
+ try:
+ self.socket.close()
+ except OSError:
+ pass
+ finally:
+ self._closed = True
+
+ def closed(self):
+ return self._closed
+
+ def defunct(self):
+ return self._defunct
+
+
class AsyncBolt4x1(AsyncBolt4x0):
    """ Protocol handler for Bolt 4.1.

    This is supported by Neo4j versions 4.1 - 4.4.
    """

    PROTOCOL_VERSION = Version(4, 1)

    def get_base_headers(self):
        """ Bolt 4.1 passes the routing context, originally taken from
        the URI, into the connection initialisation message. This
        enables server-side routing to propagate the same behaviour
        through its driver.
        """
        headers = {
            "user_agent": self.user_agent,
        }
        # Only include the routing key when a routing context exists at all;
        # None means routing was not requested.
        if self.routing_context is not None:
            headers["routing"] = self.routing_context
        return headers
+
+
class AsyncBolt4x2(AsyncBolt4x1):
    """ Protocol handler for Bolt 4.2.

    This is supported by Neo4j version 4.2 - 4.4.

    Behaviourally identical to Bolt 4.1 in this driver; only the announced
    protocol version differs.
    """

    PROTOCOL_VERSION = Version(4, 2)
+
+
class AsyncBolt4x3(AsyncBolt4x2):
    """ Protocol handler for Bolt 4.3.

    This is supported by Neo4j version 4.3 - 4.4.
    """

    PROTOCOL_VERSION = Version(4, 3)

    async def route(self, database=None, imp_user=None, bookmarks=None):
        """Fetch a routing table via the dedicated ROUTE message (0x66),
        which replaces the 4.0-style routing procedure call.

        :param database: database to route for (``None`` for the default)
        :param imp_user: not supported before Bolt 4.4; must be ``None``
        :param bookmarks: iterable of bookmarks to pass along
        :return: single-element list holding the ``rt`` metadata entry
        :raises ConfigurationError: if ``imp_user`` is given
        """
        if imp_user is not None:
            raise ConfigurationError(
                "Impersonation is not supported in Bolt Protocol {!r}. "
                "Trying to impersonate {!r}.".format(
                    self.PROTOCOL_VERSION, imp_user
                )
            )

        routing_context = self.routing_context or {}
        log.debug("[#%04X] C: ROUTE %r %r %r", self.local_port,
                  routing_context, bookmarks, database)
        metadata = {}
        if bookmarks is None:
            bookmarks = []
        else:
            bookmarks = list(bookmarks)
        self._append(b"\x66", (routing_context, bookmarks, database),
                     response=Response(self, "route",
                                       on_success=metadata.update))
        await self.send_all()
        await self.fetch_all()
        return [metadata.get("rt")]

    async def hello(self):
        """Send HELLO and additionally process server configuration hints
        (new in 4.3), e.g. the suggested socket receive timeout."""

        def on_success(metadata):
            self.configuration_hints.update(metadata.pop("hints", {}))
            self.server_info.update(metadata)
            if "connection.recv_timeout_seconds" in self.configuration_hints:
                recv_timeout = self.configuration_hints[
                    "connection.recv_timeout_seconds"
                ]
                # Only positive integer timeouts are valid; anything else is
                # ignored with a warning-level hint to the operator.
                if isinstance(recv_timeout, int) and recv_timeout > 0:
                    self.socket.settimeout(recv_timeout)
                else:
                    log.info("[#%04X] Server supplied an invalid value for "
                             "connection.recv_timeout_seconds (%r). Make sure "
                             "the server and network is set up correctly.",
                             self.local_port, recv_timeout)

        headers = self.get_base_headers()
        headers.update(self.auth_dict)
        # Never write credentials to the log.
        logged_headers = dict(headers)
        if "credentials" in logged_headers:
            logged_headers["credentials"] = "*******"
        log.debug("[#%04X] C: HELLO %r", self.local_port, logged_headers)
        self._append(b"\x01", (headers,),
                     response=InitResponse(self, "hello",
                                           on_success=on_success))
        await self.send_all()
        await self.fetch_all()
        check_supported_server_product(self.server_info.agent)
+
+
class AsyncBolt4x4(AsyncBolt4x3):
    """ Protocol handler for Bolt 4.4.

    This is supported by Neo4j version 4.4.

    Bolt 4.4 adds user impersonation (``imp_user``) to ROUTE, RUN and BEGIN.
    """

    PROTOCOL_VERSION = Version(4, 4)

    async def route(self, database=None, imp_user=None, bookmarks=None):
        """Fetch a routing table via ROUTE (0x66); unlike 4.3, the third
        field is a dict that may carry ``db`` and ``imp_user``.

        :param database: database to route for (``None`` for the default)
        :param imp_user: user to impersonate for the routing request
        :param bookmarks: iterable of bookmarks to pass along
        :return: single-element list holding the ``rt`` metadata entry
        """
        routing_context = self.routing_context or {}
        db_context = {}
        if database is not None:
            db_context.update(db=database)
        if imp_user is not None:
            db_context.update(imp_user=imp_user)
        log.debug("[#%04X] C: ROUTE %r %r %r", self.local_port,
                  routing_context, bookmarks, db_context)
        metadata = {}
        if bookmarks is None:
            bookmarks = []
        else:
            bookmarks = list(bookmarks)
        self._append(b"\x66", (routing_context, bookmarks, db_context),
                     response=Response(self, "route",
                                       on_success=metadata.update))
        await self.send_all()
        await self.fetch_all()
        return [metadata.get("rt")]

    def run(self, query, parameters=None, mode=None, bookmarks=None,
            metadata=None, timeout=None, db=None, imp_user=None, **handlers):
        """Queue a RUN message (0x10); same contract as the 4.0 handler but
        with ``imp_user`` supported via the extras dict.

        :raises TypeError: on bad bookmarks/metadata/timeout types
        """
        if not parameters:
            parameters = {}
        extra = {}
        if mode in (READ_ACCESS, "r"):
            # It will default to mode "w" if nothing is specified
            extra["mode"] = "r"
        if db:
            extra["db"] = db
        if imp_user:
            extra["imp_user"] = imp_user
        if bookmarks:
            try:
                extra["bookmarks"] = list(bookmarks)
            except TypeError:
                raise TypeError("Bookmarks must be provided within an iterable")
        if metadata:
            try:
                extra["tx_metadata"] = dict(metadata)
            except TypeError:
                raise TypeError("Metadata must be coercible to a dict")
        if timeout:
            try:
                extra["tx_timeout"] = int(1000 * timeout)
            except TypeError:
                raise TypeError("Timeout must be specified as a number of "
                                "seconds")
        fields = (query, parameters, extra)
        log.debug("[#%04X] C: RUN %s", self.local_port,
                  " ".join(map(repr, fields)))
        # A COMMIT query gets a CommitResponse so bookmark handling applies.
        if query.upper() == u"COMMIT":
            self._append(b"\x10", fields, CommitResponse(self, "run",
                                                         **handlers))
        else:
            self._append(b"\x10", fields, Response(self, "run", **handlers))

    def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None,
              db=None, imp_user=None, **handlers):
        """Queue a BEGIN message (0x11); same contract as the 4.0 handler
        but with ``imp_user`` supported via the extras dict.

        :raises TypeError: on bad bookmarks/metadata/timeout types
        """
        extra = {}
        if mode in (READ_ACCESS, "r"):
            # It will default to mode "w" if nothing is specified
            extra["mode"] = "r"
        if db:
            extra["db"] = db
        if imp_user:
            extra["imp_user"] = imp_user
        if bookmarks:
            try:
                extra["bookmarks"] = list(bookmarks)
            except TypeError:
                raise TypeError("Bookmarks must be provided within an iterable")
        if metadata:
            try:
                extra["tx_metadata"] = dict(metadata)
            except TypeError:
                raise TypeError("Metadata must be coercible to a dict")
        if timeout:
            try:
                extra["tx_timeout"] = int(1000 * timeout)
            except TypeError:
                raise TypeError("Timeout must be specified as a number of "
                                "seconds")
        log.debug("[#%04X] C: BEGIN %r", self.local_port, extra)
        self._append(b"\x11", (extra,), Response(self, "begin", **handlers))
diff --git a/neo4j/_async/io/_common.py b/neo4j/_async/io/_common.py
new file mode 100644
index 00000000..5e0d18f5
--- /dev/null
+++ b/neo4j/_async/io/_common.py
@@ -0,0 +1,280 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import asyncio
+import logging
+import socket
+from struct import pack as struct_pack
+
+from ..._async_compat.util import AsyncUtil
+from ...exceptions import (
+ Neo4jError,
+ ServiceUnavailable,
+ SessionExpired,
+ UnsupportedServerProduct,
+)
+from ...packstream import (
+ UnpackableBuffer,
+ Unpacker,
+)
+
+
+log = logging.getLogger("neo4j")
+
+
class AsyncMessageInbox:
    """Parses Bolt messages out of a socket stream, asynchronously.

    On the wire, each message is a sequence of chunks prefixed with a
    big-endian 16-bit length; a zero-length chunk terminates a message, and
    a stand-alone zero-length chunk is a NOOP keep-alive.
    """

    def __init__(self, s, on_error):
        """
        :param s: socket-like object providing async ``recv_into``
        :param on_error: callback invoked with the error on I/O failure
        """
        self.on_error = on_error
        self._messages = self._yield_messages(s)

    async def _yield_messages(self, sock):
        """Async generator yielding ``(tag, fields)`` per complete message."""
        try:
            buffer = UnpackableBuffer()
            unpacker = Unpacker(buffer)
            chunk_size = 0
            while True:

                while chunk_size == 0:
                    # Determine the chunk size and skip noop
                    await receive_into_buffer(sock, buffer, 2)
                    chunk_size = buffer.pop_u16()
                    if chunk_size == 0:
                        # FIX(review): the original logged an empty message;
                        # the "<NOOP>" marker was presumably stripped (angle
                        # brackets lost) — restored.
                        log.debug("[#%04X] S: <NOOP>", sock.getsockname()[1])

                # Read the chunk payload plus the next 2-byte chunk header.
                # NOTE(review): pop_u16 appears to take the *trailing* two
                # bytes of the buffered data (the next header) — confirm
                # against UnpackableBuffer.
                await receive_into_buffer(sock, buffer, chunk_size + 2)
                chunk_size = buffer.pop_u16()

                if chunk_size == 0:
                    # chunk_size was the end marker for the message
                    size, tag = unpacker.unpack_structure_header()
                    fields = [unpacker.unpack() for _ in range(size)]
                    yield tag, fields
                    # Reset for new message
                    unpacker.reset()

        except (OSError, socket.timeout) as error:
            await AsyncUtil.callback(self.on_error, error)

    async def pop(self):
        """Return the next parsed message as ``(tag, fields)``."""
        return await AsyncUtil.next(self._messages)
+
+
class AsyncInbox(AsyncMessageInbox):
    """Message inbox that splits RECORD messages from summary messages."""

    async def __anext__(self):
        # Returns (detail_records, summary_signature, summary_metadata).
        tag, fields = await self.pop()
        if tag == b"\x71":
            # RECORD: detail data, no summary yet.
            return fields, None, None
        elif fields:
            # Summary message (SUCCESS/FAILURE/IGNORED) with metadata.
            return [], tag, fields[0]
        else:
            # Summary message without metadata.
            return [], tag, None
+
+
class Outbox:
    """Buffers outgoing bytes and splits them into Bolt chunks.

    Raw bytes are appended via :attr:`write`; chunking converts them into
    length-prefixed chunks of at most ``max_chunk_size`` payload bytes each
    (big-endian u16 header). :meth:`wrap_message` terminates the current
    message with an empty (0x0000) chunk.
    """

    def __init__(self, max_chunk_size=16384):
        self._max_chunk_size = max_chunk_size
        self._chunked_data = bytearray()
        self._raw_data = bytearray()
        # Expose the raw buffer's extend directly as the write entry point.
        self.write = self._raw_data.extend

    def max_chunk_size(self):
        """Return the configured maximum chunk payload size."""
        return self._max_chunk_size

    def clear(self):
        """Drop all buffered data, both chunked and raw."""
        self._chunked_data = bytearray()
        self._raw_data.clear()

    def _chunk_data(self):
        """Move everything from the raw buffer into the chunked buffer."""
        source = memoryview(self._raw_data)
        total = len(source)
        offset = 0
        while offset < total:
            size = min(total - offset, self._max_chunk_size)
            # u16 length header followed by the payload slice.
            self._chunked_data += struct_pack(">H", size)
            self._chunked_data += source[offset:offset + size]
            offset += size
        # Release the buffer export before resizing the bytearray.
        del source
        self._raw_data.clear()

    def wrap_message(self):
        """Chunk pending raw data and append the end-of-message marker."""
        self._chunk_data()
        self._chunked_data += b"\x00\x00"

    def view(self):
        """Return a zero-copy view of all chunked data so far."""
        self._chunk_data()
        return memoryview(self._chunked_data)
+
+
class ConnectionErrorHandler:
    """
    Wrapper class for handling connection errors.

    The class will wrap each method to invoke a callback if the method raises
    Neo4jError, SessionExpired, or ServiceUnavailable.
    The error will be re-raised after the callback.
    """

    def __init__(self, connection, on_error):
        """
        :param connection: the connection object to wrap
        :type connection: Bolt
        :param on_error: the function to be called when a method of
            connection raises any of the caught errors.
        :type on_error: callable
        """
        self.__connection = connection
        self.__on_error = on_error

    def __getattr__(self, name):
        connection_attr = getattr(self.__connection, name)
        if not callable(connection_attr):
            # Plain attribute: pass it straight through.
            return connection_attr

        def outer(func):
            def inner(*args, **kwargs):
                try:
                    # FIX(review): forward the wrapped method's return value;
                    # previously it was silently discarded.
                    return func(*args, **kwargs)
                except (Neo4jError, ServiceUnavailable, SessionExpired) as exc:
                    # Sync path must not be handed an async error callback.
                    assert not asyncio.iscoroutinefunction(self.__on_error)
                    self.__on_error(exc)
                    raise
            return inner

        def outer_async(coroutine_func):
            async def inner(*args, **kwargs):
                try:
                    # FIX(review): forward the awaited result as well.
                    return await coroutine_func(*args, **kwargs)
                except (Neo4jError, ServiceUnavailable, SessionExpired) as exc:
                    await AsyncUtil.callback(self.__on_error, exc)
                    raise
            return inner

        if asyncio.iscoroutinefunction(connection_attr):
            return outer_async(connection_attr)
        return outer(connection_attr)

    def __setattr__(self, name, value):
        # Name-mangled attributes belong to the wrapper itself; everything
        # else is proxied onto the wrapped connection.
        if name.startswith("_" + self.__class__.__name__ + "__"):
            super().__setattr__(name, value)
        else:
            setattr(self.__connection, name, value)
+
+
class Response:
    """ Subscriber object for a full response (zero or
    more detail messages followed by one summary message).

    Handlers are passed as keyword arguments (``on_records``,
    ``on_success``, ``on_failure``, ``on_ignored``, ``on_summary``) and are
    invoked — sync or async — via ``AsyncUtil.callback`` as the matching
    server messages arrive.
    """

    def __init__(self, connection, message, **handlers):
        self.connection = connection
        self.handlers = handlers
        self.message = message
        # Set to True by the connection once the summary has been processed.
        self.complete = False

    async def on_records(self, records):
        """ Called when one or more RECORD messages have been received.
        """
        handler = self.handlers.get("on_records")
        await AsyncUtil.callback(handler, records)

    async def on_success(self, metadata):
        """ Called when a SUCCESS message has been received.
        """
        handler = self.handlers.get("on_success")
        await AsyncUtil.callback(handler, metadata)

        # "has_more" means further batches will follow; only fire the
        # summary handler when the stream is truly finished.
        if not metadata.get("has_more"):
            handler = self.handlers.get("on_summary")
            await AsyncUtil.callback(handler)

    async def on_failure(self, metadata):
        """ Called when a FAILURE message has been received.

        Resets the connection (best effort), dispatches to the failure and
        summary handlers, then raises the hydrated :class:`Neo4jError`.
        """
        try:
            # FIX(review): ``reset`` is a coroutine on the async connection
            # classes; without ``await`` the coroutine was created but never
            # executed, so the connection was never actually reset.
            await self.connection.reset()
        except (SessionExpired, ServiceUnavailable):
            pass
        handler = self.handlers.get("on_failure")
        await AsyncUtil.callback(handler, metadata)
        handler = self.handlers.get("on_summary")
        await AsyncUtil.callback(handler)
        raise Neo4jError.hydrate(**metadata)

    async def on_ignored(self, metadata=None):
        """ Called when an IGNORED message has been received.
        """
        handler = self.handlers.get("on_ignored")
        await AsyncUtil.callback(handler, metadata)
        handler = self.handlers.get("on_summary")
        await AsyncUtil.callback(handler)
+
+
class InitResponse(Response):
    """Response subscriber for connection initialisation (HELLO).

    Any failure other than an authentication error is converted into
    :class:`ServiceUnavailable`, since the connection never became usable.
    """

    async def on_failure(self, metadata):
        code = metadata.get("code")
        if code == "Neo.ClientError.Security.Unauthorized":
            # Auth errors are surfaced as-is so callers can react to them.
            raise Neo4jError.hydrate(**metadata)
        else:
            raise ServiceUnavailable(
                metadata.get("message", "Connection initialisation failed")
            )


class CommitResponse(Response):
    """Marker subclass identifying responses to COMMIT messages."""

    pass
+
+
def check_supported_server_product(agent):
    """Validate that the server product is supported by this driver.

    The check is made against the server agent string: only genuine Neo4j
    servers (agent beginning with ``"Neo4j/"``) are accepted.

    :param agent: server agent string to check for validity
    :raises UnsupportedServerProduct: if the product is not supported
    """
    is_neo4j = agent.startswith("Neo4j/")
    if is_neo4j:
        return
    raise UnsupportedServerProduct(agent)
+
+
async def receive_into_buffer(sock, buffer, n_bytes):
    """Receive exactly ``n_bytes`` from ``sock`` into ``buffer``.

    Grows ``buffer.data`` if needed and advances ``buffer.used`` as bytes
    arrive; returns only once all requested bytes have been read.

    :raises OSError: if the peer closes the connection before enough bytes
        arrive
    """
    end = buffer.used + n_bytes
    if end > len(buffer.data):
        # Grow the backing buffer to fit the requested read.
        buffer.data += bytearray(end - len(buffer.data))
    view = memoryview(buffer.data)
    while buffer.used < end:
        n = await sock.recv_into(view[buffer.used:end], end - buffer.used)
        if n == 0:
            # Zero-byte read means the peer hung up.
            raise OSError("No data")
        buffer.used += n
diff --git a/neo4j/_async/io/_pool.py b/neo4j/_async/io/_pool.py
new file mode 100644
index 00000000..30998b26
--- /dev/null
+++ b/neo4j/_async/io/_pool.py
@@ -0,0 +1,701 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import abc
+from collections import (
+ defaultdict,
+ deque,
+)
+import logging
+from logging import getLogger
+from random import choice
+from time import perf_counter
+
+from ..._async_compat.concurrency import (
+ AsyncCondition,
+ AsyncRLock,
+)
+from ..._async_compat.network import AsyncNetworkUtil
+from ..._exceptions import BoltError
+from ...api import (
+ READ_ACCESS,
+ WRITE_ACCESS,
+)
+from ...conf import (
+ PoolConfig,
+ WorkspaceConfig,
+)
+from ...exceptions import (
+ ClientError,
+ ConfigurationError,
+ DriverError,
+ Neo4jError,
+ ReadServiceUnavailable,
+ ServiceUnavailable,
+ SessionExpired,
+ WriteServiceUnavailable,
+)
+from ...routing import RoutingTable
+from ._bolt import AsyncBolt
+
+
+# Set up logger
+log = getLogger("neo4j")
+
+
class AsyncIOPool(abc.ABC):
    """ A collection of connections to one or more server addresses.
    """

    def __init__(self, opener, pool_config, workspace_config):
        """
        :param opener: async callable ``(address, timeout) -> connection``
        :param pool_config: :class:`PoolConfig` used for new connections
        :param workspace_config: :class:`WorkspaceConfig` (supplies the
            connection acquisition timeout, among others)
        """
        assert callable(opener)
        assert isinstance(pool_config, PoolConfig)
        assert isinstance(workspace_config, WorkspaceConfig)

        self.opener = opener
        self.pool_config = pool_config
        self.workspace_config = workspace_config
        # address -> deque of connections (both free and in use)
        self.connections = defaultdict(deque)
        self.lock = AsyncRLock()
        # Signalled in release() when connections become available.
        self.cond = AsyncCondition(self.lock)

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        # Using the pool as an async context manager closes it on exit.
        await self.close()
+
    async def _acquire(self, address, timeout):
        """ Acquire a connection to a given address from the pool.
        The address supplied should always be an IP address, not
        a host name.

        This method is thread safe.

        :param address: resolved address to connect to
        :param timeout: overall acquisition timeout in seconds; ``None``
            falls back to the workspace's connection_acquisition_timeout
        :raises ClientError: if no connection could be acquired in time
        """
        t0 = perf_counter()
        if timeout is None:
            timeout = self.workspace_config.connection_acquisition_timeout

        async with self.lock:
            def time_remaining():
                # Seconds left of the acquisition budget (never negative).
                t = timeout - (perf_counter() - t0)
                return t if t > 0 else 0

            while True:
                # try to find a free connection in pool
                for connection in list(self.connections.get(address, [])):
                    if (connection.closed() or connection.defunct()
                            or (connection.stale() and not connection.in_use)):
                        # `close` is a noop on already closed connections.
                        # This is to make sure that the connection is
                        # gracefully closed, e.g. if it's just marked as
                        # `stale` but still alive.
                        if log.isEnabledFor(logging.DEBUG):
                            log.debug(
                                "[#%04X] C: removing old connection "
                                "(closed=%s, defunct=%s, stale=%s, in_use=%s)",
                                connection.local_port,
                                connection.closed(), connection.defunct(),
                                connection.stale(), connection.in_use
                            )
                        await connection.close()
                        try:
                            self.connections.get(address, []).remove(connection)
                        except ValueError:
                            # If closure fails (e.g. because the server went
                            # down), all connections to the same address will
                            # be removed. Therefore, we silently ignore if the
                            # connection isn't in the pool anymore.
                            pass
                        continue
                    if not connection.in_use:
                        connection.in_use = True
                        return connection
                # all connections in pool are in-use
                connections = self.connections[address]
                max_pool_size = self.pool_config.max_connection_pool_size
                infinite_pool_size = (max_pool_size < 0
                                      or max_pool_size == float("inf"))
                can_create_new_connection = (
                    infinite_pool_size
                    or len(connections) < max_pool_size
                )
                if can_create_new_connection:
                    # Cap the connect timeout by the remaining budget.
                    timeout = min(self.pool_config.connection_timeout,
                                  time_remaining())
                    try:
                        connection = await self.opener(address, timeout)
                    except ServiceUnavailable:
                        await self.remove(address)
                        raise
                    else:
                        connection.pool = self
                        connection.in_use = True
                        connections.append(connection)
                        return connection

                # failed to obtain a connection from pool because the
                # pool is full and no free connection in the pool
                if time_remaining():
                    await self.cond.wait(time_remaining())
                    # if timed out, then we throw error. This time
                    # computation is needed, as with python 2.7, we
                    # cannot tell if the condition is notified or
                    # timed out when we come to this line
                    if not time_remaining():
                        raise ClientError("Failed to obtain a connection from pool "
                                          "within {!r}s".format(timeout))
                else:
                    raise ClientError("Failed to obtain a connection from pool "
                                      "within {!r}s".format(timeout))
+
    @abc.abstractmethod
    async def acquire(
        self, access_mode=None, timeout=None, database=None, bookmarks=None
    ):
        """ Acquire a connection to a server that can satisfy a set of parameters.

        :param access_mode: requested access mode (read or write)
        :param timeout: acquisition timeout in seconds
        :param database: target database name
        :param bookmarks: iterable of bookmarks to satisfy
        """

    async def release(self, *connections):
        """ Release a connection back into the pool.
        This method is thread safe.
        """
        async with self.lock:
            for connection in connections:
                if not (connection.defunct()
                        or connection.closed()
                        or connection.is_reset):
                    try:
                        # Make sure released connections are in a clean state.
                        await connection.reset()
                    except (Neo4jError, DriverError, BoltError) as e:
                        log.debug(
                            "Failed to reset connection on release: %s", e
                        )
                connection.in_use = False
            # Wake up any acquirers waiting for a free connection.
            self.cond.notify_all()
+
+ def in_use_connection_count(self, address):
+ """ Count the number of connections currently in use to a given
+ address.
+ """
+ try:
+ connections = self.connections[address]
+ except KeyError:
+ return 0
+ else:
+ return sum(1 if connection.in_use else 0 for connection in connections)
+
    async def mark_all_stale(self):
        # Flag every pooled connection (all addresses) for replacement.
        async with self.lock:
            for address in self.connections:
                for connection in self.connections[address]:
                    connection.set_stale()

    async def deactivate(self, address):
        """ Deactivate an address from the connection pool, if present, closing
        all idle connection to that address
        """
        async with self.lock:
            try:
                connections = self.connections[address]
            except KeyError: # already removed from the connection pool
                return
            for conn in list(connections):
                if not conn.in_use:
                    connections.remove(conn)
                    try:
                        await conn.close()
                    except OSError:
                        pass
            if not connections:
                # Nothing left for this address; drop the entry entirely.
                await self.remove(address)

    def on_write_failure(self, address):
        # Direct pools have no writer fallback; subclasses with routing
        # override this.
        raise WriteServiceUnavailable(
            "No write service available for pool {}".format(self)
        )

    async def remove(self, address):
        """ Remove an address from the connection pool, if present, closing
        all connections to that address.
        """
        async with self.lock:
            for connection in self.connections.pop(address, ()):
                try:
                    await connection.close()
                except OSError:
                    pass

    async def close(self):
        """ Close all connections and empty the pool.
        This method is thread safe.
        """
        try:
            async with self.lock:
                for address in list(self.connections):
                    await self.remove(address)
        except TypeError:
            # NOTE(review): presumably guards against a partially
            # initialised pool — confirm what actually raises TypeError here.
            pass
+
+
class AsyncBoltPool(AsyncIOPool):
    """Connection pool for direct (single-server, bolt://) connections."""

    @classmethod
    def open(cls, address, *, auth, pool_config, workspace_config):
        """Create a new BoltPool

        :param address: server address to connect to
        :param auth: authentication credentials
        :param pool_config: :class:`PoolConfig`
        :param workspace_config: :class:`WorkspaceConfig`
        :return: BoltPool
        """

        async def opener(addr, timeout):
            # Direct connections carry no routing context.
            return await AsyncBolt.open(
                addr, auth=auth, timeout=timeout, routing_context=None,
                **pool_config
            )

        pool = cls(opener, pool_config, workspace_config, address)
        return pool

    def __init__(self, opener, pool_config, workspace_config, address):
        super().__init__(opener, pool_config, workspace_config)
        # The single server address this pool connects to.
        self.address = address

    def __repr__(self):
        return "<{} address={!r}>".format(self.__class__.__name__,
                                          self.address)

    async def acquire(
        self, access_mode=None, timeout=None, database=None, bookmarks=None
    ):
        # The access_mode and database is not needed for a direct connection,
        # it's just there for consistency.
        return await self._acquire(self.address, timeout)
+
+
+class AsyncNeo4jPool(AsyncIOPool):
+ """ Connection pool with routing table.
+ """
+
+ @classmethod
+ def open(cls, *addresses, auth, pool_config, workspace_config,
+ routing_context=None):
+ """Create a new Neo4jPool
+
+ :param addresses: one or more address as positional argument
+ :param auth:
+ :param pool_config:
+ :param workspace_config:
+ :param routing_context:
+ :return: Neo4jPool
+ """
+
+ address = addresses[0]
+ if routing_context is None:
+ routing_context = {}
+ elif "address" in routing_context:
+ raise ConfigurationError("The key 'address' is reserved for routing context.")
+ routing_context["address"] = str(address)
+
+ async def opener(addr, timeout):
+ return await AsyncBolt.open(
+ addr, auth=auth, timeout=timeout,
+ routing_context=routing_context, **pool_config
+ )
+
+ pool = cls(opener, pool_config, workspace_config, address)
+ return pool
+
+ def __init__(self, opener, pool_config, workspace_config, address):
+ """
+
+ :param opener:
+ :param pool_config:
+ :param workspace_config:
+ :param address:
+ """
+ super().__init__(opener, pool_config, workspace_config)
+        # Each database has a routing table; the default database is a special case.
+ log.debug("[#0000] C: routing address %r", address)
+ self.address = address
+ self.routing_tables = {workspace_config.database: RoutingTable(database=workspace_config.database, routers=[address])}
+ self.refresh_lock = AsyncRLock()
+
+ def __repr__(self):
+ """ The representation shows the initial routing addresses.
+
+ :return: The representation
+ :rtype: str
+ """
+ return "<{} addresses={!r}>".format(self.__class__.__name__, self.get_default_database_initial_router_addresses())
+
+ @property
+ def first_initial_routing_address(self):
+ return self.get_default_database_initial_router_addresses()[0]
+
+ def get_default_database_initial_router_addresses(self):
+ """ Get the initial router addresses for the default database.
+
+ :return:
+ :rtype: OrderedSet
+ """
+ return self.get_routing_table_for_default_database().initial_routers
+
+ def get_default_database_router_addresses(self):
+ """ Get the router addresses for the default database.
+
+ :return:
+ :rtype: OrderedSet
+ """
+ return self.get_routing_table_for_default_database().routers
+
+ def get_routing_table_for_default_database(self):
+ return self.routing_tables[self.workspace_config.database]
+
+ async def get_or_create_routing_table(self, database):
+ async with self.refresh_lock:
+ if database not in self.routing_tables:
+ self.routing_tables[database] = RoutingTable(
+ database=database,
+ routers=self.get_default_database_initial_router_addresses()
+ )
+ return self.routing_tables[database]
+
+ async def fetch_routing_info(
+ self, address, database, imp_user, bookmarks, timeout
+ ):
+ """ Fetch raw routing info from a given router address.
+
+ :param address: router address
+ :param database: the database name to get routing table for
+ :param imp_user: the user to impersonate while fetching the routing
+ table
+ :type imp_user: str or None
+ :param bookmarks: iterable of bookmark values after which the routing
+ info should be fetched
+ :param timeout: connection acquisition timeout in seconds
+
+ :return: list of routing records, or None if no connection
+ could be established or if no readers or writers are present
+ :raise ServiceUnavailable: if the server does not support
+ routing, or if routing support is broken or outdated
+ """
+ cx = await self._acquire(address, timeout)
+ try:
+ routing_table = await cx.route(
+ database or self.workspace_config.database,
+ imp_user or self.workspace_config.impersonated_user,
+ bookmarks
+ )
+ finally:
+ await self.release(cx)
+ return routing_table
+
+ async def fetch_routing_table(
+ self, *, address, timeout, database, imp_user, bookmarks
+ ):
+ """ Fetch a routing table from a given router address.
+
+ :param address: router address
+ :param timeout: seconds
+ :param database: the database name
+ :type: str
+ :param imp_user: the user to impersonate while fetching the routing
+ table
+ :type imp_user: str or None
+ :param bookmarks: bookmarks used when fetching routing table
+
+ :return: a new RoutingTable instance or None if the given router is
+ currently unable to provide routing information
+ """
+ new_routing_info = None
+ try:
+ new_routing_info = await self.fetch_routing_info(
+ address, database, imp_user, bookmarks, timeout
+ )
+ except Neo4jError as e:
+ # checks if the code is an error that is caused by the client. In
+ # this case there is no sense in trying to fetch a RT from another
+ # router. Hence, the driver should fail fast during discovery.
+ if e.is_fatal_during_discovery():
+ raise
+ except (ServiceUnavailable, SessionExpired):
+ pass
+ if not new_routing_info:
+ log.debug("Failed to fetch routing info %s", address)
+ return None
+ else:
+ servers = new_routing_info[0]["servers"]
+ ttl = new_routing_info[0]["ttl"]
+ database = new_routing_info[0].get("db", database)
+ new_routing_table = RoutingTable.parse_routing_info(
+ database=database, servers=servers, ttl=ttl
+ )
+
+ # Parse routing info and count the number of each type of server
+ num_routers = len(new_routing_table.routers)
+ num_readers = len(new_routing_table.readers)
+
+ # num_writers = len(new_routing_table.writers)
+ # If no writers are available. This likely indicates a temporary state,
+ # such as leader switching, so we should not signal an error.
+
+ # No routers
+ if num_routers == 0:
+ log.debug("No routing servers returned from server %s", address)
+ return None
+
+ # No readers
+ if num_readers == 0:
+ log.debug("No read servers returned from server %s", address)
+ return None
+
+ # At least one of each is fine, so return this table
+ return new_routing_table
+
+    async def _update_routing_table_from(
+        self, *routers, database=None, imp_user=None, bookmarks=None,
+        database_callback=None
+    ):
+        """ Try to update routing tables with the given routers.
+
+        :return: True if the routing table is successfully updated,
+            otherwise False
+        """
+        log.debug("Attempting to update routing table from {}".format(
+            ", ".join(map(repr, routers)))
+        )
+        for router in routers:
+            async for address in AsyncNetworkUtil.resolve_address(
+                router, resolver=self.pool_config.resolver
+            ):
+                new_routing_table = await self.fetch_routing_table(
+                    address=address,
+                    timeout=self.pool_config.connection_timeout,
+                    database=database, imp_user=imp_user, bookmarks=bookmarks
+                )
+                if new_routing_table is not None:
+                    new_database = new_routing_table.database
+                    old_routing_table = await self.get_or_create_routing_table(
+                        new_database
+                    )
+                    old_routing_table.update(new_routing_table)
+                    log.debug(
+                        "[#0000] C: <ROUTING> table=%r (%r)",
+                        address, self.routing_tables[new_database]
+                    )
+                    if callable(database_callback):
+                        database_callback(new_database)
+                    return True
+            await self.deactivate(router)
+        return False
+
+ async def update_routing_table(
+ self, *, database, imp_user, bookmarks, database_callback=None
+ ):
+ """ Update the routing table from the first router able to provide
+ valid routing information.
+
+ :param database: The database name
+ :param imp_user: the user to impersonate while fetching the routing
+ table
+ :type imp_user: str or None
+ :param bookmarks: bookmarks used when fetching routing table
+ :param database_callback: A callback function that will be called with
+ the database name as only argument when a new routing table has been
+            acquired. This database name might differ from `database` if that
+ was None and the underlying protocol supports reporting back the
+ actual database.
+
+ :raise neo4j.exceptions.ServiceUnavailable:
+ """
+ async with self.refresh_lock:
+ routing_table = await self.get_or_create_routing_table(database)
+ # copied because it can be modified
+ existing_routers = set(routing_table.routers)
+
+ prefer_initial_routing_address = \
+ self.routing_tables[database].initialized_without_writers
+
+ if prefer_initial_routing_address:
+ # TODO: Test this state
+ if await self._update_routing_table_from(
+ self.first_initial_routing_address, database=database,
+ imp_user=imp_user, bookmarks=bookmarks,
+ database_callback=database_callback
+ ):
+ # Why is only the first initial routing address used?
+ return
+ if await self._update_routing_table_from(
+ *(existing_routers - {self.first_initial_routing_address}),
+ database=database, imp_user=imp_user, bookmarks=bookmarks,
+ database_callback=database_callback
+ ):
+ return
+
+ if not prefer_initial_routing_address:
+ if await self._update_routing_table_from(
+ self.first_initial_routing_address, database=database,
+ imp_user=imp_user, bookmarks=bookmarks,
+ database_callback=database_callback
+ ):
+ # Why is only the first initial routing address used?
+ return
+
+ # None of the routers have been successful, so just fail
+ log.error("Unable to retrieve routing information")
+ raise ServiceUnavailable("Unable to retrieve routing information")
+
+ async def update_connection_pool(self, *, database):
+ routing_table = await self.get_or_create_routing_table(database)
+ servers = routing_table.servers()
+ for address in list(self.connections):
+ if address.unresolved not in servers:
+ await super(AsyncNeo4jPool, self).deactivate(address)
+
+ async def ensure_routing_table_is_fresh(
+ self, *, access_mode, database, imp_user, bookmarks,
+ database_callback=None
+ ):
+ """ Update the routing table if stale.
+
+ This method performs two freshness checks, before and after acquiring
+ the refresh lock. If the routing table is already fresh on entry, the
+ method exits immediately; otherwise, the refresh lock is acquired and
+ the second freshness check that follows determines whether an update
+ is still required.
+
+ This method is thread-safe.
+
+ :return: `True` if an update was required, `False` otherwise.
+ """
+ from neo4j.api import READ_ACCESS
+ async with self.refresh_lock:
+ routing_table = await self.get_or_create_routing_table(database)
+ if routing_table.is_fresh(readonly=(access_mode == READ_ACCESS)):
+ # Readers are fresh.
+ return False
+
+ await self.update_routing_table(
+ database=database, imp_user=imp_user, bookmarks=bookmarks,
+ database_callback=database_callback
+ )
+ await self.update_connection_pool(database=database)
+
+ for database in list(self.routing_tables.keys()):
+ # Remove unused databases in the routing table
+ # Remove the routing table after a timeout = TTL + 30s
+ log.debug("[#0000] C: database=%s", database)
+ if (self.routing_tables[database].should_be_purged_from_memory()
+ and database != self.workspace_config.database):
+ del self.routing_tables[database]
+
+ return True
+
+ async def _select_address(self, *, access_mode, database):
+ from ...api import READ_ACCESS
+ """ Selects the address with the fewest in-use connections.
+ """
+ async with self.refresh_lock:
+ if access_mode == READ_ACCESS:
+ addresses = self.routing_tables[database].readers
+ else:
+ addresses = self.routing_tables[database].writers
+ addresses_by_usage = {}
+ for address in addresses:
+ addresses_by_usage.setdefault(
+ self.in_use_connection_count(address), []
+ ).append(address)
+ if not addresses_by_usage:
+ if access_mode == READ_ACCESS:
+ raise ReadServiceUnavailable(
+ "No read service currently available"
+ )
+ else:
+ raise WriteServiceUnavailable(
+ "No write service currently available"
+ )
+ return choice(addresses_by_usage[min(addresses_by_usage)])
+
+ async def acquire(
+ self, access_mode=None, timeout=None, database=None, bookmarks=None
+ ):
+ if access_mode not in (WRITE_ACCESS, READ_ACCESS):
+ raise ClientError("Non valid 'access_mode'; {}".format(access_mode))
+ if not timeout:
+ raise ClientError("'timeout' must be a float larger than 0; {}"
+ .format(timeout))
+
+ from neo4j.api import check_access_mode
+ access_mode = check_access_mode(access_mode)
+ async with self.refresh_lock:
+ log.debug("[#0000] C: %r",
+ self.routing_tables)
+ await self.ensure_routing_table_is_fresh(
+ access_mode=access_mode, database=database, imp_user=None,
+ bookmarks=bookmarks
+ )
+
+ while True:
+ try:
+                # Get an address for a connection that has the fewest in-use
+ # connections.
+ address = await self._select_address(
+ access_mode=access_mode, database=database
+ )
+ except (ReadServiceUnavailable, WriteServiceUnavailable) as err:
+ raise SessionExpired("Failed to obtain connection towards '%s' server." % access_mode) from err
+ try:
+ log.debug("[#0000] C: database=%r address=%r", database, address)
+ # should always be a resolved address
+ connection = await self._acquire(address, timeout=timeout)
+ except ServiceUnavailable:
+ await self.deactivate(address=address)
+ else:
+ return connection
+
+ async def deactivate(self, address):
+ """ Deactivate an address from the connection pool,
+ if present, remove from the routing table and also closing
+ all idle connections to that address.
+ """
+ log.debug("[#0000] C: Deactivating address %r", address)
+ # We use `discard` instead of `remove` here since the former
+ # will not fail if the address has already been removed.
+ for database in self.routing_tables.keys():
+ self.routing_tables[database].routers.discard(address)
+ self.routing_tables[database].readers.discard(address)
+ self.routing_tables[database].writers.discard(address)
+ log.debug("[#0000] C: table=%r", self.routing_tables)
+ await super(AsyncNeo4jPool, self).deactivate(address)
+
+ def on_write_failure(self, address):
+ """ Remove a writer address from the routing table, if present.
+ """
+ log.debug("[#0000] C: Removing writer %r", address)
+ for database in self.routing_tables.keys():
+ self.routing_tables[database].writers.discard(address)
+ log.debug("[#0000] C: table=%r", self.routing_tables)
diff --git a/neo4j/_async/work/__init__.py b/neo4j/_async/work/__init__.py
new file mode 100644
index 00000000..e48e1c21
--- /dev/null
+++ b/neo4j/_async/work/__init__.py
@@ -0,0 +1,32 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from .session import (
+ AsyncResult,
+ AsyncSession,
+ AsyncTransaction,
+ AsyncWorkspace,
+)
+
+
+__all__ = [
+ "AsyncResult",
+ "AsyncSession",
+ "AsyncTransaction",
+ "AsyncWorkspace",
+]
diff --git a/neo4j/_async/work/result.py b/neo4j/_async/work/result.py
new file mode 100644
index 00000000..6184d963
--- /dev/null
+++ b/neo4j/_async/work/result.py
@@ -0,0 +1,379 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from collections import deque
+from warnings import warn
+
+from ..._async_compat.util import AsyncUtil
+from ...data import DataDehydrator
+from ...work import ResultSummary
+from ..io import ConnectionErrorHandler
+
+
+class AsyncResult:
+ """A handler for the result of Cypher query execution. Instances
+ of this class are typically constructed and returned by
+    :meth:`.AsyncSession.run` and :meth:`.AsyncTransaction.run`.
+ """
+
+ def __init__(self, connection, hydrant, fetch_size, on_closed,
+ on_error):
+ self._connection = ConnectionErrorHandler(connection, on_error)
+ self._hydrant = hydrant
+ self._on_closed = on_closed
+ self._metadata = None
+ self._record_buffer = deque()
+ self._summary = None
+ self._bookmark = None
+ self._raw_qid = -1
+ self._fetch_size = fetch_size
+
+ # states
+ self._discarding = False # discard the remainder of records
+ self._attached = False # attached to a connection
+ # there are still more response messages we wait for
+ self._streaming = False
+        # there are more records available to pull from the server
+ self._has_more = False
+ # the result has been fully iterated or consumed
+ self._closed = False
+
+ @property
+ def _qid(self):
+ if self._raw_qid == self._connection.most_recent_qid:
+ return -1
+ else:
+ return self._raw_qid
+
+ async def _tx_ready_run(self, query, parameters, **kwargs):
+ # BEGIN+RUN does not carry any extra on the RUN message.
+ # BEGIN {extra}
+ # RUN "query" {parameters} {extra}
+ await self._run(
+ query, parameters, None, None, None, None, **kwargs
+ )
+
+ async def _run(
+ self, query, parameters, db, imp_user, access_mode, bookmarks,
+ **kwargs
+ ):
+ query_text = str(query) # Query or string object
+ query_metadata = getattr(query, "metadata", None)
+ query_timeout = getattr(query, "timeout", None)
+
+ parameters = DataDehydrator.fix_parameters(dict(parameters or {}, **kwargs))
+
+ self._metadata = {
+ "query": query_text,
+ "parameters": parameters,
+ "server": self._connection.server_info,
+ }
+
+ def on_attached(metadata):
+ self._metadata.update(metadata)
+ # For auto-commit there is no qid and Bolt 3 does not support qid
+ self._raw_qid = metadata.get("qid", -1)
+ if self._raw_qid != -1:
+ self._connection.most_recent_qid = self._raw_qid
+ self._keys = metadata.get("fields")
+ self._attached = True
+
+ async def on_failed_attach(metadata):
+ self._metadata.update(metadata)
+ self._attached = False
+ await AsyncUtil.callback(self._on_closed)
+
+ self._connection.run(
+ query_text,
+ parameters=parameters,
+ mode=access_mode,
+ bookmarks=bookmarks,
+ metadata=query_metadata,
+ timeout=query_timeout,
+ db=db,
+ imp_user=imp_user,
+ on_success=on_attached,
+ on_failure=on_failed_attach,
+ )
+ self._pull()
+ await self._connection.send_all()
+ await self._attach()
+
+ def _pull(self):
+ def on_records(records):
+ if not self._discarding:
+ self._record_buffer.extend(self._hydrant.hydrate_records(self._keys, records))
+
+ async def on_summary():
+ self._attached = False
+ await AsyncUtil.callback(self._on_closed)
+
+ async def on_failure(metadata):
+ self._attached = False
+ await AsyncUtil.callback(self._on_closed)
+
+ def on_success(summary_metadata):
+ self._streaming = False
+ has_more = summary_metadata.get("has_more")
+ self._has_more = bool(has_more)
+ if has_more:
+ return
+ self._metadata.update(summary_metadata)
+ self._bookmark = summary_metadata.get("bookmark")
+
+ self._connection.pull(
+ n=self._fetch_size,
+ qid=self._qid,
+ on_records=on_records,
+ on_success=on_success,
+ on_failure=on_failure,
+ on_summary=on_summary,
+ )
+ self._streaming = True
+
+ def _discard(self):
+ async def on_summary():
+ self._attached = False
+ await AsyncUtil.callback(self._on_closed)
+
+ async def on_failure(metadata):
+ self._metadata.update(metadata)
+ self._attached = False
+ await AsyncUtil.callback(self._on_closed)
+
+ def on_success(summary_metadata):
+ self._streaming = False
+ has_more = summary_metadata.get("has_more")
+ self._has_more = bool(has_more)
+ if has_more:
+ return
+ self._discarding = False
+ self._metadata.update(summary_metadata)
+ self._bookmark = summary_metadata.get("bookmark")
+
+ # This was the last page received, discard the rest
+ self._connection.discard(
+ n=-1,
+ qid=self._qid,
+ on_success=on_success,
+ on_failure=on_failure,
+ on_summary=on_summary,
+ )
+ self._streaming = True
+
+ async def __aiter__(self):
+ """Iterator returning Records.
+ :returns: Record, it is an immutable ordered collection of key-value pairs.
+ :rtype: :class:`neo4j.Record`
+ """
+ while self._record_buffer or self._attached:
+ if self._record_buffer:
+ yield self._record_buffer.popleft()
+ elif self._streaming:
+ await self._connection.fetch_message()
+ elif self._discarding:
+ self._discard()
+ await self._connection.send_all()
+ elif self._has_more:
+ self._pull()
+ await self._connection.send_all()
+
+ self._closed = True
+
+ async def _attach(self):
+ """Sets the Result object in an attached state by fetching messages from
+ the connection to the buffer.
+ """
+ if self._closed is False:
+ while self._attached is False:
+ await self._connection.fetch_message()
+
+ async def _buffer(self, n=None):
+ """Try to fill `self._record_buffer` with n records.
+
+ Might end up with more records in the buffer if the fetch size makes it
+ overshoot.
+        Might end up with fewer records in the buffer if there are not enough
+ records available.
+ """
+ record_buffer = deque()
+ async for record in self:
+ record_buffer.append(record)
+ if n is not None and len(record_buffer) >= n:
+ break
+ self._closed = False
+ if n is None:
+ self._record_buffer = record_buffer
+ else:
+ self._record_buffer.extend(record_buffer)
+
+ async def _buffer_all(self):
+ """Sets the Result object in an detached state by fetching all records
+ from the connection to the buffer.
+ """
+ await self._buffer()
+
+ def _obtain_summary(self):
+ """Obtain the summary of this result, buffering any remaining records.
+
+ :returns: The :class:`neo4j.ResultSummary` for this result
+ """
+ if self._summary is None:
+ if self._metadata:
+ self._summary = ResultSummary(
+ self._connection.unresolved_address, **self._metadata
+ )
+ elif self._connection:
+ self._summary = ResultSummary(
+ self._connection.unresolved_address,
+ server=self._connection.server_info
+ )
+
+ return self._summary
+
+ def keys(self):
+ """The keys for the records in this result.
+
+ :returns: tuple of key names
+ :rtype: tuple
+ """
+ return self._keys
+
+ async def consume(self):
+ """Consume the remainder of this result and return a :class:`neo4j.ResultSummary`.
+
+ Example::
+
+ def create_node_tx(tx, name):
+ result = await tx.run(
+ "CREATE (n:ExampleNode { name: $name }) RETURN n", name=name
+ )
+ record = await result.single()
+ value = record.value()
+ info = await result.consume()
+ return value, info
+
+ async with driver.session() as session:
+ node_id, info = await session.write_transaction(create_node_tx, "example")
+
+ Example::
+
+ async def get_two_tx(tx):
+ result = await tx.run("UNWIND [1,2,3,4] AS x RETURN x")
+ values = []
+ async for record in result:
+ if len(values) >= 2:
+ break
+ values.append(record.values())
+ # discard the remaining records if there are any
+ info = await result.consume()
+ # use the info for logging etc.
+ return values, info
+
+ with driver.session() as session:
+ values, info = session.read_transaction(get_two_tx)
+
+ :returns: The :class:`neo4j.ResultSummary` for this result
+ """
+ if self._closed is False:
+ self._discarding = True
+ async for _ in self:
+ pass
+
+ return self._obtain_summary()
+
+ async def single(self):
+ """Obtain the next and only remaining record from this result if available else return None.
+ Calling this method always exhausts the result.
+
+ A warning is generated if more than one record is available but
+ the first of these is still returned.
+
+ :returns: the next :class:`neo4j.Record` or :const:`None` if none remain
+ :warns: if more than one record is available
+ """
+ # TODO in 5.0 replace with this code that raises an error if there's not
+ # exactly one record in the left result stream.
+ # self._buffer(2).
+ # if len(self._record_buffer) != 1:
+ # raise SomeError("Expected exactly 1 record, found %i"
+ # % len(self._record_buffer))
+ # return self._record_buffer.popleft()
+ # TODO: exhausts the result with self.consume if there are more records.
+ records = await AsyncUtil.list(self)
+ size = len(records)
+ if size == 0:
+ return None
+ if size != 1:
+ warn("Expected a result with a single record, but this result contains %d" % size)
+ return records[0]
+
+ async def peek(self):
+ """Obtain the next record from this result without consuming it.
+ This leaves the record in the buffer for further processing.
+
+ :returns: the next :class:`.Record` or :const:`None` if none remain
+ """
+ await self._buffer(1)
+ if self._record_buffer:
+ return self._record_buffer[0]
+
+ async def graph(self):
+ """Return a :class:`neo4j.graph.Graph` instance containing all the graph objects
+ in the result. After calling this method, the result becomes
+ detached, buffering all remaining records.
+
+ :returns: a result graph
+ :rtype: :class:`neo4j.graph.Graph`
+ """
+ await self._buffer_all()
+ return self._hydrant.graph
+
+ async def value(self, key=0, default=None):
+ """Helper function that return the remainder of the result as a list of values.
+
+ See :class:`neo4j.Record.value`
+
+ :param key: field to return for each remaining record. Obtain a single value from the record by index or key.
+ :param default: default value, used if the index of key is unavailable
+ :returns: list of individual values
+ :rtype: list
+ """
+ return [record.value(key, default) async for record in self]
+
+ async def values(self, *keys):
+ """Helper function that return the remainder of the result as a list of values lists.
+
+ See :class:`neo4j.Record.values`
+
+ :param keys: fields to return for each remaining record. Optionally filtering to include only certain values by index or key.
+ :returns: list of values lists
+ :rtype: list
+ """
+ return [record.values(*keys) async for record in self]
+
+ async def data(self, *keys):
+ """Helper function that return the remainder of the result as a list of dictionaries.
+
+ See :class:`neo4j.Record.data`
+
+ :param keys: fields to return for each remaining record. Optionally filtering to include only certain values by index or key.
+ :returns: list of dictionaries
+ :rtype: list
+ """
+ return [record.data(*keys) async for record in self]
diff --git a/neo4j/_async/work/session.py b/neo4j/_async/work/session.py
new file mode 100644
index 00000000..bedf2915
--- /dev/null
+++ b/neo4j/_async/work/session.py
@@ -0,0 +1,447 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import asyncio
+from logging import getLogger
+from random import random
+from time import perf_counter
+
+from ..._async_compat import async_sleep
+from ...api import (
+ READ_ACCESS,
+ WRITE_ACCESS,
+)
+from ...conf import SessionConfig
+from ...data import DataHydrator
+from ...exceptions import (
+ ClientError,
+ IncompleteCommit,
+ Neo4jError,
+ ServiceUnavailable,
+ SessionExpired,
+ TransactionError,
+ TransientError,
+)
+from ...work import Query
+from .result import AsyncResult
+from .transaction import AsyncTransaction
+from .workspace import AsyncWorkspace
+
+
+log = getLogger("neo4j")
+
+
+class AsyncSession(AsyncWorkspace):
+ """A :class:`.AsyncSession` is a logical context for transactional units
+ of work. Connections are drawn from the :class:`.AsyncDriver` connection
+ pool as required.
+
+ Session creation is a lightweight operation and sessions are not safe to
+ be used in concurrent contexts (multiple threads/coroutines).
+ Therefore, a session should generally be short-lived, and must not
+ span multiple threads/coroutines.
+
+ In general, sessions will be created and destroyed within a `with`
+ context. For example::
+
+ async with driver.session() as session:
+ result = await session.run("MATCH (n:Person) RETURN n.name AS name")
+ # do something with the result...
+
+ :param pool: connection pool instance
+ :param config: session config instance
+ """
+
+ # The current connection.
+ _connection = None
+
+ # The current :class:`.Transaction` instance, if any.
+ _transaction = None
+
+ # The current auto-transaction result, if any.
+ _auto_result = None
+
+ # The state this session is in.
+ _state_failed = False
+
+ # Session have been properly closed.
+ _closed = False
+
+ def __init__(self, pool, session_config):
+ super().__init__(pool, session_config)
+ assert isinstance(session_config, SessionConfig)
+ self._bookmarks = tuple(session_config.bookmarks)
+
+ def __del__(self):
+ if asyncio.iscoroutinefunction(self.close):
+ return
+ try:
+ self.close()
+ except (OSError, ServiceUnavailable, SessionExpired):
+ pass
+
+ async def __aenter__(self):
+ return self
+
+ async def __aexit__(self, exception_type, exception_value, traceback):
+ if exception_type:
+ self._state_failed = True
+ await self.close()
+
+ async def _connect(self, access_mode):
+ if access_mode is None:
+ access_mode = self._config.default_access_mode
+ await super()._connect(access_mode)
+
+ def _collect_bookmark(self, bookmark):
+ if bookmark:
+ self._bookmarks = [bookmark]
+
+ async def _result_closed(self):
+ if self._auto_result:
+ self._collect_bookmark(self._auto_result._bookmark)
+ self._auto_result = None
+ await self._disconnect()
+
+ async def _result_error(self, _):
+ if self._auto_result:
+ self._auto_result = None
+ await self._disconnect()
+
+    async def close(self):
+        """Close the session.
+
+        This will release any borrowed resources, such as connections, and will
+        roll back any outstanding transactions.
+        """
+        if self._connection:
+            if self._auto_result:
+                if self._state_failed is False:
+                    try:
+                        await self._auto_result.consume()
+                        self._collect_bookmark(self._auto_result._bookmark)
+                    except Exception:
+                        # TODO: Investigate potential non graceful close states
+                        self._auto_result = None
+                        self._state_failed = True
+
+            if self._transaction:
+                if self._transaction.closed() is False:
+                    await self._transaction.rollback()  # roll back the transaction if it is not closed
+                self._transaction = None
+
+            try:
+                if self._connection:
+                    await self._connection.send_all()
+                    await self._connection.fetch_all()
+                    # TODO: Investigate potential non graceful close states
+            except Neo4jError:
+                pass
+            except TransactionError:
+                pass
+            except ServiceUnavailable:
+                pass
+            except SessionExpired:
+                pass
+            finally:
+                await self._disconnect()
+
+            self._state_failed = False
+        self._closed = True
+
+ async def run(self, query, parameters=None, **kwargs):
+ """Run a Cypher query within an auto-commit transaction.
+
+ The query is sent and the result header received
+ immediately but the :class:`neo4j.Result` content is
+ fetched lazily as consumed by the client application.
+
+ If a query is executed before a previous
+ :class:`neo4j.AsyncResult` in the same :class:`.AsyncSession` has
+ been fully consumed, the first result will be fully fetched
+ and buffered. Note therefore that the generally recommended
+ pattern of usage is to fully consume one result before
+ executing a subsequent query. If two results need to be
+ consumed in parallel, multiple :class:`.AsyncSession` objects
+ can be used as an alternative to result buffering.
+
+ For more usage details, see :meth:`.AsyncTransaction.run`.
+
+ :param query: cypher query
+ :type query: str, neo4j.Query
+ :param parameters: dictionary of parameters
+ :type parameters: dict
+ :param kwargs: additional keyword parameters
+ :returns: a new :class:`neo4j.AsyncResult` object
+ :rtype: AsyncResult
+ """
+ if not query:
+ raise ValueError("Cannot run an empty query")
+ if not isinstance(query, (str, Query)):
+ raise TypeError("query must be a string or a Query instance")
+
+ if self._transaction:
+ raise ClientError("Explicit Transaction must be handled explicitly")
+
+ if self._auto_result:
+            # This will buffer up all records for the previous auto-transaction
+ await self._auto_result._buffer_all()
+
+ if not self._connection:
+ await self._connect(self._config.default_access_mode)
+ cx = self._connection
+ protocol_version = cx.PROTOCOL_VERSION
+ server_info = cx.server_info
+
+ hydrant = DataHydrator()
+
+ self._auto_result = AsyncResult(
+ cx, hydrant, self._config.fetch_size, self._result_closed,
+ self._result_error
+ )
+ await self._auto_result._run(
+ query, parameters, self._config.database,
+ self._config.impersonated_user, self._config.default_access_mode,
+ self._bookmarks, **kwargs
+ )
+
+ return self._auto_result
+
+    async def last_bookmark(self):
+        """Return the bookmark received following the last completed transaction.
+        Note: For auto-transaction (Session.run) this will trigger a consume of the current result.
+
+        :returns: :class:`neo4j.Bookmark` object
+        """
+        # The set of bookmarks to be passed into the next transaction.
+
+        if self._auto_result:
+            await self._auto_result.consume()
+
+        if self._transaction and self._transaction._closed:
+            self._collect_bookmark(self._transaction._bookmark)
+            self._transaction = None
+
+        if self._bookmarks:
+            return self._bookmarks[-1]
+        return None
+
+ async def _transaction_closed_handler(self):
+ if self._transaction:
+ self._collect_bookmark(self._transaction._bookmark)
+ self._transaction = None
+ await self._disconnect()
+
+ async def _transaction_error_handler(self, _):
+ if self._transaction:
+ self._transaction = None
+ await self._disconnect()
+
+ async def _open_transaction(self, *, access_mode, metadata=None,
+ timeout=None):
+ await self._connect(access_mode=access_mode)
+ self._transaction = AsyncTransaction(
+ self._connection, self._config.fetch_size,
+ self._transaction_closed_handler,
+ self._transaction_error_handler
+ )
+ await self._transaction._begin(
+ self._config.database, self._config.impersonated_user,
+ self._bookmarks, access_mode, metadata, timeout
+ )
+
+    async def begin_transaction(self, metadata=None, timeout=None):
+        """ Begin a new unmanaged transaction. Creates a new :class:`.AsyncTransaction` within this session.
+        At most one transaction may exist in a session at any point in time.
+        To maintain multiple concurrent transactions, use multiple concurrent sessions.
+
+        Note: For auto-transaction (AsyncSession.run) this will trigger a consume of the current result.
+
+        :param metadata:
+            a dictionary with metadata.
+            Specified metadata will be attached to the executing transaction and visible in the output of ``dbms.listQueries`` and ``dbms.listTransactions`` procedures.
+            It will also get logged to the ``query.log``.
+            This functionality makes it easier to tag transactions and is equivalent to ``dbms.setTXMetaData`` procedure, see https://neo4j.com/docs/operations-manual/current/reference/procedures/ for procedure reference.
+        :type metadata: dict
+
+        :param timeout:
+            the transaction timeout in seconds.
+            Transactions that execute longer than the configured timeout will be terminated by the database.
+            This functionality allows to limit query/transaction execution time.
+            Specified timeout overrides the default timeout configured in the database using ``dbms.transaction.timeout`` setting.
+            Value should not represent a duration of zero or negative duration.
+        :type timeout: int
+
+        :returns: A new transaction instance.
+        :rtype: AsyncTransaction
+
+        :raises TransactionError: :class:`neo4j.exceptions.TransactionError` if a transaction is already open.
+        """
+        # TODO: Implement TransactionConfig consumption
+
+        if self._auto_result:
+            await self._auto_result.consume()
+
+        if self._transaction:
+            raise TransactionError("Explicit transaction already open")
+
+        await self._open_transaction(
+            access_mode=self._config.default_access_mode, metadata=metadata,
+            timeout=timeout
+        )
+
+        return self._transaction
+
+    async def _run_transaction(
+        self, access_mode, transaction_function, *args, **kwargs
+    ):
+        if not callable(transaction_function):
+            raise TypeError("Unit of work is not callable")
+
+        metadata = getattr(transaction_function, "metadata", None)
+        timeout = getattr(transaction_function, "timeout", None)
+
+        retry_delay = retry_delay_generator(self._config.initial_retry_delay, self._config.retry_delay_multiplier, self._config.retry_delay_jitter_factor)
+
+        errors = []
+
+        t0 = -1  # Timer
+
+        while True:
+            try:
+                await self._open_transaction(
+                    access_mode=access_mode, metadata=metadata,
+                    timeout=timeout
+                )
+                tx = self._transaction
+                try:
+                    result = await transaction_function(tx, *args, **kwargs)
+                except Exception:
+                    await tx.close()
+                    raise
+                else:
+                    await tx.commit()
+            except IncompleteCommit:
+                raise
+            except (ServiceUnavailable, SessionExpired) as error:
+                errors.append(error)
+                await self._disconnect()
+            except TransientError as transient_error:
+                if not transient_error.is_retriable():
+                    raise
+                errors.append(transient_error)
+            else:
+                return result
+            if t0 == -1:
+                t0 = perf_counter()  # The timer should be started after the first attempt
+            t1 = perf_counter()
+            if t1 - t0 > self._config.max_transaction_retry_time:
+                break
+            delay = next(retry_delay)
+            log.warning("Transaction failed and will be retried in {}s ({})".format(delay, "; ".join(map(str, errors[-1].args))))
+            await async_sleep(delay)
+
+        if errors:
+            raise errors[-1]
+        else:
+            raise ServiceUnavailable("Transaction failed")
+
+ async def read_transaction(self, transaction_function, *args, **kwargs):
+ """Execute a unit of work in a managed read transaction.
+ This transaction will automatically be committed unless an exception is thrown during query execution or by the user code.
+ Note, that this function perform retries and that the supplied `transaction_function` might get invoked more than once.
+
+ Managed transactions should not generally be explicitly committed
+ (via ``await tx.commit()``).
+
+ Example::
+
+ async def do_cypher_tx(tx, cypher):
+ result = await tx.run(cypher)
+ values = [record.values() async for record in result]
+ return values
+
+ async with driver.session() as session:
+ values = await session.read_transaction(do_cypher_tx, "RETURN 1 AS x")
+
+ Example::
+
+ async def get_two_tx(tx):
+ result = await tx.run("UNWIND [1,2,3,4] AS x RETURN x")
+ values = []
+ async for record in result:
+ if len(values) >= 2:
+ break
+ values.append(record.values())
+ # discard the remaining records if there are any
+ info = await result.consume()
+ # use the info for logging etc.
+ return values
+
+ async with driver.session() as session:
+ values = await session.read_transaction(get_two_tx)
+
+ :param transaction_function: a function that takes a transaction as an
+ argument and does work with the transaction.
+ `transaction_function(tx, *args, **kwargs)` where `tx` is a
+ :class:`.AsyncTransaction`.
+ :param args: arguments for the `transaction_function`
+ :param kwargs: key word arguments for the `transaction_function`
+ :return: a result as returned by the given unit of work
+ """
+ return await self._run_transaction(
+ READ_ACCESS, transaction_function, *args, **kwargs
+ )
+
+ async def write_transaction(self, transaction_function, *args, **kwargs):
+ """Execute a unit of work in a managed write transaction.
+ This transaction will automatically be committed unless an exception is thrown during query execution or by the user code.
+ Note, that this function perform retries and that the supplied `transaction_function` might get invoked more than once.
+
+ Managed transactions should not generally be explicitly committed (via tx.commit()).
+
+ Example::
+
+ async def create_node_tx(tx, name):
+ query = "CREATE (n:NodeExample { name: $name }) RETURN id(n) AS node_id"
+ result = await tx.run(query, name=name)
+ record = await result.single()
+ return record["node_id"]
+
+ async with driver.session() as session:
+ node_id = await session.write_transaction(create_node_tx, "example")
+
+ :param transaction_function: a function that takes a transaction as an
+ argument and does work with the transaction.
+ `transaction_function(tx, *args, **kwargs)` where `tx` is a
+ :class:`.AsyncTransaction`.
+ :param args: key word arguments for the `transaction_function`
+ :param kwargs: key word arguments for the `transaction_function`
+ :return: a result as returned by the given unit of work
+ """
+ return await self._run_transaction(
+ WRITE_ACCESS, transaction_function, *args, **kwargs
+ )
+
+
+def retry_delay_generator(initial_delay, multiplier, jitter_factor):
+ delay = initial_delay
+ while True:
+ jitter = jitter_factor * delay
+ yield delay - jitter + (2 * jitter * random())
+ delay *= multiplier
diff --git a/neo4j/_async/work/transaction.py b/neo4j/_async/work/transaction.py
new file mode 100644
index 00000000..a2ae32f5
--- /dev/null
+++ b/neo4j/_async/work/transaction.py
@@ -0,0 +1,199 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from ..._async_compat.util import AsyncUtil
+from ...data import DataHydrator
+from ...exceptions import TransactionError
+from ...work import Query
+from ..io import ConnectionErrorHandler
+from .result import AsyncResult
+
+
+class AsyncTransaction:
+ """ Container for multiple Cypher queries to be executed within a single
+    context. Transactions can be used within an :py:const:`async with`
+ block where the transaction is committed or rolled back on based on
+ whether an exception is raised::
+
+ async with session.begin_transaction() as tx:
+ ...
+
+ """
+
+ def __init__(self, connection, fetch_size, on_closed, on_error):
+ self._connection = connection
+ self._error_handling_connection = ConnectionErrorHandler(
+ connection, self._error_handler
+ )
+ self._bookmark = None
+ self._results = []
+ self._closed = False
+ self._last_error = None
+ self._fetch_size = fetch_size
+ self._on_closed = on_closed
+ self._on_error = on_error
+
+ async def __aenter__(self):
+ return self
+
+ async def __aexit__(self, exception_type, exception_value, traceback):
+ if self._closed:
+ return
+ success = not bool(exception_type)
+ if success:
+ await self.commit()
+ await self.close()
+
+ async def _begin(
+ self, database, imp_user, bookmarks, access_mode, metadata, timeout
+ ):
+ self._connection.begin(
+ bookmarks=bookmarks, metadata=metadata, timeout=timeout,
+ mode=access_mode, db=database, imp_user=imp_user
+ )
+ await self._error_handling_connection.send_all()
+ await self._error_handling_connection.fetch_all()
+
+ async def _result_on_closed_handler(self):
+ pass
+
+ async def _error_handler(self, exc):
+ self._last_error = exc
+ await AsyncUtil.callback(self._on_error, exc)
+
+ async def _consume_results(self):
+ for result in self._results:
+ await result.consume()
+ self._results = []
+
+ async def run(self, query, parameters=None, **kwparameters):
+ """ Run a Cypher query within the context of this transaction.
+
+ Cypher is typically expressed as a query template plus a
+ set of named parameters. In Python, parameters may be expressed
+ through a dictionary of parameters, through individual parameter
+ arguments, or as a mixture of both. For example, the `run`
+ queries below are all equivalent::
+
+ >>> query = "CREATE (a:Person { name: $name, age: $age })"
+ >>> result = await tx.run(query, {"name": "Alice", "age": 33})
+ >>> result = await tx.run(query, {"name": "Alice"}, age=33)
+ >>> result = await tx.run(query, name="Alice", age=33)
+
+ Parameter values can be of any type supported by the Neo4j type
+ system. In Python, this includes :class:`bool`, :class:`int`,
+ :class:`str`, :class:`list` and :class:`dict`. Note however that
+ :class:`list` properties must be homogenous.
+
+ :param query: cypher query
+ :type query: str
+ :param parameters: dictionary of parameters
+ :type parameters: dict
+ :param kwparameters: additional keyword parameters
+        :returns: a new :class:`neo4j.AsyncResult` object
+        :rtype: :class:`neo4j.AsyncResult`
+ :raise TransactionError: if the transaction is already closed
+ """
+ if isinstance(query, Query):
+ raise ValueError("Query object is only supported for session.run")
+
+ if self._closed:
+ raise TransactionError(self, "Transaction closed")
+ if self._last_error:
+ raise TransactionError(self,
+ "Transaction failed") from self._last_error
+
+ if (self._results
+ and self._connection.supports_multiple_results is False):
+ # Bolt 3 Support
+ # Buffer up all records for the previous Result because it does not
+ # have any qid to fetch in batches.
+ await self._results[-1]._buffer_all()
+
+ result = AsyncResult(
+ self._connection, DataHydrator(), self._fetch_size,
+ self._result_on_closed_handler,
+ self._error_handler
+ )
+ self._results.append(result)
+
+ await result._tx_ready_run(query, parameters, **kwparameters)
+
+ return result
+
+ async def commit(self):
+ """Mark this transaction as successful and close in order to trigger a COMMIT.
+
+ :raise TransactionError: if the transaction is already closed
+ """
+ if self._closed:
+ raise TransactionError(self, "Transaction closed")
+ if self._last_error:
+ raise TransactionError(self,
+ "Transaction failed") from self._last_error
+
+ metadata = {}
+ try:
+ # DISCARD pending records then do a commit.
+ await self._consume_results()
+ self._connection.commit(on_success=metadata.update)
+ await self._connection.send_all()
+ await self._connection.fetch_all()
+ self._bookmark = metadata.get("bookmark")
+ finally:
+ self._closed = True
+ await AsyncUtil.callback(self._on_closed)
+
+ return self._bookmark
+
+ async def rollback(self):
+ """Mark this transaction as unsuccessful and close in order to trigger a ROLLBACK.
+
+ :raise TransactionError: if the transaction is already closed
+ """
+ if self._closed:
+ raise TransactionError(self, "Transaction closed")
+
+ metadata = {}
+ try:
+ if not (self._connection.defunct()
+ or self._connection.closed()
+ or self._connection.is_reset):
+ # DISCARD pending records then do a rollback.
+ await self._consume_results()
+ self._connection.rollback(on_success=metadata.update)
+ await self._connection.send_all()
+ await self._connection.fetch_all()
+ finally:
+ self._closed = True
+ await AsyncUtil.callback(self._on_closed)
+
+ async def close(self):
+ """Close this transaction, triggering a ROLLBACK if not closed.
+ """
+ if self._closed:
+ return
+ await self.rollback()
+
+ def closed(self):
+ """Indicator to show whether the transaction has been closed.
+
+ :return: :const:`True` if closed, :const:`False` otherwise.
+ :rtype: bool
+ """
+ return self._closed
diff --git a/neo4j/_async/work/workspace.py b/neo4j/_async/work/workspace.py
new file mode 100644
index 00000000..5dee191a
--- /dev/null
+++ b/neo4j/_async/work/workspace.py
@@ -0,0 +1,102 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import asyncio
+
+from ...conf import WorkspaceConfig
+from ...exceptions import ServiceUnavailable
+from ..io import AsyncNeo4jPool
+
+
+class AsyncWorkspace:
+
+ def __init__(self, pool, config):
+ assert isinstance(config, WorkspaceConfig)
+ self._pool = pool
+ self._config = config
+ self._connection = None
+ self._connection_access_mode = None
+ # Sessions are supposed to cache the database on which to operate.
+ self._cached_database = False
+ self._bookmarks = None
+
+ def __del__(self):
+ if asyncio.iscoroutinefunction(self.close):
+ return
+ try:
+ self.close()
+ except OSError:
+ pass
+
+ async def __aenter__(self):
+ return self
+
+ async def __aexit__(self, exc_type, exc_value, traceback):
+ await self.close()
+
+ def _set_cached_database(self, database):
+ self._cached_database = True
+ self._config.database = database
+
+ async def _connect(self, access_mode):
+ if self._connection:
+ # TODO: Investigate this
+ # log.warning("FIXME: should always disconnect before connect")
+ await self._connection.send_all()
+ await self._connection.fetch_all()
+ await self._disconnect()
+ if not self._cached_database:
+ if (self._config.database is not None
+ or not isinstance(self._pool, AsyncNeo4jPool)):
+ self._set_cached_database(self._config.database)
+ else:
+ # This is the first time we open a connection to a server in a
+ # cluster environment for this session without explicitly
+ # configured database. Hence, we request a routing table update
+ # to try to fetch the home database. If provided by the server,
+ # we shall use this database explicitly for all subsequent
+ # actions within this session.
+ await self._pool.update_routing_table(
+ database=self._config.database,
+ imp_user=self._config.impersonated_user,
+ bookmarks=self._bookmarks,
+ database_callback=self._set_cached_database
+ )
+ self._connection = await self._pool.acquire(
+ access_mode=access_mode,
+ timeout=self._config.connection_acquisition_timeout,
+ database=self._config.database,
+ bookmarks=self._bookmarks
+ )
+ self._connection_access_mode = access_mode
+
+ async def _disconnect(self, sync=False):
+ if self._connection:
+ if sync:
+ try:
+ await self._connection.send_all()
+ await self._connection.fetch_all()
+ except ServiceUnavailable:
+ pass
+ if self._connection:
+ await self._pool.release(self._connection)
+ self._connection = None
+ self._connection_access_mode = None
+
+ async def close(self):
+ await self._disconnect(sync=True)
diff --git a/neo4j/_async_compat/__init__.py b/neo4j/_async_compat/__init__.py
new file mode 100644
index 00000000..87430ab5
--- /dev/null
+++ b/neo4j/_async_compat/__init__.py
@@ -0,0 +1,26 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from asyncio import sleep as async_sleep
+from time import sleep
+
+
+__all__ = [
+ "async_sleep",
+ "sleep",
+]
diff --git a/neo4j/_async_compat/concurrency.py b/neo4j/_async_compat/concurrency.py
new file mode 100644
index 00000000..868ad1bd
--- /dev/null
+++ b/neo4j/_async_compat/concurrency.py
@@ -0,0 +1,231 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import asyncio
+import collections
+import re
+import threading
+
+
+__all__ = [
+ "AsyncCondition",
+ "AsyncRLock",
+ "Condition",
+ "RLock",
+]
+
+
+class AsyncRLock(asyncio.Lock):
+ """Reentrant asyncio.lock
+
+ Inspired by Python's RLock implementation
+ """
+
+ _WAITERS_RE = re.compile(r"(?:\W|^)waiters[:=](\d+)(?:\W|$)")
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._owner = None
+ self._count = 0
+
+ def __repr__(self):
+ res = object.__repr__(self)
+ lock_repr = super().__repr__()
+ extra = "locked" if self._count > 0 else "unlocked"
+ extra += f" count={self._count}"
+ waiters_match = self._WAITERS_RE.search(lock_repr)
+ if waiters_match:
+ extra += f" waiters={waiters_match.group(1)}"
+ if self._owner:
+ extra += f" owner={self._owner}"
+ return f'<{res[1:-1]} [{extra}]>'
+
+ def is_owner(self, task=None):
+ if task is None:
+ task = asyncio.current_task()
+ return self._owner == task
+
+ async def _acquire(self, me):
+ if self.is_owner(task=me):
+ self._count += 1
+ return
+ await super().acquire()
+ self._owner = me
+ self._count = 1
+
+ async def acquire(self, timeout=None):
+ """Acquire the lock."""
+ me = asyncio.current_task()
+ if timeout is None:
+ return await self._acquire(me)
+ return await asyncio.wait_for(self._acquire(me), timeout)
+
+ __aenter__ = acquire
+
+ def _release(self, me):
+ if not self.is_owner(task=me):
+ if self._owner is None:
+ raise RuntimeError("Cannot release un-acquired lock.")
+ raise RuntimeError("Cannot release foreign lock.")
+ self._count -= 1
+ if not self._count:
+ self._owner = None
+ super().release()
+
+ def release(self):
+ """Release the lock"""
+ me = asyncio.current_task()
+ return self._release(me)
+
+ async def __aexit__(self, t, v, tb):
+ self.release()
+
+
+# copied and modified from asyncio.locks (3.7)
+# to add support for `.wait(timeout)`
+class AsyncCondition:
+ """Asynchronous equivalent to threading.Condition.
+
+ This class implements condition variable objects. A condition variable
+ allows one or more coroutines to wait until they are notified by another
+ coroutine.
+
+ A new Lock object is created and used as the underlying lock.
+ """
+
+ def __init__(self, lock=None, *, loop=None):
+ if loop is not None:
+ self._loop = loop
+ else:
+ self._loop = asyncio.get_event_loop()
+
+ if lock is None:
+ lock = asyncio.Lock(loop=self._loop)
+ elif (hasattr(lock, "_loop")
+ and lock._loop is not None
+ and lock._loop is not self._loop):
+ raise ValueError("loop argument must agree with lock")
+
+ self._lock = lock
+ # Export the lock's locked(), acquire() and release() methods.
+ self.locked = lock.locked
+ self.acquire = lock.acquire
+ self.release = lock.release
+
+ self._waiters = collections.deque()
+
+ async def __aenter__(self):
+ await self.acquire()
+ # We have no use for the "as ..." clause in the with
+ # statement for locks.
+ return None
+
+ async def __aexit__(self, exc_type, exc, tb):
+ self.release()
+
+ def __repr__(self):
+ res = super().__repr__()
+ extra = 'locked' if self.locked() else 'unlocked'
+ if self._waiters:
+ extra = f'{extra}, waiters:{len(self._waiters)}'
+ return f'<{res[1:-1]} [{extra}]>'
+
+ async def _wait(self, me=None):
+ """Wait until notified.
+
+ If the calling coroutine has not acquired the lock when this
+ method is called, a RuntimeError is raised.
+
+ This method releases the underlying lock, and then blocks
+ until it is awakened by a notify() or notify_all() call for
+ the same condition variable in another coroutine. Once
+ awakened, it re-acquires the lock and returns True.
+ """
+ if not self.locked():
+ raise RuntimeError('cannot wait on un-acquired lock')
+
+ if isinstance(self._lock, AsyncRLock):
+ self._lock._release(me)
+ else:
+ self._lock.release()
+ try:
+ fut = self._loop.create_future()
+ self._waiters.append(fut)
+ try:
+ await fut
+ return True
+ finally:
+ self._waiters.remove(fut)
+
+ finally:
+ # Must reacquire lock even if wait is cancelled
+ cancelled = False
+ while True:
+ try:
+ if isinstance(self._lock, AsyncRLock):
+ await self._lock._acquire(me)
+ else:
+ await self._lock.acquire()
+ break
+ except asyncio.CancelledError:
+ cancelled = True
+
+ if cancelled:
+ raise asyncio.CancelledError
+
+ async def wait(self, timeout=None):
+ if not timeout:
+ return await self._wait()
+ me = asyncio.current_task()
+ return await asyncio.wait_for(self._wait(me), timeout)
+
+ def notify(self, n=1):
+ """By default, wake up one coroutine waiting on this condition, if any.
+ If the calling coroutine has not acquired the lock when this method
+ is called, a RuntimeError is raised.
+
+ This method wakes up at most n of the coroutines waiting for the
+ condition variable; it is a no-op if no coroutines are waiting.
+
+ Note: an awakened coroutine does not actually return from its
+ wait() call until it can reacquire the lock. Since notify() does
+ not release the lock, its caller should.
+ """
+ if not self.locked():
+ raise RuntimeError('cannot notify on un-acquired lock')
+
+ idx = 0
+ for fut in self._waiters:
+ if idx >= n:
+ break
+
+ if not fut.done():
+ idx += 1
+ fut.set_result(False)
+
+ def notify_all(self):
+ """Wake up all threads waiting on this condition. This method acts
+ like notify(), but wakes up all waiting threads instead of one. If the
+ calling thread has not acquired the lock when this method is called,
+ a RuntimeError is raised.
+ """
+ self.notify(len(self._waiters))
+
+
+Condition = threading.Condition
+RLock = threading.RLock
diff --git a/neo4j/_async_compat/network/__init__.py b/neo4j/_async_compat/network/__init__.py
new file mode 100644
index 00000000..c2611569
--- /dev/null
+++ b/neo4j/_async_compat/network/__init__.py
@@ -0,0 +1,34 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from .bolt_socket import (
+ AsyncBoltSocket,
+ BoltSocket,
+)
+from .util import (
+ AsyncNetworkUtil,
+ NetworkUtil,
+)
+
+
+__all__ = [
+ "AsyncBoltSocket",
+ "AsyncNetworkUtil",
+ "BoltSocket",
+ "NetworkUtil",
+]
diff --git a/neo4j/_async_compat/network/bolt_socket.py b/neo4j/_async_compat/network/bolt_socket.py
new file mode 100644
index 00000000..1af0793a
--- /dev/null
+++ b/neo4j/_async_compat/network/bolt_socket.py
@@ -0,0 +1,539 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import asyncio
+import logging
+import selectors
+import socket
+from socket import (
+ AF_INET,
+ AF_INET6,
+ SHUT_RDWR,
+ SO_KEEPALIVE,
+ socket,
+ SOL_SOCKET,
+ timeout as SocketTimeout,
+)
+from ssl import (
+ CertificateError,
+ HAS_SNI,
+ SSLError,
+)
+import struct
+
+from ... import addressing
+from ..._exceptions import (
+ BoltError,
+ BoltProtocolError,
+ BoltSecurityError,
+)
+from ...exceptions import (
+ DriverError,
+ ServiceUnavailable,
+)
+from .util import (
+ AsyncNetworkUtil,
+ NetworkUtil,
+)
+
+
+log = logging.getLogger("neo4j")
+
+
+class AsyncBoltSocket:
+ Bolt = None
+
+ def __init__(self, reader, protocol, writer): # , loop):
+ self._reader = reader # type: asyncio.StreamReader
+ self._protocol = protocol # type: asyncio.StreamReaderProtocol
+ self._writer = writer # type: asyncio.StreamWriter
+ # self._loop = loop # type: asyncio.BaseEventLoop
+ # 0 - non-blocking
+ # None infinitely blocking
+ # int - seconds to wait for data
+ self._timeout = None
+
+ async def _wait_for_io(self, io_fut):
+ if self._timeout is not None and self._timeout <= 0:
+ # give the io-operation time for one loop cycle to do its thing
+ await asyncio.sleep(0)
+ try:
+ return await asyncio.wait_for(io_fut, self._timeout)
+ except asyncio.TimeoutError:
+ raise SocketTimeout
+
+ @property
+ def _socket(self) -> socket:
+ return self._writer.transport.get_extra_info("socket")
+
+ def getsockname(self):
+ return self._writer.transport.get_extra_info("sockname")
+
+ def getpeername(self):
+ return self._writer.transport.get_extra_info("peername")
+
+ def getpeercert(self, *args, **kwargs):
+ return self._writer.transport.get_extra_info("ssl_object")\
+ .getpeercert(*args, **kwargs)
+
+ def gettimeout(self):
+ return self._timeout
+
+ def settimeout(self, timeout):
+ if timeout is None:
+ self._timeout = timeout
+ else:
+ assert timeout >= 0
+ self._timeout = timeout
+
+ async def recv(self, n):
+ io_fut = self._reader.read(n)
+ return await self._wait_for_io(io_fut)
+
+ async def recv_into(self, buffer, nbytes):
+ # FIXME: not particularly memory or time efficient
+ io_fut = self._reader.read(nbytes)
+ res = await self._wait_for_io(io_fut)
+ buffer[:len(res)] = res
+ return len(res)
+
+ async def sendall(self, data):
+ self._writer.write(data)
+ io_fut = self._writer.drain()
+ return await self._wait_for_io(io_fut)
+
+ def close(self):
+ self._writer.close()
+
+ @classmethod
+ async def _connect_secure(cls, resolved_address, timeout, keep_alive, ssl):
+ """
+
+ :param resolved_address:
+ :param timeout: seconds
+ :param keep_alive: True or False
+ :param ssl: SSLContext or None
+
+ :return: AsyncBoltSocket object
+ """
+
+ loop = asyncio.get_event_loop()
+ s = None
+
+ # TODO: tomorrow me: fix this mess
+ try:
+ if len(resolved_address) == 2:
+ s = socket(AF_INET)
+ elif len(resolved_address) == 4:
+ s = socket(AF_INET6)
+ else:
+ raise ValueError(
+ "Unsupported address {!r}".format(resolved_address))
+ s.setblocking(False) # asyncio + blocking = no-no!
+            log.debug("[#0000] C: <OPEN> %s", resolved_address)
+ await asyncio.wait_for(
+ loop.sock_connect(s, resolved_address),
+ timeout
+ )
+
+ keep_alive = 1 if keep_alive else 0
+ s.setsockopt(SOL_SOCKET, SO_KEEPALIVE, keep_alive)
+
+ ssl_kwargs = {}
+
+ if ssl is not None:
+ hostname = resolved_address.host_name or None
+ ssl_kwargs.update(
+ ssl=ssl, server_hostname=hostname if HAS_SNI else None
+ )
+
+ reader = asyncio.StreamReader(
+ limit=2 ** 16, # 64 KiB,
+ loop=loop
+ )
+ protocol = asyncio.StreamReaderProtocol(reader, loop=loop)
+ transport, _ = await loop.create_connection(
+ lambda: protocol, sock=s, **ssl_kwargs
+ )
+ writer = asyncio.StreamWriter(transport, protocol, reader, loop)
+
+ if ssl is not None:
+ # Check that the server provides a certificate
+ der_encoded_server_certificate = \
+ transport.get_extra_info("ssl_object").getpeercert(
+ binary_form=True)
+ if der_encoded_server_certificate is None:
+ local_port = s.getsockname()[1]
+ raise BoltProtocolError(
+ "When using an encrypted socket, the server should "
+ "always provide a certificate",
+ address=(resolved_address.host_name, local_port)
+ )
+
+ return cls(reader, protocol, writer)
+
+ except asyncio.TimeoutError:
+            log.debug("[#0000] C: <TIMEOUT> %s", resolved_address)
+            log.debug("[#0000] C: <CLOSE> %s", resolved_address)
+ if s:
+ cls.close_socket(s)
+ raise ServiceUnavailable(
+ "Timed out trying to establish connection to {!r}".format(
+ resolved_address))
+ except (SSLError, CertificateError) as error:
+ local_port = s.getsockname()[1]
+ raise BoltSecurityError(
+ message="Failed to establish encrypted connection.",
+ address=(resolved_address.host_name, local_port)
+ ) from error
+ except OSError as error:
+            log.debug("[#0000] C: <ERROR> %s %s", type(error).__name__,
+                      " ".join(map(repr, error.args)))
+            log.debug("[#0000] C: <CLOSE> %s", resolved_address)
+ s.close()
+ raise ServiceUnavailable(
+ "Failed to establish connection to {!r} (reason {})".format(
+ resolved_address, error))
+
+ async def _handshake(self, resolved_address):
+ """
+
+ :param s: Socket
+ :param resolved_address:
+
+ :return: (socket, version, client_handshake, server_response_data)
+ """
+ local_port = self.getsockname()[1]
+
+ # TODO: Optimize logging code
+ handshake = self.Bolt.get_handshake()
+ handshake = struct.unpack(">16B", handshake)
+ handshake = [handshake[i:i + 4] for i in range(0, len(handshake), 4)]
+
+ supported_versions = [
+ ("0x%02X%02X%02X%02X" % (vx[0], vx[1], vx[2], vx[3])) for vx in
+ handshake]
+
+        log.debug("[#%04X] C: <MAGIC> 0x%08X", local_port,
+                  int.from_bytes(self.Bolt.MAGIC_PREAMBLE, byteorder="big"))
+        log.debug("[#%04X] C: <HANDSHAKE> %s %s %s %s", local_port,
+                  *supported_versions)
+
+ data = self.Bolt.MAGIC_PREAMBLE + self.Bolt.get_handshake()
+ await self.sendall(data)
+
+ # Handle the handshake response
+ original_timeout = self.gettimeout()
+ if original_timeout is not None:
+ self.settimeout(original_timeout + 1)
+ try:
+ data = await self.recv(4)
+ except OSError:
+ raise ServiceUnavailable(
+ "Failed to read any data from server {!r} "
+ "after connected".format(resolved_address))
+ finally:
+ self.settimeout(original_timeout)
+ data_size = len(data)
+ if data_size == 0:
+ # If no data is returned after a successful select
+ # response, the server has closed the connection
+            log.debug("[#%04X] S: <CLOSE>", local_port)
+ self.close()
+ raise ServiceUnavailable(
+ "Connection to {address} closed without handshake response".format(
+ address=resolved_address))
+ if data_size != 4:
+ # Some garbled data has been received
+ log.debug("[#%04X] S: @*#!", local_port)
+ self.close()
+ raise BoltProtocolError(
+ "Expected four byte Bolt handshake response from %r, received %r instead; check for incorrect port number" % (
+ resolved_address, data), address=resolved_address)
+ elif data == b"HTTP":
+            log.debug("[#%04X] S: <CLOSE>", local_port)
+ self.close()
+ raise ServiceUnavailable(
+ "Cannot to connect to Bolt service on {!r} "
+ "(looks like HTTP)".format(resolved_address))
+ agreed_version = data[-1], data[-2]
+        log.debug("[#%04X] S: <HANDSHAKE> 0x%06X%02X", local_port,
+ agreed_version[1], agreed_version[0])
+ return self, agreed_version, handshake, data
+
+ @classmethod
+ def close_socket(cls, socket_):
+ if isinstance(socket_, socket):
+ try:
+ socket_.shutdown(SHUT_RDWR)
+ socket_.close()
+ except OSError:
+ pass
+ else:
+ socket_.close()
+
+ @classmethod
+ async def connect(cls, address, *, timeout, custom_resolver, ssl_context,
+ keep_alive):
+ """ Connect and perform a handshake and return a valid Connection object,
+ assuming a protocol version can be agreed.
+ """
+ errors = []
+ failed_addresses = []
+ # Establish a connection to the host and port specified
+ # Catches refused connections see:
+ # https://docs.python.org/2/library/errno.html
+
+ resolved_addresses = AsyncNetworkUtil.resolve_address(
+ addressing.Address(address), resolver=custom_resolver
+ )
+ async for resolved_address in resolved_addresses:
+ s = None
+ try:
+ s = await cls._connect_secure(
+ resolved_address, timeout, keep_alive, ssl_context
+ )
+ return await s._handshake(resolved_address)
+ except (BoltError, DriverError, OSError) as error:
+ try:
+ local_port = s.getsockname()[1]
+ except (OSError, AttributeError, TypeError):
+ local_port = 0
+ err_str = error.__class__.__name__
+ if str(error):
+ err_str += ": " + str(error)
+                log.debug("[#%04X] C: <CONNECTION FAILED> %s", local_port,
+ err_str)
+ if s:
+ cls.close_socket(s)
+ errors.append(error)
+ failed_addresses.append(resolved_address)
+ except Exception:
+ if s:
+ cls.close_socket(s)
+ raise
+ if not errors:
+ raise ServiceUnavailable(
+ "Couldn't connect to %s (resolved to %s)" % (
+ str(address), tuple(map(str, failed_addresses)))
+ )
+ else:
+ raise ServiceUnavailable(
+ "Couldn't connect to %s (resolved to %s):\n%s" % (
+ str(address), tuple(map(str, failed_addresses)),
+ "\n".join(map(str, errors))
+ )
+ ) from errors[0]
+
+
+class BoltSocket:
+ Bolt = None
+
+ @classmethod
+ def _connect(cls, resolved_address, timeout, keep_alive):
+ """
+
+ :param resolved_address:
+ :param timeout: seconds
+ :param keep_alive: True or False
+ :return: socket object
+ """
+
+ s = None # The socket
+
+ try:
+ if len(resolved_address) == 2:
+ s = socket(AF_INET)
+ elif len(resolved_address) == 4:
+ s = socket(AF_INET6)
+ else:
+ raise ValueError(
+ "Unsupported address {!r}".format(resolved_address))
+ t = s.gettimeout()
+ if timeout:
+ s.settimeout(timeout)
+            log.debug("[#0000] C: <OPEN> %s", resolved_address)
+ s.connect(resolved_address)
+ s.settimeout(t)
+ keep_alive = 1 if keep_alive else 0
+ s.setsockopt(SOL_SOCKET, SO_KEEPALIVE, keep_alive)
+ return s
+ except SocketTimeout:
+            log.debug("[#0000] C: <TIMEOUT> %s", resolved_address)
+            log.debug("[#0000] C: <CLOSE> %s", resolved_address)
+ cls.close_socket(s)
+ raise ServiceUnavailable(
+ "Timed out trying to establish connection to {!r}".format(
+ resolved_address))
+ except OSError as error:
+            log.debug("[#0000] C: <ERROR> %s %s", type(error).__name__,
+                      " ".join(map(repr, error.args)))
+            log.debug("[#0000] C: <CLOSE> %s", resolved_address)
+ s.close()
+ raise ServiceUnavailable(
+ "Failed to establish connection to {!r} (reason {})".format(
+ resolved_address, error))
+
+ @classmethod
+ def _secure(cls, s, host, ssl_context):
+ local_port = s.getsockname()[1]
+ # Secure the connection if an SSL context has been provided
+ if ssl_context:
+            log.debug("[#%04X] C: <SECURE> %s", local_port, host)
+ try:
+ sni_host = host if HAS_SNI and host else None
+ s = ssl_context.wrap_socket(s, server_hostname=sni_host)
+ except (OSError, SSLError, CertificateError) as cause:
+ raise BoltSecurityError(
+ message="Failed to establish encrypted connection.",
+ address=(host, local_port)
+ ) from cause
+ # Check that the server provides a certificate
+ der_encoded_server_certificate = s.getpeercert(binary_form=True)
+ if der_encoded_server_certificate is None:
+ raise BoltProtocolError(
+ "When using an encrypted socket, the server should always "
+ "provide a certificate", address=(host, local_port)
+ )
+ return s
+ return s
+
+ @classmethod
+ def _handshake(cls, s, resolved_address):
+ """
+
+ :param s: Socket
+ :param resolved_address:
+
+ :return: (socket, version, client_handshake, server_response_data)
+ """
+ local_port = s.getsockname()[1]
+
+ # TODO: Optimize logging code
+ handshake = cls.Bolt.get_handshake()
+ handshake = struct.unpack(">16B", handshake)
+ handshake = [handshake[i:i + 4] for i in range(0, len(handshake), 4)]
+
+ supported_versions = [
+ ("0x%02X%02X%02X%02X" % (vx[0], vx[1], vx[2], vx[3])) for vx in
+ handshake]
+
+        log.debug("[#%04X] C: <MAGIC> 0x%08X", local_port,
+                  int.from_bytes(cls.Bolt.MAGIC_PREAMBLE, byteorder="big"))
+        log.debug("[#%04X] C: <HANDSHAKE> %s %s %s %s", local_port,
+                  *supported_versions)
+
+ data = cls.Bolt.MAGIC_PREAMBLE + cls.Bolt.get_handshake()
+ s.sendall(data)
+
+ # Handle the handshake response
+ ready_to_read = False
+ with selectors.DefaultSelector() as selector:
+ selector.register(s, selectors.EVENT_READ)
+ selector.select(1)
+ try:
+ data = s.recv(4)
+ except OSError:
+ raise ServiceUnavailable(
+ "Failed to read any data from server {!r} "
+ "after connected".format(resolved_address))
+ data_size = len(data)
+ if data_size == 0:
+ # If no data is returned after a successful select
+ # response, the server has closed the connection
+            log.debug("[#%04X] S: <CLOSE>", local_port)
+ BoltSocket.close_socket(s)
+ raise ServiceUnavailable(
+ "Connection to {address} closed without handshake response".format(
+ address=resolved_address))
+ if data_size != 4:
+ # Some garbled data has been received
+ log.debug("[#%04X] S: @*#!", local_port)
+ s.close()
+ raise BoltProtocolError(
+ "Expected four byte Bolt handshake response from %r, received %r instead; check for incorrect port number" % (
+ resolved_address, data), address=resolved_address)
+ elif data == b"HTTP":
+            log.debug("[#%04X] S: <CLOSE>", local_port)
+ BoltSocket.close_socket(s)
+ raise ServiceUnavailable(
+ "Cannot to connect to Bolt service on {!r} "
+ "(looks like HTTP)".format(resolved_address))
+ agreed_version = data[-1], data[-2]
+        log.debug("[#%04X] S: <HANDSHAKE> 0x%06X%02X", local_port,
+ agreed_version[1], agreed_version[0])
+ return s, agreed_version, handshake, data
+
+ @classmethod
+ def close_socket(cls, socket_):
+ try:
+ socket_.shutdown(SHUT_RDWR)
+ socket_.close()
+ except OSError:
+ pass
+
+ @classmethod
+ def connect(cls, address, *, timeout, custom_resolver, ssl_context,
+ keep_alive):
+ """ Connect and perform a handshake and return a valid Connection object,
+ assuming a protocol version can be agreed.
+ """
+ errors = []
+ # Establish a connection to the host and port specified
+ # Catches refused connections see:
+ # https://docs.python.org/2/library/errno.html
+
+ resolved_addresses = NetworkUtil.resolve_address(
+ addressing.Address(address), resolver=custom_resolver
+ )
+ for resolved_address in resolved_addresses:
+ s = None
+ try:
+ s = BoltSocket._connect(resolved_address, timeout, keep_alive)
+ s = BoltSocket._secure(s, resolved_address.host_name,
+ ssl_context)
+ return BoltSocket._handshake(s, resolved_address)
+ except (BoltError, DriverError, OSError) as error:
+ try:
+ local_port = s.getsockname()[1]
+ except (OSError, AttributeError):
+ local_port = 0
+ err_str = error.__class__.__name__
+ if str(error):
+ err_str += ": " + str(error)
+                log.debug("[#%04X] C: <CONNECTION FAILED> %s", local_port,
+ err_str)
+ if s:
+ BoltSocket.close_socket(s)
+ errors.append(error)
+ except Exception:
+ if s:
+ BoltSocket.close_socket(s)
+ raise
+ if not errors:
+ raise ServiceUnavailable(
+ "Couldn't connect to %s (resolved to %s)" % (
+ str(address), tuple(map(str, resolved_addresses)))
+ )
+ else:
+ raise ServiceUnavailable(
+ "Couldn't connect to %s (resolved to %s):\n%s" % (
+ str(address), tuple(map(str, resolved_addresses)),
+ "\n".join(map(str, errors))
+ )
+ ) from errors[0]
diff --git a/neo4j/_async_compat/network/util.py b/neo4j/_async_compat/network/util.py
new file mode 100644
index 00000000..a9243512
--- /dev/null
+++ b/neo4j/_async_compat/network/util.py
@@ -0,0 +1,150 @@
+import asyncio
+import logging
+import socket
+
+from ... import addressing
+
+
+log = logging.getLogger("neo4j")
+
+
+def _resolved_addresses_from_info(info, host_name):
+ resolved = []
+ for fam, _, _, _, addr in info:
+ if fam == socket.AF_INET6 and addr[3] != 0:
+ # skip any IPv6 addresses with a non-zero scope id
+ # as these appear to cause problems on some platforms
+ continue
+ if addr not in resolved:
+ resolved.append(addr)
+ yield addressing.ResolvedAddress(
+ addr, host_name=host_name
+ )
+
+
+class AsyncNetworkUtil:
+ @staticmethod
+ async def get_address_info(host, port, *,
+ family=0, type=0, proto=0, flags=0):
+ loop = asyncio.get_event_loop()
+ return await loop.getaddrinfo(
+ host, port, family=family, type=type, proto=proto, flags=flags
+ )
+
+ @staticmethod
+ async def _dns_resolver(address, family=0):
+ """ Regular DNS resolver. Takes an address object and optional
+ address family for filtering.
+
+ :param address:
+ :param family:
+ :return:
+ """
+ try:
+ info = await AsyncNetworkUtil.get_address_info(
+ address.host, address.port, family=family,
+ type=socket.SOCK_STREAM
+ )
+ except OSError:
+ raise ValueError("Cannot resolve address {}".format(address))
+ return _resolved_addresses_from_info(info, address.host_name)
+
+ @staticmethod
+ async def resolve_address(address, family=0, resolver=None):
+ """ Carry out domain name resolution on this Address object.
+
+ If a resolver function is supplied, and is callable, this is
+ called first, with this object as its argument. This may yield
+ multiple output addresses, which are chained into a subsequent
+ regular DNS resolution call. If no resolver function is passed,
+ the DNS resolution is carried out on the original Address
+ object.
+
+ This function returns a list of resolved Address objects.
+
+ :param address: the Address to resolve
+ :param family: optional address family to filter resolved
+ addresses by (e.g. `socket.AF_INET6`)
+ :param resolver: optional customer resolver function to be
+ called before regular DNS resolution
+ """
+ if isinstance(address, addressing.ResolvedAddress):
+ yield address
+ return
+
+        log.debug("[#0000] C: <RESOLVE> %s", address)
+ if resolver:
+ if asyncio.iscoroutinefunction(resolver):
+ resolved_addresses = await resolver(address)
+ else:
+ resolved_addresses = resolver(address)
+ for address in map(addressing.Address, resolved_addresses):
+ for resolved_address in await AsyncNetworkUtil._dns_resolver(
+ address, family=family
+ ):
+ yield resolved_address
+ else:
+ for resolved_address in await AsyncNetworkUtil._dns_resolver(
+ address, family=family
+ ):
+ yield resolved_address
+
+
+class NetworkUtil:
+ @staticmethod
+ def get_address_info(host, port, *, family=0, type=0, proto=0, flags=0):
+ return socket.getaddrinfo(host, port, family, type, proto, flags)
+
+ @staticmethod
+ def _dns_resolver(address, family=0):
+ """ Regular DNS resolver. Takes an address object and optional
+ address family for filtering.
+
+ :param address:
+ :param family:
+ :return:
+ """
+ try:
+ info = NetworkUtil.get_address_info(
+ address.host, address.port, family=family,
+ type=socket.SOCK_STREAM
+ )
+ except OSError:
+ raise ValueError("Cannot resolve address {}".format(address))
+ return _resolved_addresses_from_info(info, address.host_name)
+
+ @staticmethod
+ def resolve_address(address, family=0, resolver=None):
+ """ Carry out domain name resolution on this Address object.
+
+ If a resolver function is supplied, and is callable, this is
+ called first, with this object as its argument. This may yield
+ multiple output addresses, which are chained into a subsequent
+ regular DNS resolution call. If no resolver function is passed,
+ the DNS resolution is carried out on the original Address
+ object.
+
+ This function returns a list of resolved Address objects.
+
+ :param address: the Address to resolve
+ :param family: optional address family to filter resolved
+ addresses by (e.g. `socket.AF_INET6`)
+ :param resolver: optional customer resolver function to be
+ called before regular DNS resolution
+ """
+ if isinstance(address, addressing.ResolvedAddress):
+ yield address
+ return
+
+        addressing.log.debug("[#0000] C: <RESOLVE> %s", address)
+ if resolver:
+ for address in map(addressing.Address, resolver(address)):
+ for resolved_address in NetworkUtil._dns_resolver(
+ address, family=family
+ ):
+ yield resolved_address
+ else:
+ for resolved_address in NetworkUtil._dns_resolver(
+ address, family=family
+ ):
+ yield resolved_address
diff --git a/neo4j/_async_compat/util.py b/neo4j/_async_compat/util.py
new file mode 100644
index 00000000..ca432766
--- /dev/null
+++ b/neo4j/_async_compat/util.py
@@ -0,0 +1,63 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import inspect
+
+from ..meta import experimental
+
+
+class AsyncUtil:
+ @staticmethod
+ async def iter(it):
+ async for x in it:
+ yield x
+
+ @staticmethod
+ async def next(it):
+ return await it.__anext__()
+
+ @staticmethod
+ async def list(it):
+ return [x async for x in it]
+
+ @staticmethod
+ async def callback(cb, *args, **kwargs):
+ if callable(cb):
+ res = cb(*args, **kwargs)
+ if inspect.isawaitable(res):
+ return await res
+ return res
+
+ experimental_async = experimental
+
+
+class Util:
+ iter = iter
+ next = next
+ list = list
+
+ @staticmethod
+ def callback(cb, *args, **kwargs):
+ if callable(cb):
+ return cb(*args, **kwargs)
+
+ @staticmethod
+ def experimental_async(message):
+ def f_(f):
+ return f
+ return f_
diff --git a/neo4j/_exceptions.py b/neo4j/_exceptions.py
index 67db7f6c..a48fec26 100644
--- a/neo4j/_exceptions.py
+++ b/neo4j/_exceptions.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
diff --git a/neo4j/_sync/__init__.py b/neo4j/_sync/__init__.py
new file mode 100644
index 00000000..b81a309d
--- /dev/null
+++ b/neo4j/_sync/__init__.py
@@ -0,0 +1,16 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/neo4j/_driver.py b/neo4j/_sync/driver.py
similarity index 58%
rename from neo4j/_driver.py
rename to neo4j/_sync/driver.py
index c8aa688b..711b99e7 100644
--- a/neo4j/_driver.py
+++ b/neo4j/_sync/driver.py
@@ -16,19 +16,132 @@
# limitations under the License.
-from .addressing import Address
-from .api import READ_ACCESS
-from .conf import (
+import asyncio
+
+from .._async_compat.util import Util
+from ..addressing import Address
+from ..api import (
+ READ_ACCESS,
+ TRUST_ALL_CERTIFICATES,
+ TRUST_SYSTEM_CA_SIGNED_CERTIFICATES,
+)
+from ..conf import (
Config,
PoolConfig,
SessionConfig,
WorkspaceConfig,
)
-from .meta import experimental
-from .work.simple import Session
+from ..meta import experimental
+
+
+class GraphDatabase:
+ """Accessor for :class:`neo4j.Driver` construction.
+ """
+
+ @classmethod
+ @Util.experimental_async(
+ "neo4j is in experimental phase. It might be removed or changed "
+ "at any time (including patch releases)."
+ )
+ def driver(cls, uri, *, auth=None, **config):
+ """Create a driver.
+
+ :param uri: the connection URI for the driver, see :ref:`uri-ref` for available URIs.
+ :param auth: the authentication details, see :ref:`auth-ref` for available authentication details.
+ :param config: driver configuration key-word arguments, see :ref:`driver-configuration-ref` for available key-word arguments.
+
+ :rtype: Neo4jDriver or BoltDriver
+ """
+
+ from ..api import (
+ DRIVER_BOLT,
+ DRIVER_NEO4j,
+ parse_neo4j_uri,
+ parse_routing_context,
+ SECURITY_TYPE_NOT_SECURE,
+ SECURITY_TYPE_SECURE,
+ SECURITY_TYPE_SELF_SIGNED_CERTIFICATE,
+ URI_SCHEME_BOLT,
+ URI_SCHEME_BOLT_SECURE,
+ URI_SCHEME_BOLT_SELF_SIGNED_CERTIFICATE,
+ URI_SCHEME_NEO4J,
+ URI_SCHEME_NEO4J_SECURE,
+ URI_SCHEME_NEO4J_SELF_SIGNED_CERTIFICATE,
+ )
+
+ driver_type, security_type, parsed = parse_neo4j_uri(uri)
+
+ if "trust" in config.keys():
+ if config.get("trust") not in [TRUST_ALL_CERTIFICATES, TRUST_SYSTEM_CA_SIGNED_CERTIFICATES]:
+ from neo4j.exceptions import ConfigurationError
+ raise ConfigurationError("The config setting `trust` values are {!r}".format(
+ [
+ TRUST_ALL_CERTIFICATES,
+ TRUST_SYSTEM_CA_SIGNED_CERTIFICATES,
+ ]
+ ))
+
+ if security_type in [SECURITY_TYPE_SELF_SIGNED_CERTIFICATE, SECURITY_TYPE_SECURE] and ("encrypted" in config.keys() or "trust" in config.keys()):
+ from neo4j.exceptions import ConfigurationError
+ raise ConfigurationError("The config settings 'encrypted' and 'trust' can only be used with the URI schemes {!r}. Use the other URI schemes {!r} for setting encryption settings.".format(
+ [
+ URI_SCHEME_BOLT,
+ URI_SCHEME_NEO4J,
+ ],
+ [
+ URI_SCHEME_BOLT_SELF_SIGNED_CERTIFICATE,
+ URI_SCHEME_BOLT_SECURE,
+ URI_SCHEME_NEO4J_SELF_SIGNED_CERTIFICATE,
+ URI_SCHEME_NEO4J_SECURE,
+ ]
+ ))
+
+ if security_type == SECURITY_TYPE_SECURE:
+ config["encrypted"] = True
+ elif security_type == SECURITY_TYPE_SELF_SIGNED_CERTIFICATE:
+ config["encrypted"] = True
+ config["trust"] = TRUST_ALL_CERTIFICATES
+
+ if driver_type == DRIVER_BOLT:
+ return cls.bolt_driver(parsed.netloc, auth=auth, **config)
+ elif driver_type == DRIVER_NEO4j:
+ routing_context = parse_routing_context(parsed.query)
+ return cls.neo4j_driver(parsed.netloc, auth=auth, routing_context=routing_context, **config)
+
+ @classmethod
+ def bolt_driver(cls, target, *, auth=None, **config):
+ """ Create a driver for direct Bolt server access that uses
+ socket I/O and thread-based concurrency.
+ """
+ from .._exceptions import (
+ BoltHandshakeError,
+ BoltSecurityError,
+ )
+
+ try:
+ return BoltDriver.open(target, auth=auth, **config)
+ except (BoltHandshakeError, BoltSecurityError) as error:
+ from neo4j.exceptions import ServiceUnavailable
+ raise ServiceUnavailable(str(error)) from error
+
+ @classmethod
+ def neo4j_driver(cls, *targets, auth=None, routing_context=None, **config):
+ """ Create a driver for routing-capable Neo4j service access
+ that uses socket I/O and thread-based concurrency.
+ """
+ from neo4j._exceptions import (
+ BoltHandshakeError,
+ BoltSecurityError,
+ )
+
+ try:
+ return Neo4jDriver.open(*targets, auth=auth, routing_context=routing_context, **config)
+ except (BoltHandshakeError, BoltSecurityError) as error:
+ from neo4j.exceptions import ServiceUnavailable
+ raise ServiceUnavailable(str(error)) from error
-class Direct:
+class _Direct:
default_host = "localhost"
default_port = 7687
@@ -53,7 +166,7 @@ def parse_target(cls, target):
return address
-class Routing:
+class _Routing:
default_host = "localhost"
default_port = 7687
@@ -80,8 +193,8 @@ def parse_targets(cls, *targets):
class Driver:
- """ Base class for all types of :class:`neo4j.Driver`, instances of which are
- used as the primary access point to Neo4j.
+ """ Base class for all types of :class:`neo4j.Driver`, instances of
+ which are used as the primary access point to Neo4j.
"""
#: Connection pool
@@ -91,15 +204,16 @@ def __init__(self, pool):
assert pool is not None
self._pool = pool
- def __del__(self):
- self.close()
-
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
self.close()
+ def __del__(self):
+ if not asyncio.iscoroutinefunction(self.close):
+ self.close()
+
@property
def encrypted(self):
return bool(self._pool.pool_config.encrypted)
@@ -107,18 +221,14 @@ def encrypted(self):
def session(self, **config):
"""Create a session, see :ref:`session-construction-ref`
- :param config: session configuration key-word arguments, see :ref:`session-configuration-ref` for available key-word arguments.
+ :param config: session configuration key-word arguments,
+ see :ref:`session-configuration-ref` for available key-word
+ arguments.
:returns: new :class:`neo4j.Session` object
"""
raise NotImplementedError
- @experimental("The pipeline API is experimental and may be removed or changed in a future release")
- def pipeline(self, **config):
- """ Create a pipeline.
- """
- raise NotImplementedError
-
def close(self):
""" Shut down, closing any open connections in the pool.
"""
@@ -148,13 +258,16 @@ def supports_multi_db(self):
return session._connection.supports_multiple_databases
-class BoltDriver(Direct, Driver):
- """ A :class:`.BoltDriver` is created from a ``bolt`` URI and addresses
- a single database machine. This may be a standalone server or could be a
- specific member of a cluster.
+class BoltDriver(_Direct, Driver):
+ """:class:`.BoltDriver` is instantiated for ``bolt`` URIs and
+ addresses a single database machine. This may be a standalone server or
+ could be a specific member of a cluster.
- Connections established by a :class:`.BoltDriver` are always made to the
- exact host and port detailed in the URI.
+ Connections established by a :class:`.BoltDriver` are always made to
+ the exact host and port detailed in the URI.
+
+ This class is not supposed to be instantiated externally. Use
+ :meth:`GraphDatabase.driver` instead.
"""
@classmethod
@@ -167,14 +280,14 @@ def open(cls, target, *, auth=None, **config):
:return:
:rtype: :class: `neo4j.BoltDriver`
"""
- from neo4j.io import BoltPool
+ from .io import BoltPool
address = cls.parse_target(target)
pool_config, default_workspace_config = Config.consume_chain(config, PoolConfig, WorkspaceConfig)
pool = BoltPool.open(address, auth=auth, pool_config=pool_config, workspace_config=default_workspace_config)
return cls(pool, default_workspace_config)
def __init__(self, pool, default_workspace_config):
- Direct.__init__(self, pool.address)
+ _Direct.__init__(self, pool.address)
Driver.__init__(self, pool)
self._default_workspace_config = default_workspace_config
@@ -185,20 +298,11 @@ def session(self, **config):
:return:
:rtype: :class: `neo4j.Session`
"""
- from neo4j.work.simple import Session
+ from .work import Session
session_config = SessionConfig(self._default_workspace_config, config)
SessionConfig.consume(config) # Consume the config
return Session(self._pool, session_config)
- def pipeline(self, **config):
- from neo4j.work.pipelining import (
- Pipeline,
- PipelineConfig,
- )
- pipeline_config = PipelineConfig(self._default_workspace_config, config)
- PipelineConfig.consume(config) # Consume the config
- return Pipeline(self._pool, pipeline_config)
-
@experimental("The configuration may change in the future.")
def verify_connectivity(self, **config):
server_agent = None
@@ -211,41 +315,36 @@ def verify_connectivity(self, **config):
return server_agent
-class Neo4jDriver(Routing, Driver):
- """ A :class:`.Neo4jDriver` is created from a ``neo4j`` URI. The
+class Neo4jDriver(_Routing, Driver):
+ """:class:`.Neo4jDriver` is instantiated for ``neo4j`` URIs. The
    routing behaviour works in tandem with Neo4j's `Causal Clustering
    <https://neo4j.com/docs/operations-manual/current/clustering/>`_
feature by directing read and write behaviour to appropriate
cluster members.
+
+ This class is not supposed to be instantiated externally. Use
+ :meth:`GraphDatabase.driver` instead.
"""
@classmethod
def open(cls, *targets, auth=None, routing_context=None, **config):
- from neo4j.io import Neo4jPool
+ from .io import Neo4jPool
addresses = cls.parse_targets(*targets)
pool_config, default_workspace_config = Config.consume_chain(config, PoolConfig, WorkspaceConfig)
pool = Neo4jPool.open(*addresses, auth=auth, routing_context=routing_context, pool_config=pool_config, workspace_config=default_workspace_config)
return cls(pool, default_workspace_config)
def __init__(self, pool, default_workspace_config):
- Routing.__init__(self, pool.get_default_database_initial_router_addresses())
+ _Routing.__init__(self, pool.get_default_database_initial_router_addresses())
Driver.__init__(self, pool)
self._default_workspace_config = default_workspace_config
def session(self, **config):
+ from .work import Session
session_config = SessionConfig(self._default_workspace_config, config)
SessionConfig.consume(config) # Consume the config
return Session(self._pool, session_config)
- def pipeline(self, **config):
- from neo4j.work.pipelining import (
- Pipeline,
- PipelineConfig,
- )
- pipeline_config = PipelineConfig(self._default_workspace_config, config)
- PipelineConfig.consume(config) # Consume the config
- return Pipeline(self._pool, pipeline_config)
-
@experimental("The configuration may change in the future.")
def verify_connectivity(self, **config):
"""
@@ -255,7 +354,7 @@ def verify_connectivity(self, **config):
return self._verify_routing_connectivity()
def _verify_routing_connectivity(self):
- from neo4j.exceptions import (
+ from ..exceptions import (
Neo4jError,
ServiceUnavailable,
SessionExpired,
diff --git a/neo4j/_sync/io/__init__.py b/neo4j/_sync/io/__init__.py
new file mode 100644
index 00000000..b598d07d
--- /dev/null
+++ b/neo4j/_sync/io/__init__.py
@@ -0,0 +1,43 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This module contains the low-level functionality required for speaking
+Bolt. It is not intended to be used directly by driver users. Instead,
+the `session` module provides the main user-facing abstractions.
+"""
+
+
+__all__ = [
+ "Bolt",
+ "BoltPool",
+ "Neo4jPool",
+ "check_supported_server_product",
+ "ConnectionErrorHandler",
+]
+
+
+from ._bolt import Bolt
+from ._common import (
+ check_supported_server_product,
+ ConnectionErrorHandler,
+)
+from ._pool import (
+ BoltPool,
+ Neo4jPool,
+)
diff --git a/neo4j/_sync/io/_bolt.py b/neo4j/_sync/io/_bolt.py
new file mode 100644
index 00000000..82ee8b62
--- /dev/null
+++ b/neo4j/_sync/io/_bolt.py
@@ -0,0 +1,571 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import abc
+import asyncio
+from collections import deque
+from logging import getLogger
+from time import perf_counter
+
+from ..._async_compat.network import BoltSocket
+from ..._exceptions import BoltHandshakeError
+from ...addressing import Address
+from ...api import (
+ ServerInfo,
+ Version,
+)
+from ...conf import PoolConfig
+from ...exceptions import (
+ AuthError,
+ IncompleteCommit,
+ ServiceUnavailable,
+ SessionExpired,
+)
+from ...meta import get_user_agent
+from ...packstream import (
+ Packer,
+ Unpacker,
+)
+from ._common import (
+ CommitResponse,
+ Inbox,
+ Outbox,
+)
+
+
+# Set up logger
+log = getLogger("neo4j")
+
+
+class Bolt:
+    """ Server connection for Bolt protocol.
+
+    A :class:`.Bolt` instance should be constructed
+    following a successful Bolt handshake (see
+    :meth:`.open`) and takes ownership of the socket
+    over which that handshake was carried out.
+    """
+
+
+ MAGIC_PREAMBLE = b"\x60\x60\xB0\x17"
+
+ PROTOCOL_VERSION = None
+
+ # flag if connection needs RESET to go back to READY state
+ is_reset = False
+
+    # flag if the connection is currently checked out for use
+    in_use = False
+
+    # flag if the connection has been closed
+    _closed = False
+
+    # flag if the connection has suffered an unrecoverable error
+    _defunct = False
+
+ #: The pool of which this connection is a member
+ pool = None
+
+    # Store the id of the most recently run query to be able to reduce sent
+    # bits by using the default (-1) to refer to the most recent query when
+    # pulling results for it.
+ most_recent_qid = None
+
+ def __init__(self, unresolved_address, sock, max_connection_lifetime, *,
+ auth=None, user_agent=None, routing_context=None):
+ self.unresolved_address = unresolved_address
+ self.socket = sock
+ self.server_info = ServerInfo(Address(sock.getpeername()),
+ self.PROTOCOL_VERSION)
+ # so far `connection.recv_timeout_seconds` is the only available
+ # configuration hint that exists. Therefore, all hints can be stored at
+ # connection level. This might change in the future.
+ self.configuration_hints = {}
+ self.outbox = Outbox()
+ self.inbox = Inbox(self.socket, on_error=self._set_defunct_read)
+ self.packer = Packer(self.outbox)
+ self.unpacker = Unpacker(self.inbox)
+ self.responses = deque()
+ self._max_connection_lifetime = max_connection_lifetime
+ self._creation_timestamp = perf_counter()
+ self.routing_context = routing_context
+
+ # Determine the user agent
+ if user_agent:
+ self.user_agent = user_agent
+ else:
+ self.user_agent = get_user_agent()
+
+ # Determine auth details
+ if not auth:
+ self.auth_dict = {}
+ elif isinstance(auth, tuple) and 2 <= len(auth) <= 3:
+ from neo4j import Auth
+ self.auth_dict = vars(Auth("basic", *auth))
+ else:
+ try:
+ self.auth_dict = vars(auth)
+ except (KeyError, TypeError):
+ raise AuthError("Cannot determine auth details from %r" % auth)
+
+ # Check for missing password
+ try:
+ credentials = self.auth_dict["credentials"]
+ except KeyError:
+ pass
+ else:
+ if credentials is None:
+ raise AuthError("Password cannot be None")
+
+ def __del__(self):
+ if not asyncio.iscoroutinefunction(self.close):
+ self.close()
+
+ @property
+ @abc.abstractmethod
+ def supports_multiple_results(self):
+ """ Boolean flag to indicate if the connection version supports multiple
+ queries to be buffered on the server side (True) or if all results need
+ to be eagerly pulled before sending the next RUN (False).
+ """
+ pass
+
+ @property
+ @abc.abstractmethod
+ def supports_multiple_databases(self):
+ """ Boolean flag to indicate if the connection version supports multiple
+ databases.
+ """
+ pass
+
+ @classmethod
+ def protocol_handlers(cls, protocol_version=None):
+ """ Return a dictionary of available Bolt protocol handlers,
+ keyed by version tuple. If an explicit protocol version is
+ provided, the dictionary will contain either zero or one items,
+ depending on whether that version is supported. If no protocol
+ version is provided, all available versions will be returned.
+
+ :param protocol_version: tuple identifying a specific protocol
+ version (e.g. (3, 5)) or None
+ :return: dictionary of version tuple to handler class for all
+ relevant and supported protocol versions
+ :raise TypeError: if protocol version is not passed in a tuple
+ """
+
+ # Carry out Bolt subclass imports locally to avoid circular dependency issues.
+ from ._bolt3 import Bolt3
+ from ._bolt4 import (
+ Bolt4x0,
+ Bolt4x1,
+ Bolt4x2,
+ Bolt4x3,
+ Bolt4x4,
+ )
+
+ handlers = {
+ Bolt3.PROTOCOL_VERSION: Bolt3,
+ Bolt4x0.PROTOCOL_VERSION: Bolt4x0,
+ Bolt4x1.PROTOCOL_VERSION: Bolt4x1,
+ Bolt4x2.PROTOCOL_VERSION: Bolt4x2,
+ Bolt4x3.PROTOCOL_VERSION: Bolt4x3,
+ Bolt4x4.PROTOCOL_VERSION: Bolt4x4,
+ }
+
+ if protocol_version is None:
+ return handlers
+
+ if not isinstance(protocol_version, tuple):
+ raise TypeError("Protocol version must be specified as a tuple")
+
+ if protocol_version in handlers:
+ return {protocol_version: handlers[protocol_version]}
+
+ return {}
+
+ @classmethod
+ def version_list(cls, versions, limit=4):
+ """ Return a list of supported protocol versions in order of
+ preference. The number of protocol versions (or ranges)
+ returned is limited to four.
+ """
+        # In fact, 4.3 is the first version to support ranges. However, the
+        # range support got backported to 4.2. But even if the server is too
+        # old to have the backport, negotiating BOLT 4.1 is no problem as
+        # it's equivalent to 4.2
+ first_with_range_support = Version(4, 2)
+ result = []
+ for version in versions:
+ if (result
+ and version >= first_with_range_support
+ and result[-1][0] == version[0]
+ and result[-1][1][1] == version[1] + 1):
+ # can use range to encompass this version
+ result[-1][1][1] = version[1]
+ continue
+ result.append(Version(version[0], [version[1], version[1]]))
+ if len(result) == 4:
+ break
+ return result
+
+ @classmethod
+ def get_handshake(cls):
+ """ Return the supported Bolt versions as bytes.
+ The length is 16 bytes as specified in the Bolt version negotiation.
+ :return: bytes
+ """
+ supported_versions = sorted(cls.protocol_handlers().keys(), reverse=True)
+ offered_versions = cls.version_list(supported_versions)
+ return b"".join(version.to_bytes() for version in offered_versions).ljust(16, b"\x00")
+
+ @classmethod
+ def ping(cls, address, *, timeout=None, **config):
+ """ Attempt to establish a Bolt connection, returning the
+ agreed Bolt protocol version if successful.
+ """
+ config = PoolConfig.consume(config)
+ try:
+ s, protocol_version, handshake, data = \
+ BoltSocket.connect(
+ address,
+ timeout=timeout,
+ custom_resolver=config.resolver,
+ ssl_context=config.get_ssl_context(),
+ keep_alive=config.keep_alive,
+ )
+ except (ServiceUnavailable, SessionExpired, BoltHandshakeError):
+ return None
+ else:
+ BoltSocket.close_socket(s)
+ return protocol_version
+
+ @classmethod
+ def open(
+ cls, address, *, auth=None, timeout=None, routing_context=None, **pool_config
+ ):
+ """ Open a new Bolt connection to a given server address.
+
+ :param address:
+ :param auth:
+ :param timeout: the connection timeout in seconds
+ :param routing_context: dict containing routing context
+ :param pool_config:
+ :return:
+ :raise BoltHandshakeError: raised if the Bolt Protocol can not negotiate a protocol version.
+ :raise ServiceUnavailable: raised if there was a connection issue.
+ """
+ pool_config = PoolConfig.consume(pool_config)
+ s, pool_config.protocol_version, handshake, data = \
+ BoltSocket.connect(
+ address,
+ timeout=timeout,
+ custom_resolver=pool_config.resolver,
+ ssl_context=pool_config.get_ssl_context(),
+ keep_alive=pool_config.keep_alive,
+ )
+
+ # Carry out Bolt subclass imports locally to avoid circular dependency
+ # issues.
+ if pool_config.protocol_version == (3, 0):
+ from ._bolt3 import Bolt3
+ bolt_cls = Bolt3
+ elif pool_config.protocol_version == (4, 0):
+ from ._bolt4 import Bolt4x0
+ bolt_cls = Bolt4x0
+ elif pool_config.protocol_version == (4, 1):
+ from ._bolt4 import Bolt4x1
+ bolt_cls = Bolt4x1
+ elif pool_config.protocol_version == (4, 2):
+ from ._bolt4 import Bolt4x2
+ bolt_cls = Bolt4x2
+ elif pool_config.protocol_version == (4, 3):
+ from ._bolt4 import Bolt4x3
+ bolt_cls = Bolt4x3
+ elif pool_config.protocol_version == (4, 4):
+ from ._bolt4 import Bolt4x4
+ bolt_cls = Bolt4x4
+ else:
+ log.debug("[#%04X] S: ", s.getsockname()[1])
+ BoltSocket.close_socket(s)
+
+ supported_versions = cls.protocol_handlers().keys()
+ raise BoltHandshakeError("The Neo4J server does not support communication with this driver. This driver have support for Bolt Protocols {}".format(supported_versions), address=address, request_data=handshake, response_data=data)
+
+ connection = bolt_cls(
+ address, s, pool_config.max_connection_lifetime, auth=auth,
+ user_agent=pool_config.user_agent, routing_context=routing_context
+ )
+
+ try:
+ connection.hello()
+ except Exception:
+ connection.close()
+ raise
+
+ return connection
+
+ @property
+ @abc.abstractmethod
+ def encrypted(self):
+ pass
+
+ @property
+ @abc.abstractmethod
+ def der_encoded_server_certificate(self):
+ pass
+
+ @property
+ @abc.abstractmethod
+ def local_port(self):
+ pass
+
+ @abc.abstractmethod
+ def hello(self):
+ """ Appends a HELLO message to the outgoing queue, sends it and consumes
+ all remaining messages.
+ """
+ pass
+
+ @abc.abstractmethod
+ def route(self, database=None, imp_user=None, bookmarks=None):
+ """ Fetch a routing table from the server for the given
+ `database`. For Bolt 4.3 and above, this appends a ROUTE
+ message; for earlier versions, a procedure call is made via
+ the regular Cypher execution mechanism. In all cases, this is
+ sent to the network, and a response is fetched.
+
+ :param database: database for which to fetch a routing table
+ Requires Bolt 4.0+.
+ :param imp_user: the user to impersonate
+ Requires Bolt 4.4+.
+ :param bookmarks: iterable of bookmark values after which this
+ transaction should begin
+ :return: dictionary of raw routing data
+ """
+ pass
+
+ @abc.abstractmethod
+ def run(self, query, parameters=None, mode=None, bookmarks=None,
+ metadata=None, timeout=None, db=None, imp_user=None, **handlers):
+ """ Appends a RUN message to the output queue.
+
+ :param query: Cypher query string
+ :param parameters: dictionary of Cypher parameters
+ :param mode: access mode for routing - "READ" or "WRITE" (default)
+ :param bookmarks: iterable of bookmark values after which this transaction should begin
+ :param metadata: custom metadata dictionary to attach to the transaction
+ :param timeout: timeout for transaction execution (seconds)
+ :param db: name of the database against which to begin the transaction
+ Requires Bolt 4.0+.
+ :param imp_user: the user to impersonate
+ Requires Bolt 4.4+.
+ :param handlers: handler functions passed into the returned Response object
+ :return: Response object
+ """
+ pass
+
+ @abc.abstractmethod
+ def discard(self, n=-1, qid=-1, **handlers):
+ """ Appends a DISCARD message to the output queue.
+
+ :param n: number of records to discard, default = -1 (ALL)
+ :param qid: query ID to discard for, default = -1 (last query)
+ :param handlers: handler functions passed into the returned Response object
+ :return: Response object
+ """
+ pass
+
+ @abc.abstractmethod
+ def pull(self, n=-1, qid=-1, **handlers):
+ """ Appends a PULL message to the output queue.
+
+ :param n: number of records to pull, default = -1 (ALL)
+ :param qid: query ID to pull for, default = -1 (last query)
+ :param handlers: handler functions passed into the returned Response object
+ :return: Response object
+ """
+ pass
+
+ @abc.abstractmethod
+ def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None,
+ db=None, imp_user=None, **handlers):
+ """ Appends a BEGIN message to the output queue.
+
+ :param mode: access mode for routing - "READ" or "WRITE" (default)
+ :param bookmarks: iterable of bookmark values after which this transaction should begin
+ :param metadata: custom metadata dictionary to attach to the transaction
+ :param timeout: timeout for transaction execution (seconds)
+ :param db: name of the database against which to begin the transaction
+ Requires Bolt 4.0+.
+ :param imp_user: the user to impersonate
+ Requires Bolt 4.4+
+ :param handlers: handler functions passed into the returned Response object
+ :return: Response object
+ """
+ pass
+
+ @abc.abstractmethod
+ def commit(self, **handlers):
+ """ Appends a COMMIT message to the output queue."""
+ pass
+
+ @abc.abstractmethod
+ def rollback(self, **handlers):
+ """ Appends a ROLLBACK message to the output queue."""
+ pass
+
+ @abc.abstractmethod
+ def reset(self):
+ """ Appends a RESET message to the outgoing queue, sends it and consumes
+ all remaining messages.
+ """
+ pass
+
+ def _append(self, signature, fields=(), response=None):
+ """ Appends a message to the outgoing queue.
+
+ :param signature: the signature of the message
+ :param fields: the fields of the message as a tuple
+ :param response: a response object to handle callbacks
+ """
+ self.packer.pack_struct(signature, fields)
+ self.outbox.wrap_message()
+ self.responses.append(response)
+
+ def _send_all(self):
+ data = self.outbox.view()
+ if data:
+ try:
+ self.socket.sendall(data)
+ except OSError as error:
+ self._set_defunct_write(error)
+ self.outbox.clear()
+
+ def send_all(self):
+ """ Send all queued messages to the server.
+ """
+ if self.closed():
+ raise ServiceUnavailable("Failed to write to closed connection {!r} ({!r})".format(
+ self.unresolved_address, self.server_info.address))
+
+ if self.defunct():
+ raise ServiceUnavailable("Failed to write to defunct connection {!r} ({!r})".format(
+ self.unresolved_address, self.server_info.address))
+
+ self._send_all()
+
+ @abc.abstractmethod
+ def fetch_message(self):
+ """ Receive at most one message from the server, if available.
+
+ :return: 2-tuple of number of detail messages and number of summary
+ messages fetched
+ """
+ pass
+
+ def fetch_all(self):
+ """ Fetch all outstanding messages.
+
+ :return: 2-tuple of number of detail messages and number of summary
+ messages fetched
+ """
+ detail_count = summary_count = 0
+ while self.responses:
+ response = self.responses[0]
+ while not response.complete:
+ detail_delta, summary_delta = self.fetch_message()
+ detail_count += detail_delta
+ summary_count += summary_delta
+ return detail_count, summary_count
+
+ def _set_defunct_read(self, error=None, silent=False):
+ message = "Failed to read from defunct connection {!r} ({!r})".format(
+ self.unresolved_address, self.server_info.address
+ )
+ self._set_defunct(message, error=error, silent=silent)
+
+ def _set_defunct_write(self, error=None, silent=False):
+ message = "Failed to write data to connection {!r} ({!r})".format(
+ self.unresolved_address, self.server_info.address
+ )
+ self._set_defunct(message, error=error, silent=silent)
+
+ def _set_defunct(self, message, error=None, silent=False):
+ from ._pool import BoltPool
+ direct_driver = isinstance(self.pool, BoltPool)
+
+ if error:
+ log.debug("[#%04X] %s", self.socket.getsockname()[1], error)
+ log.error(message)
+ # We were attempting to receive data but the connection
+ # has unexpectedly terminated. So, we need to close the
+ # connection from the client side, and remove the address
+ # from the connection pool.
+ self._defunct = True
+ self.close()
+ if self.pool:
+ self.pool.deactivate(address=self.unresolved_address)
+ # Iterate through the outstanding responses, and if any correspond
+ # to COMMIT requests then raise an error to signal that we are
+ # unable to confirm that the COMMIT completed successfully.
+ if silent:
+ return
+ for response in self.responses:
+ if isinstance(response, CommitResponse):
+ if error:
+ raise IncompleteCommit(message) from error
+ else:
+ raise IncompleteCommit(message)
+
+ if direct_driver:
+ if error:
+ raise ServiceUnavailable(message) from error
+ else:
+ raise ServiceUnavailable(message)
+ else:
+ if error:
+ raise SessionExpired(message) from error
+ else:
+ raise SessionExpired(message)
+
+ def stale(self):
+ return (self._stale
+ or (0 <= self._max_connection_lifetime
+ <= perf_counter() - self._creation_timestamp))
+
+ _stale = False
+
+ def set_stale(self):
+ self._stale = True
+
+ @abc.abstractmethod
+ def close(self):
+ """ Close the connection.
+ """
+ pass
+
+ @abc.abstractmethod
+ def closed(self):
+ pass
+
+ @abc.abstractmethod
+ def defunct(self):
+ pass
+
+
+BoltSocket.Bolt = Bolt
diff --git a/neo4j/io/_bolt3.py b/neo4j/_sync/io/_bolt3.py
similarity index 97%
rename from neo4j/io/_bolt3.py
rename to neo4j/_sync/io/_bolt3.py
index 704d8740..e19bd4c2 100644
--- a/neo4j/io/_bolt3.py
+++ b/neo4j/_sync/io/_bolt3.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -18,19 +15,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+
from enum import Enum
from logging import getLogger
from ssl import SSLSocket
-from neo4j._exceptions import (
+from ..._async_compat.util import Util
+from ..._exceptions import (
BoltError,
BoltProtocolError,
)
-from neo4j.api import (
+from ...api import (
READ_ACCESS,
Version,
)
-from neo4j.exceptions import (
+from ...exceptions import (
ConfigurationError,
DatabaseUnavailable,
DriverError,
@@ -39,11 +38,9 @@
NotALeader,
ServiceUnavailable,
)
-from neo4j.io import (
- Bolt,
+from ._bolt import Bolt
+from ._common import (
check_supported_server_product,
-)
-from neo4j.io._common import (
CommitResponse,
InitResponse,
Response,
@@ -93,8 +90,8 @@ def transition(self, message, metadata):
if metadata.get("has_more"):
return
state_before = self.state
- self.state = self._STATE_TRANSITIONS\
- .get(self.state, {})\
+ self.state = self._STATE_TRANSITIONS \
+ .get(self.state, {}) \
.get(message, self.state)
if state_before != self.state and callable(self._on_change):
self._on_change(state_before, self.state)
@@ -331,7 +328,8 @@ def fetch_message(self):
return 0, 0
# Receive exactly one message
- details, summary_signature, summary_metadata = next(self.inbox)
+ details, summary_signature, summary_metadata = \
+ Util.next(self.inbox)
if details:
log.debug("[#%04X] S: RECORD * %d", self.local_port, len(details)) # Do not log any data
@@ -357,11 +355,11 @@ def fetch_message(self):
response.on_failure(summary_metadata or {})
except (ServiceUnavailable, DatabaseUnavailable):
if self.pool:
- self.pool.deactivate(address=self.unresolved_address),
+ self.pool.deactivate(address=self.unresolved_address)
raise
except (NotALeader, ForbiddenOnReadOnlyDatabase):
if self.pool:
- self.pool.on_write_failure(address=self.unresolved_address),
+ self.pool.on_write_failure(address=self.unresolved_address)
raise
except Neo4jError as e:
if self.pool and e.invalidates_all_connections():
diff --git a/neo4j/io/_bolt4.py b/neo4j/_sync/io/_bolt4.py
similarity index 98%
rename from neo4j/io/_bolt4.py
rename to neo4j/_sync/io/_bolt4.py
index 3bb90ab6..33242216 100644
--- a/neo4j/io/_bolt4.py
+++ b/neo4j/_sync/io/_bolt4.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -18,20 +15,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from enum import Enum
+
from logging import getLogger
from ssl import SSLSocket
-from neo4j._exceptions import (
+from ..._async_compat.util import Util
+from ..._exceptions import (
BoltError,
BoltProtocolError,
)
-from neo4j.api import (
+from ...api import (
READ_ACCESS,
SYSTEM_DATABASE,
Version,
)
-from neo4j.exceptions import (
+from ...exceptions import (
ConfigurationError,
DatabaseUnavailable,
DriverError,
@@ -40,19 +38,17 @@
NotALeader,
ServiceUnavailable,
)
-from neo4j.io import (
- Bolt,
- check_supported_server_product,
+from ._bolt3 import (
+ ServerStateManager,
+ ServerStates,
)
-from neo4j.io._common import (
+from ._bolt import Bolt
+from ._common import (
+ check_supported_server_product,
CommitResponse,
InitResponse,
Response,
)
-from neo4j.io._bolt3 import (
- ServerStateManager,
- ServerStates,
-)
log = getLogger("neo4j")
@@ -283,7 +279,8 @@ def fetch_message(self):
return 0, 0
# Receive exactly one message
- details, summary_signature, summary_metadata = next(self.inbox)
+ details, summary_signature, summary_metadata = \
+ Util.next(self.inbox)
if details:
log.debug("[#%04X] S: RECORD * %d", self.local_port, len(details)) # Do not log any data
@@ -309,11 +306,11 @@ def fetch_message(self):
response.on_failure(summary_metadata or {})
except (ServiceUnavailable, DatabaseUnavailable):
if self.pool:
- self.pool.deactivate(address=self.unresolved_address),
+ self.pool.deactivate(address=self.unresolved_address)
raise
except (NotALeader, ForbiddenOnReadOnlyDatabase):
if self.pool:
- self.pool.on_write_failure(address=self.unresolved_address),
+ self.pool.on_write_failure(address=self.unresolved_address)
raise
except Neo4jError as e:
if self.pool and e.invalidates_all_connections():
diff --git a/neo4j/io/_common.py b/neo4j/_sync/io/_common.py
similarity index 78%
rename from neo4j/io/_common.py
rename to neo4j/_sync/io/_common.py
index becb7db4..408de0a1 100644
--- a/neo4j/io/_common.py
+++ b/neo4j/_sync/io/_common.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -19,21 +16,24 @@
# limitations under the License.
+import asyncio
+import logging
import socket
from struct import pack as struct_pack
-from neo4j.exceptions import (
- AuthError,
+from ..._async_compat.util import Util
+from ...exceptions import (
Neo4jError,
ServiceUnavailable,
SessionExpired,
+ UnsupportedServerProduct,
)
-from neo4j.packstream import (
+from ...packstream import (
UnpackableBuffer,
Unpacker,
)
-import logging
+
log = logging.getLogger("neo4j")
@@ -52,12 +52,12 @@ def _yield_messages(self, sock):
while chunk_size == 0:
# Determine the chunk size and skip noop
- buffer.receive(sock, 2)
+ receive_into_buffer(sock, buffer, 2)
chunk_size = buffer.pop_u16()
if chunk_size == 0:
log.debug("[#%04X] S: ", sock.getsockname()[1])
- buffer.receive(sock, chunk_size + 2)
+ receive_into_buffer(sock, buffer, chunk_size + 2)
chunk_size = buffer.pop_u16()
if chunk_size == 0:
@@ -69,10 +69,10 @@ def _yield_messages(self, sock):
unpacker.reset()
except (OSError, socket.timeout) as error:
- self.on_error(error)
+ Util.callback(self.on_error, error)
def pop(self):
- return next(self._messages)
+ return Util.next(self._messages)
class Inbox(MessageInbox):
@@ -166,10 +166,22 @@ def inner(*args, **kwargs):
try:
func(*args, **kwargs)
except (Neo4jError, ServiceUnavailable, SessionExpired) as exc:
+ assert not asyncio.iscoroutinefunction(self.__on_error)
self.__on_error(exc)
raise
return inner
+ def outer_async(coroutine_func):
+ def inner(*args, **kwargs):
+ try:
+ coroutine_func(*args, **kwargs)
+ except (Neo4jError, ServiceUnavailable, SessionExpired) as exc:
+ Util.callback(self.__on_error, exc)
+ raise
+ return inner
+
+ if asyncio.iscoroutinefunction(connection_attr):
+ return outer_async(connection_attr)
return outer(connection_attr)
def __setattr__(self, name, value):
@@ -194,20 +206,17 @@ def on_records(self, records):
""" Called when one or more RECORD messages have been received.
"""
handler = self.handlers.get("on_records")
- if callable(handler):
- handler(records)
+ Util.callback(handler, records)
def on_success(self, metadata):
""" Called when a SUCCESS message has been received.
"""
handler = self.handlers.get("on_success")
- if callable(handler):
- handler(metadata)
+ Util.callback(handler, metadata)
if not metadata.get("has_more"):
handler = self.handlers.get("on_summary")
- if callable(handler):
- handler()
+ Util.callback(handler)
def on_failure(self, metadata):
""" Called when a FAILURE message has been received.
@@ -217,22 +226,18 @@ def on_failure(self, metadata):
except (SessionExpired, ServiceUnavailable):
pass
handler = self.handlers.get("on_failure")
- if callable(handler):
- handler(metadata)
+ Util.callback(handler, metadata)
handler = self.handlers.get("on_summary")
- if callable(handler):
- handler()
+ Util.callback(handler)
raise Neo4jError.hydrate(**metadata)
def on_ignored(self, metadata=None):
""" Called when an IGNORED message has been received.
"""
handler = self.handlers.get("on_ignored")
- if callable(handler):
- handler(metadata)
+ Util.callback(handler, metadata)
handler = self.handlers.get("on_summary")
- if callable(handler):
- handler()
+ Util.callback(handler)
class InitResponse(Response):
@@ -250,3 +255,26 @@ def on_failure(self, metadata):
class CommitResponse(Response):
pass
+
+
+def check_supported_server_product(agent):
+ """ Checks that a server product is supported by the driver by
+ looking at the server agent string.
+
+ :param agent: server agent string to check for validity
+ :raises UnsupportedServerProduct: if the product is not supported
+ """
+ if not agent.startswith("Neo4j/"):
+ raise UnsupportedServerProduct(agent)
+
+
+def receive_into_buffer(sock, buffer, n_bytes):
+ end = buffer.used + n_bytes
+ if end > len(buffer.data):
+ buffer.data += bytearray(end - len(buffer.data))
+ view = memoryview(buffer.data)
+ while buffer.used < end:
+ n = sock.recv_into(view[buffer.used:end], end - buffer.used)
+ if n == 0:
+ raise OSError("No data")
+ buffer.used += n
diff --git a/neo4j/_sync/io/_pool.py b/neo4j/_sync/io/_pool.py
new file mode 100644
index 00000000..5969fd6d
--- /dev/null
+++ b/neo4j/_sync/io/_pool.py
@@ -0,0 +1,701 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import abc
+from collections import (
+ defaultdict,
+ deque,
+)
+import logging
+from logging import getLogger
+from random import choice
+from time import perf_counter
+
+from ..._async_compat.concurrency import (
+ Condition,
+ RLock,
+)
+from ..._async_compat.network import NetworkUtil
+from ..._exceptions import BoltError
+from ...api import (
+ READ_ACCESS,
+ WRITE_ACCESS,
+)
+from ...conf import (
+ PoolConfig,
+ WorkspaceConfig,
+)
+from ...exceptions import (
+ ClientError,
+ ConfigurationError,
+ DriverError,
+ Neo4jError,
+ ReadServiceUnavailable,
+ ServiceUnavailable,
+ SessionExpired,
+ WriteServiceUnavailable,
+)
+from ...routing import RoutingTable
+from ._bolt import Bolt
+
+
+# Set up logger
+log = getLogger("neo4j")
+
+
+class IOPool(abc.ABC):
+ """ A collection of connections to one or more server addresses.
+ """
+
+ def __init__(self, opener, pool_config, workspace_config):
+ assert callable(opener)
+ assert isinstance(pool_config, PoolConfig)
+ assert isinstance(workspace_config, WorkspaceConfig)
+
+ self.opener = opener
+ self.pool_config = pool_config
+ self.workspace_config = workspace_config
+ self.connections = defaultdict(deque)
+ self.lock = RLock()
+ self.cond = Condition(self.lock)
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ self.close()
+
+ def _acquire(self, address, timeout):
+ """ Acquire a connection to a given address from the pool.
+ The address supplied should always be an IP address, not
+ a host name.
+
+ This method is thread safe.
+ """
+ t0 = perf_counter()
+ if timeout is None:
+ timeout = self.workspace_config.connection_acquisition_timeout
+
+ with self.lock:
+ def time_remaining():
+ t = timeout - (perf_counter() - t0)
+ return t if t > 0 else 0
+
+ while True:
+ # try to find a free connection in pool
+ for connection in list(self.connections.get(address, [])):
+ if (connection.closed() or connection.defunct()
+ or (connection.stale() and not connection.in_use)):
+ # `close` is a noop on already closed connections.
+ # This is to make sure that the connection is
+ # gracefully closed, e.g. if it's just marked as
+ # `stale` but still alive.
+ if log.isEnabledFor(logging.DEBUG):
+ log.debug(
+ "[#%04X] C: removing old connection "
+ "(closed=%s, defunct=%s, stale=%s, in_use=%s)",
+ connection.local_port,
+ connection.closed(), connection.defunct(),
+ connection.stale(), connection.in_use
+ )
+ connection.close()
+ try:
+ self.connections.get(address, []).remove(connection)
+ except ValueError:
+ # If closure fails (e.g. because the server went
+ # down), all connections to the same address will
+ # be removed. Therefore, we silently ignore if the
+ # connection isn't in the pool anymore.
+ pass
+ continue
+ if not connection.in_use:
+ connection.in_use = True
+ return connection
+ # all connections in pool are in-use
+ connections = self.connections[address]
+ max_pool_size = self.pool_config.max_connection_pool_size
+ infinite_pool_size = (max_pool_size < 0
+ or max_pool_size == float("inf"))
+ can_create_new_connection = (
+ infinite_pool_size
+ or len(connections) < max_pool_size
+ )
+ if can_create_new_connection:
+ timeout = min(self.pool_config.connection_timeout,
+ time_remaining())
+ try:
+ connection = self.opener(address, timeout)
+ except ServiceUnavailable:
+ self.remove(address)
+ raise
+ else:
+ connection.pool = self
+ connection.in_use = True
+ connections.append(connection)
+ return connection
+
+ # failed to obtain a connection from pool because the
+ # pool is full and no free connection in the pool
+ if time_remaining():
+ self.cond.wait(time_remaining())
+                # if timed out, then we throw error. This time
+                # computation is needed, as we cannot tell
+                # whether the condition was notified or timed
+                # out when we reach this line
+ if not time_remaining():
+ raise ClientError("Failed to obtain a connection from pool "
+ "within {!r}s".format(timeout))
+ else:
+ raise ClientError("Failed to obtain a connection from pool "
+ "within {!r}s".format(timeout))
+
+ @abc.abstractmethod
+ def acquire(
+ self, access_mode=None, timeout=None, database=None, bookmarks=None
+ ):
+ """ Acquire a connection to a server that can satisfy a set of parameters.
+
+ :param access_mode:
+ :param timeout:
+ :param database:
+ :param bookmarks:
+ """
+
+ def release(self, *connections):
+ """ Release a connection back into the pool.
+ This method is thread safe.
+ """
+ with self.lock:
+ for connection in connections:
+ if not (connection.defunct()
+ or connection.closed()
+ or connection.is_reset):
+ try:
+ connection.reset()
+ except (Neo4jError, DriverError, BoltError) as e:
+ log.debug(
+ "Failed to reset connection on release: %s", e
+ )
+ connection.in_use = False
+ self.cond.notify_all()
+
+ def in_use_connection_count(self, address):
+ """ Count the number of connections currently in use to a given
+ address.
+ """
+ try:
+ connections = self.connections[address]
+ except KeyError:
+ return 0
+ else:
+ return sum(1 if connection.in_use else 0 for connection in connections)
+
+ def mark_all_stale(self):
+ with self.lock:
+ for address in self.connections:
+ for connection in self.connections[address]:
+ connection.set_stale()
+
+ def deactivate(self, address):
+ """ Deactivate an address from the connection pool, if present, closing
+ all idle connection to that address
+ """
+ with self.lock:
+ try:
+ connections = self.connections[address]
+ except KeyError: # already removed from the connection pool
+ return
+ for conn in list(connections):
+ if not conn.in_use:
+ connections.remove(conn)
+ try:
+ conn.close()
+ except OSError:
+ pass
+ if not connections:
+ self.remove(address)
+
+ def on_write_failure(self, address):
+ raise WriteServiceUnavailable(
+ "No write service available for pool {}".format(self)
+ )
+
+ def remove(self, address):
+ """ Remove an address from the connection pool, if present, closing
+ all connections to that address.
+ """
+ with self.lock:
+ for connection in self.connections.pop(address, ()):
+ try:
+ connection.close()
+ except OSError:
+ pass
+
+ def close(self):
+ """ Close all connections and empty the pool.
+ This method is thread safe.
+ """
+ try:
+ with self.lock:
+ for address in list(self.connections):
+ self.remove(address)
+ except TypeError:
+ pass
+
+
+class BoltPool(IOPool):
+
+ @classmethod
+ def open(cls, address, *, auth, pool_config, workspace_config):
+ """Create a new BoltPool
+
+ :param address:
+ :param auth:
+ :param pool_config:
+ :param workspace_config:
+ :return: BoltPool
+ """
+
+ def opener(addr, timeout):
+ return Bolt.open(
+ addr, auth=auth, timeout=timeout, routing_context=None,
+ **pool_config
+ )
+
+ pool = cls(opener, pool_config, workspace_config, address)
+ return pool
+
+ def __init__(self, opener, pool_config, workspace_config, address):
+ super().__init__(opener, pool_config, workspace_config)
+ self.address = address
+
+ def __repr__(self):
+ return "<{} address={!r}>".format(self.__class__.__name__,
+ self.address)
+
+ def acquire(
+ self, access_mode=None, timeout=None, database=None, bookmarks=None
+ ):
+ # The access_mode and database is not needed for a direct connection,
+ # it's just there for consistency.
+ return self._acquire(self.address, timeout)
+
+
+class Neo4jPool(IOPool):
+ """ Connection pool with routing table.
+ """
+
+ @classmethod
+ def open(cls, *addresses, auth, pool_config, workspace_config,
+ routing_context=None):
+ """Create a new Neo4jPool
+
+ :param addresses: one or more address as positional argument
+ :param auth:
+ :param pool_config:
+ :param workspace_config:
+ :param routing_context:
+ :return: Neo4jPool
+ """
+
+ address = addresses[0]
+ if routing_context is None:
+ routing_context = {}
+ elif "address" in routing_context:
+ raise ConfigurationError("The key 'address' is reserved for routing context.")
+ routing_context["address"] = str(address)
+
+ def opener(addr, timeout):
+ return Bolt.open(
+ addr, auth=auth, timeout=timeout,
+ routing_context=routing_context, **pool_config
+ )
+
+ pool = cls(opener, pool_config, workspace_config, address)
+ return pool
+
+ def __init__(self, opener, pool_config, workspace_config, address):
+ """
+
+ :param opener:
+ :param pool_config:
+ :param workspace_config:
+ :param address:
+ """
+ super().__init__(opener, pool_config, workspace_config)
+        # Each database has a routing table; the default database is a special case.
+ log.debug("[#0000] C: routing address %r", address)
+ self.address = address
+ self.routing_tables = {workspace_config.database: RoutingTable(database=workspace_config.database, routers=[address])}
+ self.refresh_lock = RLock()
+
+ def __repr__(self):
+ """ The representation shows the initial routing addresses.
+
+ :return: The representation
+ :rtype: str
+ """
+ return "<{} addresses={!r}>".format(self.__class__.__name__, self.get_default_database_initial_router_addresses())
+
+ @property
+ def first_initial_routing_address(self):
+ return self.get_default_database_initial_router_addresses()[0]
+
+ def get_default_database_initial_router_addresses(self):
+ """ Get the initial router addresses for the default database.
+
+ :return:
+ :rtype: OrderedSet
+ """
+ return self.get_routing_table_for_default_database().initial_routers
+
+ def get_default_database_router_addresses(self):
+ """ Get the router addresses for the default database.
+
+ :return:
+ :rtype: OrderedSet
+ """
+ return self.get_routing_table_for_default_database().routers
+
+ def get_routing_table_for_default_database(self):
+ return self.routing_tables[self.workspace_config.database]
+
+ def get_or_create_routing_table(self, database):
+ with self.refresh_lock:
+ if database not in self.routing_tables:
+ self.routing_tables[database] = RoutingTable(
+ database=database,
+ routers=self.get_default_database_initial_router_addresses()
+ )
+ return self.routing_tables[database]
+
+ def fetch_routing_info(
+ self, address, database, imp_user, bookmarks, timeout
+ ):
+ """ Fetch raw routing info from a given router address.
+
+ :param address: router address
+ :param database: the database name to get routing table for
+ :param imp_user: the user to impersonate while fetching the routing
+ table
+ :type imp_user: str or None
+ :param bookmarks: iterable of bookmark values after which the routing
+ info should be fetched
+ :param timeout: connection acquisition timeout in seconds
+
+ :return: list of routing records, or None if no connection
+ could be established or if no readers or writers are present
+ :raise ServiceUnavailable: if the server does not support
+ routing, or if routing support is broken or outdated
+ """
+ cx = self._acquire(address, timeout)
+ try:
+ routing_table = cx.route(
+ database or self.workspace_config.database,
+ imp_user or self.workspace_config.impersonated_user,
+ bookmarks
+ )
+ finally:
+ self.release(cx)
+ return routing_table
+
+ def fetch_routing_table(
+ self, *, address, timeout, database, imp_user, bookmarks
+ ):
+ """ Fetch a routing table from a given router address.
+
+ :param address: router address
+ :param timeout: seconds
+ :param database: the database name
+ :type: str
+ :param imp_user: the user to impersonate while fetching the routing
+ table
+ :type imp_user: str or None
+ :param bookmarks: bookmarks used when fetching routing table
+
+ :return: a new RoutingTable instance or None if the given router is
+ currently unable to provide routing information
+ """
+ new_routing_info = None
+ try:
+ new_routing_info = self.fetch_routing_info(
+ address, database, imp_user, bookmarks, timeout
+ )
+ except Neo4jError as e:
+ # checks if the code is an error that is caused by the client. In
+ # this case there is no sense in trying to fetch a RT from another
+ # router. Hence, the driver should fail fast during discovery.
+ if e.is_fatal_during_discovery():
+ raise
+ except (ServiceUnavailable, SessionExpired):
+ pass
+ if not new_routing_info:
+ log.debug("Failed to fetch routing info %s", address)
+ return None
+ else:
+ servers = new_routing_info[0]["servers"]
+ ttl = new_routing_info[0]["ttl"]
+ database = new_routing_info[0].get("db", database)
+ new_routing_table = RoutingTable.parse_routing_info(
+ database=database, servers=servers, ttl=ttl
+ )
+
+ # Parse routing info and count the number of each type of server
+ num_routers = len(new_routing_table.routers)
+ num_readers = len(new_routing_table.readers)
+
+ # num_writers = len(new_routing_table.writers)
+ # If no writers are available. This likely indicates a temporary state,
+ # such as leader switching, so we should not signal an error.
+
+ # No routers
+ if num_routers == 0:
+ log.debug("No routing servers returned from server %s", address)
+ return None
+
+ # No readers
+ if num_readers == 0:
+ log.debug("No read servers returned from server %s", address)
+ return None
+
+ # At least one of each is fine, so return this table
+ return new_routing_table
+
+ def _update_routing_table_from(
+ self, *routers, database=None, imp_user=None, bookmarks=None,
+ database_callback=None
+ ):
+ """ Try to update routing tables with the given routers.
+
+ :return: True if the routing table is successfully updated,
+ otherwise False
+ """
+ log.debug("Attempting to update routing table from {}".format(
+ ", ".join(map(repr, routers)))
+ )
+ for router in routers:
+ for address in NetworkUtil.resolve_address(
+ router, resolver=self.pool_config.resolver
+ ):
+ new_routing_table = self.fetch_routing_table(
+ address=address,
+ timeout=self.pool_config.connection_timeout,
+ database=database, imp_user=imp_user, bookmarks=bookmarks
+ )
+ if new_routing_table is not None:
+                    new_database = new_routing_table.database
+                    old_routing_table = self.get_or_create_routing_table(
+                        new_database
+                    )
+                    old_routing_table.update(new_routing_table)
+                    log.debug(
+                        "[#0000] C: address=%r (%r)",
+                        address, self.routing_tables[new_database]
+                    )
+                    if callable(database_callback):
+                        database_callback(new_database)
+ return True
+ self.deactivate(router)
+ return False
+
+ def update_routing_table(
+ self, *, database, imp_user, bookmarks, database_callback=None
+ ):
+ """ Update the routing table from the first router able to provide
+ valid routing information.
+
+ :param database: The database name
+ :param imp_user: the user to impersonate while fetching the routing
+ table
+ :type imp_user: str or None
+ :param bookmarks: bookmarks used when fetching routing table
+ :param database_callback: A callback function that will be called with
+ the database name as only argument when a new routing table has been
+            acquired. This database name might be different from `database` if that
+ was None and the underlying protocol supports reporting back the
+ actual database.
+
+ :raise neo4j.exceptions.ServiceUnavailable:
+ """
+ with self.refresh_lock:
+ routing_table = self.get_or_create_routing_table(database)
+ # copied because it can be modified
+ existing_routers = set(routing_table.routers)
+
+ prefer_initial_routing_address = \
+ self.routing_tables[database].initialized_without_writers
+
+ if prefer_initial_routing_address:
+ # TODO: Test this state
+ if self._update_routing_table_from(
+ self.first_initial_routing_address, database=database,
+ imp_user=imp_user, bookmarks=bookmarks,
+ database_callback=database_callback
+ ):
+ # Why is only the first initial routing address used?
+ return
+ if self._update_routing_table_from(
+ *(existing_routers - {self.first_initial_routing_address}),
+ database=database, imp_user=imp_user, bookmarks=bookmarks,
+ database_callback=database_callback
+ ):
+ return
+
+ if not prefer_initial_routing_address:
+ if self._update_routing_table_from(
+ self.first_initial_routing_address, database=database,
+ imp_user=imp_user, bookmarks=bookmarks,
+ database_callback=database_callback
+ ):
+ # Why is only the first initial routing address used?
+ return
+
+ # None of the routers have been successful, so just fail
+ log.error("Unable to retrieve routing information")
+ raise ServiceUnavailable("Unable to retrieve routing information")
+
+ def update_connection_pool(self, *, database):
+ routing_table = self.get_or_create_routing_table(database)
+ servers = routing_table.servers()
+ for address in list(self.connections):
+ if address.unresolved not in servers:
+ super(Neo4jPool, self).deactivate(address)
+
+ def ensure_routing_table_is_fresh(
+ self, *, access_mode, database, imp_user, bookmarks,
+ database_callback=None
+ ):
+ """ Update the routing table if stale.
+
+ This method performs two freshness checks, before and after acquiring
+ the refresh lock. If the routing table is already fresh on entry, the
+ method exits immediately; otherwise, the refresh lock is acquired and
+ the second freshness check that follows determines whether an update
+ is still required.
+
+ This method is thread-safe.
+
+ :return: `True` if an update was required, `False` otherwise.
+ """
+ from neo4j.api import READ_ACCESS
+ with self.refresh_lock:
+ routing_table = self.get_or_create_routing_table(database)
+ if routing_table.is_fresh(readonly=(access_mode == READ_ACCESS)):
+ # Readers are fresh.
+ return False
+
+ self.update_routing_table(
+ database=database, imp_user=imp_user, bookmarks=bookmarks,
+ database_callback=database_callback
+ )
+ self.update_connection_pool(database=database)
+
+ for database in list(self.routing_tables.keys()):
+ # Remove unused databases in the routing table
+ # Remove the routing table after a timeout = TTL + 30s
+ log.debug("[#0000] C: database=%s", database)
+ if (self.routing_tables[database].should_be_purged_from_memory()
+ and database != self.workspace_config.database):
+ del self.routing_tables[database]
+
+ return True
+
+ def _select_address(self, *, access_mode, database):
+ from ...api import READ_ACCESS
+ """ Selects the address with the fewest in-use connections.
+ """
+ with self.refresh_lock:
+ if access_mode == READ_ACCESS:
+ addresses = self.routing_tables[database].readers
+ else:
+ addresses = self.routing_tables[database].writers
+ addresses_by_usage = {}
+ for address in addresses:
+ addresses_by_usage.setdefault(
+ self.in_use_connection_count(address), []
+ ).append(address)
+ if not addresses_by_usage:
+ if access_mode == READ_ACCESS:
+ raise ReadServiceUnavailable(
+ "No read service currently available"
+ )
+ else:
+ raise WriteServiceUnavailable(
+ "No write service currently available"
+ )
+ return choice(addresses_by_usage[min(addresses_by_usage)])
+
+ def acquire(
+ self, access_mode=None, timeout=None, database=None, bookmarks=None
+ ):
+ if access_mode not in (WRITE_ACCESS, READ_ACCESS):
+ raise ClientError("Non valid 'access_mode'; {}".format(access_mode))
+ if not timeout:
+ raise ClientError("'timeout' must be a float larger than 0; {}"
+ .format(timeout))
+
+ from neo4j.api import check_access_mode
+ access_mode = check_access_mode(access_mode)
+ with self.refresh_lock:
+ log.debug("[#0000] C: %r",
+ self.routing_tables)
+ self.ensure_routing_table_is_fresh(
+ access_mode=access_mode, database=database, imp_user=None,
+ bookmarks=bookmarks
+ )
+
+ while True:
+ try:
+                # Get an address for a connection that has the fewest in-use
+ # connections.
+ address = self._select_address(
+ access_mode=access_mode, database=database
+ )
+ except (ReadServiceUnavailable, WriteServiceUnavailable) as err:
+ raise SessionExpired("Failed to obtain connection towards '%s' server." % access_mode) from err
+ try:
+ log.debug("[#0000] C: database=%r address=%r", database, address)
+ # should always be a resolved address
+ connection = self._acquire(address, timeout=timeout)
+ except ServiceUnavailable:
+ self.deactivate(address=address)
+ else:
+ return connection
+
+ def deactivate(self, address):
+ """ Deactivate an address from the connection pool,
+ if present, remove from the routing table and also closing
+ all idle connections to that address.
+ """
+ log.debug("[#0000] C: Deactivating address %r", address)
+ # We use `discard` instead of `remove` here since the former
+ # will not fail if the address has already been removed.
+ for database in self.routing_tables.keys():
+ self.routing_tables[database].routers.discard(address)
+ self.routing_tables[database].readers.discard(address)
+ self.routing_tables[database].writers.discard(address)
+ log.debug("[#0000] C: table=%r", self.routing_tables)
+ super(Neo4jPool, self).deactivate(address)
+
+ def on_write_failure(self, address):
+ """ Remove a writer address from the routing table, if present.
+ """
+ log.debug("[#0000] C: Removing writer %r", address)
+ for database in self.routing_tables.keys():
+ self.routing_tables[database].writers.discard(address)
+ log.debug("[#0000] C: table=%r", self.routing_tables)
diff --git a/neo4j/_sync/work/__init__.py b/neo4j/_sync/work/__init__.py
new file mode 100644
index 00000000..3ceebdb1
--- /dev/null
+++ b/neo4j/_sync/work/__init__.py
@@ -0,0 +1,32 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from .session import (
+ Result,
+ Session,
+ Transaction,
+ Workspace,
+)
+
+
+__all__ = [
+ "Result",
+ "Session",
+ "Transaction",
+ "Workspace",
+]
diff --git a/neo4j/work/result.py b/neo4j/_sync/work/result.py
similarity index 90%
rename from neo4j/work/result.py
rename to neo4j/_sync/work/result.py
index bb3dba02..8d6342fd 100644
--- a/neo4j/work/result.py
+++ b/neo4j/_sync/work/result.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -22,16 +19,16 @@
from collections import deque
from warnings import warn
-from neo4j.data import DataDehydrator
-from neo4j.io import ConnectionErrorHandler
-from neo4j.work.summary import ResultSummary
-from neo4j.exceptions import ResultConsumedError
+from ..._async_compat.util import Util
+from ...data import DataDehydrator
+from ...work import ResultSummary
+from ..io import ConnectionErrorHandler
class Result:
"""A handler for the result of Cypher query execution. Instances
of this class are typically constructed and returned by
- :meth:`.Session.run` and :meth:`.Transaction.run`.
+    :meth:`.Session.run` and :meth:`.Transaction.run`.
"""
def __init__(self, connection, hydrant, fetch_size, on_closed,
@@ -63,19 +60,23 @@ def _qid(self):
else:
return self._raw_qid
- def _tx_ready_run(self, query, parameters, **kwparameters):
+ def _tx_ready_run(self, query, parameters, **kwargs):
# BEGIN+RUN does not carry any extra on the RUN message.
# BEGIN {extra}
# RUN "query" {parameters} {extra}
- self._run(query, parameters, None, None, None, None, **kwparameters)
+ self._run(
+ query, parameters, None, None, None, None, **kwargs
+ )
- def _run(self, query, parameters, db, imp_user, access_mode, bookmarks,
- **kwparameters):
+ def _run(
+ self, query, parameters, db, imp_user, access_mode, bookmarks,
+ **kwargs
+ ):
query_text = str(query) # Query or string object
query_metadata = getattr(query, "metadata", None)
query_timeout = getattr(query, "timeout", None)
- parameters = DataDehydrator.fix_parameters(dict(parameters or {}, **kwparameters))
+ parameters = DataDehydrator.fix_parameters(dict(parameters or {}, **kwargs))
self._metadata = {
"query": query_text,
@@ -95,7 +96,7 @@ def on_attached(metadata):
def on_failed_attach(metadata):
self._metadata.update(metadata)
self._attached = False
- self._on_closed()
+ Util.callback(self._on_closed)
self._connection.run(
query_text,
@@ -120,11 +121,11 @@ def on_records(records):
def on_summary():
self._attached = False
- self._on_closed()
+ Util.callback(self._on_closed)
def on_failure(metadata):
self._attached = False
- self._on_closed()
+ Util.callback(self._on_closed)
def on_success(summary_metadata):
self._streaming = False
@@ -148,12 +149,12 @@ def on_success(summary_metadata):
def _discard(self):
def on_summary():
self._attached = False
- self._on_closed()
+ Util.callback(self._on_closed)
def on_failure(metadata):
self._metadata.update(metadata)
self._attached = False
- self._on_closed()
+ Util.callback(self._on_closed)
def on_success(summary_metadata):
self._streaming = False
@@ -203,7 +204,7 @@ def _attach(self):
self._connection.fetch_message()
def _buffer(self, n=None):
- """Try to fill `self_record_buffer` with n records.
+ """Try to fill `self._record_buffer` with n records.
Might end up with more records in the buffer if the fetch size makes it
overshoot.
@@ -259,7 +260,9 @@ def consume(self):
Example::
def create_node_tx(tx, name):
- result = tx.run("CREATE (n:ExampleNode { name: $name }) RETURN n", name=name)
+ result = tx.run(
+ "CREATE (n:ExampleNode { name: $name }) RETURN n", name=name
+ )
record = result.single()
value = record.value()
info = result.consume()
@@ -273,11 +276,12 @@ def create_node_tx(tx, name):
def get_two_tx(tx):
result = tx.run("UNWIND [1,2,3,4] AS x RETURN x")
values = []
- for ix, record in enumerate(result):
- if x > 1:
+ for record in result:
+ if len(values) >= 2:
break
values.append(record.values())
- info = result.consume() # discard the remaining records if there are any
+ # discard the remaining records if there are any
+ info = result.consume()
# use the info for logging etc.
return values, info
@@ -310,7 +314,8 @@ def single(self):
# raise SomeError("Expected exactly 1 record, found %i"
# % len(self._record_buffer))
# return self._record_buffer.popleft()
- records = list(self) # TODO: exhausts the result with self.consume if there are more records.
+ # TODO: exhausts the result with self.consume if there are more records.
+ records = Util.list(self)
size = len(records)
if size == 0:
return None
diff --git a/neo4j/work/simple.py b/neo4j/_sync/work/session.py
similarity index 74%
rename from neo4j/work/simple.py
rename to neo4j/_sync/work/session.py
index 9c404da4..dbdf9482 100644
--- a/neo4j/work/simple.py
+++ b/neo4j/_sync/work/session.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -19,31 +16,32 @@
# limitations under the License.
+import asyncio
from logging import getLogger
from random import random
-from time import (
- perf_counter,
- sleep,
-)
+from time import perf_counter
-from neo4j.api import (
+from ..._async_compat import sleep
+from ...api import (
READ_ACCESS,
WRITE_ACCESS,
)
-from neo4j.conf import SessionConfig
-from neo4j.data import DataHydrator
-from neo4j.exceptions import (
+from ...conf import SessionConfig
+from ...data import DataHydrator
+from ...exceptions import (
ClientError,
IncompleteCommit,
Neo4jError,
ServiceUnavailable,
SessionExpired,
- TransientError,
TransactionError,
+ TransientError,
)
-from neo4j.work import Workspace
-from neo4j.work.result import Result
-from neo4j.work.transaction import Transaction
+from ...work import Query
+from .result import Result
+from .transaction import Transaction
+from .workspace import Workspace
+
log = getLogger("neo4j")
@@ -53,9 +51,10 @@ class Session(Workspace):
of work. Connections are drawn from the :class:`.Driver` connection
pool as required.
- Session creation is a lightweight operation and sessions are not thread
- safe. Therefore a session should generally be short-lived, and not
- span multiple threads.
+ Session creation is a lightweight operation and sessions are not safe to
+ be used in concurrent contexts (multiple threads/coroutines).
+ Therefore, a session should generally be short-lived, and must not
+ span multiple threads/coroutines.
In general, sessions will be created and destroyed within a `with`
context. For example::
@@ -75,7 +74,7 @@ class Session(Workspace):
_transaction = None
# The current auto-transaction result, if any.
- _autoResult = None
+ _auto_result = None
# The state this session is in.
_state_failed = False
@@ -89,6 +88,8 @@ def __init__(self, pool, session_config):
self._bookmarks = tuple(session_config.bookmarks)
def __del__(self):
+ if asyncio.iscoroutinefunction(self.close):
+ return
try:
self.close()
except (OSError, ServiceUnavailable, SessionExpired):
@@ -112,14 +113,14 @@ def _collect_bookmark(self, bookmark):
self._bookmarks = [bookmark]
def _result_closed(self):
- if self._autoResult:
- self._collect_bookmark(self._autoResult._bookmark)
- self._autoResult = None
+ if self._auto_result:
+ self._collect_bookmark(self._auto_result._bookmark)
+ self._auto_result = None
self._disconnect()
def _result_error(self, _):
- if self._autoResult:
- self._autoResult = None
+ if self._auto_result:
+ self._auto_result = None
self._disconnect()
def close(self):
@@ -129,14 +130,14 @@ def close(self):
roll back any outstanding transactions.
"""
if self._connection:
- if self._autoResult:
+ if self._auto_result:
if self._state_failed is False:
try:
- self._autoResult.consume()
- self._collect_bookmark(self._autoResult._bookmark)
+ self._auto_result.consume()
+ self._collect_bookmark(self._auto_result._bookmark)
except Exception as error:
# TODO: Investigate potential non graceful close states
- self._autoResult = None
+ self._auto_result = None
self._state_failed = True
if self._transaction:
@@ -163,7 +164,7 @@ def close(self):
self._state_failed = False
self._closed = True
- def run(self, query, parameters=None, **kwparameters):
+ def run(self, query, parameters=None, **kwargs):
"""Run a Cypher query within an auto-commit transaction.
The query is sent and the result header received
@@ -185,9 +186,9 @@ def run(self, query, parameters=None, **kwparameters):
:type query: str, neo4j.Query
:param parameters: dictionary of parameters
:type parameters: dict
- :param kwparameters: additional keyword parameters
+ :param kwargs: additional keyword parameters
:returns: a new :class:`neo4j.Result` object
- :rtype: :class:`neo4j.Result`
+ :rtype: Result
"""
if not query:
raise ValueError("Cannot run an empty query")
@@ -197,8 +198,9 @@ def run(self, query, parameters=None, **kwparameters):
if self._transaction:
raise ClientError("Explicit Transaction must be handled explicitly")
- if self._autoResult:
- self._autoResult._buffer_all() # This will buffer upp all records for the previous auto-transaction
+ if self._auto_result:
+ # This will buffer upp all records for the previous auto-transaction
+ self._auto_result._buffer_all()
if not self._connection:
self._connect(self._config.default_access_mode)
@@ -208,17 +210,17 @@ def run(self, query, parameters=None, **kwparameters):
hydrant = DataHydrator()
- self._autoResult = Result(
+ self._auto_result = Result(
cx, hydrant, self._config.fetch_size, self._result_closed,
self._result_error
)
- self._autoResult._run(
+ self._auto_result._run(
query, parameters, self._config.database,
self._config.impersonated_user, self._config.default_access_mode,
- self._bookmarks, **kwparameters
+ self._bookmarks, **kwargs
)
- return self._autoResult
+ return self._auto_result
def last_bookmark(self):
"""Return the bookmark received following the last completed transaction.
@@ -228,8 +230,8 @@ def last_bookmark(self):
"""
# The set of bookmarks to be passed into the next transaction.
- if self._autoResult:
- self._autoResult.consume()
+ if self._auto_result:
+ self._auto_result.consume()
if self._transaction and self._transaction._closed:
self._collect_bookmark(self._transaction._bookmark)
@@ -286,25 +288,28 @@ def begin_transaction(self, metadata=None, timeout=None):
:type timeout: int
:returns: A new transaction instance.
- :rtype: :class:`neo4j.Transaction`
+ :rtype: Transaction
:raises TransactionError: :class:`neo4j.exceptions.TransactionError` if a transaction is already open.
"""
# TODO: Implement TransactionConfig consumption
- if self._autoResult:
- self._autoResult.consume()
+ if self._auto_result:
+ self._auto_result.consume()
if self._transaction:
raise TransactionError("Explicit transaction already open")
- self._open_transaction(access_mode=self._config.default_access_mode,
- metadata=metadata, timeout=timeout)
+ self._open_transaction(
+ access_mode=self._config.default_access_mode, metadata=metadata,
+ timeout=timeout
+ )
return self._transaction
- def _run_transaction(self, access_mode, transaction_function, *args, **kwargs):
-
+ def _run_transaction(
+ self, access_mode, transaction_function, *args, **kwargs
+ ):
if not callable(transaction_function):
raise TypeError("Unit of work is not callable")
@@ -319,7 +324,10 @@ def _run_transaction(self, access_mode, transaction_function, *args, **kwargs):
while True:
try:
- self._open_transaction(access_mode=access_mode, metadata=metadata, timeout=timeout)
+ self._open_transaction(
+ access_mode=access_mode, metadata=metadata,
+ timeout=timeout
+ )
tx = self._transaction
try:
result = transaction_function(tx, *args, **kwargs)
@@ -358,15 +366,14 @@ def read_transaction(self, transaction_function, *args, **kwargs):
This transaction will automatically be committed unless an exception is thrown during query execution or by the user code.
Note, that this function perform retries and that the supplied `transaction_function` might get invoked more than once.
- Managed transactions should not generally be explicitly committed (via tx.commit()).
+ Managed transactions should not generally be explicitly committed
+ (via ``tx.commit()``).
Example::
def do_cypher_tx(tx, cypher):
result = tx.run(cypher)
- values = []
- for record in result:
- values.append(record.values())
+ values = [record.values() for record in result]
return values
with driver.session() as session:
@@ -377,23 +384,29 @@ def do_cypher_tx(tx, cypher):
def get_two_tx(tx):
result = tx.run("UNWIND [1,2,3,4] AS x RETURN x")
values = []
- for ix, record in enumerate(result):
- if x > 1:
+ for record in result:
+ if len(values) >= 2:
break
values.append(record.values())
- info = result.consume() # discard the remaining records if there are any
+ # discard the remaining records if there are any
+ info = result.consume()
# use the info for logging etc.
return values
with driver.session() as session:
values = session.read_transaction(get_two_tx)
- :param transaction_function: a function that takes a transaction as an argument and does work with the transaction. `tx_function(tx, *args, **kwargs)`
+ :param transaction_function: a function that takes a transaction as an
+ argument and does work with the transaction.
+ `transaction_function(tx, *args, **kwargs)` where `tx` is a
+ :class:`.Transaction`.
:param args: arguments for the `transaction_function`
:param kwargs: key word arguments for the `transaction_function`
:return: a result as returned by the given unit of work
"""
- return self._run_transaction(READ_ACCESS, transaction_function, *args, **kwargs)
+ return self._run_transaction(
+ READ_ACCESS, transaction_function, *args, **kwargs
+ )
def write_transaction(self, transaction_function, *args, **kwargs):
"""Execute a unit of work in a managed write transaction.
@@ -405,79 +418,25 @@ def write_transaction(self, transaction_function, *args, **kwargs):
Example::
def create_node_tx(tx, name):
- result = tx.run("CREATE (n:NodeExample { name: $name }) RETURN id(n) AS node_id", name=name)
+ query = "CREATE (n:NodeExample { name: $name }) RETURN id(n) AS node_id"
+ result = tx.run(query, name=name)
record = result.single()
return record["node_id"]
with driver.session() as session:
node_id = session.write_transaction(create_node_tx, "example")
-
- :param transaction_function: a function that takes a transaction as an argument and does work with the transaction. `tx_function(tx, *args, **kwargs)`
+ :param transaction_function: a function that takes a transaction as an
+ argument and does work with the transaction.
+ `transaction_function(tx, *args, **kwargs)` where `tx` is a
+ :class:`.Transaction`.
:param args: key word arguments for the `transaction_function`
:param kwargs: key word arguments for the `transaction_function`
:return: a result as returned by the given unit of work
"""
- return self._run_transaction(WRITE_ACCESS, transaction_function, *args, **kwargs)
-
-
-class Query:
- """ Create a new query.
-
- :param text: The query text.
- :type text: str
- :param metadata: metadata attached to the query.
- :type metadata: dict
- :param timeout: seconds.
- :type timeout: int
- """
- def __init__(self, text, metadata=None, timeout=None):
- self.text = text
-
- self.metadata = metadata
- self.timeout = timeout
-
- def __str__(self):
- return str(self.text)
-
-
-def unit_of_work(metadata=None, timeout=None):
- """This function is a decorator for transaction functions that allows extra control over how the transaction is carried out.
-
- For example, a timeout may be applied::
-
- @unit_of_work(timeout=100)
- def count_people_tx(tx):
- result = tx.run("MATCH (a:Person) RETURN count(a) AS persons")
- record = result.single()
- return record["persons"]
-
- :param metadata:
- a dictionary with metadata.
- Specified metadata will be attached to the executing transaction and visible in the output of ``dbms.listQueries`` and ``dbms.listTransactions`` procedures.
- It will also get logged to the ``query.log``.
- This functionality makes it easier to tag transactions and is equivalent to ``dbms.setTXMetaData`` procedure, see https://neo4j.com/docs/operations-manual/current/reference/procedures/ for procedure reference.
- :type metadata: dict
-
- :param timeout:
- the transaction timeout in seconds.
- Transactions that execute longer than the configured timeout will be terminated by the database.
- This functionality allows to limit query/transaction execution time.
- Specified timeout overrides the default timeout configured in the database using ``dbms.transaction.timeout`` setting.
- Value should not represent a duration of zero or negative duration.
- :type timeout: int
- """
-
- def wrapper(f):
-
- def wrapped(*args, **kwargs):
- return f(*args, **kwargs)
-
- wrapped.metadata = metadata
- wrapped.timeout = timeout
- return wrapped
-
- return wrapper
+ return self._run_transaction(
+ WRITE_ACCESS, transaction_function, *args, **kwargs
+ )
def retry_delay_generator(initial_delay, multiplier, jitter_factor):
diff --git a/neo4j/work/transaction.py b/neo4j/_sync/work/transaction.py
similarity index 90%
rename from neo4j/work/transaction.py
rename to neo4j/_sync/work/transaction.py
index 8b773987..73d08238 100644
--- a/neo4j/work/transaction.py
+++ b/neo4j/_sync/work/transaction.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -19,22 +16,22 @@
# limitations under the License.
-from neo4j.work.result import Result
-from neo4j.data import DataHydrator
-from neo4j.exceptions import (
- TransactionError,
-)
-from neo4j.io import ConnectionErrorHandler
+from ..._async_compat.util import Util
+from ...data import DataHydrator
+from ...exceptions import TransactionError
+from ...work import Query
+from ..io import ConnectionErrorHandler
+from .result import Result
class Transaction:
- """ Container for multiple Cypher queries to be executed within
- a single context. Transactions can be used within a :py:const:`with`
+ """ Container for multiple Cypher queries to be executed within a single
+ context. Transactions can be used within a :py:const:`with`
block where the transaction is committed or rolled back on based on
- whether or not an exception is raised::
+ whether an exception is raised::
with session.begin_transaction() as tx:
- pass
+ ...
"""
@@ -62,7 +59,9 @@ def __exit__(self, exception_type, exception_value, traceback):
self.commit()
self.close()
- def _begin(self, database, imp_user, bookmarks, access_mode, metadata, timeout):
+ def _begin(
+ self, database, imp_user, bookmarks, access_mode, metadata, timeout
+ ):
self._connection.begin(
bookmarks=bookmarks, metadata=metadata, timeout=timeout,
mode=access_mode, db=database, imp_user=imp_user
@@ -75,7 +74,7 @@ def _result_on_closed_handler(self):
def _error_handler(self, exc):
self._last_error = exc
- self._on_error(exc)
+ Util.callback(self._on_error, exc)
def _consume_results(self):
for result in self._results:
@@ -110,7 +109,6 @@ def run(self, query, parameters=None, **kwparameters):
:rtype: :class:`neo4j.Result`
:raise TransactionError: if the transaction is already closed
"""
- from neo4j.work.simple import Query
if isinstance(query, Query):
raise ValueError("Query object is only supported for session.run")
@@ -151,14 +149,15 @@ def commit(self):
metadata = {}
try:
- self._consume_results() # DISCARD pending records then do a commit.
+ # DISCARD pending records then do a commit.
+ self._consume_results()
self._connection.commit(on_success=metadata.update)
self._connection.send_all()
self._connection.fetch_all()
self._bookmark = metadata.get("bookmark")
finally:
self._closed = True
- self._on_closed()
+ Util.callback(self._on_closed)
return self._bookmark
@@ -182,7 +181,7 @@ def rollback(self):
self._connection.fetch_all()
finally:
self._closed = True
- self._on_closed()
+ Util.callback(self._on_closed)
def close(self):
"""Close this transaction, triggering a ROLLBACK if not closed.
diff --git a/neo4j/_sync/work/workspace.py b/neo4j/_sync/work/workspace.py
new file mode 100644
index 00000000..3ed50ad2
--- /dev/null
+++ b/neo4j/_sync/work/workspace.py
@@ -0,0 +1,102 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import asyncio
+
+from ...conf import WorkspaceConfig
+from ...exceptions import ServiceUnavailable
+from ..io import Neo4jPool
+
+
+class Workspace:
+
+ def __init__(self, pool, config):
+ assert isinstance(config, WorkspaceConfig)
+ self._pool = pool
+ self._config = config
+ self._connection = None
+ self._connection_access_mode = None
+ # Sessions are supposed to cache the database on which to operate.
+ self._cached_database = False
+ self._bookmarks = None
+
+ def __del__(self):
+ if asyncio.iscoroutinefunction(self.close):
+ return
+ try:
+ self.close()
+ except OSError:
+ pass
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ self.close()
+
+ def _set_cached_database(self, database):
+ self._cached_database = True
+ self._config.database = database
+
+ def _connect(self, access_mode):
+ if self._connection:
+ # TODO: Investigate this
+ # log.warning("FIXME: should always disconnect before connect")
+ self._connection.send_all()
+ self._connection.fetch_all()
+ self._disconnect()
+ if not self._cached_database:
+ if (self._config.database is not None
+ or not isinstance(self._pool, Neo4jPool)):
+ self._set_cached_database(self._config.database)
+ else:
+ # This is the first time we open a connection to a server in a
+ # cluster environment for this session without explicitly
+ # configured database. Hence, we request a routing table update
+ # to try to fetch the home database. If provided by the server,
+ # we shall use this database explicitly for all subsequent
+ # actions within this session.
+ self._pool.update_routing_table(
+ database=self._config.database,
+ imp_user=self._config.impersonated_user,
+ bookmarks=self._bookmarks,
+ database_callback=self._set_cached_database
+ )
+ self._connection = self._pool.acquire(
+ access_mode=access_mode,
+ timeout=self._config.connection_acquisition_timeout,
+ database=self._config.database,
+ bookmarks=self._bookmarks
+ )
+ self._connection_access_mode = access_mode
+
+ def _disconnect(self, sync=False):
+ if self._connection:
+ if sync:
+ try:
+ self._connection.send_all()
+ self._connection.fetch_all()
+ except ServiceUnavailable:
+ pass
+ if self._connection:
+ self._pool.release(self._connection)
+ self._connection = None
+ self._connection_access_mode = None
+
+ def close(self):
+ self._disconnect(sync=True)
diff --git a/neo4j/addressing.py b/neo4j/addressing.py
index f547f05d..780e2c23 100644
--- a/neo4j/addressing.py
+++ b/neo4j/addressing.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -19,14 +16,12 @@
# limitations under the License.
+import logging
from socket import (
- getaddrinfo,
- getservbyname,
- SOCK_STREAM,
AF_INET,
AF_INET6,
+ getservbyname,
)
-import logging
log = logging.getLogger("neo4j")
@@ -142,59 +137,6 @@ def port(self):
def unresolved(self):
return self
- @classmethod
- def _dns_resolve(cls, address, family=0):
- """ Regular DNS resolver. Takes an address object and optional
- address family for filtering.
-
- :param address:
- :param family:
- :return:
- """
- try:
- info = getaddrinfo(address.host, address.port, family, SOCK_STREAM)
- except OSError:
- raise ValueError("Cannot resolve address {}".format(address))
- else:
- resolved = []
- for fam, _, _, _, addr in info:
- if fam == AF_INET6 and addr[3] != 0:
- # skip any IPv6 addresses with a non-zero scope id
- # as these appear to cause problems on some platforms
- continue
- if addr not in resolved:
- resolved.append(ResolvedAddress(
- addr, host_name=address.host_name)
- )
- return resolved
-
- def resolve(self, family=0, resolver=None):
- """ Carry out domain name resolution on this Address object.
-
- If a resolver function is supplied, and is callable, this is
- called first, with this object as its argument. This may yield
- multiple output addresses, which are chained into a subsequent
- regular DNS resolution call. If no resolver function is passed,
- the DNS resolution is carried out on the original Address
- object.
-
- This function returns a list of resolved Address objects.
-
- :param family: optional address family to filter resolved
- addresses by (e.g. AF_INET6)
- :param resolver: optional customer resolver function to be
- called before regular DNS resolution
- """
-
- log.debug("[#0000] C: %s", self)
- resolved = []
- if resolver:
- for address in map(Address, resolver(self)):
- resolved.extend(self._dns_resolve(address, family))
- else:
- resolved.extend(self._dns_resolve(self, family))
- return resolved
-
@property
def port_number(self):
try:
@@ -234,9 +176,6 @@ def host_name(self):
def unresolved(self):
return super().__new__(Address, (self._host_name, *self[1:]))
- def resolve(self, family=0, resolver=None):
- return [self]
-
def __new__(cls, iterable, host_name=None):
new = super().__new__(cls, iterable)
new._host_name = host_name
diff --git a/neo4j/api.py b/neo4j/api.py
index 1275c2a3..36c05cbb 100644
--- a/neo4j/api.py
+++ b/neo4j/api.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -18,20 +15,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+""" Base classes and helpers.
+"""
+
+
from urllib.parse import (
- urlparse,
parse_qs,
+ urlparse,
)
-from.exceptions import (
- DriverError,
+
+from .exceptions import (
ConfigurationError,
+ DriverError,
)
from .meta import deprecated
-""" Base classes and helpers.
-"""
-
READ_ACCESS = "READ"
WRITE_ACCESS = "WRITE"
@@ -107,7 +106,8 @@ def basic_auth(user, password, realm=None):
:param realm: specifies the authentication provider
:type realm: str or None
- :return: auth token for use with :meth:`GraphDatabase.driver`
+ :return: auth token for use with :meth:`GraphDatabase.driver` or
+ :meth:`AsyncGraphDatabase.driver`
:rtype: :class:`neo4j.Auth`
"""
return Auth("basic", user, password, realm)
@@ -122,7 +122,8 @@ def kerberos_auth(base64_encoded_ticket):
the credentials
:type base64_encoded_ticket: str
- :return: auth token for use with :meth:`GraphDatabase.driver`
+ :return: auth token for use with :meth:`GraphDatabase.driver` or
+ :meth:`AsyncGraphDatabase.driver`
:rtype: :class:`neo4j.Auth`
"""
return Auth("kerberos", "", base64_encoded_ticket)
@@ -137,7 +138,8 @@ def bearer_auth(base64_encoded_token):
by a Single-Sign-On provider.
:type base64_encoded_token: str
- :return: auth token for use with :meth:`GraphDatabase.driver`
+ :return: auth token for use with :meth:`GraphDatabase.driver` or
+ :meth:`AsyncGraphDatabase.driver`
:rtype: :class:`neo4j.Auth`
"""
return Auth("bearer", None, base64_encoded_token)
@@ -158,7 +160,8 @@ def custom_auth(principal, credentials, realm, scheme, **parameters):
authentication provider
:type parameters: Dict[str, Any]
- :return: auth token for use with :meth:`GraphDatabase.driver`
+ :return: auth token for use with :meth:`GraphDatabase.driver` or
+ :meth:`AsyncGraphDatabase.driver`
:rtype: :class:`neo4j.Auth`
"""
return Auth(scheme, principal, credentials, realm, **parameters)
diff --git a/neo4j/conf.py b/neo4j/conf.py
index fb86c489..7131ba56 100644
--- a/neo4j/conf.py
+++ b/neo4j/conf.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -23,17 +20,14 @@
from collections.abc import Mapping
from warnings import warn
-from neo4j.meta import get_user_agent
-
-from neo4j.api import (
- TRUST_SYSTEM_CA_SIGNED_CERTIFICATES,
+from .api import (
+ DEFAULT_DATABASE,
TRUST_ALL_CERTIFICATES,
+ TRUST_SYSTEM_CA_SIGNED_CERTIFICATES,
WRITE_ACCESS,
- DEFAULT_DATABASE,
-)
-from neo4j.exceptions import (
- ConfigurationError,
)
+from .exceptions import ConfigurationError
+from .meta import get_user_agent
def iter_items(iterable):
diff --git a/neo4j/data.py b/neo4j/data.py
index e1a3a959..d60937cd 100644
--- a/neo4j/data.py
+++ b/neo4j/data.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -19,22 +16,57 @@
# limitations under the License.
-from abc import ABCMeta, abstractmethod
-from collections.abc import Sequence, Set, Mapping
-from datetime import date, time, datetime, timedelta
+from abc import (
+ ABCMeta,
+ abstractmethod,
+)
+from collections.abc import (
+ Mapping,
+ Sequence,
+ Set,
+)
+from datetime import (
+ date,
+ datetime,
+ time,
+ timedelta,
+)
from functools import reduce
from operator import xor as xor_operator
-from neo4j.conf import iter_items
-from neo4j.graph import Graph, Node, Relationship, Path
-from neo4j.packstream import INT64_MIN, INT64_MAX, Structure
-from neo4j.spatial import Point, hydrate_point, dehydrate_point
-from neo4j.time import Date, Time, DateTime, Duration
-from neo4j.time.hydration import (
- hydrate_date, dehydrate_date,
- hydrate_time, dehydrate_time,
- hydrate_datetime, dehydrate_datetime,
- hydrate_duration, dehydrate_duration, dehydrate_timedelta,
+from .conf import iter_items
+from .graph import (
+ Graph,
+ Node,
+ Path,
+ Relationship,
+)
+from .packstream import (
+ INT64_MAX,
+ INT64_MIN,
+ Structure,
+)
+from .spatial import (
+ dehydrate_point,
+ hydrate_point,
+ Point,
+)
+from .time import (
+ Date,
+ DateTime,
+ Duration,
+ Time,
+)
+from .time.hydration import (
+ dehydrate_date,
+ dehydrate_datetime,
+ dehydrate_duration,
+ dehydrate_time,
+ dehydrate_timedelta,
+ hydrate_date,
+ hydrate_datetime,
+ hydrate_duration,
+ hydrate_time,
)
@@ -44,7 +76,7 @@
class Record(tuple, Mapping):
""" A :class:`.Record` is an immutable ordered collection of key-value
pairs. It is generally closer to a :py:class:`namedtuple` than to a
- :py:class:`OrderedDict` inasmuch as iteration of the collection will
+ :py:class:`OrderedDict` in as much as iteration of the collection will
yield values rather than keys.
"""
diff --git a/neo4j/debug.py b/neo4j/debug.py
index b8872a37..7b2f0db5 100644
--- a/neo4j/debug.py
+++ b/neo4j/debug.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -19,7 +16,16 @@
# limitations under the License.
-from logging import CRITICAL, ERROR, WARNING, INFO, DEBUG, Formatter, StreamHandler, getLogger
+from logging import (
+ CRITICAL,
+ DEBUG,
+ ERROR,
+ Formatter,
+ getLogger,
+ INFO,
+ StreamHandler,
+ WARNING,
+)
from sys import stderr
diff --git a/neo4j/exceptions.py b/neo4j/exceptions.py
index c69c75c7..2b944bcf 100644
--- a/neo4j/exceptions.py
+++ b/neo4j/exceptions.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
diff --git a/neo4j/graph/__init__.py b/neo4j/graph/__init__.py
index 66285af7..bc1e8a52 100644
--- a/neo4j/graph/__init__.py
+++ b/neo4j/graph/__init__.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -18,22 +15,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+
"""
Graph data types
"""
-from collections.abc import Mapping
-
-
__all__ = [
"Graph",
"Node",
- "Relationship",
"Path",
+ "Relationship",
]
+from collections.abc import Mapping
+
+
class Graph:
""" Local, self-contained graph object that acts as a container for
:class:`.Node` and :class:`.Relationship` instances.
diff --git a/neo4j/io/README.rst b/neo4j/io/README.rst
deleted file mode 100644
index dfe8742a..00000000
--- a/neo4j/io/README.rst
+++ /dev/null
@@ -1 +0,0 @@
-Regular (non-async) I/O for Neo4j.
\ No newline at end of file
diff --git a/neo4j/io/__init__.py b/neo4j/io/__init__.py
deleted file mode 100644
index 348bbce8..00000000
--- a/neo4j/io/__init__.py
+++ /dev/null
@@ -1,1414 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
-# Copyright (c) "Neo4j"
-# Neo4j Sweden AB [http://neo4j.com]
-#
-# This file is part of Neo4j.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-"""
-This module contains the low-level functionality required for speaking
-Bolt. It is not intended to be used directly by driver users. Instead,
-the `session` module provides the main user-facing abstractions.
-"""
-
-
-__all__ = [
- "Bolt",
- "BoltPool",
- "ConnectionErrorHandler",
- "Neo4jPool",
- "check_supported_server_product",
-]
-
-import abc
-from collections import (
- defaultdict,
- deque,
-)
-import logging
-from logging import getLogger
-from random import choice
-import selectors
-from socket import (
- AF_INET,
- AF_INET6,
- SHUT_RDWR,
- SO_KEEPALIVE,
- socket,
- SOL_SOCKET,
- timeout as SocketTimeout,
-)
-from ssl import (
- CertificateError,
- HAS_SNI,
- SSLError,
-)
-from threading import (
- Condition,
- RLock,
-)
-from time import perf_counter
-
-from neo4j._exceptions import (
- BoltError,
- BoltHandshakeError,
- BoltProtocolError,
- BoltSecurityError,
-)
-from neo4j.addressing import Address
-from neo4j.api import (
- READ_ACCESS,
- ServerInfo,
- Version,
- WRITE_ACCESS,
-)
-from neo4j.conf import (
- PoolConfig,
- WorkspaceConfig,
-)
-from neo4j.exceptions import (
- AuthError,
- ClientError,
- ConfigurationError,
- DriverError,
- IncompleteCommit,
- Neo4jError,
- ReadServiceUnavailable,
- ServiceUnavailable,
- SessionExpired,
- UnsupportedServerProduct,
- WriteServiceUnavailable,
-)
-from neo4j.io._common import (
- CommitResponse,
- ConnectionErrorHandler,
- Inbox,
- InitResponse,
- Outbox,
- Response,
-)
-from neo4j.meta import get_user_agent
-from neo4j.packstream import (
- Packer,
- Unpacker,
-)
-from neo4j.routing import RoutingTable
-
-# Set up logger
-log = getLogger("neo4j")
-
-
-class Bolt(abc.ABC):
- """ Server connection for Bolt protocol.
-
- A :class:`.Bolt` should be constructed following a
- successful .open()
-
- Bolt handshake and takes the socket over which
- the handshake was carried out.
- """
-
- MAGIC_PREAMBLE = b"\x60\x60\xB0\x17"
-
- PROTOCOL_VERSION = None
-
- # flag if connection needs RESET to go back to READY state
- is_reset = False
-
- # The socket
- in_use = False
-
- # The socket
- _closed = False
-
- # The socket
- _defunct = False
-
- #: The pool of which this connection is a member
- pool = None
-
- # Store the id of the most recent ran query to be able to reduce sent bits by
- # using the default (-1) to refer to the most recent query when pulling
- # results for it.
- most_recent_qid = None
-
- def __init__(self, unresolved_address, sock, max_connection_lifetime, *, auth=None, user_agent=None, routing_context=None):
- self.unresolved_address = unresolved_address
- self.socket = sock
- self.server_info = ServerInfo(Address(sock.getpeername()), self.PROTOCOL_VERSION)
- # so far `connection.recv_timeout_seconds` is the only available
- # configuration hint that exists. Therefore, all hints can be stored at
- # connection level. This might change in the future.
- self.configuration_hints = {}
- self.outbox = Outbox()
- self.inbox = Inbox(self.socket, on_error=self._set_defunct_read)
- self.packer = Packer(self.outbox)
- self.unpacker = Unpacker(self.inbox)
- self.responses = deque()
- self._max_connection_lifetime = max_connection_lifetime
- self._creation_timestamp = perf_counter()
- self.routing_context = routing_context
-
- # Determine the user agent
- if user_agent:
- self.user_agent = user_agent
- else:
- self.user_agent = get_user_agent()
-
- # Determine auth details
- if not auth:
- self.auth_dict = {}
- elif isinstance(auth, tuple) and 2 <= len(auth) <= 3:
- from neo4j import Auth
- self.auth_dict = vars(Auth("basic", *auth))
- else:
- try:
- self.auth_dict = vars(auth)
- except (KeyError, TypeError):
- raise AuthError("Cannot determine auth details from %r" % auth)
-
- # Check for missing password
- try:
- credentials = self.auth_dict["credentials"]
- except KeyError:
- pass
- else:
- if credentials is None:
- raise AuthError("Password cannot be None")
-
- @property
- @abc.abstractmethod
- def supports_multiple_results(self):
- """ Boolean flag to indicate if the connection version supports multiple
- queries to be buffered on the server side (True) or if all results need
- to be eagerly pulled before sending the next RUN (False).
- """
- pass
-
- @property
- @abc.abstractmethod
- def supports_multiple_databases(self):
- """ Boolean flag to indicate if the connection version supports multiple
- databases.
- """
- pass
-
- @classmethod
- def protocol_handlers(cls, protocol_version=None):
- """ Return a dictionary of available Bolt protocol handlers,
- keyed by version tuple. If an explicit protocol version is
- provided, the dictionary will contain either zero or one items,
- depending on whether that version is supported. If no protocol
- version is provided, all available versions will be returned.
-
- :param protocol_version: tuple identifying a specific protocol
- version (e.g. (3, 5)) or None
- :return: dictionary of version tuple to handler class for all
- relevant and supported protocol versions
- :raise TypeError: if protocol version is not passed in a tuple
- """
-
- # Carry out Bolt subclass imports locally to avoid circular dependency issues.
- from neo4j.io._bolt3 import Bolt3
- from neo4j.io._bolt4 import Bolt4x0, Bolt4x1, Bolt4x2, Bolt4x3, Bolt4x4
-
- handlers = {
- Bolt3.PROTOCOL_VERSION: Bolt3,
- Bolt4x0.PROTOCOL_VERSION: Bolt4x0,
- Bolt4x1.PROTOCOL_VERSION: Bolt4x1,
- Bolt4x2.PROTOCOL_VERSION: Bolt4x2,
- Bolt4x3.PROTOCOL_VERSION: Bolt4x3,
- Bolt4x4.PROTOCOL_VERSION: Bolt4x4,
- }
-
- if protocol_version is None:
- return handlers
-
- if not isinstance(protocol_version, tuple):
- raise TypeError("Protocol version must be specified as a tuple")
-
- if protocol_version in handlers:
- return {protocol_version: handlers[protocol_version]}
-
- return {}
-
- @classmethod
- def version_list(cls, versions, limit=4):
- """ Return a list of supported protocol versions in order of
- preference. The number of protocol versions (or ranges)
- returned is limited to four.
- """
- # In fact, 4.3 is the fist version to support ranges. However, the range
- # support got backported to 4.2. But even if the server is too old to
- # have the backport, negotiating BOLT 4.1 is no problem as it's
- # equivalent to 4.2
- first_with_range_support = Version(4, 2)
- result = []
- for version in versions:
- if (result
- and version >= first_with_range_support
- and result[-1][0] == version[0]
- and result[-1][1][1] == version[1] + 1):
- # can use range to encompass this version
- result[-1][1][1] = version[1]
- continue
- result.append(Version(version[0], [version[1], version[1]]))
- if len(result) == 4:
- break
- return result
-
- @classmethod
- def get_handshake(cls):
- """ Return the supported Bolt versions as bytes.
- The length is 16 bytes as specified in the Bolt version negotiation.
- :return: bytes
- """
- supported_versions = sorted(cls.protocol_handlers().keys(), reverse=True)
- offered_versions = cls.version_list(supported_versions)
- return b"".join(version.to_bytes() for version in offered_versions).ljust(16, b"\x00")
-
- @classmethod
- def ping(cls, address, *, timeout=None, **config):
- """ Attempt to establish a Bolt connection, returning the
- agreed Bolt protocol version if successful.
- """
- config = PoolConfig.consume(config)
- try:
- s, protocol_version, handshake, data = connect(
- address,
- timeout=timeout,
- custom_resolver=config.resolver,
- ssl_context=config.get_ssl_context(),
- keep_alive=config.keep_alive,
- )
- except (ServiceUnavailable, SessionExpired, BoltHandshakeError):
- return None
- else:
- _close_socket(s)
- return protocol_version
-
- @classmethod
- def open(cls, address, *, auth=None, timeout=None, routing_context=None, **pool_config):
- """ Open a new Bolt connection to a given server address.
-
- :param address:
- :param auth:
- :param timeout: the connection timeout in seconds
- :param routing_context: dict containing routing context
- :param pool_config:
- :return:
- :raise BoltHandshakeError: raised if the Bolt Protocol can not negotiate a protocol version.
- :raise ServiceUnavailable: raised if there was a connection issue.
- """
- pool_config = PoolConfig.consume(pool_config)
- s, pool_config.protocol_version, handshake, data = connect(
- address,
- timeout=timeout,
- custom_resolver=pool_config.resolver,
- ssl_context=pool_config.get_ssl_context(),
- keep_alive=pool_config.keep_alive,
- )
-
- # Carry out Bolt subclass imports locally to avoid circular dependency
- # issues.
- if pool_config.protocol_version == (3, 0):
- from neo4j.io._bolt3 import Bolt3
- bolt_cls = Bolt3
- elif pool_config.protocol_version == (4, 0):
- from neo4j.io._bolt4 import Bolt4x0
- bolt_cls = Bolt4x0
- elif pool_config.protocol_version == (4, 1):
- from neo4j.io._bolt4 import Bolt4x1
- bolt_cls = Bolt4x1
- elif pool_config.protocol_version == (4, 2):
- from neo4j.io._bolt4 import Bolt4x2
- bolt_cls = Bolt4x2
- elif pool_config.protocol_version == (4, 3):
- from neo4j.io._bolt4 import Bolt4x3
- bolt_cls = Bolt4x3
- elif pool_config.protocol_version == (4, 4):
- from neo4j.io._bolt4 import Bolt4x4
- bolt_cls = Bolt4x4
- else:
- log.debug("[#%04X] S: ", s.getsockname()[1])
- _close_socket(s)
-
- supported_versions = Bolt.protocol_handlers().keys()
- raise BoltHandshakeError("The Neo4J server does not support communication with this driver. This driver have support for Bolt Protocols {}".format(supported_versions), address=address, request_data=handshake, response_data=data)
-
- connection = bolt_cls(
- address, s, pool_config.max_connection_lifetime, auth=auth,
- user_agent=pool_config.user_agent, routing_context=routing_context
- )
-
- try:
- connection.hello()
- except Exception:
- connection.close()
- raise
-
- return connection
-
- @property
- @abc.abstractmethod
- def encrypted(self):
- pass
-
- @property
- @abc.abstractmethod
- def der_encoded_server_certificate(self):
- pass
-
- @property
- @abc.abstractmethod
- def local_port(self):
- pass
-
- @abc.abstractmethod
- def hello(self):
- """ Appends a HELLO message to the outgoing queue, sends it and consumes
- all remaining messages.
- """
- pass
-
- def __del__(self):
- try:
- self.close()
- except OSError:
- pass
-
- @abc.abstractmethod
- def route(self, database=None, imp_user=None, bookmarks=None):
- """ Fetch a routing table from the server for the given
- `database`. For Bolt 4.3 and above, this appends a ROUTE
- message; for earlier versions, a procedure call is made via
- the regular Cypher execution mechanism. In all cases, this is
- sent to the network, and a response is fetched.
-
- :param database: database for which to fetch a routing table
- :param imp_user: the user to impersonate
- :param bookmarks: iterable of bookmark values after which this
- transaction should begin
- :return: dictionary of raw routing data
- """
- pass
-
- @abc.abstractmethod
- def run(self, query, parameters=None, mode=None, bookmarks=None,
- metadata=None, timeout=None, db=None, imp_user=None, **handlers):
- """ Appends a RUN message to the output queue.
-
- :param query: Cypher query string
- :param parameters: dictionary of Cypher parameters
- :param mode: access mode for routing - "READ" or "WRITE" (default)
- :param bookmarks: iterable of bookmark values after which this transaction should begin
- :param metadata: custom metadata dictionary to attach to the transaction
- :param timeout: timeout for transaction execution (seconds)
- :param db: name of the database against which to begin the transaction
- :param imp_user: the user to impersonate
- :param handlers: handler functions passed into the returned Response object
- :return: Response object
- """
- pass
-
- @abc.abstractmethod
- def discard(self, n=-1, qid=-1, **handlers):
- """ Appends a DISCARD message to the output queue.
-
- :param n: number of records to discard, default = -1 (ALL)
- :param qid: query ID to discard for, default = -1 (last query)
- :param handlers: handler functions passed into the returned Response object
- :return: Response object
- """
- pass
-
- @abc.abstractmethod
- def pull(self, n=-1, qid=-1, **handlers):
- """ Appends a PULL message to the output queue.
-
- :param n: number of records to pull, default = -1 (ALL)
- :param qid: query ID to pull for, default = -1 (last query)
- :param handlers: handler functions passed into the returned Response object
- :return: Response object
- """
- pass
-
- @abc.abstractmethod
- def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None,
- db=None, imp_user=None, **handlers):
- """ Appends a BEGIN message to the output queue.
-
- :param mode: access mode for routing - "READ" or "WRITE" (default)
- :param bookmarks: iterable of bookmark values after which this transaction should begin
- :param metadata: custom metadata dictionary to attach to the transaction
- :param timeout: timeout for transaction execution (seconds)
- :param db: name of the database against which to begin the transaction
- :param imp_user: the user to impersonate
- :param handlers: handler functions passed into the returned Response object
- :return: Response object
- """
- pass
-
- @abc.abstractmethod
- def commit(self, **handlers):
- """ Appends a COMMIT message to the output queue."""
- pass
-
- @abc.abstractmethod
- def rollback(self, **handlers):
- """ Appends a ROLLBACK message to the output queue."""
- pass
-
- @abc.abstractmethod
- def reset(self):
- """ Appends a RESET message to the outgoing queue, sends it and consumes
- all remaining messages.
- """
- pass
-
- def _append(self, signature, fields=(), response=None):
- """ Appends a message to the outgoing queue.
-
- :param signature: the signature of the message
- :param fields: the fields of the message as a tuple
- :param response: a response object to handle callbacks
- """
- self.packer.pack_struct(signature, fields)
- self.outbox.wrap_message()
- self.responses.append(response)
-
- def _send_all(self):
- data = self.outbox.view()
- if data:
- try:
- self.socket.sendall(data)
- except OSError as error:
- self._set_defunct_write(error)
- self.outbox.clear()
-
- def send_all(self):
- """ Send all queued messages to the server.
- """
- if self.closed():
- raise ServiceUnavailable("Failed to write to closed connection {!r} ({!r})".format(
- self.unresolved_address, self.server_info.address))
-
- if self.defunct():
- raise ServiceUnavailable("Failed to write to defunct connection {!r} ({!r})".format(
- self.unresolved_address, self.server_info.address))
-
- self._send_all()
-
- @abc.abstractmethod
- def fetch_message(self):
- """ Receive at most one message from the server, if available.
-
- :return: 2-tuple of number of detail messages and number of summary
- messages fetched
- """
- pass
-
- def fetch_all(self):
- """ Fetch all outstanding messages.
-
- :return: 2-tuple of number of detail messages and number of summary
- messages fetched
- """
- detail_count = summary_count = 0
- while self.responses:
- response = self.responses[0]
- while not response.complete:
- detail_delta, summary_delta = self.fetch_message()
- detail_count += detail_delta
- summary_count += summary_delta
- return detail_count, summary_count
-
- def _set_defunct_read(self, error=None, silent=False):
- message = "Failed to read from defunct connection {!r} ({!r})".format(
- self.unresolved_address, self.server_info.address
- )
- self._set_defunct(message, error=error, silent=silent)
-
- def _set_defunct_write(self, error=None, silent=False):
- message = "Failed to write data to connection {!r} ({!r})".format(
- self.unresolved_address, self.server_info.address
- )
- self._set_defunct(message, error=error, silent=silent)
-
- def _set_defunct(self, message, error=None, silent=False):
- direct_driver = isinstance(self.pool, BoltPool)
-
- if error:
- log.debug("[#%04X] %s", self.socket.getsockname()[1], error)
- log.error(message)
- # We were attempting to receive data but the connection
- # has unexpectedly terminated. So, we need to close the
- # connection from the client side, and remove the address
- # from the connection pool.
- self._defunct = True
- self.close()
- if self.pool:
- self.pool.deactivate(address=self.unresolved_address)
- # Iterate through the outstanding responses, and if any correspond
- # to COMMIT requests then raise an error to signal that we are
- # unable to confirm that the COMMIT completed successfully.
- if silent:
- return
- for response in self.responses:
- if isinstance(response, CommitResponse):
- if error:
- raise IncompleteCommit(message) from error
- else:
- raise IncompleteCommit(message)
-
- if direct_driver:
- if error:
- raise ServiceUnavailable(message) from error
- else:
- raise ServiceUnavailable(message)
- else:
- if error:
- raise SessionExpired(message) from error
- else:
- raise SessionExpired(message)
-
- def stale(self):
- return (self._stale
- or (0 <= self._max_connection_lifetime
- <= perf_counter() - self._creation_timestamp))
-
- _stale = False
-
- def set_stale(self):
- self._stale = True
-
- @abc.abstractmethod
- def close(self):
- """ Close the connection.
- """
- pass
-
- @abc.abstractmethod
- def closed(self):
- pass
-
- @abc.abstractmethod
- def defunct(self):
- pass
-
-
-class IOPool:
- """ A collection of connections to one or more server addresses.
- """
-
- def __init__(self, opener, pool_config, workspace_config):
- assert callable(opener)
- assert isinstance(pool_config, PoolConfig)
- assert isinstance(workspace_config, WorkspaceConfig)
-
- self.opener = opener
- self.pool_config = pool_config
- self.workspace_config = workspace_config
- self.connections = defaultdict(deque)
- self.lock = RLock()
- self.cond = Condition(self.lock)
-
- def __enter__(self):
- return self
-
- def __exit__(self, exc_type, exc_value, traceback):
- self.close()
-
- def _acquire(self, address, timeout):
- """ Acquire a connection to a given address from the pool.
- The address supplied should always be an IP address, not
- a host name.
-
- This method is thread safe.
- """
- t0 = perf_counter()
- if timeout is None:
- timeout = self.workspace_config.connection_acquisition_timeout
-
- with self.lock:
- def time_remaining():
- t = timeout - (perf_counter() - t0)
- return t if t > 0 else 0
-
- while True:
- # try to find a free connection in pool
- for connection in list(self.connections.get(address, [])):
- if (connection.closed() or connection.defunct()
- or (connection.stale() and not connection.in_use)):
- # `close` is a noop on already closed connections.
- # This is to make sure that the connection is gracefully
- # closed, e.g. if it's just marked as `stale` but still
- # alive.
- if log.isEnabledFor(logging.DEBUG):
- log.debug(
- "[#%04X] C: removing old connection "
- "(closed=%s, defunct=%s, stale=%s, in_use=%s)",
- connection.local_port,
- connection.closed(), connection.defunct(),
- connection.stale(), connection.in_use
- )
- connection.close()
- try:
- self.connections.get(address, []).remove(connection)
- except ValueError:
- # If closure fails (e.g. because the server went
- # down), all connections to the same address will
- # be removed. Therefore, we silently ignore if the
- # connection isn't in the pool anymore.
- pass
- continue
- if not connection.in_use:
- connection.in_use = True
- return connection
- # all connections in pool are in-use
- connections = self.connections[address]
- max_pool_size = self.pool_config.max_connection_pool_size
- infinite_pool_size = (max_pool_size < 0
- or max_pool_size == float("inf"))
- can_create_new_connection = (
- infinite_pool_size
- or len(connections) < max_pool_size
- )
- if can_create_new_connection:
- timeout = min(self.pool_config.connection_timeout,
- time_remaining())
- try:
- connection = self.opener(address, timeout)
- except ServiceUnavailable:
- self.remove(address)
- raise
- else:
- connection.pool = self
- connection.in_use = True
- connections.append(connection)
- return connection
-
- # failed to obtain a connection from pool because the
- # pool is full and no free connection in the pool
- if time_remaining():
- self.cond.wait(time_remaining())
- # if timed out, then we throw error. This time
- # computation is needed, as with python 2.7, we
- # cannot tell if the condition is notified or
- # timed out when we come to this line
- if not time_remaining():
- raise ClientError("Failed to obtain a connection from pool "
- "within {!r}s".format(timeout))
- else:
- raise ClientError("Failed to obtain a connection from pool "
- "within {!r}s".format(timeout))
-
- def acquire(self, access_mode=None, timeout=None, database=None,
- bookmarks=None):
- """ Acquire a connection to a server that can satisfy a set of parameters.
-
- :param access_mode:
- :param timeout:
- :param database:
- :param bookmarks:
- """
-
- def release(self, *connections):
- """ Release a connection back into the pool.
- This method is thread safe.
- """
- with self.lock:
- for connection in connections:
- if not (connection.defunct()
- or connection.closed()
- or connection.is_reset):
- try:
- connection.reset()
- except (Neo4jError, DriverError, BoltError) as e:
- log.debug(
- "Failed to reset connection on release: %s", e
- )
- connection.in_use = False
- self.cond.notify_all()
-
- def in_use_connection_count(self, address):
- """ Count the number of connections currently in use to a given
- address.
- """
- try:
- connections = self.connections[address]
- except KeyError:
- return 0
- else:
- return sum(1 if connection.in_use else 0 for connection in connections)
-
- def mark_all_stale(self):
- with self.lock:
- for address in self.connections:
- for connection in self.connections[address]:
- connection.set_stale()
-
- def deactivate(self, address):
- """ Deactivate an address from the connection pool, if present, closing
- all idle connection to that address
- """
- with self.lock:
- try:
- connections = self.connections[address]
- except KeyError: # already removed from the connection pool
- return
- for conn in list(connections):
- if not conn.in_use:
- connections.remove(conn)
- try:
- conn.close()
- except OSError:
- pass
- if not connections:
- self.remove(address)
-
- def on_write_failure(self, address):
- raise WriteServiceUnavailable("No write service available for pool {}".format(self))
-
- def remove(self, address):
- """ Remove an address from the connection pool, if present, closing
- all connections to that address.
- """
- with self.lock:
- for connection in self.connections.pop(address, ()):
- try:
- connection.close()
- except OSError:
- pass
-
- def close(self):
- """ Close all connections and empty the pool.
- This method is thread safe.
- """
- try:
- with self.lock:
- for address in list(self.connections):
- self.remove(address)
- except TypeError:
- pass
-
-
-class BoltPool(IOPool):
-
- @classmethod
- def open(cls, address, *, auth, pool_config, workspace_config):
- """Create a new BoltPool
-
- :param address:
- :param auth:
- :param pool_config:
- :param workspace_config:
- :return: BoltPool
- """
-
- def opener(addr, timeout):
- return Bolt.open(
- addr, auth=auth, timeout=timeout, routing_context=None,
- **pool_config
- )
-
- pool = cls(opener, pool_config, workspace_config, address)
- return pool
-
- def __init__(self, opener, pool_config, workspace_config, address):
- super(BoltPool, self).__init__(opener, pool_config, workspace_config)
- self.address = address
-
- def __repr__(self):
- return "<{} address={!r}>".format(self.__class__.__name__, self.address)
-
- def acquire(self, access_mode=None, timeout=None, database=None, bookmarks=None):
- # The access_mode and database is not needed for a direct connection, its just there for consistency.
- return self._acquire(self.address, timeout)
-
-
-class Neo4jPool(IOPool):
- """ Connection pool with routing table.
- """
-
- @classmethod
- def open(cls, *addresses, auth, pool_config, workspace_config, routing_context=None):
- """Create a new Neo4jPool
-
- :param addresses: one or more address as positional argument
- :param auth:
- :param pool_config:
- :param workspace_config:
- :param routing_context:
- :return: Neo4jPool
- """
-
- address = addresses[0]
- if routing_context is None:
- routing_context = {}
- elif "address" in routing_context:
- raise ConfigurationError("The key 'address' is reserved for routing context.")
- routing_context["address"] = str(address)
-
- def opener(addr, timeout):
- return Bolt.open(addr, auth=auth, timeout=timeout,
- routing_context=routing_context, **pool_config)
-
- pool = cls(opener, pool_config, workspace_config, address)
- return pool
-
- def __init__(self, opener, pool_config, workspace_config, address):
- """
-
- :param opener:
- :param pool_config:
- :param workspace_config:
- :param address:
- """
- super(Neo4jPool, self).__init__(opener, pool_config, workspace_config)
- # Each database have a routing table, the default database is a special case.
- log.debug("[#0000] C: routing address %r", address)
- self.address = address
- self.routing_tables = {workspace_config.database: RoutingTable(database=workspace_config.database, routers=[address])}
- self.refresh_lock = RLock()
-
- def __repr__(self):
- """ The representation shows the initial routing addresses.
-
- :return: The representation
- :rtype: str
- """
- return "<{} addresses={!r}>".format(self.__class__.__name__, self.get_default_database_initial_router_addresses())
-
- @property
- def first_initial_routing_address(self):
- return self.get_default_database_initial_router_addresses()[0]
-
- def get_default_database_initial_router_addresses(self):
- """ Get the initial router addresses for the default database.
-
- :return:
- :rtype: OrderedSet
- """
- return self.get_routing_table_for_default_database().initial_routers
-
- def get_default_database_router_addresses(self):
- """ Get the router addresses for the default database.
-
- :return:
- :rtype: OrderedSet
- """
- return self.get_routing_table_for_default_database().routers
-
- def get_routing_table_for_default_database(self):
- return self.routing_tables[self.workspace_config.database]
-
- def get_or_create_routing_table(self, database):
- with self.refresh_lock:
- if database not in self.routing_tables:
- self.routing_tables[database] = RoutingTable(
- database=database,
- routers=self.get_default_database_initial_router_addresses()
- )
- return self.routing_tables[database]
-
- def fetch_routing_info(self, address, database, imp_user, bookmarks,
- timeout):
- """ Fetch raw routing info from a given router address.
-
- :param address: router address
- :param database: the database name to get routing table for
- :param imp_user: the user to impersonate while fetching the routing
- table
- :type imp_user: str or None
- :param bookmarks: iterable of bookmark values after which the routing
- info should be fetched
- :param timeout: connection acquisition timeout in seconds
-
- :return: list of routing records, or None if no connection
- could be established or if no readers or writers are present
- :raise ServiceUnavailable: if the server does not support
- routing, or if routing support is broken or outdated
- """
- cx = self._acquire(address, timeout)
- try:
- routing_table = cx.route(
- database or self.workspace_config.database,
- imp_user or self.workspace_config.impersonated_user,
- bookmarks
- )
- finally:
- self.release(cx)
- return routing_table
-
- def fetch_routing_table(self, *, address, timeout, database, imp_user,
- bookmarks):
- """ Fetch a routing table from a given router address.
-
- :param address: router address
- :param timeout: seconds
- :param database: the database name
- :type: str
- :param imp_user: the user to impersonate while fetching the routing
- table
- :type imp_user: str or None
- :param bookmarks: bookmarks used when fetching routing table
-
- :return: a new RoutingTable instance or None if the given router is
- currently unable to provide routing information
- """
- new_routing_info = None
- try:
- new_routing_info = self.fetch_routing_info(
- address, database, imp_user, bookmarks, timeout
- )
- except Neo4jError as e:
- # checks if the code is an error that is caused by the client. In
- # this case there is no sense in trying to fetch a RT from another
- # router. Hence, the driver should fail fast during discovery.
- if e.is_fatal_during_discovery():
- raise
- except (ServiceUnavailable, SessionExpired):
- pass
- if not new_routing_info:
- log.debug("Failed to fetch routing info %s", address)
- return None
- else:
- servers = new_routing_info[0]["servers"]
- ttl = new_routing_info[0]["ttl"]
- database = new_routing_info[0].get("db", database)
- new_routing_table = RoutingTable.parse_routing_info(
- database=database, servers=servers, ttl=ttl
- )
-
- # Parse routing info and count the number of each type of server
- num_routers = len(new_routing_table.routers)
- num_readers = len(new_routing_table.readers)
-
- # num_writers = len(new_routing_table.writers)
- # If no writers are available. This likely indicates a temporary state,
- # such as leader switching, so we should not signal an error.
-
- # No routers
- if num_routers == 0:
- log.debug("No routing servers returned from server %s", address)
- return None
-
- # No readers
- if num_readers == 0:
- log.debug("No read servers returned from server %s", address)
- return None
-
- # At least one of each is fine, so return this table
- return new_routing_table
-
- def _update_routing_table_from(self, *routers, database=None, imp_user=None,
- bookmarks=None, database_callback=None):
- """ Try to update routing tables with the given routers.
-
- :return: True if the routing table is successfully updated,
- otherwise False
- """
- log.debug("Attempting to update routing table from {}".format(", ".join(map(repr, routers))))
- for router in routers:
- for address in router.resolve(resolver=self.pool_config.resolver):
- new_routing_table = self.fetch_routing_table(
- address=address,
- timeout=self.pool_config.connection_timeout,
- database=database, imp_user=imp_user, bookmarks=bookmarks
- )
- if new_routing_table is not None:
- new_databse = new_routing_table.database
- self.get_or_create_routing_table(new_databse)\
- .update(new_routing_table)
- log.debug(
- "[#0000] C: address=%r (%r)",
- address, self.routing_tables[new_databse]
- )
- if callable(database_callback):
- database_callback(new_databse)
- return True
- self.deactivate(router)
- return False
-
- def update_routing_table(self, *, database, imp_user, bookmarks,
- database_callback=None):
- """ Update the routing table from the first router able to provide
- valid routing information.
-
- :param database: The database name
- :param imp_user: the user to impersonate while fetching the routing
- table
- :type imp_user: str or None
- :param bookmarks: bookmarks used when fetching routing table
- :param database_callback: A callback function that will be called with
- the database name as only argument when a new routing table has been
- acquired. This database name might different from `database` if that
- was None and the underlying protocol supports reporting back the
- actual database.
-
- :raise neo4j.exceptions.ServiceUnavailable:
- """
- with self.refresh_lock:
- # copied because it can be modified
- existing_routers = set(
- self.get_or_create_routing_table(database).routers
- )
-
- prefer_initial_routing_address = \
- self.routing_tables[database].initialized_without_writers
-
- if prefer_initial_routing_address:
- # TODO: Test this state
- if self._update_routing_table_from(
- self.first_initial_routing_address, database=database,
- imp_user=imp_user, bookmarks=bookmarks,
- database_callback=database_callback
- ):
- # Why is only the first initial routing address used?
- return
- if self._update_routing_table_from(
- *(existing_routers - {self.first_initial_routing_address}),
- database=database, imp_user=imp_user, bookmarks=bookmarks,
- database_callback=database_callback
- ):
- return
-
- if not prefer_initial_routing_address:
- if self._update_routing_table_from(
- self.first_initial_routing_address, database=database,
- imp_user=imp_user, bookmarks=bookmarks,
- database_callback=database_callback
- ):
- # Why is only the first initial routing address used?
- return
-
- # None of the routers have been successful, so just fail
- log.error("Unable to retrieve routing information")
- raise ServiceUnavailable("Unable to retrieve routing information")
-
- def update_connection_pool(self, *, database):
- servers = self.get_or_create_routing_table(database).servers()
- for address in list(self.connections):
- if address.unresolved not in servers:
- super(Neo4jPool, self).deactivate(address)
-
- def ensure_routing_table_is_fresh(self, *, access_mode, database, imp_user,
- bookmarks, database_callback=None):
- """ Update the routing table if stale.
-
- This method performs two freshness checks, before and after acquiring
- the refresh lock. If the routing table is already fresh on entry, the
- method exits immediately; otherwise, the refresh lock is acquired and
- the second freshness check that follows determines whether an update
- is still required.
-
- This method is thread-safe.
-
- :return: `True` if an update was required, `False` otherwise.
- """
- from neo4j.api import READ_ACCESS
- with self.refresh_lock:
- if self.get_or_create_routing_table(database)\
- .is_fresh(readonly=(access_mode == READ_ACCESS)):
- # Readers are fresh.
- return False
-
- self.update_routing_table(
- database=database, imp_user=imp_user, bookmarks=bookmarks,
- database_callback=database_callback
- )
- self.update_connection_pool(database=database)
-
- for database in list(self.routing_tables.keys()):
- # Remove unused databases in the routing table
- # Remove the routing table after a timeout = TTL + 30s
- log.debug("[#0000] C: database=%s", database)
- if (self.routing_tables[database].should_be_purged_from_memory()
- and database != self.workspace_config.database):
- del self.routing_tables[database]
-
- return True
-
- def _select_address(self, *, access_mode, database):
- from neo4j.api import READ_ACCESS
- """ Selects the address with the fewest in-use connections.
- """
- with self.refresh_lock:
- if access_mode == READ_ACCESS:
- addresses = self.routing_tables[database].readers
- else:
- addresses = self.routing_tables[database].writers
- addresses_by_usage = {}
- for address in addresses:
- addresses_by_usage.setdefault(
- self.in_use_connection_count(address), []
- ).append(address)
- if not addresses_by_usage:
- if access_mode == READ_ACCESS:
- raise ReadServiceUnavailable(
- "No read service currently available"
- )
- else:
- raise WriteServiceUnavailable(
- "No write service currently available"
- )
- return choice(addresses_by_usage[min(addresses_by_usage)])
-
- def acquire(self, access_mode=None, timeout=None, database=None,
- bookmarks=None):
- if access_mode not in (WRITE_ACCESS, READ_ACCESS):
- raise ClientError("Non valid 'access_mode'; {}".format(access_mode))
- if not timeout:
- raise ClientError("'timeout' must be a float larger than 0; {}"
- .format(timeout))
-
- from neo4j.api import check_access_mode
- access_mode = check_access_mode(access_mode)
- with self.refresh_lock:
- log.debug("[#0000] C: %r",
- self.routing_tables)
- self.ensure_routing_table_is_fresh(
- access_mode=access_mode, database=database, imp_user=None,
- bookmarks=bookmarks
- )
-
- while True:
- try:
- # Get an address for a connection that have the fewest in-use
- # connections.
- address = self._select_address(access_mode=access_mode,
- database=database)
- except (ReadServiceUnavailable, WriteServiceUnavailable) as err:
- raise SessionExpired("Failed to obtain connection towards '%s' server." % access_mode) from err
- try:
- log.debug("[#0000] C: database=%r address=%r", database, address)
- connection = self._acquire(address, timeout=timeout) # should always be a resolved address
- except ServiceUnavailable:
- self.deactivate(address=address)
- else:
- return connection
-
- def deactivate(self, address):
- """ Deactivate an address from the connection pool,
- if present, remove from the routing table and also closing
- all idle connections to that address.
- """
- log.debug("[#0000] C: Deactivating address %r", address)
- # We use `discard` instead of `remove` here since the former
- # will not fail if the address has already been removed.
- for database in self.routing_tables.keys():
- self.routing_tables[database].routers.discard(address)
- self.routing_tables[database].readers.discard(address)
- self.routing_tables[database].writers.discard(address)
- log.debug("[#0000] C: table=%r", self.routing_tables)
- super(Neo4jPool, self).deactivate(address)
-
- def on_write_failure(self, address):
- """ Remove a writer address from the routing table, if present.
- """
- log.debug("[#0000] C: Removing writer %r", address)
- for database in self.routing_tables.keys():
- self.routing_tables[database].writers.discard(address)
- log.debug("[#0000] C: table=%r", self.routing_tables)
-
-
-def _connect(resolved_address, timeout, keep_alive):
- """
-
- :param resolved_address:
- :param timeout: seconds
- :param keep_alive: True or False
- :return: socket object
- """
-
- s = None # The socket
-
- try:
- if len(resolved_address) == 2:
- s = socket(AF_INET)
- elif len(resolved_address) == 4:
- s = socket(AF_INET6)
- else:
- raise ValueError("Unsupported address {!r}".format(resolved_address))
- t = s.gettimeout()
- if timeout:
- s.settimeout(timeout)
- log.debug("[#0000] C: %s", resolved_address)
- s.connect(resolved_address)
- s.settimeout(t)
- keep_alive = 1 if keep_alive else 0
- s.setsockopt(SOL_SOCKET, SO_KEEPALIVE, keep_alive)
- except SocketTimeout:
- log.debug("[#0000] C: %s", resolved_address)
- log.debug("[#0000] C: %s", resolved_address)
- _close_socket(s)
- raise ServiceUnavailable("Timed out trying to establish connection to {!r}".format(resolved_address))
- except OSError as error:
- log.debug("[#0000] C: %s %s", type(error).__name__,
- " ".join(map(repr, error.args)))
- log.debug("[#0000] C: %s", resolved_address)
- s.close()
- raise ServiceUnavailable("Failed to establish connection to {!r} (reason {})".format(resolved_address, error))
- else:
- return s
-
-
-def _secure(s, host, ssl_context):
- local_port = s.getsockname()[1]
- # Secure the connection if an SSL context has been provided
- if ssl_context:
- last_error = None
- log.debug("[#%04X] C: %s", local_port, host)
- try:
- sni_host = host if HAS_SNI and host else None
- s = ssl_context.wrap_socket(s, server_hostname=sni_host)
- except (OSError, SSLError, CertificateError) as cause:
- raise BoltSecurityError(
- message="Failed to establish encrypted connection.",
- address=(host, local_port)
- ) from cause
- # Check that the server provides a certificate
- der_encoded_server_certificate = s.getpeercert(binary_form=True)
- if der_encoded_server_certificate is None:
- raise BoltProtocolError(
- "When using an encrypted socket, the server should always "
- "provide a certificate", address=(host, local_port)
- )
- return s
- return s
-
-
-def _handshake(s, resolved_address):
- """
-
- :param s: Socket
- :param resolved_address:
-
- :return: (socket, version, client_handshake, server_response_data)
- """
- local_port = s.getsockname()[1]
-
- # TODO: Optimize logging code
- handshake = Bolt.get_handshake()
- import struct
- handshake = struct.unpack(">16B", handshake)
- handshake = [handshake[i:i + 4] for i in range(0, len(handshake), 4)]
-
- supported_versions = [("0x%02X%02X%02X%02X" % (vx[0], vx[1], vx[2], vx[3])) for vx in handshake]
-
- log.debug("[#%04X] C: 0x%08X", local_port, int.from_bytes(Bolt.MAGIC_PREAMBLE, byteorder="big"))
- log.debug("[#%04X] C: %s %s %s %s", local_port, *supported_versions)
-
- data = Bolt.MAGIC_PREAMBLE + Bolt.get_handshake()
- s.sendall(data)
-
- # Handle the handshake response
- ready_to_read = False
- with selectors.DefaultSelector() as selector:
- selector.register(s, selectors.EVENT_READ)
- selector.select(1)
- try:
- data = s.recv(4)
- except OSError:
- raise ServiceUnavailable("Failed to read any data from server {!r} "
- "after connected".format(resolved_address))
- data_size = len(data)
- if data_size == 0:
- # If no data is returned after a successful select
- # response, the server has closed the connection
- log.debug("[#%04X] S: ", local_port)
- _close_socket(s)
- raise ServiceUnavailable("Connection to {address} closed without handshake response".format(address=resolved_address))
- if data_size != 4:
- # Some garbled data has been received
- log.debug("[#%04X] S: @*#!", local_port)
- s.close()
- raise BoltProtocolError("Expected four byte Bolt handshake response from %r, received %r instead; check for incorrect port number" % (resolved_address, data), address=resolved_address)
- elif data == b"HTTP":
- log.debug("[#%04X] S: ", local_port)
- _close_socket(s)
- raise ServiceUnavailable("Cannot to connect to Bolt service on {!r} "
- "(looks like HTTP)".format(resolved_address))
- agreed_version = data[-1], data[-2]
- log.debug("[#%04X] S: 0x%06X%02X", local_port, agreed_version[1], agreed_version[0])
- return s, agreed_version, handshake, data
-
-
-def _close_socket(socket_):
- try:
- socket_.shutdown(SHUT_RDWR)
- socket_.close()
- except OSError:
- pass
-
-
-def connect(address, *, timeout, custom_resolver, ssl_context, keep_alive):
- """ Connect and perform a handshake and return a valid Connection object,
- assuming a protocol version can be agreed.
- """
- errors = []
- # Establish a connection to the host and port specified
- # Catches refused connections see:
- # https://docs.python.org/2/library/errno.html
-
- resolved_addresses = Address(address).resolve(resolver=custom_resolver)
- for resolved_address in resolved_addresses:
- s = None
- try:
- s = _connect(resolved_address, timeout, keep_alive)
- s = _secure(s, resolved_address.host_name, ssl_context)
- return _handshake(s, resolved_address)
- except (BoltError, DriverError, OSError) as error:
- try:
- local_port = s.getsockname()[1]
- except (OSError, AttributeError):
- local_port = 0
- err_str = error.__class__.__name__
- if str(error):
- err_str += ": " + str(error)
- log.debug("[#%04X] C: %s", local_port, err_str)
- if s:
- _close_socket(s)
- errors.append(error)
- except Exception:
- if s:
- _close_socket(s)
- raise
- if not errors:
- raise ServiceUnavailable(
- "Couldn't connect to %s (resolved to %s)" % (
- str(address), tuple(map(str, resolved_addresses)))
- )
- else:
- raise ServiceUnavailable(
- "Couldn't connect to %s (resolved to %s):\n%s" % (
- str(address), tuple(map(str, resolved_addresses)),
- "\n".join(map(str, errors))
- )
- ) from errors[0]
-
-
-def check_supported_server_product(agent):
- """ Checks that a server product is supported by the driver by
- looking at the server agent string.
-
- :param agent: server agent string to check for validity
- :raises UnsupportedServerProduct: if the product is not supported
- """
- if not agent.startswith("Neo4j/"):
- raise UnsupportedServerProduct(agent)
diff --git a/neo4j/meta.py b/neo4j/meta.py
index dd727eab..37d8f8d0 100644
--- a/neo4j/meta.py
+++ b/neo4j/meta.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -31,7 +28,10 @@ def get_user_agent():
""" Obtain the default user agent string sent to the server after
a successful handshake.
"""
- from sys import platform, version_info
+ from sys import (
+ platform,
+ version_info,
+ )
template = "neo4j-python/{} Python/{}.{}.{}-{}-{} ({})"
fields = (version,) + tuple(version_info) + (platform,)
return template.format(*fields)
diff --git a/neo4j/packstream.py b/neo4j/packstream.py
index 406d761e..e453013c 100644
--- a/neo4j/packstream.py
+++ b/neo4j/packstream.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -20,8 +17,11 @@
from codecs import decode
-from io import BytesIO
-from struct import pack as struct_pack, unpack as struct_unpack
+from struct import (
+ pack as struct_pack,
+ unpack as struct_unpack,
+)
+
PACKED_UINT_8 = [struct_pack(">B", value) for value in range(0x100)]
PACKED_UINT_16 = [struct_pack(">H", value) for value in range(0x10000)]
@@ -472,14 +472,3 @@ def pop_u16(self):
return value
else:
return -1
-
- def receive(self, sock, n_bytes):
- end = self.used + n_bytes
- if end > len(self.data):
- self.data += bytearray(end - len(self.data))
- view = memoryview(self.data)
- while self.used < end:
- n = sock.recv_into(view[self.used:end], end - self.used)
- if n == 0:
- raise OSError("No data")
- self.used += n
diff --git a/neo4j/routing.py b/neo4j/routing.py
index 8303f4c2..b0546d12 100644
--- a/neo4j/routing.py
+++ b/neo4j/routing.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -19,12 +16,11 @@
# limitations under the License.
-from collections import OrderedDict
from collections.abc import MutableSet
from logging import getLogger
from time import perf_counter
-from neo4j.addressing import Address
+from .addressing import Address
log = getLogger("neo4j")
@@ -33,7 +29,8 @@
class OrderedSet(MutableSet):
def __init__(self, elements=()):
- self._elements = OrderedDict.fromkeys(elements)
+ # dicts keep insertion order starting with Python 3.7
+ self._elements = dict.fromkeys(elements)
self._current = None
def __repr__(self):
@@ -70,12 +67,12 @@ def remove(self, element):
raise ValueError(element)
def update(self, elements=()):
- self._elements.update(OrderedDict.fromkeys(elements))
+ self._elements.update(dict.fromkeys(elements))
def replace(self, elements=()):
e = self._elements
e.clear()
- e.update(OrderedDict.fromkeys(elements))
+ e.update(dict.fromkeys(elements))
class RoutingTable:
diff --git a/neo4j/spatial/__init__.py b/neo4j/spatial/__init__.py
index 36e8d66f..5e250338 100644
--- a/neo4j/spatial/__init__.py
+++ b/neo4j/spatial/__init__.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -24,21 +21,21 @@
"""
-from threading import Lock
-
-from neo4j.packstream import Structure
-
-
__all__ = [
- "Point",
"CartesianPoint",
- "WGS84Point",
- "point_type",
- "hydrate_point",
"dehydrate_point",
+ "hydrate_point",
+ "Point",
+ "point_type",
+ "WGS84Point",
]
+from threading import Lock
+
+from neo4j.packstream import Structure
+
+
# SRID to subclass mappings
__srid_table = {}
__srid_table_lock = Lock()
diff --git a/neo4j/time/__init__.py b/neo4j/time/__init__.py
index 0a76280f..4e070d7a 100644
--- a/neo4j/time/__init__.py
+++ b/neo4j/time/__init__.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# coding: utf-8
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -24,19 +21,18 @@
as a number of utility functions.
"""
+
from contextlib import contextmanager
from datetime import (
- timedelta,
date,
- time,
datetime,
+ time,
+ timedelta,
)
from decimal import (
Decimal,
localcontext,
- ROUND_DOWN,
ROUND_HALF_EVEN,
- ROUND_HALF_UP,
)
from functools import total_ordering
from re import compile as re_compile
@@ -48,18 +44,18 @@
from neo4j.meta import (
deprecated,
- deprecation_warn
+ deprecation_warn,
)
from neo4j.time.arithmetic import (
nano_add,
nano_div,
- symmetric_divmod,
round_half_to_even,
+ symmetric_divmod,
)
from neo4j.time.metaclasses import (
+ DateTimeType,
DateType,
TimeType,
- DateTimeType,
)
@@ -84,19 +80,15 @@ def inner(*args, **kwargs):
MIN_INT64 = -(2 ** 63)
MAX_INT64 = (2 ** 63) - 1
+#: The smallest year number allowed in a :class:`.Date` or :class:`.DateTime`
+#: object to be compatible with :class:`datetime.date` and
+#: :class:`datetime.datetime`.
MIN_YEAR = 1
-"""
-The smallest year number allowed in a :class:`.Date` or :class:`.DateTime`
-object to be compatible with :class:`datetime.date` and
-:class:`datetime.datetime`.
-"""
+#: The largest year number allowed in a :class:`.Date` or :class:`.DateTime`
+#: object to be compatible with :class:`datetime.date` and
+#: :class:`datetime.datetime`.
MAX_YEAR = 9999
-"""
-The largest year number allowed in a :class:`.Date` or :class:`.DateTime`
-object to be compatible with :class:`datetime.date` and
-:class:`datetime.datetime`.
-"""
DATE_ISO_PATTERN = re_compile(r"^(\d{4})-(\d{2})-(\d{2})$")
TIME_ISO_PATTERN = re_compile(
@@ -1387,11 +1379,9 @@ def __format__(self, format_spec):
Date.max = Date.from_ordinal(3652059)
Date.resolution = Duration(days=1)
+#: A :class:`neo4j.time.Date` instance set to `0000-00-00`.
+#: This has an ordinal value of `0`.
ZeroDate = object.__new__(Date)
-"""
-A :class:`neo4j.time.Date` instance set to `0000-00-00`.
-This has an ordinal value of `0`.
-"""
class Time(metaclass=TimeType):
@@ -2001,17 +1991,13 @@ def __format__(self, format_spec):
Time.max = Time(hour=23, minute=59, second=59, nanosecond=999999999)
Time.resolution = Duration(nanoseconds=1)
+#: A :class:`.Time` instance set to `00:00:00`.
+#: This has a :attr:`.ticks_ns` value of `0`.
Midnight = Time.min
-"""
-A :class:`.Time` instance set to `00:00:00`.
-This has a :attr:`.ticks_ns` value of `0`.
-"""
+#: A :class:`.Time` instance set to `12:00:00`.
+#: This has a :attr:`.ticks_ns` value of `43200000000000`.
Midday = Time(hour=12)
-"""
-A :class:`.Time` instance set to `12:00:00`.
-This has a :attr:`.ticks_ns` value of `43200000000000`.
-"""
@total_ordering
@@ -2621,12 +2607,9 @@ def __format__(self, format_spec):
DateTime.max = DateTime.combine(Date.max, Time.max)
DateTime.resolution = Time.resolution
+#: A :class:`.DateTime` instance set to `0000-00-00T00:00:00`, composed of
+#: :attr:`ZeroDate` and :attr:`Midnight`.
Never = DateTime.combine(ZeroDate, Midnight)
-"""
-A :class:`.DateTime` instance set to `0000-00-00T00:00:00`.
-This has a :class:`.Date` component equal to :attr:`ZeroDate` and a
-:class:`.Time` component equal to :attr:`Midnight`.
-"""
+#: A :class:`.DateTime` instance set to `1970-01-01T00:00:00`.
UnixEpoch = DateTime(1970, 1, 1, 0, 0, 0)
-"""A :class:`.DateTime` instance set to `1970-01-01T00:00:00`."""
diff --git a/neo4j/time/__main__.py b/neo4j/time/__main__.py
index 447b375c..9d1858b5 100644
--- a/neo4j/time/__main__.py
+++ b/neo4j/time/__main__.py
@@ -1,5 +1,4 @@
#!/usr/bin/env python
-# coding: utf-8
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
@@ -20,7 +19,11 @@
def main():
- from neo4j.time import Clock, DateTime, UnixEpoch
+ from neo4j.time import (
+ Clock,
+ DateTime,
+ UnixEpoch,
+ )
clock = Clock()
time = clock.utc_time()
print("Using %s" % type(clock).__name__)
diff --git a/neo4j/time/arithmetic.py b/neo4j/time/arithmetic.py
index c39ea2c2..71c43f9f 100644
--- a/neo4j/time/arithmetic.py
+++ b/neo4j/time/arithmetic.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# coding: utf-8
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -19,9 +16,6 @@
# limitations under the License.
-from math import isnan
-
-
def nano_add(x, y):
"""
diff --git a/neo4j/time/clock_implementations.py b/neo4j/time/clock_implementations.py
index fccbd433..de4d0ba7 100644
--- a/neo4j/time/clock_implementations.py
+++ b/neo4j/time/clock_implementations.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# coding: utf-8
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -19,10 +16,19 @@
# limitations under the License.
-from ctypes import CDLL, Structure, c_longlong, c_long, byref
+from ctypes import (
+ byref,
+ c_long,
+ c_longlong,
+ CDLL,
+ Structure,
+)
from platform import uname
-from neo4j.time import Clock, ClockTime
+from neo4j.time import (
+ Clock,
+ ClockTime,
+)
from neo4j.time.arithmetic import nano_divmod
diff --git a/neo4j/time/hydration.py b/neo4j/time/hydration.py
index 69a3ddcf..f522ebd7 100644
--- a/neo4j/time/hydration.py
+++ b/neo4j/time/hydration.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -20,17 +17,17 @@
from datetime import (
- time,
datetime,
+ time,
timedelta,
)
from neo4j.packstream import Structure
from neo4j.time import (
- Duration,
Date,
- Time,
DateTime,
+ Duration,
+ Time,
)
@@ -114,7 +111,10 @@ def hydrate_datetime(seconds, nanoseconds, tz=None):
:param tz:
:return: datetime
"""
- from pytz import FixedOffset, timezone
+ from pytz import (
+ FixedOffset,
+ timezone,
+ )
minutes, seconds = map(int, divmod(seconds, 60))
hours, minutes = map(int, divmod(minutes, 60))
days, hours = map(int, divmod(hours, 24))
diff --git a/neo4j/time/metaclasses.py b/neo4j/time/metaclasses.py
index 278a0fb6..ad5e53ff 100644
--- a/neo4j/time/metaclasses.py
+++ b/neo4j/time/metaclasses.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# coding: utf-8
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
diff --git a/neo4j/work/__init__.py b/neo4j/work/__init__.py
index 154f7381..a4254afa 100644
--- a/neo4j/work/__init__.py
+++ b/neo4j/work/__init__.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -19,88 +16,19 @@
# limitations under the License.
-from neo4j.conf import WorkspaceConfig
-from neo4j.exceptions import ServiceUnavailable
-from neo4j.io import Neo4jPool
-
-
-class Workspace:
-
- def __init__(self, pool, config):
- assert isinstance(config, WorkspaceConfig)
- self._pool = pool
- self._config = config
- self._connection = None
- self._connection_access_mode = None
- # Sessions are supposed to cache the database on which to operate.
- self._cached_database = False
- self._bookmarks = None
-
- def __del__(self):
- try:
- self.close()
- except OSError:
- pass
-
- def __enter__(self):
- return self
-
- def __exit__(self, exc_type, exc_value, traceback):
- self.close()
-
- def _set_cached_database(self, database):
- self._cached_database = True
- self._config.database = database
-
- def _connect(self, access_mode):
- if self._connection:
- # TODO: Investigate this
- # log.warning("FIXME: should always disconnect before connect")
- self._connection.send_all()
- self._connection.fetch_all()
- self._disconnect()
- if not self._cached_database:
- if (self._config.database is not None
- or not isinstance(self._pool, Neo4jPool)):
- self._set_cached_database(self._config.database)
- else:
- # This is the first time we open a connection to a server in a
- # cluster environment for this session without explicitly
- # configured database. Hence, we request a routing table update
- # to try to fetch the home database. If provided by the server,
- # we shall use this database explicitly for all subsequent
- # actions within this session.
- self._pool.update_routing_table(
- database=self._config.database,
- imp_user=self._config.impersonated_user,
- bookmarks=self._bookmarks,
- database_callback=self._set_cached_database
- )
- self._connection = self._pool.acquire(
- access_mode=access_mode,
- timeout=self._config.connection_acquisition_timeout,
- database=self._config.database,
- bookmarks=self._bookmarks
- )
- self._connection_access_mode = access_mode
-
- def _disconnect(self, sync=False):
- if self._connection:
- if sync:
- try:
- self._connection.send_all()
- self._connection.fetch_all()
- except (WorkspaceError, ServiceUnavailable):
- pass
- if self._connection:
- self._pool.release(self._connection)
- self._connection = None
- self._connection_access_mode = None
-
- def close(self):
- self._disconnect(sync=True)
-
+from .query import (
+ Query,
+ unit_of_work,
+)
+from .summary import (
+ ResultSummary,
+ SummaryCounters,
+)
-class WorkspaceError(Exception):
- pass
+__all__ = [
+ "Query",
+ "ResultSummary",
+ "SummaryCounters",
+ "unit_of_work",
+]
diff --git a/neo4j/work/pipelining.py b/neo4j/work/pipelining.py
deleted file mode 100644
index ccca8420..00000000
--- a/neo4j/work/pipelining.py
+++ /dev/null
@@ -1,136 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
-# Copyright (c) "Neo4j"
-# Neo4j Sweden AB [http://neo4j.com]
-#
-# This file is part of Neo4j.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-from collections import deque
-from threading import Thread, Lock
-from time import sleep
-
-from neo4j.work import Workspace
-from neo4j.conf import WorkspaceConfig
-from neo4j.api import (
- WRITE_ACCESS,
-)
-
-class PipelineConfig(WorkspaceConfig):
-
- #:
- flush_every = 8192 # bytes
-
-
-class Pipeline(Workspace):
-
- def __init__(self, pool, config):
- assert isinstance(config, PipelineConfig)
- super(Pipeline, self).__init__(pool, config)
- self._connect(WRITE_ACCESS)
- self._flush_every = config.flush_every
- self._data = deque()
- self._pull_lock = Lock()
-
- def push(self, statement, parameters=None):
- self._connection.run(statement, parameters)
- self._connection.pull(on_records=self._data.extend)
- output_buffer_size = len(self._connection.outbox.view())
- if output_buffer_size >= self._flush_every:
- self._connection.send_all()
-
- def _results_generator(self):
- results_returned_count = 0
- try:
- summary = 0
- while summary == 0:
- _, summary = self._connection.fetch_message()
- summary = 0
- while summary == 0:
- detail, summary = self._connection.fetch_message()
- for n in range(detail):
- response = self._data.popleft()
- results_returned_count += 1
- yield response
- finally:
- self._pull_lock.release()
-
- def pull(self):
- """Returns a generator containing the results of the next query in the pipeline"""
- # n.b. pull is now somewhat misleadingly named because it doesn't do anything
- # the connection isn't touched until you try and iterate the generator we return
- lock_acquired = self._pull_lock.acquire(blocking=False)
- if not lock_acquired:
- raise PullOrderException()
- return self._results_generator()
-
-
-class PullOrderException(Exception):
- """Raise when calling pull if a previous pull result has not been fully consumed"""
-
-
-class Pusher(Thread):
-
- def __init__(self, pipeline):
- super(Pusher, self).__init__()
- self.pipeline = pipeline
- self.running = True
- self.count = 0
-
- def run(self):
- while self.running:
- self.pipeline.push("RETURN $x", {"x": self.count})
- self.count += 1
-
-
-class Puller(Thread):
-
- def __init__(self, pipeline):
- super(Puller, self).__init__()
- self.pipeline = pipeline
- self.running = True
- self.count = 0
-
- def run(self):
- while self.running:
- for _ in self.pipeline.pull():
- pass # consume and discard records
- self.count += 1
-
-
-def main():
- from neo4j import Driver
- # from neo4j.bolt.diagnostics import watch
- # watch("neobolt")
- with Driver("bolt://", auth=("neo4j", "password")) as dx:
- p = dx.pipeline(flush_every=1024)
- pusher = Pusher(p)
- puller = Puller(p)
- try:
- pusher.start()
- puller.start()
- while True:
- print("sent %d, received %d, backlog %d" % (pusher.count, puller.count, pusher.count - puller.count))
- sleep(1)
- except KeyboardInterrupt:
- pusher.running = False
- pusher.join()
- puller.running = False
- puller.join()
-
-
-if __name__ == "__main__":
- main()
diff --git a/neo4j/work/query.py b/neo4j/work/query.py
new file mode 100644
index 00000000..acefa6e7
--- /dev/null
+++ b/neo4j/work/query.py
@@ -0,0 +1,75 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+class Query:
+ """ Create a new query.
+
+ :param text: The query text.
+ :type text: str
+ :param metadata: metadata attached to the query.
+ :type metadata: dict
+ :param timeout: seconds.
+ :type timeout: int
+ """
+ def __init__(self, text, metadata=None, timeout=None):
+ self.text = text
+
+ self.metadata = metadata
+ self.timeout = timeout
+
+ def __str__(self):
+ return str(self.text)
+
+
+def unit_of_work(metadata=None, timeout=None):
+ """This function is a decorator for transaction functions that allows extra control over how the transaction is carried out.
+
+ For example, a timeout may be applied::
+
+ @unit_of_work(timeout=100)
+ def count_people_tx(tx):
+ result = tx.run("MATCH (a:Person) RETURN count(a) AS persons")
+ record = result.single()
+ return record["persons"]
+
+ :param metadata:
+ a dictionary with metadata.
+ Specified metadata will be attached to the executing transaction and visible in the output of ``dbms.listQueries`` and ``dbms.listTransactions`` procedures.
+ It will also get logged to the ``query.log``.
+ This functionality makes it easier to tag transactions and is equivalent to ``dbms.setTXMetaData`` procedure, see https://neo4j.com/docs/operations-manual/current/reference/procedures/ for procedure reference.
+ :type metadata: dict
+
+ :param timeout:
+ the transaction timeout in seconds.
+ Transactions that execute longer than the configured timeout will be terminated by the database.
+ This functionality allows to limit query/transaction execution time.
+ Specified timeout overrides the default timeout configured in the database using ``dbms.transaction.timeout`` setting.
+    The value must represent a positive, non-zero duration.
+ :type timeout: int
+ """
+
+ def wrapper(f):
+
+ def wrapped(*args, **kwargs):
+ return f(*args, **kwargs)
+
+ wrapped.metadata = metadata
+ wrapped.timeout = timeout
+ return wrapped
+
+ return wrapper
diff --git a/neo4j/work/summary.py b/neo4j/work/summary.py
index c0955b00..600cfe2e 100644
--- a/neo4j/work/summary.py
+++ b/neo4j/work/summary.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -19,9 +16,8 @@
# limitations under the License.
-from collections import namedtuple
+from .._exceptions import BoltProtocolError
-from neo4j._exceptions import BoltProtocolError
BOLT_VERSION_1 = 1
BOLT_VERSION_2 = 2
diff --git a/requirements-dev.txt b/requirements-dev.txt
new file mode 100644
index 00000000..cc40eecd
--- /dev/null
+++ b/requirements-dev.txt
@@ -0,0 +1,10 @@
+# auto-generate sync driver from async code
+unasync>=0.5.0
+pre-commit>=2.15.0
+isort>=5.10.0
+
+# needed for running tests
+-r tests/requirements.txt
+
+# production dependencies
+-r requirements.txt
diff --git a/setup.cfg b/setup.cfg
new file mode 100644
index 00000000..c63faed9
--- /dev/null
+++ b/setup.cfg
@@ -0,0 +1,13 @@
+[isort]
+combine_as_imports=true
+ensure_newline_before_comments=true
+force_grid_wrap=2
+force_sort_within_sections=true
+include_trailing_comma=true
+#lines_before_imports=2 # currently broken
+lines_after_imports=2
+lines_between_sections=1
+multi_line_output=3
+order_by_type=false
+remove_redundant_aliases=true
+use_parentheses=true
diff --git a/setup.py b/setup.py
index 24384398..f22a1842 100644
--- a/setup.py
+++ b/setup.py
@@ -1,5 +1,4 @@
#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
@@ -20,9 +19,17 @@
import os
-from setuptools import find_packages, setup
-from neo4j.meta import package, version
+from setuptools import (
+ find_packages,
+ setup,
+)
+
+from neo4j.meta import (
+ package,
+ version,
+)
+
install_requires = [
"pytz",
@@ -44,7 +51,8 @@
}
packages = find_packages(exclude=["tests"])
-readme_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "README.rst"))
+readme_path = os.path.abspath(os.path.join(os.path.dirname(__file__),
+ "README.rst"))
with open(readme_path, mode="r", encoding="utf-8") as fr:
readme = fr.read()
diff --git a/testkit/.dockerignore b/testkit/.dockerignore
index fe7d40a5..f104652b 100644
--- a/testkit/.dockerignore
+++ b/testkit/.dockerignore
@@ -1,2 +1 @@
*.py
-
diff --git a/testkit/backend.py b/testkit/backend.py
index ca5a4779..5366afe9 100644
--- a/testkit/backend.py
+++ b/testkit/backend.py
@@ -1,8 +1,30 @@
+#!/usr/bin/env python
+
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
import subprocess
import sys
+
if __name__ == "__main__":
- subprocess.check_call(
- ["python", "-m", "testkitbackend"],
- stdout=sys.stdout, stderr=sys.stderr
- )
+ cmd = ["python", "-m", "testkitbackend"]
+ if "TEST_BACKEND_SERVER" in os.environ:
+ cmd.append(os.environ["TEST_BACKEND_SERVER"])
+ subprocess.check_call(cmd, stdout=sys.stdout, stderr=sys.stderr)
diff --git a/testkit/build.py b/testkit/build.py
index 132452e5..5de76f14 100644
--- a/testkit/build.py
+++ b/testkit/build.py
@@ -1,7 +1,29 @@
+#!/usr/bin/env python
+
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
"""
Executed in Go driver container.
Responsible for building driver and test backend.
"""
+
+
import subprocess
diff --git a/testkit/integration.py b/testkit/integration.py
index 244a421f..c4b1c685 100644
--- a/testkit/integration.py
+++ b/testkit/integration.py
@@ -1,2 +1,22 @@
+#!/usr/bin/env python
+
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
if __name__ == "__main__":
pass
diff --git a/testkit/stress.py b/testkit/stress.py
index bc9718f2..3d171c49 100644
--- a/testkit/stress.py
+++ b/testkit/stress.py
@@ -1,5 +1,23 @@
-import subprocess
-import os
+#!/usr/bin/env python
+
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
import sys
diff --git a/testkit/unittests.py b/testkit/unittests.py
index 9561611d..a3730195 100644
--- a/testkit/unittests.py
+++ b/testkit/unittests.py
@@ -1,3 +1,23 @@
+#!/usr/bin/env python
+
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
import subprocess
@@ -7,5 +27,4 @@ def run(args):
if __name__ == "__main__":
- run([
- "python", "-m", "tox", "-c", "tox-unit.ini"])
+ run(["python", "-m", "tox", "-c", "tox-unit.ini"])
diff --git a/testkitbackend/__init__.py b/testkitbackend/__init__.py
index e69de29b..b81a309d 100644
--- a/testkitbackend/__init__.py
+++ b/testkitbackend/__init__.py
@@ -0,0 +1,16 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/testkitbackend/__main__.py b/testkitbackend/__main__.py
index 3e8a4b87..2dec1adb 100644
--- a/testkitbackend/__main__.py
+++ b/testkitbackend/__main__.py
@@ -1,6 +1,52 @@
-from .server import Server
+#!/usr/bin/env python
-if __name__ == "__main__":
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import asyncio
+import sys
+
+from .server import (
+ AsyncServer,
+ Server,
+)
+
+
+def sync_main():
server = Server(("0.0.0.0", 9876))
while True:
server.handle_request()
+
+
+def async_main():
+ async def main():
+ server = AsyncServer(("0.0.0.0", 9876))
+ await server.start()
+ try:
+ await server.serve_forever()
+ finally:
+ server.stop()
+
+ asyncio.run(main())
+
+
+if __name__ == "__main__":
+ if len(sys.argv) == 2 and sys.argv[1].lower().strip() == "async":
+ async_main()
+ else:
+ sync_main()
diff --git a/testkitbackend/_async/__init__.py b/testkitbackend/_async/__init__.py
new file mode 100644
index 00000000..b81a309d
--- /dev/null
+++ b/testkitbackend/_async/__init__.py
@@ -0,0 +1,16 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/testkitbackend/_async/backend.py b/testkitbackend/_async/backend.py
new file mode 100644
index 00000000..014e99e0
--- /dev/null
+++ b/testkitbackend/_async/backend.py
@@ -0,0 +1,145 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import asyncio
+from inspect import (
+ getmembers,
+ isfunction,
+)
+from json import (
+ dumps,
+ loads,
+)
+import traceback
+
+from neo4j._exceptions import BoltError
+from neo4j.exceptions import (
+ DriverError,
+ Neo4jError,
+ UnsupportedServerProduct,
+)
+
+from . import requests
+from .._driver_logger import (
+ buffer_handler,
+ log,
+)
+from ..backend import Request
+
+
+class AsyncBackend:
+ def __init__(self, rd, wr):
+ self._rd = rd
+ self._wr = wr
+ self.drivers = {}
+ self.custom_resolutions = {}
+ self.dns_resolutions = {}
+ self.sessions = {}
+ self.results = {}
+ self.errors = {}
+ self.transactions = {}
+ self.errors = {}
+ self.key = 0
+ # Collect all request handlers
+ self._requestHandlers = dict(
+ [m for m in getmembers(requests, isfunction)])
+
+ def next_key(self):
+ self.key = self.key + 1
+ return self.key
+
+ async def process_request(self):
+ """ Reads next request from the stream and processes it.
+ """
+ in_request = False
+ request = ""
+ async for line in self._rd:
+ # Remove trailing newline
+ line = line.decode('UTF-8').rstrip()
+ if line == "#request begin":
+ in_request = True
+ elif line == "#request end":
+ await self._process(request)
+ return True
+ else:
+ if in_request:
+ request = request + line
+ return False
+
+ async def _process(self, request):
+ """ Process a received request by retrieving handler that
+ corresponds to the request name.
+ """
+ try:
+ request = loads(request, object_pairs_hook=Request)
+ if not isinstance(request, Request):
+ raise Exception("Request is not an object")
+ name = request.get('name', 'invalid')
+ handler = self._requestHandlers.get(name)
+ if not handler:
+ raise Exception("No request handler for " + name)
+ data = request["data"]
+ log.info("<<< " + name + dumps(data))
+ await handler(self, data)
+            unused_keys = request.unseen_keys
+            if unused_keys:
+                raise NotImplementedError(
+                    "Backend does not support some properties of the " + name +
+                    " request: " + ", ".join(unused_keys)
+ )
+ except (Neo4jError, DriverError, UnsupportedServerProduct,
+ BoltError) as e:
+ log.debug(traceback.format_exc())
+ if isinstance(e, Neo4jError):
+ msg = "" if e.message is None else str(e.message)
+ else:
+ msg = str(e.args[0]) if e.args else ""
+
+ key = self.next_key()
+ self.errors[key] = e
+ payload = {"id": key, "errorType": str(type(e)), "msg": msg}
+ if isinstance(e, Neo4jError):
+ payload["code"] = e.code
+ await self.send_response("DriverError", payload)
+ except requests.FrontendError as e:
+ await self.send_response("FrontendError", {"msg": str(e)})
+ except Exception:
+ tb = traceback.format_exc()
+ log.error(tb)
+ await self.send_response("BackendError", {"msg": tb})
+
+ async def send_response(self, name, data):
+ """ Sends a response to backend.
+ """
+ with buffer_handler.lock:
+ log_output = buffer_handler.stream.getvalue()
+ buffer_handler.stream.truncate(0)
+ buffer_handler.stream.seek(0)
+ if not log_output.endswith("\n"):
+ log_output += "\n"
+ self._wr.write(log_output.encode("utf-8"))
+ response = {"name": name, "data": data}
+ response = dumps(response)
+ log.info(">>> " + name + dumps(data))
+ self._wr.write(b"#response begin\n")
+ self._wr.write(bytes(response+"\n", "utf-8"))
+ self._wr.write(b"#response end\n")
+ if isinstance(self._wr, asyncio.StreamWriter):
+ await self._wr.drain()
+ else:
+ self._wr.flush()
diff --git a/testkitbackend/_async/requests.py b/testkitbackend/_async/requests.py
new file mode 100644
index 00000000..e80e39c5
--- /dev/null
+++ b/testkitbackend/_async/requests.py
@@ -0,0 +1,444 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import json
+from os import path
+
+import neo4j
+from neo4j._async_compat.util import AsyncUtil
+
+from .. import (
+ fromtestkit,
+ totestkit,
+)
+
+
+class FrontendError(Exception):
+ pass
+
+
+def load_config():
+ config_path = path.join(path.dirname(__file__), "..", "test_config.json")
+ with open(config_path, "r") as fd:
+ config = json.load(fd)
+ skips = config["skips"]
+ features = [k for k, v in config["features"].items() if v is True]
+ import ssl
+ if ssl.HAS_TLSv1_3:
+ features += ["Feature:TLS:1.3"]
+ return skips, features
+
+
+SKIPPED_TESTS, FEATURES = load_config()
+
+
+async def StartTest(backend, data):
+ if data["testName"] in SKIPPED_TESTS:
+ await backend.send_response("SkipTest", {
+ "reason": SKIPPED_TESTS[data["testName"]]
+ })
+ else:
+ await backend.send_response("RunTest", {})
+
+
+async def GetFeatures(backend, data):
+ await backend.send_response("FeatureList", {"features": FEATURES})
+
+
+async def NewDriver(backend, data):
+ auth_token = data["authorizationToken"]["data"]
+ data["authorizationToken"].mark_item_as_read_if_equals(
+ "name", "AuthorizationToken"
+ )
+ scheme = auth_token["scheme"]
+ if scheme == "basic":
+ auth = neo4j.basic_auth(
+ auth_token["principal"], auth_token["credentials"],
+ realm=auth_token.get("realm", None)
+ )
+ elif scheme == "kerberos":
+ auth = neo4j.kerberos_auth(auth_token["credentials"])
+ elif scheme == "bearer":
+ auth = neo4j.bearer_auth(auth_token["credentials"])
+ else:
+ auth = neo4j.custom_auth(
+ auth_token["principal"], auth_token["credentials"],
+ auth_token["realm"], auth_token["scheme"],
+ **auth_token.get("parameters", {})
+ )
+ auth_token.mark_item_as_read("parameters", recursive=True)
+ resolver = None
+ if data["resolverRegistered"] or data["domainNameResolverRegistered"]:
+ resolver = resolution_func(backend, data["resolverRegistered"],
+ data["domainNameResolverRegistered"])
+ connection_timeout = data.get("connectionTimeoutMs")
+ if connection_timeout is not None:
+ connection_timeout /= 1000
+ max_transaction_retry_time = data.get("maxTxRetryTimeMs")
+ if max_transaction_retry_time is not None:
+ max_transaction_retry_time /= 1000
+ data.mark_item_as_read("domainNameResolverRegistered")
+ driver = neo4j.AsyncGraphDatabase.driver(
+ data["uri"], auth=auth, user_agent=data["userAgent"],
+ resolver=resolver, connection_timeout=connection_timeout,
+ fetch_size=data.get("fetchSize"),
+ max_transaction_retry_time=max_transaction_retry_time,
+ )
+ key = backend.next_key()
+ backend.drivers[key] = driver
+ await backend.send_response("Driver", {"id": key})
+
+
+async def VerifyConnectivity(backend, data):
+ driver_id = data["driverId"]
+ driver = backend.drivers[driver_id]
+ await driver.verify_connectivity()
+ await backend.send_response("Driver", {"id": driver_id})
+
+
+async def CheckMultiDBSupport(backend, data):
+ driver_id = data["driverId"]
+ driver = backend.drivers[driver_id]
+ await backend.send_response("MultiDBSupport", {
+ "id": backend.next_key(), "available": await driver.supports_multi_db()
+ })
+
+
+def resolution_func(backend, custom_resolver=False, custom_dns_resolver=False):
+    """Build an address-resolver callback that round-trips via TestKit.
+
+    The returned coroutine asks the TestKit frontend to perform custom
+    and/or DNS resolution for every address the driver tries to connect to.
+
+    :param backend: backend instance used to exchange messages with TestKit
+    :param custom_resolver: emulate a user-registered custom resolver
+    :param custom_dns_resolver: emulate custom DNS lookup behavior
+    """
+    # This solution (putting custom resolution together with DNS resolution
+    # into one function only works because the Python driver calls the custom
+    # resolver function for every connection, which is not true for all
+    # drivers. Properly exposing a way to change the DNS lookup behavior is not
+    # possible without changing the driver's code.
+    assert custom_resolver or custom_dns_resolver
+
+    async def resolve(address):
+        # Start from the driver-supplied address, rendered as "host:port".
+        addresses = [":".join(map(str, address))]
+        if custom_resolver:
+            key = backend.next_key()
+            await backend.send_response("ResolverResolutionRequired", {
+                "id": key,
+                "address": addresses[0]
+            })
+            if not await backend.process_request():
+                # connection was closed before end of next message
+                return []
+            if key not in backend.custom_resolutions:
+                raise RuntimeError(
+                    "Backend did not receive expected "
+                    "ResolverResolutionCompleted message for id %s" % key
+                )
+            addresses = backend.custom_resolutions.pop(key)
+        if custom_dns_resolver:
+            dns_resolved_addresses = []
+            for address in addresses:
+                key = backend.next_key()
+                address = address.rsplit(":", 1)
+                await backend.send_response("DomainNameResolutionRequired", {
+                    "id": key,
+                    "name": address[0]
+                })
+                if not await backend.process_request():
+                    # connection was closed before end of next message
+                    return []
+                if key not in backend.dns_resolutions:
+                    raise RuntimeError(
+                        "Backend did not receive expected "
+                        "DomainNameResolutionCompleted message for id %s" % key
+                    )
+                # Re-attach the original port (if any) to each resolved name.
+                dns_resolved_addresses += list(map(
+                    lambda a: ":".join((a, *address[1:])),
+                    backend.dns_resolutions.pop(key)
+                ))
+
+            addresses = dns_resolved_addresses
+
+        return list(map(neo4j.Address.parse, addresses))
+
+    return resolve
+
+
+async def ResolverResolutionCompleted(backend, data):
+ backend.custom_resolutions[data["requestId"]] = data["addresses"]
+
+
+async def DomainNameResolutionCompleted(backend, data):
+ backend.dns_resolutions[data["requestId"]] = data["addresses"]
+
+
+async def DriverClose(backend, data):
+ key = data["driverId"]
+ driver = backend.drivers[key]
+ await driver.close()
+ await backend.send_response("Driver", {"id": key})
+
+
+class SessionTracker:
+ """ Keeps some extra state about the tracked session
+ """
+
+ def __init__(self, session):
+ self.session = session
+ self.state = ""
+ self.error_id = ""
+
+
+async def NewSession(backend, data):
+ driver = backend.drivers[data["driverId"]]
+ access_mode = data["accessMode"]
+ if access_mode == "r":
+ access_mode = neo4j.READ_ACCESS
+ elif access_mode == "w":
+ access_mode = neo4j.WRITE_ACCESS
+ else:
+ raise ValueError("Unknown access mode:" + access_mode)
+ config = {
+ "default_access_mode": access_mode,
+ "bookmarks": data["bookmarks"],
+ "database": data["database"],
+ "fetch_size": data.get("fetchSize", None),
+ "impersonated_user": data.get("impersonatedUser", None),
+
+ }
+ session = driver.session(**config)
+ key = backend.next_key()
+ backend.sessions[key] = SessionTracker(session)
+ await backend.send_response("Session", {"id": key})
+
+
+async def SessionRun(backend, data):
+ session = backend.sessions[data["sessionId"]].session
+ query, params = fromtestkit.to_query_and_params(data)
+ result = await session.run(query, parameters=params)
+ key = backend.next_key()
+ backend.results[key] = result
+ await backend.send_response("Result", {"id": key, "keys": result.keys()})
+
+
+async def SessionClose(backend, data):
+ key = data["sessionId"]
+ session = backend.sessions[key].session
+ await session.close()
+ del backend.sessions[key]
+ await backend.send_response("Session", {"id": key})
+
+
+async def SessionBeginTransaction(backend, data):
+ key = data["sessionId"]
+ session = backend.sessions[key].session
+ metadata, timeout = fromtestkit.to_meta_and_timeout(data)
+ tx = await session.begin_transaction(metadata=metadata, timeout=timeout)
+ key = backend.next_key()
+ backend.transactions[key] = tx
+ await backend.send_response("Transaction", {"id": key})
+
+
+async def SessionReadTransaction(backend, data):
+ await transactionFunc(backend, data, True)
+
+
+async def SessionWriteTransaction(backend, data):
+ await transactionFunc(backend, data, False)
+
+
+async def transactionFunc(backend, data, is_read):
+ key = data["sessionId"]
+ session_tracker = backend.sessions[key]
+ session = session_tracker.session
+ metadata, timeout = fromtestkit.to_meta_and_timeout(data)
+
+ @neo4j.unit_of_work(metadata=metadata, timeout=timeout)
+ async def func(tx):
+ txkey = backend.next_key()
+ backend.transactions[txkey] = tx
+ session_tracker.state = ''
+ await backend.send_response("RetryableTry", {"id": txkey})
+
+ cont = True
+ while cont:
+ cont = await backend.process_request()
+ if session_tracker.state == '+':
+ cont = False
+ elif session_tracker.state == '-':
+ if session_tracker.error_id:
+ raise backend.errors[session_tracker.error_id]
+ else:
+ raise FrontendError("Client said no")
+
+ if is_read:
+ await session.read_transaction(func)
+ else:
+ await session.write_transaction(func)
+ await backend.send_response("RetryableDone", {})
+
+
+async def RetryablePositive(backend, data):
+ key = data["sessionId"]
+ session_tracker = backend.sessions[key]
+ session_tracker.state = '+'
+
+
+async def RetryableNegative(backend, data):
+ key = data["sessionId"]
+ session_tracker = backend.sessions[key]
+ session_tracker.state = '-'
+ session_tracker.error_id = data.get('errorId', '')
+
+
+async def SessionLastBookmarks(backend, data):
+ key = data["sessionId"]
+ session = backend.sessions[key].session
+ bookmark = await session.last_bookmark()
+ bookmarks = []
+ if bookmark:
+ bookmarks.append(bookmark)
+ await backend.send_response("Bookmarks", {"bookmarks": bookmarks})
+
+
+async def TransactionRun(backend, data):
+ key = data["txId"]
+ tx = backend.transactions[key]
+ cypher, params = fromtestkit.to_cypher_and_params(data)
+ result = await tx.run(cypher, parameters=params)
+ key = backend.next_key()
+ backend.results[key] = result
+ await backend.send_response("Result", {"id": key, "keys": result.keys()})
+
+
+async def TransactionCommit(backend, data):
+ key = data["txId"]
+ tx = backend.transactions[key]
+ await tx.commit()
+ await backend.send_response("Transaction", {"id": key})
+
+
+async def TransactionRollback(backend, data):
+ key = data["txId"]
+ tx = backend.transactions[key]
+ await tx.rollback()
+ await backend.send_response("Transaction", {"id": key})
+
+
+async def TransactionClose(backend, data):
+ key = data["txId"]
+ tx = backend.transactions[key]
+ await tx.close()
+ await backend.send_response("Transaction", {"id": key})
+
+
+async def ResultNext(backend, data):
+ result = backend.results[data["resultId"]]
+
+ try:
+ record = await AsyncUtil.next(AsyncUtil.iter(result))
+ except StopAsyncIteration:
+ await backend.send_response("NullRecord", {})
+ return
+ await backend.send_response("Record", totestkit.record(record))
+
+
+async def ResultSingle(backend, data):
+ result = backend.results[data["resultId"]]
+ await backend.send_response("Record", totestkit.record(result.single()))
+
+
+async def ResultPeek(backend, data):
+ result = backend.results[data["resultId"]]
+ record = await result.peek()
+ if record is not None:
+ await backend.send_response("Record", totestkit.record(record))
+ else:
+ await backend.send_response("NullRecord", {})
+
+
+async def ResultList(backend, data):
+ result = backend.results[data["resultId"]]
+ records = await AsyncUtil.list(result)
+ await backend.send_response("RecordList", {
+ "records": [totestkit.record(r) for r in records]
+ })
+
+
+async def ResultConsume(backend, data):
+ result = backend.results[data["resultId"]]
+ summary = await result.consume()
+ from neo4j import ResultSummary
+ assert isinstance(summary, ResultSummary)
+ await backend.send_response("Summary", {
+ "serverInfo": {
+ "address": ":".join(map(str, summary.server.address)),
+ "agent": summary.server.agent,
+ "protocolVersion":
+ ".".join(map(str, summary.server.protocol_version)),
+ },
+ "counters": None if not summary.counters else {
+ "constraintsAdded": summary.counters.constraints_added,
+ "constraintsRemoved": summary.counters.constraints_removed,
+ "containsSystemUpdates": summary.counters.contains_system_updates,
+ "containsUpdates": summary.counters.contains_updates,
+ "indexesAdded": summary.counters.indexes_added,
+ "indexesRemoved": summary.counters.indexes_removed,
+ "labelsAdded": summary.counters.labels_added,
+ "labelsRemoved": summary.counters.labels_removed,
+ "nodesCreated": summary.counters.nodes_created,
+ "nodesDeleted": summary.counters.nodes_deleted,
+ "propertiesSet": summary.counters.properties_set,
+ "relationshipsCreated": summary.counters.relationships_created,
+ "relationshipsDeleted": summary.counters.relationships_deleted,
+ "systemUpdates": summary.counters.system_updates,
+ },
+ "database": summary.database,
+ "notifications": summary.notifications,
+ "plan": summary.plan,
+ "profile": summary.profile,
+ "query": {
+ "text": summary.query,
+ "parameters": {k: totestkit.field(v)
+ for k, v in summary.parameters.items()},
+ },
+ "queryType": summary.query_type,
+ "resultAvailableAfter": summary.result_available_after,
+ "resultConsumedAfter": summary.result_consumed_after,
+ })
+
+
+async def ForcedRoutingTableUpdate(backend, data):
+ driver_id = data["driverId"]
+ driver = backend.drivers[driver_id]
+ database = data["database"]
+ bookmarks = data["bookmarks"]
+ async with driver._pool.refresh_lock:
+ await driver._pool.update_routing_table(
+ database=database, imp_user=None, bookmarks=bookmarks
+ )
+ await backend.send_response("Driver", {"id": driver_id})
+
+
+async def GetRoutingTable(backend, data):
+ driver_id = data["driverId"]
+ database = data["database"]
+ driver = backend.drivers[driver_id]
+ routing_table = driver._pool.routing_tables[database]
+ response_data = {
+ "database": routing_table.database,
+ "ttl": routing_table.ttl,
+ }
+ for role in ("routers", "readers", "writers"):
+ addresses = routing_table.__getattribute__(role)
+ response_data[role] = list(map(str, addresses))
+ await backend.send_response("RoutingTable", response_data)
diff --git a/testkitbackend/_driver_logger.py b/testkitbackend/_driver_logger.py
new file mode 100644
index 00000000..ef09ec35
--- /dev/null
+++ b/testkitbackend/_driver_logger.py
@@ -0,0 +1,41 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import io
+import logging
+import sys
+
+
+# In-memory handler: captures driver log output so the backend can forward
+# it to TestKit interleaved with its responses (drained in send_response).
+buffer_handler = logging.StreamHandler(io.StringIO())
+buffer_handler.setLevel(logging.DEBUG)
+
+# Mirror all driver logging to stdout as well, for local debugging.
+handler = logging.StreamHandler(sys.stdout)
+handler.setLevel(logging.DEBUG)
+logging.getLogger("neo4j").addHandler(handler)
+logging.getLogger("neo4j").addHandler(buffer_handler)
+logging.getLogger("neo4j").setLevel(logging.DEBUG)
+
+# Logger used by the testkit backend itself (request/response tracing).
+log = logging.getLogger("testkitbackend")
+log.addHandler(handler)
+log.setLevel(logging.DEBUG)
+
+
+__all__ = [
+    "buffer_handler",
+    "log",
+]
diff --git a/testkitbackend/_sync/__init__.py b/testkitbackend/_sync/__init__.py
new file mode 100644
index 00000000..b81a309d
--- /dev/null
+++ b/testkitbackend/_sync/__init__.py
@@ -0,0 +1,16 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/testkitbackend/_sync/backend.py b/testkitbackend/_sync/backend.py
new file mode 100644
index 00000000..9f1bae94
--- /dev/null
+++ b/testkitbackend/_sync/backend.py
@@ -0,0 +1,145 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import asyncio
+from inspect import (
+ getmembers,
+ isfunction,
+)
+from json import (
+ dumps,
+ loads,
+)
+import traceback
+
+from neo4j._exceptions import BoltError
+from neo4j.exceptions import (
+ DriverError,
+ Neo4jError,
+ UnsupportedServerProduct,
+)
+
+from . import requests
+from .._driver_logger import (
+ buffer_handler,
+ log,
+)
+from ..backend import Request
+
+
+class Backend:
+ def __init__(self, rd, wr):
+ self._rd = rd
+ self._wr = wr
+ self.drivers = {}
+ self.custom_resolutions = {}
+ self.dns_resolutions = {}
+ self.sessions = {}
+ self.results = {}
+ self.errors = {}
+ self.transactions = {}
+ self.errors = {}
+ self.key = 0
+ # Collect all request handlers
+ self._requestHandlers = dict(
+ [m for m in getmembers(requests, isfunction)])
+
+ def next_key(self):
+ self.key = self.key + 1
+ return self.key
+
+    def process_request(self):
+        """ Reads next request from the stream and processes it.
+
+        :returns: True if a complete request was read and dispatched,
+            False if the stream ended (connection closed) first.
+        """
+        in_request = False
+        request = ""
+        for line in self._rd:
+            # Remove trailing newline
+            line = line.decode('UTF-8').rstrip()
+            if line == "#request begin":
+                in_request = True
+            elif line == "#request end":
+                self._process(request)
+                return True
+            else:
+                # Everything between the markers is the JSON request body.
+                if in_request:
+                    request = request + line
+        return False
+
+ def _process(self, request):
+ """ Process a received request by retrieving handler that
+ corresponds to the request name.
+ """
+ try:
+ request = loads(request, object_pairs_hook=Request)
+ if not isinstance(request, Request):
+ raise Exception("Request is not an object")
+ name = request.get('name', 'invalid')
+ handler = self._requestHandlers.get(name)
+ if not handler:
+ raise Exception("No request handler for " + name)
+ data = request["data"]
+ log.info("<<< " + name + dumps(data))
+ handler(self, data)
+ unsused_keys = request.unseen_keys
+ if unsused_keys:
+ raise NotImplementedError(
+ "Backend does not support some properties of the " + name +
+ " request: " + ", ".join(unsused_keys)
+ )
+ except (Neo4jError, DriverError, UnsupportedServerProduct,
+ BoltError) as e:
+ log.debug(traceback.format_exc())
+ if isinstance(e, Neo4jError):
+ msg = "" if e.message is None else str(e.message)
+ else:
+ msg = str(e.args[0]) if e.args else ""
+
+ key = self.next_key()
+ self.errors[key] = e
+ payload = {"id": key, "errorType": str(type(e)), "msg": msg}
+ if isinstance(e, Neo4jError):
+ payload["code"] = e.code
+ self.send_response("DriverError", payload)
+ except requests.FrontendError as e:
+ self.send_response("FrontendError", {"msg": str(e)})
+ except Exception:
+ tb = traceback.format_exc()
+ log.error(tb)
+ self.send_response("BackendError", {"msg": tb})
+
+ def send_response(self, name, data):
+ """ Sends a response to backend.
+ """
+ with buffer_handler.lock:
+ log_output = buffer_handler.stream.getvalue()
+ buffer_handler.stream.truncate(0)
+ buffer_handler.stream.seek(0)
+ if not log_output.endswith("\n"):
+ log_output += "\n"
+ self._wr.write(log_output.encode("utf-8"))
+ response = {"name": name, "data": data}
+ response = dumps(response)
+ log.info(">>> " + name + dumps(data))
+ self._wr.write(b"#response begin\n")
+ self._wr.write(bytes(response+"\n", "utf-8"))
+ self._wr.write(b"#response end\n")
+ if isinstance(self._wr, asyncio.StreamWriter):
+ self._wr.drain()
+ else:
+ self._wr.flush()
diff --git a/testkitbackend/_sync/requests.py b/testkitbackend/_sync/requests.py
new file mode 100644
index 00000000..744d8356
--- /dev/null
+++ b/testkitbackend/_sync/requests.py
@@ -0,0 +1,444 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import json
+from os import path
+
+import neo4j
+from neo4j._async_compat.util import Util
+
+from .. import (
+ fromtestkit,
+ totestkit,
+)
+
+
+class FrontendError(Exception):
+ pass
+
+
+def load_config():
+ config_path = path.join(path.dirname(__file__), "..", "test_config.json")
+ with open(config_path, "r") as fd:
+ config = json.load(fd)
+ skips = config["skips"]
+ features = [k for k, v in config["features"].items() if v is True]
+ import ssl
+ if ssl.HAS_TLSv1_3:
+ features += ["Feature:TLS:1.3"]
+ return skips, features
+
+
+SKIPPED_TESTS, FEATURES = load_config()
+
+
+def StartTest(backend, data):
+ if data["testName"] in SKIPPED_TESTS:
+ backend.send_response("SkipTest", {
+ "reason": SKIPPED_TESTS[data["testName"]]
+ })
+ else:
+ backend.send_response("RunTest", {})
+
+
+def GetFeatures(backend, data):
+ backend.send_response("FeatureList", {"features": FEATURES})
+
+
+def NewDriver(backend, data):
+ auth_token = data["authorizationToken"]["data"]
+ data["authorizationToken"].mark_item_as_read_if_equals(
+ "name", "AuthorizationToken"
+ )
+ scheme = auth_token["scheme"]
+ if scheme == "basic":
+ auth = neo4j.basic_auth(
+ auth_token["principal"], auth_token["credentials"],
+ realm=auth_token.get("realm", None)
+ )
+ elif scheme == "kerberos":
+ auth = neo4j.kerberos_auth(auth_token["credentials"])
+ elif scheme == "bearer":
+ auth = neo4j.bearer_auth(auth_token["credentials"])
+ else:
+ auth = neo4j.custom_auth(
+ auth_token["principal"], auth_token["credentials"],
+ auth_token["realm"], auth_token["scheme"],
+ **auth_token.get("parameters", {})
+ )
+ auth_token.mark_item_as_read("parameters", recursive=True)
+ resolver = None
+ if data["resolverRegistered"] or data["domainNameResolverRegistered"]:
+ resolver = resolution_func(backend, data["resolverRegistered"],
+ data["domainNameResolverRegistered"])
+ connection_timeout = data.get("connectionTimeoutMs")
+ if connection_timeout is not None:
+ connection_timeout /= 1000
+ max_transaction_retry_time = data.get("maxTxRetryTimeMs")
+ if max_transaction_retry_time is not None:
+ max_transaction_retry_time /= 1000
+ data.mark_item_as_read("domainNameResolverRegistered")
+ driver = neo4j.GraphDatabase.driver(
+ data["uri"], auth=auth, user_agent=data["userAgent"],
+ resolver=resolver, connection_timeout=connection_timeout,
+ fetch_size=data.get("fetchSize"),
+ max_transaction_retry_time=max_transaction_retry_time,
+ )
+ key = backend.next_key()
+ backend.drivers[key] = driver
+ backend.send_response("Driver", {"id": key})
+
+
+def VerifyConnectivity(backend, data):
+ driver_id = data["driverId"]
+ driver = backend.drivers[driver_id]
+ driver.verify_connectivity()
+ backend.send_response("Driver", {"id": driver_id})
+
+
+def CheckMultiDBSupport(backend, data):
+ driver_id = data["driverId"]
+ driver = backend.drivers[driver_id]
+ backend.send_response("MultiDBSupport", {
+ "id": backend.next_key(), "available": driver.supports_multi_db()
+ })
+
+
+def resolution_func(backend, custom_resolver=False, custom_dns_resolver=False):
+ # This solution (putting custom resolution together with DNS resolution
+ # into one function only works because the Python driver calls the custom
+ # resolver function for every connection, which is not true for all
+ # drivers. Properly exposing a way to change the DNS lookup behavior is not
+ # possible without changing the driver's code.
+ assert custom_resolver or custom_dns_resolver
+
+ def resolve(address):
+ addresses = [":".join(map(str, address))]
+ if custom_resolver:
+ key = backend.next_key()
+ backend.send_response("ResolverResolutionRequired", {
+ "id": key,
+ "address": addresses[0]
+ })
+ if not backend.process_request():
+ # connection was closed before end of next message
+ return []
+ if key not in backend.custom_resolutions:
+ raise RuntimeError(
+ "Backend did not receive expected "
+ "ResolverResolutionCompleted message for id %s" % key
+ )
+ addresses = backend.custom_resolutions.pop(key)
+ if custom_dns_resolver:
+ dns_resolved_addresses = []
+ for address in addresses:
+ key = backend.next_key()
+ address = address.rsplit(":", 1)
+ backend.send_response("DomainNameResolutionRequired", {
+ "id": key,
+ "name": address[0]
+ })
+ if not backend.process_request():
+ # connection was closed before end of next message
+ return []
+ if key not in backend.dns_resolutions:
+ raise RuntimeError(
+ "Backend did not receive expected "
+ "DomainNameResolutionCompleted message for id %s" % key
+ )
+ dns_resolved_addresses += list(map(
+ lambda a: ":".join((a, *address[1:])),
+ backend.dns_resolutions.pop(key)
+ ))
+
+ addresses = dns_resolved_addresses
+
+ return list(map(neo4j.Address.parse, addresses))
+
+ return resolve
+
+
+def ResolverResolutionCompleted(backend, data):
+ backend.custom_resolutions[data["requestId"]] = data["addresses"]
+
+
+def DomainNameResolutionCompleted(backend, data):
+ backend.dns_resolutions[data["requestId"]] = data["addresses"]
+
+
+def DriverClose(backend, data):
+ key = data["driverId"]
+ driver = backend.drivers[key]
+ driver.close()
+ backend.send_response("Driver", {"id": key})
+
+
+class SessionTracker:
+ """ Keeps some extra state about the tracked session
+ """
+
+ def __init__(self, session):
+ self.session = session
+ self.state = ""
+ self.error_id = ""
+
+
+def NewSession(backend, data):
+ driver = backend.drivers[data["driverId"]]
+ access_mode = data["accessMode"]
+ if access_mode == "r":
+ access_mode = neo4j.READ_ACCESS
+ elif access_mode == "w":
+ access_mode = neo4j.WRITE_ACCESS
+ else:
+ raise ValueError("Unknown access mode:" + access_mode)
+ config = {
+ "default_access_mode": access_mode,
+ "bookmarks": data["bookmarks"],
+ "database": data["database"],
+ "fetch_size": data.get("fetchSize", None),
+ "impersonated_user": data.get("impersonatedUser", None),
+
+ }
+ session = driver.session(**config)
+ key = backend.next_key()
+ backend.sessions[key] = SessionTracker(session)
+ backend.send_response("Session", {"id": key})
+
+
+def SessionRun(backend, data):
+ session = backend.sessions[data["sessionId"]].session
+ query, params = fromtestkit.to_query_and_params(data)
+ result = session.run(query, parameters=params)
+ key = backend.next_key()
+ backend.results[key] = result
+ backend.send_response("Result", {"id": key, "keys": result.keys()})
+
+
+def SessionClose(backend, data):
+ key = data["sessionId"]
+ session = backend.sessions[key].session
+ session.close()
+ del backend.sessions[key]
+ backend.send_response("Session", {"id": key})
+
+
+def SessionBeginTransaction(backend, data):
+ key = data["sessionId"]
+ session = backend.sessions[key].session
+ metadata, timeout = fromtestkit.to_meta_and_timeout(data)
+ tx = session.begin_transaction(metadata=metadata, timeout=timeout)
+ key = backend.next_key()
+ backend.transactions[key] = tx
+ backend.send_response("Transaction", {"id": key})
+
+
+def SessionReadTransaction(backend, data):
+ transactionFunc(backend, data, True)
+
+
+def SessionWriteTransaction(backend, data):
+ transactionFunc(backend, data, False)
+
+
+def transactionFunc(backend, data, is_read):
+    """Run a TestKit-driven transaction function through the retry logic.
+
+    Sends "RetryableTry" for each attempt, then serves frontend requests
+    until the frontend settles the attempt via RetryablePositive ('+') or
+    RetryableNegative ('-').
+    """
+    key = data["sessionId"]
+    session_tracker = backend.sessions[key]
+    session = session_tracker.session
+    metadata, timeout = fromtestkit.to_meta_and_timeout(data)
+
+    @neo4j.unit_of_work(metadata=metadata, timeout=timeout)
+    def func(tx):
+        txkey = backend.next_key()
+        backend.transactions[txkey] = tx
+        session_tracker.state = ''
+        backend.send_response("RetryableTry", {"id": txkey})
+
+        cont = True
+        while cont:
+            # Keep serving nested requests (e.g. TransactionRun) until the
+            # frontend resolves this attempt or the connection closes.
+            cont = backend.process_request()
+            if session_tracker.state == '+':
+                cont = False
+            elif session_tracker.state == '-':
+                if session_tracker.error_id:
+                    # Re-raise the recorded driver error so the driver's
+                    # retry machinery reacts exactly as to a real failure.
+                    raise backend.errors[session_tracker.error_id]
+                else:
+                    raise FrontendError("Client said no")
+
+    if is_read:
+        session.read_transaction(func)
+    else:
+        session.write_transaction(func)
+    backend.send_response("RetryableDone", {})
+
+
+def RetryablePositive(backend, data):
+ key = data["sessionId"]
+ session_tracker = backend.sessions[key]
+ session_tracker.state = '+'
+
+
+def RetryableNegative(backend, data):
+ key = data["sessionId"]
+ session_tracker = backend.sessions[key]
+ session_tracker.state = '-'
+ session_tracker.error_id = data.get('errorId', '')
+
+
+def SessionLastBookmarks(backend, data):
+ key = data["sessionId"]
+ session = backend.sessions[key].session
+ bookmark = session.last_bookmark()
+ bookmarks = []
+ if bookmark:
+ bookmarks.append(bookmark)
+ backend.send_response("Bookmarks", {"bookmarks": bookmarks})
+
+
+def TransactionRun(backend, data):
+ key = data["txId"]
+ tx = backend.transactions[key]
+ cypher, params = fromtestkit.to_cypher_and_params(data)
+ result = tx.run(cypher, parameters=params)
+ key = backend.next_key()
+ backend.results[key] = result
+ backend.send_response("Result", {"id": key, "keys": result.keys()})
+
+
+def TransactionCommit(backend, data):
+ key = data["txId"]
+ tx = backend.transactions[key]
+ tx.commit()
+ backend.send_response("Transaction", {"id": key})
+
+
+def TransactionRollback(backend, data):
+ key = data["txId"]
+ tx = backend.transactions[key]
+ tx.rollback()
+ backend.send_response("Transaction", {"id": key})
+
+
+def TransactionClose(backend, data):
+ key = data["txId"]
+ tx = backend.transactions[key]
+ tx.close()
+ backend.send_response("Transaction", {"id": key})
+
+
+def ResultNext(backend, data):
+ result = backend.results[data["resultId"]]
+
+ try:
+ record = Util.next(Util.iter(result))
+ except StopIteration:
+ backend.send_response("NullRecord", {})
+ return
+ backend.send_response("Record", totestkit.record(record))
+
+
+def ResultSingle(backend, data):
+ result = backend.results[data["resultId"]]
+ backend.send_response("Record", totestkit.record(result.single()))
+
+
+def ResultPeek(backend, data):
+ result = backend.results[data["resultId"]]
+ record = result.peek()
+ if record is not None:
+ backend.send_response("Record", totestkit.record(record))
+ else:
+ backend.send_response("NullRecord", {})
+
+
+def ResultList(backend, data):
+ result = backend.results[data["resultId"]]
+ records = Util.list(result)
+ backend.send_response("RecordList", {
+ "records": [totestkit.record(r) for r in records]
+ })
+
+
+def ResultConsume(backend, data):
+ result = backend.results[data["resultId"]]
+ summary = result.consume()
+ from neo4j import ResultSummary
+ assert isinstance(summary, ResultSummary)
+ backend.send_response("Summary", {
+ "serverInfo": {
+ "address": ":".join(map(str, summary.server.address)),
+ "agent": summary.server.agent,
+ "protocolVersion":
+ ".".join(map(str, summary.server.protocol_version)),
+ },
+ "counters": None if not summary.counters else {
+ "constraintsAdded": summary.counters.constraints_added,
+ "constraintsRemoved": summary.counters.constraints_removed,
+ "containsSystemUpdates": summary.counters.contains_system_updates,
+ "containsUpdates": summary.counters.contains_updates,
+ "indexesAdded": summary.counters.indexes_added,
+ "indexesRemoved": summary.counters.indexes_removed,
+ "labelsAdded": summary.counters.labels_added,
+ "labelsRemoved": summary.counters.labels_removed,
+ "nodesCreated": summary.counters.nodes_created,
+ "nodesDeleted": summary.counters.nodes_deleted,
+ "propertiesSet": summary.counters.properties_set,
+ "relationshipsCreated": summary.counters.relationships_created,
+ "relationshipsDeleted": summary.counters.relationships_deleted,
+ "systemUpdates": summary.counters.system_updates,
+ },
+ "database": summary.database,
+ "notifications": summary.notifications,
+ "plan": summary.plan,
+ "profile": summary.profile,
+ "query": {
+ "text": summary.query,
+ "parameters": {k: totestkit.field(v)
+ for k, v in summary.parameters.items()},
+ },
+ "queryType": summary.query_type,
+ "resultAvailableAfter": summary.result_available_after,
+ "resultConsumedAfter": summary.result_consumed_after,
+ })
+
+
+def ForcedRoutingTableUpdate(backend, data):
+ driver_id = data["driverId"]
+ driver = backend.drivers[driver_id]
+ database = data["database"]
+ bookmarks = data["bookmarks"]
+ with driver._pool.refresh_lock:
+ driver._pool.update_routing_table(
+ database=database, imp_user=None, bookmarks=bookmarks
+ )
+ backend.send_response("Driver", {"id": driver_id})
+
+
+def GetRoutingTable(backend, data):
+ driver_id = data["driverId"]
+ database = data["database"]
+ driver = backend.drivers[driver_id]
+ routing_table = driver._pool.routing_tables[database]
+ response_data = {
+ "database": routing_table.database,
+ "ttl": routing_table.ttl,
+ }
+ for role in ("routers", "readers", "writers"):
+ addresses = routing_table.__getattribute__(role)
+ response_data[role] = list(map(str, addresses))
+ backend.send_response("RoutingTable", response_data)
diff --git a/testkitbackend/backend.py b/testkitbackend/backend.py
index c488f8ad..f97dcc1c 100644
--- a/testkitbackend/backend.py
+++ b/testkitbackend/backend.py
@@ -14,39 +14,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from inspect import (
- getmembers,
- isfunction,
-)
-import io
-from json import loads, dumps
-import logging
-import sys
-import traceback
-
-from neo4j._exceptions import (
- BoltError
-)
-from neo4j.exceptions import (
- DriverError,
- Neo4jError,
- UnsupportedServerProduct,
-)
-
-import testkitbackend.requests as requests
-
-buffer_handler = logging.StreamHandler(io.StringIO())
-buffer_handler.setLevel(logging.DEBUG)
-
-handler = logging.StreamHandler(sys.stdout)
-handler.setLevel(logging.DEBUG)
-logging.getLogger("neo4j").addHandler(handler)
-logging.getLogger("neo4j").addHandler(buffer_handler)
-logging.getLogger("neo4j").setLevel(logging.DEBUG)
-
-log = logging.getLogger("testkitbackend")
-log.addHandler(handler)
-log.setLevel(logging.DEBUG)
class Request(dict):
@@ -93,104 +60,3 @@ def unseen_keys(self):
@property
def seen_all_keys(self):
return not self.unseen_keys
-
-
-class Backend:
- def __init__(self, rd, wr):
- self._rd = rd
- self._wr = wr
- self.drivers = {}
- self.custom_resolutions = {}
- self.dns_resolutions = {}
- self.sessions = {}
- self.results = {}
- self.errors = {}
- self.transactions = {}
- self.errors = {}
- self.key = 0
- # Collect all request handlers
- self._requestHandlers = dict(
- [m for m in getmembers(requests, isfunction)])
-
- def next_key(self):
- self.key = self.key + 1
- return self.key
-
- def process_request(self):
- """ Reads next request from the stream and processes it.
- """
- in_request = False
- request = ""
- for line in self._rd:
- # Remove trailing newline
- line = line.decode('UTF-8').rstrip()
- if line == "#request begin":
- in_request = True
- elif line == "#request end":
- self._process(request)
- return True
- else:
- if in_request:
- request = request + line
- return False
-
- def _process(self, request):
- """ Process a received request by retrieving handler that
- corresponds to the request name.
- """
- try:
- request = loads(request, object_pairs_hook=Request)
- if not isinstance(request, Request):
- raise Exception("Request is not an object")
- name = request.get('name', 'invalid')
- handler = self._requestHandlers.get(name)
- if not handler:
- raise Exception("No request handler for " + name)
- data = request["data"]
- log.info("<<< " + name + dumps(data))
- handler(self, data)
- unsused_keys = request.unseen_keys
- if unsused_keys:
- raise NotImplementedError(
- "Backend does not support some properties of the " + name +
- " request: " + ", ".join(unsused_keys)
- )
- except (Neo4jError, DriverError, UnsupportedServerProduct,
- BoltError) as e:
- log.debug(traceback.format_exc())
- if isinstance(e, Neo4jError):
- msg = "" if e.message is None else str(e.message)
- else:
- msg = str(e.args[0]) if e.args else ""
-
- key = self.next_key()
- self.errors[key] = e
- payload = {"id": key, "errorType": str(type(e)), "msg": msg}
- if isinstance(e, Neo4jError):
- payload["code"] = e.code
- self.send_response("DriverError", payload)
- except requests.FrontendError as e:
- self.send_response("FrontendError", {"msg": str(e)})
- except Exception:
- tb = traceback.format_exc()
- log.error(tb)
- self.send_response("BackendError", {"msg": tb})
-
- def send_response(self, name, data):
- """ Sends a response to backend.
- """
- buffer_handler.acquire()
- log_output = buffer_handler.stream.getvalue()
- buffer_handler.stream.truncate(0)
- buffer_handler.stream.seek(0)
- buffer_handler.release()
- if not log_output.endswith("\n"):
- log_output += "\n"
- self._wr.write(log_output.encode("utf-8"))
- response = {"name": name, "data": data}
- response = dumps(response)
- log.info(">>> " + name + dumps(data))
- self._wr.write(b"#response begin\n")
- self._wr.write(bytes(response+"\n", "utf-8"))
- self._wr.write(b"#response end\n")
- self._wr.flush()
diff --git a/testkitbackend/fromtestkit.py b/testkitbackend/fromtestkit.py
index 5c3c92d8..7cf17b92 100644
--- a/testkitbackend/fromtestkit.py
+++ b/testkitbackend/fromtestkit.py
@@ -15,7 +15,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from neo4j.work.simple import Query
+
+from neo4j import Query
def to_cypher_and_params(data):
diff --git a/testkitbackend/requests.py b/testkitbackend/requests.py
index 09a2c972..72e4d66f 100644
--- a/testkitbackend/requests.py
+++ b/testkitbackend/requests.py
@@ -14,13 +14,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
+
import json
from os import path
import neo4j
-import testkitbackend.fromtestkit as fromtestkit
-import testkitbackend.totestkit as totestkit
-from testkitbackend.fromtestkit import to_meta_and_timeout
+
+from . import (
+ fromtestkit,
+ totestkit,
+)
+from .fromtestkit import to_meta_and_timeout
class FrontendError(Exception):
@@ -357,7 +362,7 @@ def ResultList(backend, data):
def ResultConsume(backend, data):
result = backend.results[data["resultId"]]
summary = result.consume()
- from neo4j.work.summary import ResultSummary
+ from neo4j import ResultSummary
assert isinstance(summary, ResultSummary)
backend.send_response("Summary", {
"serverInfo": {
diff --git a/testkitbackend/server.py b/testkitbackend/server.py
index 45728164..0eb36bd1 100644
--- a/testkitbackend/server.py
+++ b/testkitbackend/server.py
@@ -14,8 +14,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from socketserver import TCPServer, StreamRequestHandler
-from testkitbackend.backend import Backend
+
+
+import asyncio
+from socketserver import (
+ StreamRequestHandler,
+ TCPServer,
+)
+
+from ._async.backend import AsyncBackend
+from ._sync.backend import Backend
class Server(TCPServer):
@@ -29,3 +37,32 @@ def handle(self):
pass
print("Disconnected")
super(Server, self).__init__(address, Handler)
+
+
+class AsyncServer:
+ def __init__(self, address):
+ self._address = address
+ self._server = None
+
+ @staticmethod
+ async def _handler(reader, writer):
+ backend = AsyncBackend(reader, writer)
+ while await backend.process_request():
+ pass
+ print("Disconnected")
+
+ async def start(self):
+ self._server = await asyncio.start_server(
+ self._handler, host=self._address[0], port=self._address[1],
+ limit=float("inf") # this is dirty but works (for now)
+ )
+
+ async def serve_forever(self):
+ if not self._server:
+ raise RuntimeError("Server not started")
+ await self._server.serve_forever()
+
+ def stop(self):
+ if not self._server:
+ raise RuntimeError("Try starting the server before stopping it ;)")
+ self._server.close()
diff --git a/testkitbackend/totestkit.py b/testkitbackend/totestkit.py
index 3068017a..2e48861e 100644
--- a/testkitbackend/totestkit.py
+++ b/testkitbackend/totestkit.py
@@ -14,6 +14,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
+
import math
from neo4j.graph import (
diff --git a/tests/__init__.py b/tests/__init__.py
index e69de29b..b81a309d 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -0,0 +1,16 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/env.py b/tests/env.py
index cacb2440..f3e965cb 100644
--- a/tests/env.py
+++ b/tests/env.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py
index 015ba64d..9c123ac5 100644
--- a/tests/integration/conftest.py
+++ b/tests/integration/conftest.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -21,18 +18,19 @@
from math import ceil
from os import getenv
-from os.path import dirname, join
+from os.path import (
+ dirname,
+ join,
+)
from threading import RLock
import pytest
-import urllib
-from neo4j import (
- GraphDatabase,
-)
-from neo4j.exceptions import ServiceUnavailable
+from neo4j import GraphDatabase
from neo4j._exceptions import BoltHandshakeError
-from neo4j.io import Bolt
+from neo4j._sync.io import Bolt
+from neo4j.exceptions import ServiceUnavailable
+
# import logging
# log = logging.getLogger("neo4j")
@@ -50,7 +48,7 @@
NEO4J_CORES = 3
NEO4J_REPLICAS = 2
NEO4J_USER = "neo4j"
-NEO4J_PASSWORD = "password"
+NEO4J_PASSWORD = "pass"
NEO4J_AUTH = (NEO4J_USER, NEO4J_PASSWORD)
NEO4J_LOCK = RLock()
NEO4J_SERVICE = None
@@ -76,7 +74,10 @@ def __init__(self, name=None, image=None, auth=None,
n_cores=None, n_replicas=None,
bolt_port=None, http_port=None, debug_port=None,
debug_suspend=None, dir_spec=None, config=None):
- from boltkit.legacy.controller import _install, create_controller
+ from boltkit.legacy.controller import (
+ _install,
+ create_controller,
+ )
assert image.endswith("-enterprise")
release = image[:-11]
if release == "snapshot":
diff --git a/tests/integration/examples/__init__.py b/tests/integration/examples/__init__.py
index 87bbaf26..a629622e 100644
--- a/tests/integration/examples/__init__.py
+++ b/tests/integration/examples/__init__.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
diff --git a/tests/integration/examples/test_autocommit_transaction_example.py b/tests/integration/examples/test_autocommit_transaction_example.py
index 0f5b16e1..f4ffa7e7 100644
--- a/tests/integration/examples/test_autocommit_transaction_example.py
+++ b/tests/integration/examples/test_autocommit_transaction_example.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -19,9 +16,11 @@
# limitations under the License.
+# isort: off
# tag::autocommit-transaction-import[]
from neo4j import Query
# end::autocommit-transaction-import[]
+# isort: on
# python -m pytest tests/integration/examples/test_autocommit_transaction_example.py -s -v
diff --git a/tests/integration/examples/test_basic_auth_example.py b/tests/integration/examples/test_basic_auth_example.py
index 90e32808..482ac7ba 100644
--- a/tests/integration/examples/test_basic_auth_example.py
+++ b/tests/integration/examples/test_basic_auth_example.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -21,13 +18,16 @@
import pytest
+from neo4j._exceptions import BoltHandshakeError
+from neo4j.exceptions import ServiceUnavailable
+from tests.integration.examples import DriverSetupExample
+
+
+# isort: off
# tag::basic-auth-import[]
from neo4j import GraphDatabase
# end::basic-auth-import[]
-
-from neo4j.exceptions import ServiceUnavailable
-from neo4j._exceptions import BoltHandshakeError
-from tests.integration.examples import DriverSetupExample
+# isort: on
# python -m pytest tests/integration/examples/test_basic_auth_example.py -s -v
diff --git a/tests/integration/examples/test_bearer_auth_example.py b/tests/integration/examples/test_bearer_auth_example.py
index 727dd6a9..191ca7e0 100644
--- a/tests/integration/examples/test_bearer_auth_example.py
+++ b/tests/integration/examples/test_bearer_auth_example.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -19,22 +16,22 @@
# limitations under the License.
-import pytest
-
import neo4j
+from tests.integration.examples import DriverSetupExample
+
+
+# isort: off
# tag::bearer-auth-import[]
from neo4j import (
bearer_auth,
GraphDatabase,
)
# end::bearer-auth-import[]
-
-from tests.integration.examples import DriverSetupExample
+# isort: on
# python -m pytest tests/integration/examples/test_bearer_auth_example.py -s -v
-
class BearerAuthExample(DriverSetupExample):
# tag::bearer-auth[]
@@ -61,4 +58,3 @@ def test_example(uri, mocker):
assert not hasattr(auth, "principal")
assert auth.credentials == token
assert not hasattr(auth, "parameters")
-
diff --git a/tests/integration/examples/test_config_connection_pool_example.py b/tests/integration/examples/test_config_connection_pool_example.py
index 7ddaacc7..233adaa6 100644
--- a/tests/integration/examples/test_config_connection_pool_example.py
+++ b/tests/integration/examples/test_config_connection_pool_example.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -21,14 +18,16 @@
import pytest
-from neo4j.exceptions import ServiceUnavailable
from neo4j._exceptions import BoltHandshakeError
+from neo4j.exceptions import ServiceUnavailable
+from tests.integration.examples import DriverSetupExample
+
+# isort: off
# tag::config-connection-pool-import[]
from neo4j import GraphDatabase
# end::config-connection-pool-import[]
-
-from tests.integration.examples import DriverSetupExample
+# isort: on
# python -m pytest tests/integration/examples/test_config_connection_pool_example.py -s -v
diff --git a/tests/integration/examples/test_config_connection_timeout_example.py b/tests/integration/examples/test_config_connection_timeout_example.py
index 00d6c97d..ef17e2b4 100644
--- a/tests/integration/examples/test_config_connection_timeout_example.py
+++ b/tests/integration/examples/test_config_connection_timeout_example.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -21,14 +18,16 @@
import pytest
-from neo4j.exceptions import ServiceUnavailable
from neo4j._exceptions import BoltHandshakeError
+from neo4j.exceptions import ServiceUnavailable
+from tests.integration.examples import DriverSetupExample
+
+# isort: off
# tag::config-connection-timeout-import[]
from neo4j import GraphDatabase
# end::config-connection-timeout-import[]
-
-from tests.integration.examples import DriverSetupExample
+# isort: on
# python -m pytest tests/integration/examples/test_config_connection_timeout_example.py -s -v
diff --git a/tests/integration/examples/test_config_max_retry_time_example.py b/tests/integration/examples/test_config_max_retry_time_example.py
index 8cd25905..3a28235d 100644
--- a/tests/integration/examples/test_config_max_retry_time_example.py
+++ b/tests/integration/examples/test_config_max_retry_time_example.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -21,13 +18,16 @@
import pytest
+from neo4j._exceptions import BoltHandshakeError
+from neo4j.exceptions import ServiceUnavailable
+from tests.integration.examples import DriverSetupExample
+
+
+# isort: off
# tag::config-max-retry-time-import[]
from neo4j import GraphDatabase
# end::config-max-retry-time-import[]
-
-from neo4j.exceptions import ServiceUnavailable
-from neo4j._exceptions import BoltHandshakeError
-from tests.integration.examples import DriverSetupExample
+# isort: on
# python -m pytest tests/integration/examples/test_config_max_retry_time_example.py -s -v
diff --git a/tests/integration/examples/test_config_secure_example.py b/tests/integration/examples/test_config_secure_example.py
index ebd048dc..2a6c6603 100644
--- a/tests/integration/examples/test_config_secure_example.py
+++ b/tests/integration/examples/test_config_secure_example.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -21,15 +18,21 @@
import pytest
-# tag::config-secure-import[]
-from neo4j import GraphDatabase, TRUST_SYSTEM_CA_SIGNED_CERTIFICATES
-# end::config-secure-import[]
-
-from neo4j.exceptions import ServiceUnavailable
from neo4j._exceptions import BoltHandshakeError
+from neo4j.exceptions import ServiceUnavailable
from tests.integration.examples import DriverSetupExample
+# isort: off
+# tag::config-secure-import[]
+from neo4j import (
+ GraphDatabase,
+ TRUST_SYSTEM_CA_SIGNED_CERTIFICATES,
+)
+# end::config-secure-import[]
+# isort: on
+
+
# python -m pytest tests/integration/examples/test_config_secure_example.py -s -v
class ConfigSecureExample(DriverSetupExample):
diff --git a/tests/integration/examples/test_config_trust_example.py b/tests/integration/examples/test_config_trust_example.py
index 1466c536..b63139f1 100644
--- a/tests/integration/examples/test_config_trust_example.py
+++ b/tests/integration/examples/test_config_trust_example.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -21,15 +18,17 @@
import pytest
+from tests.integration.examples import DriverSetupExample
+
+
+# isort: off
# tag::config-trust-import[]
from neo4j import (
GraphDatabase,
- TRUST_SYSTEM_CA_SIGNED_CERTIFICATES,
- TRUST_ALL_CERTIFICATES,
+ TRUST_ALL_CERTIFICATES
)
# end::config-trust-import[]
-
-from tests.integration.examples import DriverSetupExample
+# isort: on
class ConfigTrustExample(DriverSetupExample):
diff --git a/tests/integration/examples/test_config_unencrypted_example.py b/tests/integration/examples/test_config_unencrypted_example.py
index 6d70ad03..d04297df 100644
--- a/tests/integration/examples/test_config_unencrypted_example.py
+++ b/tests/integration/examples/test_config_unencrypted_example.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -21,14 +18,16 @@
import pytest
+from neo4j._exceptions import BoltHandshakeError
+from neo4j.exceptions import ServiceUnavailable
+from tests.integration.examples import DriverSetupExample
+
+
+# isort: off
# tag::config-unencrypted-import[]
from neo4j import GraphDatabase
# end::config-unencrypted-import[]
-
-from neo4j.exceptions import ServiceUnavailable
-from neo4j._exceptions import BoltHandshakeError
-
-from tests.integration.examples import DriverSetupExample
+# isort: on
# python -m pytest tests/integration/examples/test_config_unencrypted_example.py -s -v
diff --git a/tests/integration/examples/test_custom_auth_example.py b/tests/integration/examples/test_custom_auth_example.py
index 1aaddd1e..c9621c6c 100644
--- a/tests/integration/examples/test_custom_auth_example.py
+++ b/tests/integration/examples/test_custom_auth_example.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -21,16 +18,19 @@
import pytest
+from neo4j._exceptions import BoltHandshakeError
+from neo4j.exceptions import ServiceUnavailable
+from tests.integration.examples import DriverSetupExample
+
+
+# isort: off
# tag::custom-auth-import[]
from neo4j import (
GraphDatabase,
custom_auth,
)
# end::custom-auth-import[]
-
-from neo4j.exceptions import ServiceUnavailable
-from neo4j._exceptions import BoltHandshakeError
-from tests.integration.examples import DriverSetupExample
+# isort: on
# python -m pytest tests/integration/examples/test_custom_auth_example.py -s -v
diff --git a/tests/integration/examples/test_custom_resolver_example.py b/tests/integration/examples/test_custom_resolver_example.py
index fe4e5e08..9bc349de 100644
--- a/tests/integration/examples/test_custom_resolver_example.py
+++ b/tests/integration/examples/test_custom_resolver_example.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -21,15 +18,18 @@
import pytest
+from neo4j._exceptions import BoltHandshakeError
+from neo4j.exceptions import ServiceUnavailable
+
+
+# isort: off
# tag::custom-resolver-import[]
from neo4j import (
GraphDatabase,
WRITE_ACCESS,
)
# end::custom-resolver-import[]
-
-from neo4j.exceptions import ServiceUnavailable
-from neo4j._exceptions import BoltHandshakeError
+# isort: on
# python -m pytest tests/integration/examples/test_custom_resolver_example.py -s -v
diff --git a/tests/integration/examples/test_cypher_error_example.py b/tests/integration/examples/test_cypher_error_example.py
index c2108b9f..411929db 100644
--- a/tests/integration/examples/test_cypher_error_example.py
+++ b/tests/integration/examples/test_cypher_error_example.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -22,9 +19,12 @@
from contextlib import redirect_stdout
from io import StringIO
+
+# isort: off
# tag::cypher-error-import[]
from neo4j.exceptions import ClientError
# end::cypher-error-import[]
+# isort: on
class Neo4jErrorExample:
diff --git a/tests/integration/examples/test_database_selection_example.py b/tests/integration/examples/test_database_selection_example.py
index 99293ea8..65b8b8fc 100644
--- a/tests/integration/examples/test_database_selection_example.py
+++ b/tests/integration/examples/test_database_selection_example.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -19,24 +16,22 @@
# limitations under the License.
-import pytest
-
from contextlib import redirect_stdout
from io import StringIO
-from neo4j import GraphDatabase
+# isort: off
# tag::database-selection-import[]
-from neo4j import READ_ACCESS
+from neo4j import (
+ GraphDatabase,
+ READ_ACCESS,
+)
# end::database-selection-import[]
-
-from neo4j.exceptions import ServiceUnavailable
-from neo4j._exceptions import BoltHandshakeError
+# isort: on
# python -m pytest tests/integration/examples/test_database_selection_example.py -s -v
-
class DatabaseSelectionExample:
def __init__(self, uri, user, password):
diff --git a/tests/integration/examples/test_driver_introduction_example.py b/tests/integration/examples/test_driver_introduction_example.py
index febd26a7..4348b35f 100644
--- a/tests/integration/examples/test_driver_introduction_example.py
+++ b/tests/integration/examples/test_driver_introduction_example.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -19,11 +16,15 @@
# limitations under the License.
-import pytest
-
from contextlib import redirect_stdout
from io import StringIO
+import pytest
+
+from neo4j._exceptions import BoltHandshakeError
+
+
+# isort: off
# tag::driver-introduction-example-import[]
import logging
import sys
@@ -31,8 +32,7 @@
from neo4j import GraphDatabase
from neo4j.exceptions import ServiceUnavailable
# end::driver-introduction-example-import[]
-
-from neo4j._exceptions import BoltHandshakeError
+# isort: on
# python -m pytest tests/integration/examples/test_driver_introduction_example.py -s -v
diff --git a/tests/integration/examples/test_driver_lifecycle_example.py b/tests/integration/examples/test_driver_lifecycle_example.py
index d416173f..3ba18854 100644
--- a/tests/integration/examples/test_driver_lifecycle_example.py
+++ b/tests/integration/examples/test_driver_lifecycle_example.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -21,12 +18,15 @@
import pytest
+from neo4j._exceptions import BoltHandshakeError
+from neo4j.exceptions import ServiceUnavailable
+
+
+# isort: off
# tag::driver-lifecycle-import[]
from neo4j import GraphDatabase
# end::driver-lifecycle-import[]
-
-from neo4j.exceptions import ServiceUnavailable
-from neo4j._exceptions import BoltHandshakeError
+# isort: on
# python -m pytest tests/integration/examples/test_driver_lifecycle_example.py -s -v
diff --git a/tests/integration/examples/test_geospatial_types_example.py b/tests/integration/examples/test_geospatial_types_example.py
index 30b7ff4e..39a2df8e 100644
--- a/tests/integration/examples/test_geospatial_types_example.py
+++ b/tests/integration/examples/test_geospatial_types_example.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -18,8 +15,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+
import pytest
+
# python -m pytest tests/integration/examples/test_geospatial_types_example.py -s -v
@@ -28,9 +27,11 @@ def _echo(tx, x):
def test_cartesian_point(driver):
+ # isort: off
# tag::geospatial-types-cartesian-import[]
from neo4j.spatial import CartesianPoint
# end::geospatial-types-cartesian-import[]
+ # isort: on
# tag::geospatial-types-cartesian[]
# Creating a 2D point in Cartesian space
@@ -87,9 +88,11 @@ def test_cartesian_point(driver):
def test_wgs84_point(driver):
+ # isort: off
# tag::geospatial-types-wgs84-import[]
from neo4j.spatial import WGS84Point
# end::geospatial-types-wgs84-import[]
+ # isort: on
# tag::geospatial-types-wgs84[]
# Creating a 2D point in WSG84 space
diff --git a/tests/integration/examples/test_hello_world_example.py b/tests/integration/examples/test_hello_world_example.py
index a2723453..b3c83e8a 100644
--- a/tests/integration/examples/test_hello_world_example.py
+++ b/tests/integration/examples/test_hello_world_example.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -19,17 +16,20 @@
# limitations under the License.
-import pytest
-
from contextlib import redirect_stdout
from io import StringIO
+import pytest
+
+from neo4j._exceptions import BoltHandshakeError
+from neo4j.exceptions import ServiceUnavailable
+
+
+# isort: off
# tag::hello-world-import[]
from neo4j import GraphDatabase
# end::hello-world-import[]
-
-from neo4j.exceptions import ServiceUnavailable
-from neo4j._exceptions import BoltHandshakeError
+# isort: on
# python -m pytest tests/integration/examples/test_hello_world_example.py -s -v
@@ -79,4 +79,3 @@ def test_hello_world_example(uri, auth):
except ServiceUnavailable as error:
if isinstance(error.__cause__, BoltHandshakeError):
pytest.skip(error.args[0])
-
diff --git a/tests/integration/examples/test_kerberos_auth_example.py b/tests/integration/examples/test_kerberos_auth_example.py
index 4a8c42ec..ed8b8c17 100644
--- a/tests/integration/examples/test_kerberos_auth_example.py
+++ b/tests/integration/examples/test_kerberos_auth_example.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -18,16 +15,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import pytest
+
+from tests.integration.examples import DriverSetupExample
+
-import neo4j
+# isort: off
# tag::kerberos-auth-import[]
from neo4j import (
GraphDatabase,
kerberos_auth,
)
# end::kerberos-auth-import[]
-
-from tests.integration.examples import DriverSetupExample
+# isort: on
# python -m pytest tests/integration/examples/test_kerberos_auth_example.py -s -v
@@ -39,21 +39,5 @@ def __init__(self, uri, ticket):
# end::kerberos-auth[]
-def test_example(uri, mocker):
- # Currently, there is no way of running the test against a server with SSO
- # setup.
- mocker.patch("neo4j.GraphDatabase.bolt_driver")
- mocker.patch("neo4j.GraphDatabase.neo4j_driver")
-
- ticket = "myTicket"
- KerberosAuthExample(uri, ticket)
- calls = (neo4j.GraphDatabase.bolt_driver.call_args_list
- + neo4j.GraphDatabase.neo4j_driver.call_args_list)
- assert len(calls) == 1
- args_, kwargs = calls[0]
- auth = kwargs.get("auth")
- assert isinstance(auth, neo4j.Auth)
- assert auth.scheme == "kerberos"
- assert auth.principal == ""
- assert auth.credentials == ticket
- assert not hasattr(auth, "parameters")
+def test_example():
+ pytest.skip("Currently no way to test Kerberos auth")
diff --git a/tests/integration/examples/test_pass_bookmarks_example.py b/tests/integration/examples/test_pass_bookmarks_example.py
index 6f51584a..51cc33d9 100644
--- a/tests/integration/examples/test_pass_bookmarks_example.py
+++ b/tests/integration/examples/test_pass_bookmarks_example.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -21,12 +18,15 @@
import pytest
+from neo4j._exceptions import BoltHandshakeError
+from neo4j.exceptions import ServiceUnavailable
+
+
+# isort: off
# tag::pass-bookmarks-import[]
from neo4j import GraphDatabase
# end::pass-bookmarks-import[]
-
-from neo4j.exceptions import ServiceUnavailable
-from neo4j._exceptions import BoltHandshakeError
+# isort: on
# python -m pytest tests/integration/examples/test_pass_bookmarks_example.py -s -v
diff --git a/tests/integration/examples/test_read_write_transaction_example.py b/tests/integration/examples/test_read_write_transaction_example.py
index 04a1a4ef..787a4849 100644
--- a/tests/integration/examples/test_read_write_transaction_example.py
+++ b/tests/integration/examples/test_read_write_transaction_example.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
diff --git a/tests/integration/examples/test_result_consume_example.py b/tests/integration/examples/test_result_consume_example.py
index 53da413b..3770d9ed 100644
--- a/tests/integration/examples/test_result_consume_example.py
+++ b/tests/integration/examples/test_result_consume_example.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
diff --git a/tests/integration/examples/test_result_retain_example.py b/tests/integration/examples/test_result_retain_example.py
index a96bf02e..828a8236 100644
--- a/tests/integration/examples/test_result_retain_example.py
+++ b/tests/integration/examples/test_result_retain_example.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
diff --git a/tests/integration/examples/test_service_unavailable_example.py b/tests/integration/examples/test_service_unavailable_example.py
index 7b2961db..21972798 100644
--- a/tests/integration/examples/test_service_unavailable_example.py
+++ b/tests/integration/examples/test_service_unavailable_example.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -21,9 +18,12 @@
import pytest
+
+# isort: off
# tag::service-unavailable-import[]
from neo4j.exceptions import ServiceUnavailable
# end::service-unavailable-import[]
+# isort: on
def service_unavailable_example(driver):
diff --git a/tests/integration/examples/test_session_example.py b/tests/integration/examples/test_session_example.py
index 65cb3255..14302c31 100644
--- a/tests/integration/examples/test_session_example.py
+++ b/tests/integration/examples/test_session_example.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
diff --git a/tests/integration/examples/test_temporal_types_example.py b/tests/integration/examples/test_temporal_types_example.py
index 97b57e88..42b17cdc 100644
--- a/tests/integration/examples/test_temporal_types_example.py
+++ b/tests/integration/examples/test_temporal_types_example.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -26,12 +23,14 @@ def _echo(tx, x):
def test_datetime(driver):
+ # isort: off
# tag::temporal-types-datetime-import[]
from datetime import datetime
from neo4j.time import DateTime
import pytz
# end::temporal-types-datetime-import[]
+ # isort: on
# tag::temporal-types-datetime[]
# Create datetimes to be used as query parameters
@@ -78,11 +77,13 @@ def test_datetime(driver):
def test_date(driver):
+ # isort: off
# tag::temporal-types-date-import[]
from datetime import date
from neo4j.time import Date
# end::temporal-types-date-import[]
+ # isort: on
# tag::temporal-types-date[]
# Create dates to be used as query parameters
@@ -127,12 +128,14 @@ def test_date(driver):
def test_time(driver):
+ # isort: off
# tag::temporal-types-time-import[]
from datetime import time
from neo4j.time import Time
import pytz
# end::temporal-types-time-import[]
+ # isort: on
# tag::temporal-types-time[]
# Create datetimes to be used as query parameters
@@ -176,12 +179,13 @@ def test_time(driver):
def test_local_datetime(driver):
+ # isort: off
# tag::temporal-types-local-datetime-import[]
from datetime import datetime
from neo4j.time import DateTime
- import pytz
# end::temporal-types-local-datetime-import[]
+ # isort: on
# tag::temporal-types-local-datetime[]
# Create datetimes to be used as query parameters
@@ -226,12 +230,13 @@ def test_local_datetime(driver):
def test_local_time(driver):
+ # isort: off
# tag::temporal-types-local-time-import[]
from datetime import time
from neo4j.time import Time
- import pytz
# end::temporal-types-local-time-import[]
+ # isort: on
# tag::temporal-types-local-time[]
# Create datetimes to be used as query parameters
@@ -275,11 +280,13 @@ def test_local_time(driver):
def test_duration_example(driver):
+ # isort: off
# tag::temporal-types-duration-import[]
from datetime import timedelta
from neo4j.time import Duration
# end::temporal-types-duration-import[]
+ # isort: on
# tag::temporal-types-duration[]
# Creating durations to be used as query parameters
diff --git a/tests/integration/examples/test_transaction_function_example.py b/tests/integration/examples/test_transaction_function_example.py
index 5e525dff..e0b6588c 100644
--- a/tests/integration/examples/test_transaction_function_example.py
+++ b/tests/integration/examples/test_transaction_function_example.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -19,9 +16,11 @@
# limitations under the License.
+# isort: off
# tag::transaction-function-import[]
from neo4j import unit_of_work
# end::transaction-function-import[]
+# isort: on
# python -m pytest tests/integration/examples/test_transaction_function_example.py -s -v
diff --git a/tests/integration/examples/test_transaction_metadata_config_example.py b/tests/integration/examples/test_transaction_metadata_config_example.py
index ef7ef39d..8862d22a 100644
--- a/tests/integration/examples/test_transaction_metadata_config_example.py
+++ b/tests/integration/examples/test_transaction_metadata_config_example.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -18,7 +15,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from neo4j import unit_of_work, Query
+
+from neo4j import unit_of_work
# python -m pytest tests/integration/examples/test_transaction_metadata_config_example.py -s -v
diff --git a/tests/integration/examples/test_transaction_timeout_config_example.py b/tests/integration/examples/test_transaction_timeout_config_example.py
index cf3ed5a6..e3503691 100644
--- a/tests/integration/examples/test_transaction_timeout_config_example.py
+++ b/tests/integration/examples/test_transaction_timeout_config_example.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -18,7 +15,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from neo4j import unit_of_work, Query
+
+from neo4j import unit_of_work
# python -m pytest tests/integration/examples/test_transaction_timeout_config_example.py -s -v
diff --git a/tests/integration/test_autocommit.py b/tests/integration/test_autocommit.py
index 32619991..960cbe3f 100644
--- a/tests/integration/test_autocommit.py
+++ b/tests/integration/test_autocommit.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -19,7 +16,7 @@
# limitations under the License.
-from neo4j.work.simple import Query
+from neo4j import Query
# TODO: this test will stay until a uniform behavior for `.single()` across the
diff --git a/tests/integration/test_bolt_driver.py b/tests/integration/test_bolt_driver.py
index 5b2e76a0..346b82cd 100644
--- a/tests/integration/test_bolt_driver.py
+++ b/tests/integration/test_bolt_driver.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -19,30 +16,8 @@
# limitations under the License.
-import pytest
from pytest import fixture
-from neo4j import (
- GraphDatabase,
- BoltDriver,
- Version,
- READ_ACCESS,
- WRITE_ACCESS,
- ResultSummary,
- unit_of_work,
- Transaction,
- Result,
- ServerInfo,
-)
-from neo4j.exceptions import (
- ServiceUnavailable,
- AuthError,
- ConfigurationError,
- ClientError,
-)
-from neo4j._exceptions import BoltHandshakeError
-from neo4j.io._bolt3 import Bolt3
-
# TODO: this test will stay until a uniform behavior for `.single()` across the
# drivers has been specified and tests are created in testkit
diff --git a/tests/integration/test_pipelines.py b/tests/integration/test_pipelines.py
deleted file mode 100644
index 299ec76f..00000000
--- a/tests/integration/test_pipelines.py
+++ /dev/null
@@ -1,308 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
-# Copyright (c) "Neo4j"
-# Neo4j Sweden AB [http://neo4j.com]
-#
-# This file is part of Neo4j.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-import pytest
-from uuid import uuid4
-
-
-from neo4j.packstream import Structure
-from neo4j.exceptions import (
- Neo4jError,
- CypherSyntaxError,
-)
-from neo4j.graph import (
- Node,
- Relationship,
- Path,
-)
-from neo4j.work.pipelining import PullOrderException
-
-
-def test_can_run_simple_statement(bolt_driver):
- pipeline = bolt_driver.pipeline(flush_every=0)
- pipeline.push("RETURN 1 AS n")
- for record in pipeline.pull():
- assert len(record) == 1
- assert record[0] == 1
- # TODO: why does pipeline result not look like a regular result?
- # assert record["n"] == 1
- # with pytest.raises(KeyError):
- # _ = record["x"]
- # assert record["n"] == 1
- # with pytest.raises(KeyError):
- # _ = record["x"]
- with pytest.raises(TypeError):
- _ = record[object()]
- assert repr(record)
- assert len(record) == 1
- pipeline.close()
-
-
-def test_can_run_simple_statement_with_params(bolt_driver):
- pipeline = bolt_driver.pipeline(flush_every=0)
- count = 0
- pipeline.push("RETURN $x AS n", {"x": {"abc": ["d", "e", "f"]}})
- for record in pipeline.pull():
- assert record[0] == {"abc": ["d", "e", "f"]}
- # TODO: why does pipeline result not look like a regular result?
- # assert record["n"] == {"abc": ["d", "e", "f"]}
- assert repr(record)
- assert len(record) == 1
- count += 1
- pipeline.close()
- assert count == 1
-
-
-def test_can_run_write_statement_with_no_return(driver):
- pipeline = driver.pipeline(flush_every=0)
- count = 0
- test_uid = str(uuid4())
- pipeline.push("CREATE (a:Person {uid:$test_uid})", dict(test_uid=test_uid))
-
- for _ in pipeline.pull():
- raise Exception("Should not return any results from create with no return")
- # Note you still have to consume the generator if you want to be allowed to pull from the pipeline again even
- # though it doesn't apparently return any items.
-
- pipeline.push("MATCH (a:Person {uid:$test_uid}) RETURN a LIMIT 1", dict(test_uid=test_uid))
- for _ in pipeline.pull():
- count += 1
- pipeline.close()
- assert count == 1
-
-
-def test_fails_on_bad_syntax(bolt_driver):
- pipeline = bolt_driver.pipeline(flush_every=0)
- with pytest.raises(Neo4jError):
- pipeline.push("X")
- next(pipeline.pull())
-
-
-def test_doesnt_fail_on_bad_syntax_somewhere(bolt_driver):
- pipeline = bolt_driver.pipeline(flush_every=0)
- pipeline.push("RETURN 1 AS n")
- pipeline.push("X")
- assert next(pipeline.pull())[0] == 1
- with pytest.raises(Neo4jError):
- next(pipeline.pull())
-
-
-def test_fails_on_missing_parameter(bolt_driver):
- pipeline = bolt_driver.pipeline(flush_every=0)
- with pytest.raises(Neo4jError):
- pipeline.push("RETURN $x")
- next(pipeline.pull())
-
-
-def test_can_run_simple_statement_from_bytes_string(bolt_driver):
- pipeline = bolt_driver.pipeline(flush_every=0)
- count = 0
- pytest.skip("FIXME: why can't pipeline handle bytes string?")
- pipeline.push(b"RETURN 1 AS n")
- for record in pipeline.pull():
- assert record[0] == 1
- assert record["n"] == 1
- assert repr(record)
- assert len(record) == 1
- count += 1
- pipeline.close()
- assert count == 1
-
-
-def test_can_run_statement_that_returns_multiple_records(bolt_driver):
- pipeline = bolt_driver.pipeline(flush_every=0)
- count = 0
- pipeline.push("unwind(range(1, 10)) AS z RETURN z")
- for record in pipeline.pull():
- assert 1 <= record[0] <= 10
- count += 1
- pipeline.close()
- assert count == 10
-
-
-def test_can_return_node(neo4j_driver):
- with neo4j_driver.pipeline(flush_every=0) as pipeline:
- pipeline.push("CREATE (a:Person {name:'Alice'}) RETURN a")
- record_list = list(pipeline.pull())
- assert len(record_list) == 1
- for record in record_list:
- alice = record[0]
- print(alice)
- pytest.skip("FIXME: why does pipeline result not look like a regular result?")
- assert isinstance(alice, Node)
- assert alice.labels == {"Person"}
- assert dict(alice) == {"name": "Alice"}
-
-
-def test_can_return_relationship(neo4j_driver):
- with neo4j_driver.pipeline(flush_every=0) as pipeline:
- pipeline.push("CREATE ()-[r:KNOWS {since:1999}]->() RETURN r")
- record_list = list(pipeline.pull())
- assert len(record_list) == 1
- for record in record_list:
- rel = record[0]
- print(rel)
- pytest.skip("FIXME: why does pipeline result not look like a regular result?")
- assert isinstance(rel, Relationship)
- assert rel.type == "KNOWS"
- assert dict(rel) == {"since": 1999}
-
-
-def test_can_return_path(neo4j_driver):
- with neo4j_driver.pipeline(flush_every=0) as pipeline:
- test_uid = str(uuid4())
- pipeline.push(
- "MERGE p=(alice:Person {name:'Alice', test_uid: $test_uid})"
- "-[:KNOWS {test_uid: $test_uid}]->"
- "(:Person {name:'Bob', test_uid: $test_uid})"
- " RETURN p",
- dict(test_uid=test_uid)
- )
- record_list = list(pipeline.pull())
- assert len(record_list) == 1
- for record in record_list:
- path = record[0]
- print(path)
- pytest.skip("FIXME: why does pipeline result not look like a regular result?")
- assert isinstance(path, Path)
- assert path.start_node["name"] == "Alice"
- assert path.end_node["name"] == "Bob"
- assert path.relationships[0].type == "KNOWS"
- assert len(path.nodes) == 2
- assert len(path.relationships) == 1
-
-
-def test_can_handle_cypher_error(bolt_driver):
- with bolt_driver.pipeline(flush_every=0) as pipeline:
- pipeline.push("X")
- with pytest.raises(Neo4jError):
- next(pipeline.pull())
-
-
-def test_should_not_allow_empty_statements(bolt_driver, requires_bolt_4x):
- with bolt_driver.pipeline(flush_every=0) as pipeline:
- pipeline.push("")
- with pytest.raises(CypherSyntaxError):
- next(pipeline.pull())
-
-
-def test_can_queue_multiple_statements(bolt_driver):
- count = 0
- with bolt_driver.pipeline(flush_every=0) as pipeline:
- pipeline.push("unwind(range(1, 10)) AS z RETURN z")
- pipeline.push("unwind(range(11, 20)) AS z RETURN z")
- pipeline.push("unwind(range(21, 30)) AS z RETURN z")
- for i in range(3):
- for record in pipeline.pull():
- assert (i * 10 + 1) <= record[0] <= ((i + 1) * 10)
- count += 1
- assert count == 30
-
-
-def test_pull_order_exception(bolt_driver):
- """If you try and pull when you haven't finished iterating the previous result you get an error"""
- pipeline = bolt_driver.pipeline(flush_every=0)
- with pytest.raises(PullOrderException):
- pipeline.push("unwind(range(1, 10)) AS z RETURN z")
- pipeline.push("unwind(range(11, 20)) AS z RETURN z")
- generator_one = pipeline.pull()
- generator_two = pipeline.pull()
-
-
-def test_pipeline_can_read_own_writes(neo4j_driver):
- """I am not sure that we _should_ guarantee this"""
- count = 0
- with neo4j_driver.pipeline(flush_every=0) as pipeline:
- test_uid = str(uuid4())
- pipeline.push(
- "CREATE (a:Person {name:'Alice', test_uid: $test_uid})",
- dict(test_uid=test_uid)
- )
- pipeline.push(
- "MATCH (alice:Person {name:'Alice', test_uid: $test_uid}) "
- "MERGE (alice)"
- "-[:KNOWS {test_uid: $test_uid}]->"
- "(:Person {name:'Bob', test_uid: $test_uid})",
- dict(test_uid=test_uid)
- )
- pipeline.push("MATCH (n:Person {test_uid: $test_uid}) RETURN n", dict(test_uid=test_uid))
- pipeline.push(
- "MATCH"
- " p=(:Person {test_uid: $test_uid})-[:KNOWS {test_uid: $test_uid}]->(:Person {test_uid: $test_uid})"
- " RETURN p",
- dict(test_uid=test_uid)
- )
-
- # create Alice
- # n.b. we have to consume the result
- assert next(pipeline.pull(), True) == True
-
- # merge "knows Bob"
- # n.b. we have to consume the result
- assert next(pipeline.pull(), True) == True
-
- # get people
- for result in pipeline.pull():
- count += 1
-
- assert len(result) == 1
- person = result[0]
- print(person)
- assert isinstance(person, Structure)
- assert person.tag == b'N'
- print(person.fields)
- assert set(person.fields[1]) == {"Person"}
-
- print(count)
- assert count == 2
-
- # get path
- for result in pipeline.pull():
- count += 1
-
- assert len(result) == 1
- path = result[0]
- print(path)
- assert isinstance(path, Structure)
- assert path.tag == b'P'
-
- # TODO: return Path / Node / Rel instances rather than Structures
- # assert isinstance(path, Path)
- # assert path.start_node["name"] == "Alice"
- # assert path.end_node["name"] == "Bob"
- # assert path.relationships[0].type == "KNOWS"
- # assert len(path.nodes) == 2
- # assert len(path.relationships) == 1
-
- assert count == 3
-
-
-def test_automatic_reset_after_failure(bolt_driver):
- with bolt_driver.pipeline(flush_every=0) as pipeline:
- try:
- pipeline.push("X")
- next(pipeline.pull())
- except Neo4jError:
- pipeline.push("RETURN 1")
- record = next(pipeline.pull())
- assert record[0] == 1
- else:
- assert False, "A Cypher error should have occurred"
diff --git a/tests/integration/test_readme.py b/tests/integration/test_readme.py
index 4788411d..fde46639 100644
--- a/tests/integration/test_readme.py
+++ b/tests/integration/test_readme.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -18,10 +15,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+
import pytest
-from neo4j.exceptions import ServiceUnavailable
from neo4j._exceptions import BoltHandshakeError
+from neo4j.exceptions import ServiceUnavailable
+
# python -m pytest tests/integration/test_readme.py -s -v
diff --git a/tests/integration/test_result.py b/tests/integration/test_result.py
index 80e14fb2..b4af833a 100644
--- a/tests/integration/test_result.py
+++ b/tests/integration/test_result.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
diff --git a/tests/integration/test_result_graph.py b/tests/integration/test_result_graph.py
index 8c2cf81f..15ea7a37 100644
--- a/tests/integration/test_result_graph.py
+++ b/tests/integration/test_result_graph.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -21,13 +18,7 @@
import pytest
-from neo4j.graph import (
- Node,
- Relationship,
- Graph,
- Path,
-)
-from neo4j.exceptions import Neo4jError
+from neo4j.graph import Graph
def test_result_graph_instance(session):
diff --git a/tests/integration/test_spatial_types.py b/tests/integration/test_spatial_types.py
index 2227715f..71ed4cd5 100644
--- a/tests/integration/test_spatial_types.py
+++ b/tests/integration/test_spatial_types.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -21,7 +18,6 @@
import pytest
-
from neo4j.spatial import (
CartesianPoint,
WGS84Point,
diff --git a/tests/integration/test_temporal_types.py b/tests/integration/test_temporal_types.py
index 0e996db2..e9e64fb6 100644
--- a/tests/integration/test_temporal_types.py
+++ b/tests/integration/test_temporal_types.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -19,9 +16,9 @@
# limitations under the License.
-import pytest
import datetime
+import pytest
from pytz import (
FixedOffset,
timezone,
@@ -31,9 +28,9 @@
from neo4j.exceptions import CypherTypeError
from neo4j.time import (
Date,
- Time,
DateTime,
Duration,
+ Time,
)
@@ -402,4 +399,4 @@ def test_time_parameter_case3(session):
t2 = session.run("RETURN $time", time=t1).single().value()
assert isinstance(t2, Time)
- assert t1 == t2
\ No newline at end of file
+ assert t1 == t2
diff --git a/tests/integration/test_tx_functions.py b/tests/integration/test_tx_functions.py
index 184efbcd..bfaa2d58 100644
--- a/tests/integration/test_tx_functions.py
+++ b/tests/integration/test_tx_functions.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -19,15 +16,17 @@
# limitations under the License.
-import pytest
from uuid import uuid4
-from neo4j.work.simple import unit_of_work
+import pytest
+
+from neo4j import unit_of_work
from neo4j.exceptions import (
- Neo4jError,
ClientError,
+ Neo4jError,
)
+
# python -m pytest tests/integration/test_tx_functions.py -s -v
diff --git a/tests/performance/test_async_results.py b/tests/performance/test_async_results.py
new file mode 100644
index 00000000..085001e0
--- /dev/null
+++ b/tests/performance/test_async_results.py
@@ -0,0 +1,86 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import asyncio
+from itertools import product
+
+from pytest import mark
+
+from neo4j import AsyncGraphDatabase
+
+from .tools import RemoteGraphDatabaseServer
+
+
+class AsyncReadWorkload(object):
+
+ server = None
+ driver = None
+ loop = None
+
+ @classmethod
+ def setup_class(cls):
+ cls.server = server = RemoteGraphDatabaseServer()
+ server.start()
+ cls.loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(cls.loop)
+ cls.driver = AsyncGraphDatabase.driver(server.server_uri,
+ auth=server.auth_token,
+ encrypted=server.encrypted)
+
+ @classmethod
+ def teardown_class(cls):
+ try:
+ cls.loop.run_until_complete(cls.driver.close())
+ cls.server.stop()
+ finally:
+ cls.loop.stop()
+ asyncio.set_event_loop(None)
+
+ def work(self, *units_of_work):
+ async def runner():
+ async with self.driver.session() as session:
+ for unit_of_work in units_of_work:
+ await session.read_transaction(unit_of_work)
+
+ def sync_runner():
+ self.loop.run_until_complete(runner())
+
+ return sync_runner
+
+
+class TestAsyncReadWorkload(AsyncReadWorkload):
+
+ @staticmethod
+ def uow(record_count, record_width, value):
+
+ async def _(tx):
+ s = "UNWIND range(1, $record_count) AS _ RETURN {}".format(
+ ", ".join("$x AS x{}".format(i) for i in range(record_width)))
+ p = {"record_count": record_count, "x": value}
+ async for record in await tx.run(s, p):
+ assert all(x == value for x in record.values())
+
+ return _
+
+ @mark.parametrize("record_count,record_width,value", product(
+ [1, 1000], # record count
+ [1, 10], # record width
+ [1, u'hello, world'], # value
+ ))
+ def test_1x1(self, benchmark, record_count, record_width, value):
+ benchmark(self.work(self.uow(record_count, record_width, value)))
diff --git a/tests/performance/test_results.py b/tests/performance/test_results.py
index 15971faa..9629c7d8 100644
--- a/tests/performance/test_results.py
+++ b/tests/performance/test_results.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -24,6 +21,7 @@
from pytest import mark
from neo4j import GraphDatabase
+
from .tools import RemoteGraphDatabaseServer
diff --git a/tests/performance/tools.py b/tests/performance/tools.py
index 3c51a4fe..511644e5 100644
--- a/tests/performance/tools.py
+++ b/tests/performance/tools.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -19,17 +16,15 @@
# limitations under the License.
-from unittest import TestCase, SkipTest
-
-try:
- from urllib.request import urlretrieve
-except ImportError:
- from urllib import urlretrieve
+from unittest import SkipTest
from neo4j import GraphDatabase
from neo4j.exceptions import AuthError
-
-from tests.env import NEO4J_USER, NEO4J_PASSWORD, NEO4J_SERVER_URI
+from tests.env import (
+ NEO4J_PASSWORD,
+ NEO4J_SERVER_URI,
+ NEO4J_USER,
+)
def is_listening(address):
diff --git a/tests/requirements.txt b/tests/requirements.txt
index b9405952..d3b59806 100644
--- a/tests/requirements.txt
+++ b/tests/requirements.txt
@@ -1,7 +1,10 @@
git+https://github.com/neo4j-drivers/boltkit@4.2#egg=boltkit
-coverage
-pytest
-pytest-benchmark
-pytest-cov
-pytest-mock
-teamcity-messages
+coverage>=5.5
+pytest>=6.2.5
+pytest-asyncio>=0.16.0
+pytest-benchmark>=3.4.1
+pytest-cov>=3.0.0
+pytest-mock>=3.6.1
+# brings mock.AsyncMock to Python 3.7 (3.8+ ships with built-in support)
+mock>=4.0.3; python_version < '3.8'
+teamcity-messages>=1.29
diff --git a/tests/stub/conftest.py b/tests/stub/conftest.py
index c8d91e69..7738a949 100644
--- a/tests/stub/conftest.py
+++ b/tests/stub/conftest.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -19,18 +16,17 @@
# limitations under the License.
-import subprocess
+import logging
import os
-import time
-
from platform import system
+import subprocess
from threading import Thread
from time import sleep
from boltkit.server.stub import BoltStubService
from pytest import fixture
-import logging
+
log = logging.getLogger("neo4j")
# from neo4j.debug import watch
diff --git a/tests/stub/scripts/v1/empty_explicit_hello_goodbye.script b/tests/stub/scripts/v1/empty_explicit_hello_goodbye.script
index cd48e67a..f55e38cc 100644
--- a/tests/stub/scripts/v1/empty_explicit_hello_goodbye.script
+++ b/tests/stub/scripts/v1/empty_explicit_hello_goodbye.script
@@ -3,4 +3,4 @@
C: INIT {"user_agent": "test", "scheme": "basic", "principal": "test", "credentials": "test"}
S: SUCCESS {"server": "Neo4j/3.3.0", "connection_id": "123e4567-e89b-12d3-a456-426655440000"}
C: RESET
-S:
\ No newline at end of file
+S:
diff --git a/tests/stub/scripts/v2/empty_explicit_hello_goodbye.script b/tests/stub/scripts/v2/empty_explicit_hello_goodbye.script
index 7e8ae848..8e661fde 100644
--- a/tests/stub/scripts/v2/empty_explicit_hello_goodbye.script
+++ b/tests/stub/scripts/v2/empty_explicit_hello_goodbye.script
@@ -3,4 +3,4 @@
C: INIT {"user_agent": "test", "scheme": "basic", "principal": "test", "credentials": "test"}
S: SUCCESS {"server": "Neo4j/3.4.0", "connection_id": "123e4567-e89b-12d3-a456-426655440000"}
C: RESET
-S:
\ No newline at end of file
+S:
diff --git a/tests/stub/scripts/v3/dbms_cluster_routing_get_routing_table_system.script b/tests/stub/scripts/v3/dbms_cluster_routing_get_routing_table_system.script
index 7aea9af8..e47a2e2e 100644
--- a/tests/stub/scripts/v3/dbms_cluster_routing_get_routing_table_system.script
+++ b/tests/stub/scripts/v3/dbms_cluster_routing_get_routing_table_system.script
@@ -11,4 +11,4 @@ S: SUCCESS {"fields": ["ttl", "servers"]}
SUCCESS {}
C: GOODBYE
-S:
\ No newline at end of file
+S:
diff --git a/tests/stub/scripts/v3/empty_explicit_hello_goodbye.script b/tests/stub/scripts/v3/empty_explicit_hello_goodbye.script
index 4ed2b804..be98dd7a 100644
--- a/tests/stub/scripts/v3/empty_explicit_hello_goodbye.script
+++ b/tests/stub/scripts/v3/empty_explicit_hello_goodbye.script
@@ -4,4 +4,4 @@
C: HELLO {"user_agent": "test", "scheme": "basic", "principal": "test", "credentials": "test"}
S: SUCCESS {"server": "Neo4j/3.5.0", "connection_id": "123e4567-e89b-12d3-a456-426655440000"}
C: GOODBYE
-S:
\ No newline at end of file
+S:
diff --git a/tests/stub/scripts/v3/get_routing_table.script b/tests/stub/scripts/v3/get_routing_table.script
index b3ec1fa0..07e1ef20 100644
--- a/tests/stub/scripts/v3/get_routing_table.script
+++ b/tests/stub/scripts/v3/get_routing_table.script
@@ -8,4 +8,4 @@ C: RUN "CALL dbms.cluster.routing.getRoutingTable($context)" {"context": {"addre
PULL_ALL
S: SUCCESS {"fields": ["ttl", "servers"]}
RECORD [9223372036854775807, [{"addresses": ["127.0.0.1:9001"],"role": "WRITE"}, {"addresses": ["127.0.0.1:9002"], "role": "READ"},{"addresses": ["127.0.0.1:9001", "127.0.0.1:9002"], "role": "ROUTE"}]]
- SUCCESS {}
\ No newline at end of file
+ SUCCESS {}
diff --git a/tests/stub/scripts/v3/get_routing_table_with_context.script b/tests/stub/scripts/v3/get_routing_table_with_context.script
index 65e67bc6..03a23967 100644
--- a/tests/stub/scripts/v3/get_routing_table_with_context.script
+++ b/tests/stub/scripts/v3/get_routing_table_with_context.script
@@ -8,4 +8,4 @@ C: RUN "CALL dbms.cluster.routing.getRoutingTable($context)" {"context": {"name"
PULL_ALL
S: SUCCESS {"fields": ["ttl", "servers"]}
RECORD [9223372036854775807, [{"addresses": ["127.0.0.1:9001"],"role": "WRITE"}, {"addresses": ["127.0.0.1:9002"], "role": "READ"},{"addresses": ["127.0.0.1:9001", "127.0.0.1:9002"], "role": "ROUTE"}]]
- SUCCESS {}
\ No newline at end of file
+ SUCCESS {}
diff --git a/tests/stub/scripts/v3/pull_all_port_9001_transaction_function.script b/tests/stub/scripts/v3/pull_all_port_9001_transaction_function.script
index 8dd7b406..1003b6ef 100644
--- a/tests/stub/scripts/v3/pull_all_port_9001_transaction_function.script
+++ b/tests/stub/scripts/v3/pull_all_port_9001_transaction_function.script
@@ -16,4 +16,4 @@ S: RECORD [1]
RECORD [4]
SUCCESS {"type": "r", "t_last": 500}
C: COMMIT
-S: SUCCESS {"bookmark": "neo4j:bookmark-test-1"}
\ No newline at end of file
+S: SUCCESS {"bookmark": "neo4j:bookmark-test-1"}
diff --git a/tests/stub/scripts/v3/return_1_port_9001.script b/tests/stub/scripts/v3/return_1_port_9001.script
index 06cc51a5..5135a788 100644
--- a/tests/stub/scripts/v3/return_1_port_9001.script
+++ b/tests/stub/scripts/v3/return_1_port_9001.script
@@ -9,4 +9,4 @@ C: RUN "RETURN 1 AS x" {} {"mode": "r"}
PULL_ALL
S: SUCCESS {"fields": ["x"]}
RECORD [1]
- SUCCESS {"bookmark": "neo4j:bookmark-test-1", "type": "r", "t_last": 5}
\ No newline at end of file
+ SUCCESS {"bookmark": "neo4j:bookmark-test-1", "type": "r", "t_last": 5}
diff --git a/tests/stub/scripts/v3/router_with_multiple_servers.script b/tests/stub/scripts/v3/router_with_multiple_servers.script
index 34ac0f0f..fec9d25d 100644
--- a/tests/stub/scripts/v3/router_with_multiple_servers.script
+++ b/tests/stub/scripts/v3/router_with_multiple_servers.script
@@ -8,4 +8,4 @@ C: RUN "CALL dbms.cluster.routing.getRoutingTable($context)" {"context": {"addre
PULL_ALL
S: SUCCESS {"fields": ["ttl", "servers"]}
RECORD [300, [{"role":"ROUTE","addresses":["127.0.0.1:9001","127.0.0.1:9002"]},{"role":"READ","addresses":["127.0.0.1:9001","127.0.0.1:9003"]},{"role":"WRITE","addresses":["127.0.0.1:9004"]}]]
- SUCCESS {}
\ No newline at end of file
+ SUCCESS {}
diff --git a/tests/stub/scripts/v3/rude_reader.script b/tests/stub/scripts/v3/rude_reader.script
index 055c06f1..ea323e98 100644
--- a/tests/stub/scripts/v3/rude_reader.script
+++ b/tests/stub/scripts/v3/rude_reader.script
@@ -7,4 +7,3 @@
C: RUN "RETURN 1" {} {"mode": "r"}
PULL_ALL
S:
-
diff --git a/tests/stub/scripts/v3/rude_router.script b/tests/stub/scripts/v3/rude_router.script
index 9fa163b8..a29bbbd4 100644
--- a/tests/stub/scripts/v3/rude_router.script
+++ b/tests/stub/scripts/v3/rude_router.script
@@ -7,4 +7,3 @@
C: RUN "CALL dbms.cluster.routing.getRoutingTable($context)" {"context": {"address": "localhost:9001"}} {"mode": "r"}
PULL_ALL
S:
-
diff --git a/tests/stub/scripts/v4x0/dbms_routing_get_routing_table_system_default.script b/tests/stub/scripts/v4x0/dbms_routing_get_routing_table_system_default.script
index 05f9a9ab..ec865816 100644
--- a/tests/stub/scripts/v4x0/dbms_routing_get_routing_table_system_default.script
+++ b/tests/stub/scripts/v4x0/dbms_routing_get_routing_table_system_default.script
@@ -11,4 +11,4 @@ S: SUCCESS {"fields": ["ttl", "servers"]}
SUCCESS {"bookmark": "neo4j:bookmark-test-1", "type": "s", "t_last": 15, "db": "system"}
C: GOODBYE
-S:
\ No newline at end of file
+S:
diff --git a/tests/stub/scripts/v4x0/dbms_routing_get_routing_table_system_neo4j.script b/tests/stub/scripts/v4x0/dbms_routing_get_routing_table_system_neo4j.script
index aed4b3c4..ba93794d 100644
--- a/tests/stub/scripts/v4x0/dbms_routing_get_routing_table_system_neo4j.script
+++ b/tests/stub/scripts/v4x0/dbms_routing_get_routing_table_system_neo4j.script
@@ -11,4 +11,4 @@ S: SUCCESS {"fields": ["ttl", "servers"]}
SUCCESS {"bookmark": "neo4j:bookmark-test-1", "type": "r", "t_last": 15, "db": "neo4j"}
C: GOODBYE
-S:
\ No newline at end of file
+S:
diff --git a/tests/stub/scripts/v4x0/empty.script b/tests/stub/scripts/v4x0/empty.script
index 6e97eca9..9b11fd90 100644
--- a/tests/stub/scripts/v4x0/empty.script
+++ b/tests/stub/scripts/v4x0/empty.script
@@ -2,4 +2,4 @@
!: AUTO HELLO
!: AUTO GOODBYE
!: AUTO RESET
-!: PORT 9001
\ No newline at end of file
+!: PORT 9001
diff --git a/tests/stub/scripts/v4x0/empty_explicit_hello_goodbye.script b/tests/stub/scripts/v4x0/empty_explicit_hello_goodbye.script
index 29957f14..5a2bdc88 100644
--- a/tests/stub/scripts/v4x0/empty_explicit_hello_goodbye.script
+++ b/tests/stub/scripts/v4x0/empty_explicit_hello_goodbye.script
@@ -4,4 +4,4 @@
C: HELLO {"user_agent": "test", "scheme": "basic", "principal": "test", "credentials": "test"}
S: SUCCESS {"server": "Neo4j/4.0.0", "connection_id": "123e4567-e89b-12d3-a456-426655440000"}
C: GOODBYE
-S:
\ No newline at end of file
+S:
diff --git a/tests/stub/scripts/v4x0/return_1_four_times_port_9004.script b/tests/stub/scripts/v4x0/return_1_four_times_port_9004.script
index 47698633..a1f56d8b 100644
--- a/tests/stub/scripts/v4x0/return_1_four_times_port_9004.script
+++ b/tests/stub/scripts/v4x0/return_1_four_times_port_9004.script
@@ -26,4 +26,4 @@ C: RUN "RETURN 1" {} {"mode": "r"}
PULL {"n": -1}
S: SUCCESS {"fields": ["1"]}
RECORD [1]
- SUCCESS {}
\ No newline at end of file
+ SUCCESS {}
diff --git a/tests/stub/scripts/v4x0/return_1_port_9001.script b/tests/stub/scripts/v4x0/return_1_port_9001.script
index c0c9c04c..188c1839 100644
--- a/tests/stub/scripts/v4x0/return_1_port_9001.script
+++ b/tests/stub/scripts/v4x0/return_1_port_9001.script
@@ -9,4 +9,4 @@ C: RUN "RETURN 1 AS x" {} {"mode": "r"}
PULL {"n": -1}
S: SUCCESS {"fields": ["x"]}
RECORD [1]
- SUCCESS {"bookmark": "neo4j:bookmark-test-1", "type": "r", "t_last": 5, "db": "system"}
\ No newline at end of file
+ SUCCESS {"bookmark": "neo4j:bookmark-test-1", "type": "r", "t_last": 5, "db": "system"}
diff --git a/tests/stub/scripts/v4x0/router_port_9001_one_read_port_9004_one_write_port_9006.script b/tests/stub/scripts/v4x0/router_port_9001_one_read_port_9004_one_write_port_9006.script
index cafc5742..b373111f 100644
--- a/tests/stub/scripts/v4x0/router_port_9001_one_read_port_9004_one_write_port_9006.script
+++ b/tests/stub/scripts/v4x0/router_port_9001_one_read_port_9004_one_write_port_9006.script
@@ -10,4 +10,4 @@ C: RUN "CALL dbms.routing.getRoutingTable($context)" {"context": {"address": "lo
PULL {"n": -1}
S: SUCCESS {"fields": ["ttl", "servers"]}
RECORD [300, [{"role":"ROUTE", "addresses":["127.0.0.1:9001", "127.0.0.1:9002"]}, {"role":"READ", "addresses":["127.0.0.1:9004"]}, {"role":"WRITE", "addresses":["127.0.0.1:9006"]}]]
- SUCCESS {"bookmark": "neo4j:bookmark-test-1", "type": "s", "t_last": 5, "db": "system"}
\ No newline at end of file
+ SUCCESS {"bookmark": "neo4j:bookmark-test-1", "type": "s", "t_last": 5, "db": "system"}
diff --git a/tests/stub/test_directdriver.py b/tests/stub/test_directdriver.py
index 2e7ae7f6..492e323f 100644
--- a/tests/stub/test_directdriver.py
+++ b/tests/stub/test_directdriver.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -21,33 +18,14 @@
import pytest
-from neo4j.exceptions import (
- ServiceUnavailable,
- ConfigurationError,
- UnsupportedServerProduct,
-)
-from neo4j._exceptions import (
- BoltHandshakeError,
- BoltSecurityError,
-)
-
from neo4j import (
- GraphDatabase,
BoltDriver,
- Query,
- WRITE_ACCESS,
+ GraphDatabase,
READ_ACCESS,
- TRUST_ALL_CERTIFICATES,
- TRUST_SYSTEM_CA_SIGNED_CERTIFICATES,
- DEFAULT_DATABASE,
- Result,
- unit_of_work,
- Transaction,
)
+from neo4j.exceptions import ServiceUnavailable
+from tests.stub.conftest import StubCluster
-from tests.stub.conftest import (
- StubCluster,
-)
# python -m pytest tests/stub/test_directdriver.py -s -v
diff --git a/tests/stub/test_routingdriver.py b/tests/stub/test_routingdriver.py
index 139c225d..c94b5d95 100644
--- a/tests/stub/test_routingdriver.py
+++ b/tests/stub/test_routingdriver.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -24,25 +21,11 @@
from neo4j import (
GraphDatabase,
Neo4jDriver,
- TRUST_ALL_CERTIFICATES,
- TRUST_SYSTEM_CA_SIGNED_CERTIFICATES,
- DEFAULT_DATABASE,
-)
-from neo4j.api import (
- READ_ACCESS,
- WRITE_ACCESS,
-)
-from neo4j.exceptions import (
- ServiceUnavailable,
- TransientError,
- SessionExpired,
- ConfigurationError,
-)
-from neo4j._exceptions import (
- BoltSecurityError,
)
+from neo4j.exceptions import ServiceUnavailable
from tests.stub.conftest import StubCluster
+
# python -m pytest tests/stub/test_routingdriver.py -s -v
# TODO: those tests will stay until a uniform behavior across the drivers has
# been specified and tests are created in testkit
diff --git a/tests/unit/README.md b/tests/unit/README.md
new file mode 100644
index 00000000..415e2d23
--- /dev/null
+++ b/tests/unit/README.md
@@ -0,0 +1,10 @@
+Structure:
+
+```
+.
+├── _async_compat # utility code to allow auto-converting async tests to sync tests
+├── async_ # tests that are specific to async classes, methods, or functions
+├── common # tests of classes, methods, and functions that only exist in sync
+├── mixed # tests that cannot be auto converted from async to sync
+└── sync # auto-generated sync versions of tests in async_
+```
diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py
index e69de29b..b81a309d 100644
--- a/tests/unit/__init__.py
+++ b/tests/unit/__init__.py
@@ -0,0 +1,16 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/unit/_async_compat/__init__.py b/tests/unit/_async_compat/__init__.py
new file mode 100644
index 00000000..285e3af5
--- /dev/null
+++ b/tests/unit/_async_compat/__init__.py
@@ -0,0 +1,53 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import sys
+
+
+if sys.version_info >= (3, 8):
+ from unittest import mock
+ from unittest.mock import AsyncMockMixin
+else:
+ import mock
+ from mock.mock import AsyncMockMixin
+
+from .mark_decorator import (
+ mark_async_test,
+ mark_sync_test,
+)
+
+
+AsyncMagicMock = mock.AsyncMock
+MagicMock = mock.MagicMock
+Mock = mock.Mock
+mock.NonCallableMagicMock
+
+
+class AsyncMock(AsyncMockMixin, Mock):
+ pass
+
+
+__all__ = [
+ "mark_async_test",
+ "mark_sync_test",
+ "AsyncMagicMock",
+ "AsyncMock",
+ "MagicMock",
+ "Mock",
+ "mock",
+]
diff --git a/tests/unit/_async_compat/mark_decorator.py b/tests/unit/_async_compat/mark_decorator.py
new file mode 100644
index 00000000..6cd4832d
--- /dev/null
+++ b/tests/unit/_async_compat/mark_decorator.py
@@ -0,0 +1,26 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import pytest
+
+
+mark_async_test = pytest.mark.asyncio
+
+
+def mark_sync_test(f):
+ return f
diff --git a/tests/unit/async_/__init__.py b/tests/unit/async_/__init__.py
new file mode 100644
index 00000000..b81a309d
--- /dev/null
+++ b/tests/unit/async_/__init__.py
@@ -0,0 +1,16 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/unit/async_/io/__init__.py b/tests/unit/async_/io/__init__.py
new file mode 100644
index 00000000..b81a309d
--- /dev/null
+++ b/tests/unit/async_/io/__init__.py
@@ -0,0 +1,16 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/unit/async_/io/conftest.py b/tests/unit/async_/io/conftest.py
new file mode 100644
index 00000000..1da3e44e
--- /dev/null
+++ b/tests/unit/async_/io/conftest.py
@@ -0,0 +1,156 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from io import BytesIO
+from struct import (
+ pack as struct_pack,
+ unpack as struct_unpack,
+)
+
+import pytest
+
+from neo4j._async.io._common import AsyncMessageInbox
+from neo4j.packstream import (
+ Packer,
+ UnpackableBuffer,
+ Unpacker,
+)
+
+
+class AsyncFakeSocket:
+
+ def __init__(self, address):
+ self.address = address
+ self.captured = b""
+ self.messages = AsyncMessageInbox(self, on_error=print)
+
+ def getsockname(self):
+ return "127.0.0.1", 0xFFFF
+
+ def getpeername(self):
+ return self.address
+
+ async def recv_into(self, buffer, nbytes):
+ data = self.captured[:nbytes]
+ actual = len(data)
+ buffer[:actual] = data
+ self.captured = self.captured[actual:]
+ return actual
+
+ async def sendall(self, data):
+ self.captured += data
+
+ def close(self):
+ return
+
+ async def pop_message(self):
+ return await self.messages.pop()
+
+
+class AsyncFakeSocket2:
+
+ def __init__(self, address=None, on_send=None):
+ self.address = address
+ self.recv_buffer = bytearray()
+ self._messages = AsyncMessageInbox(self, on_error=print)
+ self.on_send = on_send
+
+ def getsockname(self):
+ return "127.0.0.1", 0xFFFF
+
+ def getpeername(self):
+ return self.address
+
+ async def recv_into(self, buffer, nbytes):
+ data = self.recv_buffer[:nbytes]
+ actual = len(data)
+ buffer[:actual] = data
+ self.recv_buffer = self.recv_buffer[actual:]
+ return actual
+
+ async def sendall(self, data):
+ if callable(self.on_send):
+ self.on_send(data)
+
+ def close(self):
+ return
+
+ def inject(self, data):
+ self.recv_buffer += data
+
+ def _pop_chunk(self):
+ chunk_size, = struct_unpack(">H", self.recv_buffer[:2])
+ print("CHUNK SIZE %r" % chunk_size)
+ end = 2 + chunk_size
+ chunk_data, self.recv_buffer = self.recv_buffer[2:end], self.recv_buffer[end:]
+ return chunk_data
+
+ async def pop_message(self):
+ data = bytearray()
+ while True:
+ chunk = self._pop_chunk()
+ print("CHUNK %r" % chunk)
+ if chunk:
+ data.extend(chunk)
+ elif data:
+ break # end of message
+ else:
+ continue # NOOP
+ header = data[0]
+ n_fields = header % 0x10
+ tag = data[1]
+ buffer = UnpackableBuffer(data[2:])
+ unpacker = Unpacker(buffer)
+ fields = [unpacker.unpack() for _ in range(n_fields)]
+ return tag, fields
+
+ async def send_message(self, tag, *fields):
+ data = self.encode_message(tag, *fields)
+ await self.sendall(struct_pack(">H", len(data)) + data + b"\x00\x00")
+
+ @classmethod
+ def encode_message(cls, tag, *fields):
+ b = BytesIO()
+ packer = Packer(b)
+ for field in fields:
+ packer.pack(field)
+ return bytearray([0xB0 + len(fields), tag]) + b.getvalue()
+
+
+class AsyncFakeSocketPair:
+
+ def __init__(self, address):
+ self.client = AsyncFakeSocket2(address)
+ self.server = AsyncFakeSocket2()
+ self.client.on_send = self.server.inject
+ self.server.on_send = self.client.inject
+
+
+@pytest.fixture
+def fake_socket():
+ return AsyncFakeSocket
+
+
+@pytest.fixture
+def fake_socket_2():
+ return AsyncFakeSocket2
+
+
+@pytest.fixture
+def fake_socket_pair():
+ return AsyncFakeSocketPair
diff --git a/tests/unit/async_/io/test__common.py b/tests/unit/async_/io/test__common.py
new file mode 100644
index 00000000..7fefd5b9
--- /dev/null
+++ b/tests/unit/async_/io/test__common.py
@@ -0,0 +1,50 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import pytest
+
+from neo4j._async.io._common import Outbox
+
+
+@pytest.mark.parametrize(("chunk_size", "data", "result"), (
+ (
+ 2,
+ (bytes(range(10, 15)),),
+ bytes((0, 2, 10, 11, 0, 2, 12, 13, 0, 1, 14))
+ ),
+ (
+ 2,
+ (bytes(range(10, 14)),),
+ bytes((0, 2, 10, 11, 0, 2, 12, 13))
+ ),
+ (
+ 2,
+ (bytes((5, 6, 7)), bytes((8, 9))),
+ bytes((0, 2, 5, 6, 0, 2, 7, 8, 0, 1, 9))
+ ),
+))
+def test_async_outbox_chunking(chunk_size, data, result):
+ outbox = Outbox(max_chunk_size=chunk_size)
+ assert bytes(outbox.view()) == b""
+ for d in data:
+ outbox.write(d)
+ assert bytes(outbox.view()) == result
+ # make sure this works multiple times
+ assert bytes(outbox.view()) == result
+ outbox.clear()
+ assert bytes(outbox.view()) == b""
diff --git a/tests/unit/async_/io/test_class_bolt.py b/tests/unit/async_/io/test_class_bolt.py
new file mode 100644
index 00000000..50706d75
--- /dev/null
+++ b/tests/unit/async_/io/test_class_bolt.py
@@ -0,0 +1,62 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import pytest
+
+from neo4j._async.io import AsyncBolt
+
+
+# python -m pytest tests/unit/io/test_class_bolt.py -s -v
+
+
+def test_class_method_protocol_handlers():
+ # python -m pytest tests/unit/io/test_class_bolt.py -s -v -k test_class_method_protocol_handlers
+ protocol_handlers = AsyncBolt.protocol_handlers()
+ assert len(protocol_handlers) == 6
+
+
+@pytest.mark.parametrize(
+ "test_input, expected",
+ [
+ ((0, 0), 0),
+ ((4, 0), 1),
+ ]
+)
+def test_class_method_protocol_handlers_with_protocol_version(test_input, expected):
+ # python -m pytest tests/unit/io/test_class_bolt.py -s -v -k test_class_method_protocol_handlers_with_protocol_version
+ protocol_handlers = AsyncBolt.protocol_handlers(protocol_version=test_input)
+ assert len(protocol_handlers) == expected
+
+
+def test_class_method_protocol_handlers_with_invalid_protocol_version():
+ # python -m pytest tests/unit/io/test_class_bolt.py -s -v -k test_class_method_protocol_handlers_with_invalid_protocol_version
+ with pytest.raises(TypeError):
+ AsyncBolt.protocol_handlers(protocol_version=2)
+
+
+def test_class_method_get_handshake():
+ # python -m pytest tests/unit/io/test_class_bolt.py -s -v -k test_class_method_get_handshake
+ handshake = AsyncBolt.get_handshake()
+ assert handshake == b"\x00\x02\x04\x04\x00\x00\x01\x04\x00\x00\x00\x04\x00\x00\x00\x03"
+
+
+def test_magic_preamble():
+ # python -m pytest tests/unit/io/test_class_bolt.py -s -v -k test_magic_preamble
+ preamble = 0x6060B017
+ preamble_bytes = preamble.to_bytes(4, byteorder="big")
+ assert AsyncBolt.MAGIC_PREAMBLE == preamble_bytes
diff --git a/tests/unit/async_/io/test_class_bolt3.py b/tests/unit/async_/io/test_class_bolt3.py
new file mode 100644
index 00000000..b1220da6
--- /dev/null
+++ b/tests/unit/async_/io/test_class_bolt3.py
@@ -0,0 +1,115 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import pytest
+
+from neo4j._async.io._bolt3 import AsyncBolt3
+from neo4j.conf import PoolConfig
+from neo4j.exceptions import ConfigurationError
+
+from ..._async_compat import (
+ AsyncMagicMock,
+ mark_async_test,
+)
+
+
+@pytest.mark.parametrize("set_stale", (True, False))
+def test_conn_is_stale(fake_socket, set_stale):
+ address = ("127.0.0.1", 7687)
+ max_connection_lifetime = 0
+ connection = AsyncBolt3(address, fake_socket(address), max_connection_lifetime)
+ if set_stale:
+ connection.set_stale()
+ assert connection.stale() is True
+
+
+@pytest.mark.parametrize("set_stale", (True, False))
+def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale):
+ address = ("127.0.0.1", 7687)
+ max_connection_lifetime = -1
+ connection = AsyncBolt3(address, fake_socket(address), max_connection_lifetime)
+ if set_stale:
+ connection.set_stale()
+ assert connection.stale() is set_stale
+
+
+@pytest.mark.parametrize("set_stale", (True, False))
+def test_conn_is_not_stale(fake_socket, set_stale):
+ address = ("127.0.0.1", 7687)
+ max_connection_lifetime = 999999999
+ connection = AsyncBolt3(address, fake_socket(address), max_connection_lifetime)
+ if set_stale:
+ connection.set_stale()
+ assert connection.stale() is set_stale
+
+
+def test_db_extra_not_supported_in_begin(fake_socket):
+ address = ("127.0.0.1", 7687)
+ connection = AsyncBolt3(address, fake_socket(address), PoolConfig.max_connection_lifetime)
+ with pytest.raises(ConfigurationError):
+ connection.begin(db="something")
+
+
+def test_db_extra_not_supported_in_run(fake_socket):
+ address = ("127.0.0.1", 7687)
+ connection = AsyncBolt3(address, fake_socket(address), PoolConfig.max_connection_lifetime)
+ with pytest.raises(ConfigurationError):
+ connection.run("", db="something")
+
+
+@mark_async_test
+async def test_simple_discard(fake_socket):
+ address = ("127.0.0.1", 7687)
+ socket = fake_socket(address)
+ connection = AsyncBolt3(address, socket, PoolConfig.max_connection_lifetime)
+ connection.discard()
+ await connection.send_all()
+ tag, fields = await socket.pop_message()
+ assert tag == b"\x2F"
+ assert len(fields) == 0
+
+
+@mark_async_test
+async def test_simple_pull(fake_socket):
+ address = ("127.0.0.1", 7687)
+ socket = fake_socket(address)
+ connection = AsyncBolt3(address, socket, PoolConfig.max_connection_lifetime)
+ connection.pull()
+ await connection.send_all()
+ tag, fields = await socket.pop_message()
+ assert tag == b"\x3F"
+ assert len(fields) == 0
+
+
+@pytest.mark.parametrize("recv_timeout", (1, -1))
+@mark_async_test
+async def test_hint_recv_timeout_seconds_gets_ignored(
+ fake_socket_pair, recv_timeout
+):
+ address = ("127.0.0.1", 7687)
+ sockets = fake_socket_pair(address)
+ sockets.client.settimeout = AsyncMagicMock()
+ await sockets.server.send_message(0x70, {
+ "server": "Neo4j/3.5.0",
+ "hints": {"connection.recv_timeout_seconds": recv_timeout},
+ })
+ connection = AsyncBolt3(
+ address, sockets.client, PoolConfig.max_connection_lifetime
+ )
+ await connection.hello()
+ sockets.client.settimeout.assert_not_called()
diff --git a/tests/unit/async_/io/test_class_bolt4x0.py b/tests/unit/async_/io/test_class_bolt4x0.py
new file mode 100644
index 00000000..c2623cf1
--- /dev/null
+++ b/tests/unit/async_/io/test_class_bolt4x0.py
@@ -0,0 +1,209 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from unittest.mock import MagicMock
+
+import pytest
+
+from neo4j._async.io._bolt4 import AsyncBolt4x0
+from neo4j.conf import PoolConfig
+
+from ..._async_compat import mark_async_test
+
+
+@pytest.mark.parametrize("set_stale", (True, False))
+def test_conn_is_stale(fake_socket, set_stale):
+ address = ("127.0.0.1", 7687)
+ max_connection_lifetime = 0
+ connection = AsyncBolt4x0(address, fake_socket(address), max_connection_lifetime)
+ if set_stale:
+ connection.set_stale()
+ assert connection.stale() is True
+
+
+@pytest.mark.parametrize("set_stale", (True, False))
+def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale):
+ address = ("127.0.0.1", 7687)
+ max_connection_lifetime = -1
+ connection = AsyncBolt4x0(address, fake_socket(address), max_connection_lifetime)
+ if set_stale:
+ connection.set_stale()
+ assert connection.stale() is set_stale
+
+
+@pytest.mark.parametrize("set_stale", (True, False))
+def test_conn_is_not_stale(fake_socket, set_stale):
+ address = ("127.0.0.1", 7687)
+ max_connection_lifetime = 999999999
+ connection = AsyncBolt4x0(address, fake_socket(address), max_connection_lifetime)
+ if set_stale:
+ connection.set_stale()
+ assert connection.stale() is set_stale
+
+
+@mark_async_test
+async def test_db_extra_in_begin(fake_socket):
+ address = ("127.0.0.1", 7687)
+ socket = fake_socket(address)
+ connection = AsyncBolt4x0(address, socket, PoolConfig.max_connection_lifetime)
+ connection.begin(db="something")
+ await connection.send_all()
+ tag, fields = await socket.pop_message()
+ assert tag == b"\x11"
+ assert len(fields) == 1
+ assert fields[0] == {"db": "something"}
+
+
+@mark_async_test
+async def test_db_extra_in_run(fake_socket):
+ address = ("127.0.0.1", 7687)
+ socket = fake_socket(address)
+ connection = AsyncBolt4x0(address, socket, PoolConfig.max_connection_lifetime)
+ connection.run("", {}, db="something")
+ await connection.send_all()
+ tag, fields = await socket.pop_message()
+ assert tag == b"\x10"
+ assert len(fields) == 3
+ assert fields[0] == ""
+ assert fields[1] == {}
+ assert fields[2] == {"db": "something"}
+
+
+@mark_async_test
+async def test_n_extra_in_discard(fake_socket):
+ address = ("127.0.0.1", 7687)
+ socket = fake_socket(address)
+ connection = AsyncBolt4x0(address, socket, PoolConfig.max_connection_lifetime)
+ connection.discard(n=666)
+ await connection.send_all()
+ tag, fields = await socket.pop_message()
+ assert tag == b"\x2F"
+ assert len(fields) == 1
+ assert fields[0] == {"n": 666}
+
+
+@pytest.mark.parametrize(
+ "test_input, expected",
+ [
+ (666, {"n": -1, "qid": 666}),
+ (-1, {"n": -1}),
+ ]
+)
+@mark_async_test
+async def test_qid_extra_in_discard(fake_socket, test_input, expected):
+ address = ("127.0.0.1", 7687)
+ socket = fake_socket(address)
+ connection = AsyncBolt4x0(address, socket, PoolConfig.max_connection_lifetime)
+ connection.discard(qid=test_input)
+ await connection.send_all()
+ tag, fields = await socket.pop_message()
+ assert tag == b"\x2F"
+ assert len(fields) == 1
+ assert fields[0] == expected
+
+
+@pytest.mark.parametrize(
+ "test_input, expected",
+ [
+ (777, {"n": 666, "qid": 777}),
+ (-1, {"n": 666}),
+ ]
+)
+@mark_async_test
+async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected):
+ address = ("127.0.0.1", 7687)
+ socket = fake_socket(address)
+ connection = AsyncBolt4x0(address, socket, PoolConfig.max_connection_lifetime)
+ connection.discard(n=666, qid=test_input)
+ await connection.send_all()
+ tag, fields = await socket.pop_message()
+ assert tag == b"\x2F"
+ assert len(fields) == 1
+ assert fields[0] == expected
+
+
+@pytest.mark.parametrize(
+ "test_input, expected",
+ [
+ (666, {"n": 666}),
+ (-1, {"n": -1}),
+ ]
+)
+@mark_async_test
+async def test_n_extra_in_pull(fake_socket, test_input, expected):
+ address = ("127.0.0.1", 7687)
+ socket = fake_socket(address)
+ connection = AsyncBolt4x0(address, socket, PoolConfig.max_connection_lifetime)
+ connection.pull(n=test_input)
+ await connection.send_all()
+ tag, fields = await socket.pop_message()
+ assert tag == b"\x3F"
+ assert len(fields) == 1
+ assert fields[0] == expected
+
+
+@pytest.mark.parametrize(
+ "test_input, expected",
+ [
+ (777, {"n": -1, "qid": 777}),
+ (-1, {"n": -1}),
+ ]
+)
+@mark_async_test
+async def test_qid_extra_in_pull(fake_socket, test_input, expected):
+ address = ("127.0.0.1", 7687)
+ socket = fake_socket(address)
+ connection = AsyncBolt4x0(address, socket, PoolConfig.max_connection_lifetime)
+ connection.pull(qid=test_input)
+ await connection.send_all()
+ tag, fields = await socket.pop_message()
+ assert tag == b"\x3F"
+ assert len(fields) == 1
+ assert fields[0] == expected
+
+
+@mark_async_test
+async def test_n_and_qid_extras_in_pull(fake_socket):
+ address = ("127.0.0.1", 7687)
+ socket = fake_socket(address)
+ connection = AsyncBolt4x0(address, socket, PoolConfig.max_connection_lifetime)
+ connection.pull(n=666, qid=777)
+ await connection.send_all()
+ tag, fields = await socket.pop_message()
+ assert tag == b"\x3F"
+ assert len(fields) == 1
+ assert fields[0] == {"n": 666, "qid": 777}
+
+
+@pytest.mark.parametrize("recv_timeout", (1, -1))
+@mark_async_test
+async def test_hint_recv_timeout_seconds_gets_ignored(
+ fake_socket_pair, recv_timeout
+):
+ address = ("127.0.0.1", 7687)
+ sockets = fake_socket_pair(address)
+ sockets.client.settimeout = MagicMock()
+ await sockets.server.send_message(0x70, {
+ "server": "Neo4j/4.0.0",
+ "hints": {"connection.recv_timeout_seconds": recv_timeout},
+ })
+ connection = AsyncBolt4x0(
+ address, sockets.client, PoolConfig.max_connection_lifetime
+ )
+ await connection.hello()
+ sockets.client.settimeout.assert_not_called()
diff --git a/tests/unit/async_/io/test_class_bolt4x1.py b/tests/unit/async_/io/test_class_bolt4x1.py
new file mode 100644
index 00000000..9123e3b0
--- /dev/null
+++ b/tests/unit/async_/io/test_class_bolt4x1.py
@@ -0,0 +1,227 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import pytest
+
+from neo4j._async.io._bolt4 import AsyncBolt4x1
+from neo4j.conf import PoolConfig
+
+from ..._async_compat import (
+ AsyncMagicMock,
+ mark_async_test,
+)
+
+
+@pytest.mark.parametrize("set_stale", (True, False))
+def test_conn_is_stale(fake_socket, set_stale):
+ address = ("127.0.0.1", 7687)
+ max_connection_lifetime = 0
+ connection = AsyncBolt4x1(address, fake_socket(address), max_connection_lifetime)
+ if set_stale:
+ connection.set_stale()
+ assert connection.stale() is True
+
+
+@pytest.mark.parametrize("set_stale", (True, False))
+def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale):
+ address = ("127.0.0.1", 7687)
+ max_connection_lifetime = -1
+ connection = AsyncBolt4x1(address, fake_socket(address), max_connection_lifetime)
+ if set_stale:
+ connection.set_stale()
+ assert connection.stale() is set_stale
+
+
+@pytest.mark.parametrize("set_stale", (True, False))
+def test_conn_is_not_stale(fake_socket, set_stale):
+ address = ("127.0.0.1", 7687)
+ max_connection_lifetime = 999999999
+ connection = AsyncBolt4x1(address, fake_socket(address), max_connection_lifetime)
+ if set_stale:
+ connection.set_stale()
+ assert connection.stale() is set_stale
+
+
+@mark_async_test
+async def test_db_extra_in_begin(fake_socket):
+ address = ("127.0.0.1", 7687)
+ socket = fake_socket(address)
+ connection = AsyncBolt4x1(address, socket, PoolConfig.max_connection_lifetime)
+ connection.begin(db="something")
+ await connection.send_all()
+ tag, fields = await socket.pop_message()
+ assert tag == b"\x11"
+ assert len(fields) == 1
+ assert fields[0] == {"db": "something"}
+
+
+@mark_async_test
+async def test_db_extra_in_run(fake_socket):
+ address = ("127.0.0.1", 7687)
+ socket = fake_socket(address)
+ connection = AsyncBolt4x1(address, socket, PoolConfig.max_connection_lifetime)
+ connection.run("", {}, db="something")
+ await connection.send_all()
+ tag, fields = await socket.pop_message()
+ assert tag == b"\x10"
+ assert len(fields) == 3
+ assert fields[0] == ""
+ assert fields[1] == {}
+ assert fields[2] == {"db": "something"}
+
+
+@mark_async_test
+async def test_n_extra_in_discard(fake_socket):
+ address = ("127.0.0.1", 7687)
+ socket = fake_socket(address)
+ connection = AsyncBolt4x1(address, socket, PoolConfig.max_connection_lifetime)
+ connection.discard(n=666)
+ await connection.send_all()
+ tag, fields = await socket.pop_message()
+ assert tag == b"\x2F"
+ assert len(fields) == 1
+ assert fields[0] == {"n": 666}
+
+
+@pytest.mark.parametrize(
+ "test_input, expected",
+ [
+ (666, {"n": -1, "qid": 666}),
+ (-1, {"n": -1}),
+ ]
+)
+@mark_async_test
+async def test_qid_extra_in_discard(fake_socket, test_input, expected):
+ address = ("127.0.0.1", 7687)
+ socket = fake_socket(address)
+ connection = AsyncBolt4x1(address, socket, PoolConfig.max_connection_lifetime)
+ connection.discard(qid=test_input)
+ await connection.send_all()
+ tag, fields = await socket.pop_message()
+ assert tag == b"\x2F"
+ assert len(fields) == 1
+ assert fields[0] == expected
+
+
+@pytest.mark.parametrize(
+ "test_input, expected",
+ [
+ (777, {"n": 666, "qid": 777}),
+ (-1, {"n": 666}),
+ ]
+)
+@mark_async_test
+async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected):
+ # python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_n_and_qid_extras_in_discard
+ address = ("127.0.0.1", 7687)
+ socket = fake_socket(address)
+ connection = AsyncBolt4x1(address, socket, PoolConfig.max_connection_lifetime)
+ connection.discard(n=666, qid=test_input)
+ await connection.send_all()
+ tag, fields = await socket.pop_message()
+ assert tag == b"\x2F"
+ assert len(fields) == 1
+ assert fields[0] == expected
+
+
+@pytest.mark.parametrize(
+ "test_input, expected",
+ [
+ (666, {"n": 666}),
+ (-1, {"n": -1}),
+ ]
+)
+@mark_async_test
+async def test_n_extra_in_pull(fake_socket, test_input, expected):
+ address = ("127.0.0.1", 7687)
+ socket = fake_socket(address)
+ connection = AsyncBolt4x1(address, socket, PoolConfig.max_connection_lifetime)
+ connection.pull(n=test_input)
+ await connection.send_all()
+ tag, fields = await socket.pop_message()
+ assert tag == b"\x3F"
+ assert len(fields) == 1
+ assert fields[0] == expected
+
+
+@pytest.mark.parametrize(
+ "test_input, expected",
+ [
+ (777, {"n": -1, "qid": 777}),
+ (-1, {"n": -1}),
+ ]
+)
+@mark_async_test
+async def test_qid_extra_in_pull(fake_socket, test_input, expected):
+ # python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_qid_extra_in_pull
+ address = ("127.0.0.1", 7687)
+ socket = fake_socket(address)
+ connection = AsyncBolt4x1(address, socket, PoolConfig.max_connection_lifetime)
+ connection.pull(qid=test_input)
+ await connection.send_all()
+ tag, fields = await socket.pop_message()
+ assert tag == b"\x3F"
+ assert len(fields) == 1
+ assert fields[0] == expected
+
+
+@mark_async_test
+async def test_n_and_qid_extras_in_pull(fake_socket):
+ address = ("127.0.0.1", 7687)
+ socket = fake_socket(address)
+ connection = AsyncBolt4x1(address, socket, PoolConfig.max_connection_lifetime)
+ connection.pull(n=666, qid=777)
+ await connection.send_all()
+ tag, fields = await socket.pop_message()
+ assert tag == b"\x3F"
+ assert len(fields) == 1
+ assert fields[0] == {"n": 666, "qid": 777}
+
+
+@mark_async_test
+async def test_hello_passes_routing_metadata(fake_socket_pair):
+ address = ("127.0.0.1", 7687)
+ sockets = fake_socket_pair(address)
+ await sockets.server.send_message(0x70, {"server": "Neo4j/4.1.0"})
+ connection = AsyncBolt4x1(
+ address, sockets.client, PoolConfig.max_connection_lifetime,
+ routing_context={"foo": "bar"}
+ )
+ await connection.hello()
+ tag, fields = await sockets.server.pop_message()
+ assert tag == 0x01
+ assert len(fields) == 1
+ assert fields[0]["routing"] == {"foo": "bar"}
+
+
+@pytest.mark.parametrize("recv_timeout", (1, -1))
+@mark_async_test
+async def test_hint_recv_timeout_seconds_gets_ignored(
+ fake_socket_pair, recv_timeout
+):
+ address = ("127.0.0.1", 7687)
+ sockets = fake_socket_pair(address)
+ sockets.client.settimeout = AsyncMagicMock()
+ await sockets.server.send_message(0x70, {
+ "server": "Neo4j/4.1.0",
+ "hints": {"connection.recv_timeout_seconds": recv_timeout},
+ })
+ connection = AsyncBolt4x1(address, sockets.client,
+ PoolConfig.max_connection_lifetime)
+ await connection.hello()
+ sockets.client.settimeout.assert_not_called()
diff --git a/tests/unit/async_/io/test_class_bolt4x2.py b/tests/unit/async_/io/test_class_bolt4x2.py
new file mode 100644
index 00000000..1a575b2b
--- /dev/null
+++ b/tests/unit/async_/io/test_class_bolt4x2.py
@@ -0,0 +1,228 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import pytest
+
+from neo4j._async.io._bolt4 import AsyncBolt4x2
+from neo4j.conf import PoolConfig
+
+from ..._async_compat import (
+ AsyncMagicMock,
+ mark_async_test,
+)
+
+
+@pytest.mark.parametrize("set_stale", (True, False))
+def test_conn_is_stale(fake_socket, set_stale):
+ address = ("127.0.0.1", 7687)
+ max_connection_lifetime = 0
+ connection = AsyncBolt4x2(address, fake_socket(address), max_connection_lifetime)
+ if set_stale:
+ connection.set_stale()
+ assert connection.stale() is True
+
+
+@pytest.mark.parametrize("set_stale", (True, False))
+def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale):
+ address = ("127.0.0.1", 7687)
+ max_connection_lifetime = -1
+ connection = AsyncBolt4x2(address, fake_socket(address), max_connection_lifetime)
+ if set_stale:
+ connection.set_stale()
+ assert connection.stale() is set_stale
+
+
+@pytest.mark.parametrize("set_stale", (True, False))
+def test_conn_is_not_stale(fake_socket, set_stale):
+ address = ("127.0.0.1", 7687)
+ max_connection_lifetime = 999999999
+ connection = AsyncBolt4x2(address, fake_socket(address), max_connection_lifetime)
+ if set_stale:
+ connection.set_stale()
+ assert connection.stale() is set_stale
+
+
+@mark_async_test
+async def test_db_extra_in_begin(fake_socket):
+ address = ("127.0.0.1", 7687)
+ socket = fake_socket(address)
+ connection = AsyncBolt4x2(address, socket, PoolConfig.max_connection_lifetime)
+ connection.begin(db="something")
+ await connection.send_all()
+ tag, fields = await socket.pop_message()
+ assert tag == b"\x11"
+ assert len(fields) == 1
+ assert fields[0] == {"db": "something"}
+
+
+@mark_async_test
+async def test_db_extra_in_run(fake_socket):
+ address = ("127.0.0.1", 7687)
+ socket = fake_socket(address)
+ connection = AsyncBolt4x2(address, socket, PoolConfig.max_connection_lifetime)
+ connection.run("", {}, db="something")
+ await connection.send_all()
+ tag, fields = await socket.pop_message()
+ assert tag == b"\x10"
+ assert len(fields) == 3
+ assert fields[0] == ""
+ assert fields[1] == {}
+ assert fields[2] == {"db": "something"}
+
+
+@mark_async_test
+async def test_n_extra_in_discard(fake_socket):
+ address = ("127.0.0.1", 7687)
+ socket = fake_socket(address)
+ connection = AsyncBolt4x2(address, socket, PoolConfig.max_connection_lifetime)
+ connection.discard(n=666)
+ await connection.send_all()
+ tag, fields = await socket.pop_message()
+ assert tag == b"\x2F"
+ assert len(fields) == 1
+ assert fields[0] == {"n": 666}
+
+
+@pytest.mark.parametrize(
+ "test_input, expected",
+ [
+ (666, {"n": -1, "qid": 666}),
+ (-1, {"n": -1}),
+ ]
+)
+@mark_async_test
+async def test_qid_extra_in_discard(fake_socket, test_input, expected):
+ address = ("127.0.0.1", 7687)
+ socket = fake_socket(address)
+ connection = AsyncBolt4x2(address, socket, PoolConfig.max_connection_lifetime)
+ connection.discard(qid=test_input)
+ await connection.send_all()
+ tag, fields = await socket.pop_message()
+ assert tag == b"\x2F"
+ assert len(fields) == 1
+ assert fields[0] == expected
+
+
+@pytest.mark.parametrize(
+    "test_input, expected",
+    [
+        (777, {"n": 666, "qid": 777}),
+        (-1, {"n": 666}),
+    ]
+)
+@mark_async_test
+async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected):
+    # python -m pytest tests/unit/async_/io/test_class_bolt4x2.py -s -k test_n_and_qid_extras_in_discard
+    address = ("127.0.0.1", 7687)
+    socket = fake_socket(address)
+    connection = AsyncBolt4x2(address, socket, PoolConfig.max_connection_lifetime)
+    connection.discard(n=666, qid=test_input)
+    await connection.send_all()
+    tag, fields = await socket.pop_message()
+    assert tag == b"\x2F"
+    assert len(fields) == 1
+    assert fields[0] == expected
+
+
+@pytest.mark.parametrize(
+ "test_input, expected",
+ [
+ (666, {"n": 666}),
+ (-1, {"n": -1}),
+ ]
+)
+@mark_async_test
+async def test_n_extra_in_pull(fake_socket, test_input, expected):
+ address = ("127.0.0.1", 7687)
+ socket = fake_socket(address)
+ connection = AsyncBolt4x2(address, socket, PoolConfig.max_connection_lifetime)
+ connection.pull(n=test_input)
+ await connection.send_all()
+ tag, fields = await socket.pop_message()
+ assert tag == b"\x3F"
+ assert len(fields) == 1
+ assert fields[0] == expected
+
+
+@pytest.mark.parametrize(
+    "test_input, expected",
+    [
+        (777, {"n": -1, "qid": 777}),
+        (-1, {"n": -1}),
+    ]
+)
+@mark_async_test
+async def test_qid_extra_in_pull(fake_socket, test_input, expected):
+    # python -m pytest tests/unit/async_/io/test_class_bolt4x2.py -s -k test_qid_extra_in_pull
+    address = ("127.0.0.1", 7687)
+    socket = fake_socket(address)
+    connection = AsyncBolt4x2(address, socket, PoolConfig.max_connection_lifetime)
+    connection.pull(qid=test_input)
+    await connection.send_all()
+    tag, fields = await socket.pop_message()
+    assert tag == b"\x3F"
+    assert len(fields) == 1
+    assert fields[0] == expected
+
+
+@mark_async_test
+async def test_n_and_qid_extras_in_pull(fake_socket):
+ address = ("127.0.0.1", 7687)
+ socket = fake_socket(address)
+ connection = AsyncBolt4x2(address, socket, PoolConfig.max_connection_lifetime)
+ connection.pull(n=666, qid=777)
+ await connection.send_all()
+ tag, fields = await socket.pop_message()
+ assert tag == b"\x3F"
+ assert len(fields) == 1
+ assert fields[0] == {"n": 666, "qid": 777}
+
+
+@mark_async_test
+async def test_hello_passes_routing_metadata(fake_socket_pair):
+ address = ("127.0.0.1", 7687)
+ sockets = fake_socket_pair(address)
+ await sockets.server.send_message(0x70, {"server": "Neo4j/4.2.0"})
+ connection = AsyncBolt4x2(
+ address, sockets.client, PoolConfig.max_connection_lifetime,
+ routing_context={"foo": "bar"}
+ )
+ await connection.hello()
+ tag, fields = await sockets.server.pop_message()
+ assert tag == 0x01
+ assert len(fields) == 1
+ assert fields[0]["routing"] == {"foo": "bar"}
+
+
+@pytest.mark.parametrize("recv_timeout", (1, -1))
+@mark_async_test
+async def test_hint_recv_timeout_seconds_gets_ignored(
+ fake_socket_pair, recv_timeout
+):
+ address = ("127.0.0.1", 7687)
+ sockets = fake_socket_pair(address)
+ sockets.client.settimeout = AsyncMagicMock()
+ await sockets.server.send_message(0x70, {
+ "server": "Neo4j/4.2.0",
+ "hints": {"connection.recv_timeout_seconds": recv_timeout},
+ })
+ connection = AsyncBolt4x2(
+ address, sockets.client, PoolConfig.max_connection_lifetime
+ )
+ await connection.hello()
+ sockets.client.settimeout.assert_not_called()
diff --git a/tests/unit/async_/io/test_class_bolt4x3.py b/tests/unit/async_/io/test_class_bolt4x3.py
new file mode 100644
index 00000000..e3b22af2
--- /dev/null
+++ b/tests/unit/async_/io/test_class_bolt4x3.py
@@ -0,0 +1,255 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import logging
+
+import pytest
+
+from neo4j._async.io._bolt4 import AsyncBolt4x3
+from neo4j.conf import PoolConfig
+
+from ..._async_compat import (
+ AsyncMagicMock,
+ mark_async_test,
+)
+
+
+@pytest.mark.parametrize("set_stale", (True, False))
+def test_conn_is_stale(fake_socket, set_stale):
+ address = ("127.0.0.1", 7687)
+ max_connection_lifetime = 0
+ connection = AsyncBolt4x3(address, fake_socket(address), max_connection_lifetime)
+ if set_stale:
+ connection.set_stale()
+ assert connection.stale() is True
+
+
+@pytest.mark.parametrize("set_stale", (True, False))
+def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale):
+ address = ("127.0.0.1", 7687)
+ max_connection_lifetime = -1
+ connection = AsyncBolt4x3(address, fake_socket(address), max_connection_lifetime)
+ if set_stale:
+ connection.set_stale()
+ assert connection.stale() is set_stale
+
+
+@pytest.mark.parametrize("set_stale", (True, False))
+def test_conn_is_not_stale(fake_socket, set_stale):
+ address = ("127.0.0.1", 7687)
+ max_connection_lifetime = 999999999
+ connection = AsyncBolt4x3(address, fake_socket(address), max_connection_lifetime)
+ if set_stale:
+ connection.set_stale()
+ assert connection.stale() is set_stale
+
+
+@mark_async_test
+async def test_db_extra_in_begin(fake_socket):
+ address = ("127.0.0.1", 7687)
+ socket = fake_socket(address)
+ connection = AsyncBolt4x3(address, socket, PoolConfig.max_connection_lifetime)
+ connection.begin(db="something")
+ await connection.send_all()
+ tag, fields = await socket.pop_message()
+ assert tag == b"\x11"
+ assert len(fields) == 1
+ assert fields[0] == {"db": "something"}
+
+
+@mark_async_test
+async def test_db_extra_in_run(fake_socket):
+ address = ("127.0.0.1", 7687)
+ socket = fake_socket(address)
+ connection = AsyncBolt4x3(address, socket, PoolConfig.max_connection_lifetime)
+ connection.run("", {}, db="something")
+ await connection.send_all()
+ tag, fields = await socket.pop_message()
+ assert tag == b"\x10"
+ assert len(fields) == 3
+ assert fields[0] == ""
+ assert fields[1] == {}
+ assert fields[2] == {"db": "something"}
+
+
+@mark_async_test
+async def test_n_extra_in_discard(fake_socket):
+ address = ("127.0.0.1", 7687)
+ socket = fake_socket(address)
+ connection = AsyncBolt4x3(address, socket, PoolConfig.max_connection_lifetime)
+ connection.discard(n=666)
+ await connection.send_all()
+ tag, fields = await socket.pop_message()
+ assert tag == b"\x2F"
+ assert len(fields) == 1
+ assert fields[0] == {"n": 666}
+
+
+@pytest.mark.parametrize(
+ "test_input, expected",
+ [
+ (666, {"n": -1, "qid": 666}),
+ (-1, {"n": -1}),
+ ]
+)
+@mark_async_test
+async def test_qid_extra_in_discard(fake_socket, test_input, expected):
+ address = ("127.0.0.1", 7687)
+ socket = fake_socket(address)
+ connection = AsyncBolt4x3(address, socket, PoolConfig.max_connection_lifetime)
+ connection.discard(qid=test_input)
+ await connection.send_all()
+ tag, fields = await socket.pop_message()
+ assert tag == b"\x2F"
+ assert len(fields) == 1
+ assert fields[0] == expected
+
+
+@pytest.mark.parametrize(
+    "test_input, expected",
+    [
+        (777, {"n": 666, "qid": 777}),
+        (-1, {"n": 666}),
+    ]
+)
+@mark_async_test
+async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected):
+    # python -m pytest tests/unit/async_/io/test_class_bolt4x3.py -s -k test_n_and_qid_extras_in_discard
+    address = ("127.0.0.1", 7687)
+    socket = fake_socket(address)
+    connection = AsyncBolt4x3(address, socket, PoolConfig.max_connection_lifetime)
+    connection.discard(n=666, qid=test_input)
+    await connection.send_all()
+    tag, fields = await socket.pop_message()
+    assert tag == b"\x2F"
+    assert len(fields) == 1
+    assert fields[0] == expected
+
+
+@pytest.mark.parametrize(
+ "test_input, expected",
+ [
+ (666, {"n": 666}),
+ (-1, {"n": -1}),
+ ]
+)
+@mark_async_test
+async def test_n_extra_in_pull(fake_socket, test_input, expected):
+ address = ("127.0.0.1", 7687)
+ socket = fake_socket(address)
+ connection = AsyncBolt4x3(address, socket, PoolConfig.max_connection_lifetime)
+ connection.pull(n=test_input)
+ await connection.send_all()
+ tag, fields = await socket.pop_message()
+ assert tag == b"\x3F"
+ assert len(fields) == 1
+ assert fields[0] == expected
+
+
+@pytest.mark.parametrize(
+    "test_input, expected",
+    [
+        (777, {"n": -1, "qid": 777}),
+        (-1, {"n": -1}),
+    ]
+)
+@mark_async_test
+async def test_qid_extra_in_pull(fake_socket, test_input, expected):
+    # python -m pytest tests/unit/async_/io/test_class_bolt4x3.py -s -k test_qid_extra_in_pull
+    address = ("127.0.0.1", 7687)
+    socket = fake_socket(address)
+    connection = AsyncBolt4x3(address, socket, PoolConfig.max_connection_lifetime)
+    connection.pull(qid=test_input)
+    await connection.send_all()
+    tag, fields = await socket.pop_message()
+    assert tag == b"\x3F"
+    assert len(fields) == 1
+    assert fields[0] == expected
+
+
+@mark_async_test
+async def test_n_and_qid_extras_in_pull(fake_socket):
+ address = ("127.0.0.1", 7687)
+ socket = fake_socket(address)
+ connection = AsyncBolt4x3(address, socket, PoolConfig.max_connection_lifetime)
+ connection.pull(n=666, qid=777)
+ await connection.send_all()
+ tag, fields = await socket.pop_message()
+ assert tag == b"\x3F"
+ assert len(fields) == 1
+ assert fields[0] == {"n": 666, "qid": 777}
+
+
+@mark_async_test
+async def test_hello_passes_routing_metadata(fake_socket_pair):
+ address = ("127.0.0.1", 7687)
+ sockets = fake_socket_pair(address)
+ await sockets.server.send_message(0x70, {"server": "Neo4j/4.3.0"})
+ connection = AsyncBolt4x3(
+ address, sockets.client, PoolConfig.max_connection_lifetime,
+ routing_context={"foo": "bar"}
+ )
+ await connection.hello()
+ tag, fields = await sockets.server.pop_message()
+ assert tag == 0x01
+ assert len(fields) == 1
+ assert fields[0]["routing"] == {"foo": "bar"}
+
+
+@pytest.mark.parametrize(("hints", "valid"), (
+ ({"connection.recv_timeout_seconds": 1}, True),
+ ({"connection.recv_timeout_seconds": 42}, True),
+ ({}, True),
+ ({"whatever_this_is": "ignore me!"}, True),
+ ({"connection.recv_timeout_seconds": -1}, False),
+ ({"connection.recv_timeout_seconds": 0}, False),
+ ({"connection.recv_timeout_seconds": 2.5}, False),
+ ({"connection.recv_timeout_seconds": None}, False),
+ ({"connection.recv_timeout_seconds": False}, False),
+ ({"connection.recv_timeout_seconds": "1"}, False),
+))
+@mark_async_test
+async def test_hint_recv_timeout_seconds(
+ fake_socket_pair, hints, valid, caplog
+):
+ address = ("127.0.0.1", 7687)
+ sockets = fake_socket_pair(address)
+ sockets.client.settimeout = AsyncMagicMock()
+ await sockets.server.send_message(
+ 0x70, {"server": "Neo4j/4.3.0", "hints": hints}
+ )
+ connection = AsyncBolt4x3(address, sockets.client,
+ PoolConfig.max_connection_lifetime)
+ with caplog.at_level(logging.INFO):
+ await connection.hello()
+ if valid:
+ if "connection.recv_timeout_seconds" in hints:
+ sockets.client.settimeout.assert_called_once_with(
+ hints["connection.recv_timeout_seconds"]
+ )
+ else:
+ sockets.client.settimeout.assert_not_called()
+ assert not any("recv_timeout_seconds" in msg
+ and "invalid" in msg
+ for msg in caplog.messages)
+ else:
+ sockets.client.settimeout.assert_not_called()
+ assert any(repr(hints["connection.recv_timeout_seconds"]) in msg
+ and "recv_timeout_seconds" in msg
+ and "invalid" in msg
+ for msg in caplog.messages)
diff --git a/tests/unit/async_/io/test_class_bolt4x4.py b/tests/unit/async_/io/test_class_bolt4x4.py
new file mode 100644
index 00000000..70ecafe0
--- /dev/null
+++ b/tests/unit/async_/io/test_class_bolt4x4.py
@@ -0,0 +1,271 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import logging
+from unittest.mock import MagicMock
+
+import pytest
+
+from neo4j._async.io._bolt4 import AsyncBolt4x4
+from neo4j.conf import PoolConfig
+
+from ..._async_compat import (
+ AsyncMagicMock,
+ mark_async_test,
+)
+
+
+@pytest.mark.parametrize("set_stale", (True, False))
+def test_conn_is_stale(fake_socket, set_stale):
+ address = ("127.0.0.1", 7687)
+ max_connection_lifetime = 0
+ connection = AsyncBolt4x4(address, fake_socket(address), max_connection_lifetime)
+ if set_stale:
+ connection.set_stale()
+ assert connection.stale() is True
+
+
+@pytest.mark.parametrize("set_stale", (True, False))
+def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale):
+ address = ("127.0.0.1", 7687)
+ max_connection_lifetime = -1
+ connection = AsyncBolt4x4(address, fake_socket(address), max_connection_lifetime)
+ if set_stale:
+ connection.set_stale()
+ assert connection.stale() is set_stale
+
+
+@pytest.mark.parametrize("set_stale", (True, False))
+def test_conn_is_not_stale(fake_socket, set_stale):
+ address = ("127.0.0.1", 7687)
+ max_connection_lifetime = 999999999
+ connection = AsyncBolt4x4(address, fake_socket(address), max_connection_lifetime)
+ if set_stale:
+ connection.set_stale()
+ assert connection.stale() is set_stale
+
+
+@pytest.mark.parametrize(("args", "kwargs", "expected_fields"), (
+ (("", {}), {"db": "something"}, ({"db": "something"},)),
+ (("", {}), {"imp_user": "imposter"}, ({"imp_user": "imposter"},)),
+ (
+ ("", {}),
+ {"db": "something", "imp_user": "imposter"},
+ ({"db": "something", "imp_user": "imposter"},)
+ ),
+))
+@mark_async_test
+async def test_extra_in_begin(fake_socket, args, kwargs, expected_fields):
+ address = ("127.0.0.1", 7687)
+ socket = fake_socket(address)
+ connection = AsyncBolt4x4(address, socket, PoolConfig.max_connection_lifetime)
+ connection.begin(*args, **kwargs)
+ await connection.send_all()
+ tag, is_fields = await socket.pop_message()
+ assert tag == b"\x11"
+ assert tuple(is_fields) == expected_fields
+
+
+@pytest.mark.parametrize(("args", "kwargs", "expected_fields"), (
+ (("", {}), {"db": "something"}, ("", {}, {"db": "something"})),
+ (("", {}), {"imp_user": "imposter"}, ("", {}, {"imp_user": "imposter"})),
+ (
+ ("", {}),
+ {"db": "something", "imp_user": "imposter"},
+ ("", {}, {"db": "something", "imp_user": "imposter"})
+ ),
+))
+@mark_async_test
+async def test_extra_in_run(fake_socket, args, kwargs, expected_fields):
+ address = ("127.0.0.1", 7687)
+ socket = fake_socket(address)
+ connection = AsyncBolt4x4(address, socket, PoolConfig.max_connection_lifetime)
+ connection.run(*args, **kwargs)
+ await connection.send_all()
+ tag, is_fields = await socket.pop_message()
+ assert tag == b"\x10"
+ assert tuple(is_fields) == expected_fields
+
+
+@mark_async_test
+async def test_n_extra_in_discard(fake_socket):
+ address = ("127.0.0.1", 7687)
+ socket = fake_socket(address)
+ connection = AsyncBolt4x4(address, socket, PoolConfig.max_connection_lifetime)
+ connection.discard(n=666)
+ await connection.send_all()
+ tag, fields = await socket.pop_message()
+ assert tag == b"\x2F"
+ assert len(fields) == 1
+ assert fields[0] == {"n": 666}
+
+
+@pytest.mark.parametrize(
+ "test_input, expected",
+ [
+ (666, {"n": -1, "qid": 666}),
+ (-1, {"n": -1}),
+ ]
+)
+@mark_async_test
+async def test_qid_extra_in_discard(fake_socket, test_input, expected):
+ address = ("127.0.0.1", 7687)
+ socket = fake_socket(address)
+ connection = AsyncBolt4x4(address, socket, PoolConfig.max_connection_lifetime)
+ connection.discard(qid=test_input)
+ await connection.send_all()
+ tag, fields = await socket.pop_message()
+ assert tag == b"\x2F"
+ assert len(fields) == 1
+ assert fields[0] == expected
+
+
+@pytest.mark.parametrize(
+    "test_input, expected",
+    [
+        (777, {"n": 666, "qid": 777}),
+        (-1, {"n": 666}),
+    ]
+)
+@mark_async_test
+async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected):
+    # python -m pytest tests/unit/async_/io/test_class_bolt4x4.py -s -k test_n_and_qid_extras_in_discard
+    address = ("127.0.0.1", 7687)
+    socket = fake_socket(address)
+    connection = AsyncBolt4x4(address, socket, PoolConfig.max_connection_lifetime)
+    connection.discard(n=666, qid=test_input)
+    await connection.send_all()
+    tag, fields = await socket.pop_message()
+    assert tag == b"\x2F"
+    assert len(fields) == 1
+    assert fields[0] == expected
+
+
+@pytest.mark.parametrize(
+ "test_input, expected",
+ [
+ (666, {"n": 666}),
+ (-1, {"n": -1}),
+ ]
+)
+@mark_async_test
+async def test_n_extra_in_pull(fake_socket, test_input, expected):
+ address = ("127.0.0.1", 7687)
+ socket = fake_socket(address)
+ connection = AsyncBolt4x4(address, socket, PoolConfig.max_connection_lifetime)
+ connection.pull(n=test_input)
+ await connection.send_all()
+ tag, fields = await socket.pop_message()
+ assert tag == b"\x3F"
+ assert len(fields) == 1
+ assert fields[0] == expected
+
+
+@pytest.mark.parametrize(
+    "test_input, expected",
+    [
+        (777, {"n": -1, "qid": 777}),
+        (-1, {"n": -1}),
+    ]
+)
+@mark_async_test
+async def test_qid_extra_in_pull(fake_socket, test_input, expected):
+    # python -m pytest tests/unit/async_/io/test_class_bolt4x4.py -s -k test_qid_extra_in_pull
+    address = ("127.0.0.1", 7687)
+    socket = fake_socket(address)
+    connection = AsyncBolt4x4(address, socket, PoolConfig.max_connection_lifetime)
+    connection.pull(qid=test_input)
+    await connection.send_all()
+    tag, fields = await socket.pop_message()
+    assert tag == b"\x3F"
+    assert len(fields) == 1
+    assert fields[0] == expected
+
+
+@mark_async_test
+async def test_n_and_qid_extras_in_pull(fake_socket):
+ address = ("127.0.0.1", 7687)
+ socket = fake_socket(address)
+ connection = AsyncBolt4x4(address, socket, PoolConfig.max_connection_lifetime)
+ connection.pull(n=666, qid=777)
+ await connection.send_all()
+ tag, fields = await socket.pop_message()
+ assert tag == b"\x3F"
+ assert len(fields) == 1
+ assert fields[0] == {"n": 666, "qid": 777}
+
+
+@mark_async_test
+async def test_hello_passes_routing_metadata(fake_socket_pair):
+ address = ("127.0.0.1", 7687)
+ sockets = fake_socket_pair(address)
+ await sockets.server.send_message(0x70, {"server": "Neo4j/4.4.0"})
+ connection = AsyncBolt4x4(
+ address, sockets.client, PoolConfig.max_connection_lifetime,
+ routing_context={"foo": "bar"}
+ )
+ await connection.hello()
+ tag, fields = await sockets.server.pop_message()
+ assert tag == 0x01
+ assert len(fields) == 1
+ assert fields[0]["routing"] == {"foo": "bar"}
+
+
+@pytest.mark.parametrize(("hints", "valid"), (
+ ({"connection.recv_timeout_seconds": 1}, True),
+ ({"connection.recv_timeout_seconds": 42}, True),
+ ({}, True),
+ ({"whatever_this_is": "ignore me!"}, True),
+ ({"connection.recv_timeout_seconds": -1}, False),
+ ({"connection.recv_timeout_seconds": 0}, False),
+ ({"connection.recv_timeout_seconds": 2.5}, False),
+ ({"connection.recv_timeout_seconds": None}, False),
+ ({"connection.recv_timeout_seconds": False}, False),
+ ({"connection.recv_timeout_seconds": "1"}, False),
+))
+@mark_async_test
+async def test_hint_recv_timeout_seconds(
+ fake_socket_pair, hints, valid, caplog
+):
+ address = ("127.0.0.1", 7687)
+ sockets = fake_socket_pair(address)
+ sockets.client.settimeout = MagicMock()
+ await sockets.server.send_message(
+ 0x70, {"server": "Neo4j/4.3.4", "hints": hints}
+ )
+ connection = AsyncBolt4x4(
+ address, sockets.client, PoolConfig.max_connection_lifetime
+ )
+ with caplog.at_level(logging.INFO):
+ await connection.hello()
+ if valid:
+ if "connection.recv_timeout_seconds" in hints:
+ sockets.client.settimeout.assert_called_once_with(
+ hints["connection.recv_timeout_seconds"]
+ )
+ else:
+ sockets.client.settimeout.assert_not_called()
+ assert not any("recv_timeout_seconds" in msg
+ and "invalid" in msg
+ for msg in caplog.messages)
+ else:
+ sockets.client.settimeout.assert_not_called()
+ assert any(repr(hints["connection.recv_timeout_seconds"]) in msg
+ and "recv_timeout_seconds" in msg
+ and "invalid" in msg
+ for msg in caplog.messages)
diff --git a/tests/unit/async_/io/test_direct.py b/tests/unit/async_/io/test_direct.py
new file mode 100644
index 00000000..266b9779
--- /dev/null
+++ b/tests/unit/async_/io/test_direct.py
@@ -0,0 +1,231 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import pytest
+
+from neo4j import (
+ Config,
+ PoolConfig,
+ WorkspaceConfig,
+)
+from neo4j._async.io import AsyncBolt
+from neo4j._async.io._pool import AsyncIOPool
+from neo4j.exceptions import (
+ ClientError,
+ ServiceUnavailable,
+)
+
+from ..._async_compat import (
+ AsyncMock,
+ mark_async_test,
+ mock,
+)
+
+
+class AsyncFakeSocket:
+ def __init__(self, address):
+ self.address = address
+
+ def getpeername(self):
+ return self.address
+
+ async def sendall(self, data):
+ return
+
+ def close(self):
+ return
+
+
+class AsyncQuickConnection:
+ def __init__(self, socket):
+ self.socket = socket
+ self.address = socket.getpeername()
+
+ @property
+ def is_reset(self):
+ return True
+
+ def stale(self):
+ return False
+
+ async def reset(self):
+ pass
+
+ def close(self):
+ self.socket.close()
+
+ def closed(self):
+ return False
+
+ def defunct(self):
+ return False
+
+ def timedout(self):
+ return False
+
+
+class AsyncFakeBoltPool(AsyncIOPool):
+
+ def __init__(self, address, *, auth=None, **config):
+ self.pool_config, self.workspace_config = Config.consume_chain(config, PoolConfig, WorkspaceConfig)
+ if config:
+ raise ValueError("Unexpected config keys: %s" % ", ".join(config.keys()))
+
+ async def opener(addr, timeout):
+ return AsyncQuickConnection(AsyncFakeSocket(addr))
+
+ super().__init__(opener, self.pool_config, self.workspace_config)
+ self.address = address
+
+ async def acquire(
+ self, access_mode=None, timeout=None, database=None, bookmarks=None
+ ):
+ return await self._acquire(self.address, timeout)
+
+
+@mark_async_test
+async def test_bolt_connection_open():
+ with pytest.raises(ServiceUnavailable):
+ await AsyncBolt.open(("localhost", 9999), auth=("test", "test"))
+
+
+@mark_async_test
+async def test_bolt_connection_open_timeout():
+ with pytest.raises(ServiceUnavailable):
+ await AsyncBolt.open(("localhost", 9999), auth=("test", "test"),
+ timeout=1)
+
+
+@mark_async_test
+async def test_bolt_connection_ping():
+ protocol_version = await AsyncBolt.ping(("localhost", 9999))
+ assert protocol_version is None
+
+
+@mark_async_test
+async def test_bolt_connection_ping_timeout():
+ protocol_version = await AsyncBolt.ping(("localhost", 9999), timeout=1)
+ assert protocol_version is None
+
+
+@pytest.fixture
+async def pool():
+ async with AsyncFakeBoltPool(("127.0.0.1", 7687)) as pool:
+ yield pool
+
+
+def assert_pool_size( address, expected_active, expected_inactive, pool):
+ try:
+ connections = pool.connections[address]
+ except KeyError:
+ assert 0 == expected_active
+ assert 0 == expected_inactive
+ else:
+ assert expected_active == len([cx for cx in connections if cx.in_use])
+ assert (expected_inactive
+ == len([cx for cx in connections if not cx.in_use]))
+
+
+@mark_async_test
+async def test_pool_can_acquire(pool):
+ address = ("127.0.0.1", 7687)
+ connection = await pool._acquire(address, timeout=3)
+ assert connection.address == address
+ assert_pool_size(address, 1, 0, pool)
+
+
+@mark_async_test
+async def test_pool_can_acquire_twice(pool):
+ address = ("127.0.0.1", 7687)
+ connection_1 = await pool._acquire(address, timeout=3)
+ connection_2 = await pool._acquire(address, timeout=3)
+ assert connection_1.address == address
+ assert connection_2.address == address
+ assert connection_1 is not connection_2
+ assert_pool_size(address, 2, 0, pool)
+
+
+@mark_async_test
+async def test_pool_can_acquire_two_addresses(pool):
+ address_1 = ("127.0.0.1", 7687)
+ address_2 = ("127.0.0.1", 7474)
+ connection_1 = await pool._acquire(address_1, timeout=3)
+ connection_2 = await pool._acquire(address_2, timeout=3)
+ assert connection_1.address == address_1
+ assert connection_2.address == address_2
+ assert_pool_size(address_1, 1, 0, pool)
+ assert_pool_size(address_2, 1, 0, pool)
+
+
+@mark_async_test
+async def test_pool_can_acquire_and_release(pool):
+ address = ("127.0.0.1", 7687)
+ connection = await pool._acquire(address, timeout=3)
+ assert_pool_size(address, 1, 0, pool)
+ await pool.release(connection)
+ assert_pool_size(address, 0, 1, pool)
+
+
+@mark_async_test
+async def test_pool_releasing_twice(pool):
+ address = ("127.0.0.1", 7687)
+ connection = await pool._acquire(address, timeout=3)
+ await pool.release(connection)
+ assert_pool_size(address, 0, 1, pool)
+ await pool.release(connection)
+ assert_pool_size(address, 0, 1, pool)
+
+
+@mark_async_test
+async def test_pool_in_use_count(pool):
+ address = ("127.0.0.1", 7687)
+ assert pool.in_use_connection_count(address) == 0
+ connection = await pool._acquire(address, timeout=3)
+ assert pool.in_use_connection_count(address) == 1
+ await pool.release(connection)
+ assert pool.in_use_connection_count(address) == 0
+
+
+@mark_async_test
+async def test_pool_max_conn_pool_size(pool):
+ async with AsyncFakeBoltPool((), max_connection_pool_size=1) as pool:
+ address = ("127.0.0.1", 7687)
+ await pool._acquire(address, timeout=0)
+ assert pool.in_use_connection_count(address) == 1
+ with pytest.raises(ClientError):
+ await pool._acquire(address, timeout=0)
+ assert pool.in_use_connection_count(address) == 1
+
+
+@pytest.mark.parametrize("is_reset", (True, False))
+@mark_async_test
+async def test_pool_reset_when_released(is_reset, pool):
+ address = ("127.0.0.1", 7687)
+ quick_connection_name = AsyncQuickConnection.__name__
+ with mock.patch(f"{__name__}.{quick_connection_name}.is_reset",
+ new_callable=mock.PropertyMock) as is_reset_mock:
+ with mock.patch(f"{__name__}.{quick_connection_name}.reset",
+ new_callable=AsyncMock) as reset_mock:
+ is_reset_mock.return_value = is_reset
+ connection = await pool._acquire(address, timeout=3)
+ assert isinstance(connection, AsyncQuickConnection)
+ assert is_reset_mock.call_count == 0
+ assert reset_mock.call_count == 0
+ await pool.release(connection)
+ assert is_reset_mock.call_count == 1
+ assert reset_mock.call_count == int(not is_reset)
diff --git a/tests/unit/async_/io/test_neo4j_pool.py b/tests/unit/async_/io/test_neo4j_pool.py
new file mode 100644
index 00000000..3962a4de
--- /dev/null
+++ b/tests/unit/async_/io/test_neo4j_pool.py
@@ -0,0 +1,259 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from unittest.mock import Mock
+
+import pytest
+
+from neo4j import (
+ READ_ACCESS,
+ WRITE_ACCESS,
+)
+from neo4j._async.io import AsyncNeo4jPool
+from neo4j.addressing import ResolvedAddress
+from neo4j.conf import (
+ PoolConfig,
+ RoutingConfig,
+ WorkspaceConfig,
+)
+
+from ..._async_compat import (
+ AsyncMock,
+ mark_async_test,
+)
+from ..work import AsyncFakeConnection
+
+
+ROUTER_ADDRESS = ResolvedAddress(("1.2.3.1", 9001), host_name="host")
+READER_ADDRESS = ResolvedAddress(("1.2.3.1", 9002), host_name="host")
+WRITER_ADDRESS = ResolvedAddress(("1.2.3.1", 9003), host_name="host")
+
+
+@pytest.fixture()
+def opener():
+ async def open_(addr, timeout):
+ connection = AsyncFakeConnection()
+ connection.addr = addr
+ connection.timeout = timeout
+ route_mock = AsyncMock()
+ route_mock.return_value = [{
+ "ttl": 1000,
+ "servers": [
+ {"addresses": [str(ROUTER_ADDRESS)], "role": "ROUTE"},
+ {"addresses": [str(READER_ADDRESS)], "role": "READ"},
+ {"addresses": [str(WRITER_ADDRESS)], "role": "WRITE"},
+ ],
+ }]
+ connection.attach_mock(route_mock, "route")
+ opener_.connections.append(connection)
+ return connection
+
+ opener_ = AsyncMock()
+ opener_.connections = []
+ opener_.side_effect = open_
+ return opener_
+
+
+@mark_async_test
+async def test_acquires_new_routing_table_if_deleted(opener):
+ pool = AsyncNeo4jPool(
+ opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS
+ )
+ cx = await pool.acquire(READ_ACCESS, 30, "test_db", None)
+ await pool.release(cx)
+ assert pool.routing_tables.get("test_db")
+
+ del pool.routing_tables["test_db"]
+
+ cx = await pool.acquire(READ_ACCESS, 30, "test_db", None)
+ await pool.release(cx)
+ assert pool.routing_tables.get("test_db")
+
+
+@mark_async_test
+async def test_acquires_new_routing_table_if_stale(opener):
+ pool = AsyncNeo4jPool(
+ opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS
+ )
+ cx = await pool.acquire(READ_ACCESS, 30, "test_db", None)
+ await pool.release(cx)
+ assert pool.routing_tables.get("test_db")
+
+ old_value = pool.routing_tables["test_db"].last_updated_time
+ pool.routing_tables["test_db"].ttl = 0
+
+ cx = await pool.acquire(READ_ACCESS, 30, "test_db", None)
+ await pool.release(cx)
+ assert pool.routing_tables["test_db"].last_updated_time > old_value
+
+
+@mark_async_test
+async def test_removes_old_routing_table(opener):
+ pool = AsyncNeo4jPool(
+ opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS
+ )
+ cx = await pool.acquire(READ_ACCESS, 30, "test_db1", None)
+ await pool.release(cx)
+ assert pool.routing_tables.get("test_db1")
+ cx = await pool.acquire(READ_ACCESS, 30, "test_db2", None)
+ await pool.release(cx)
+ assert pool.routing_tables.get("test_db2")
+
+ old_value = pool.routing_tables["test_db1"].last_updated_time
+ pool.routing_tables["test_db1"].ttl = 0
+ pool.routing_tables["test_db2"].ttl = \
+ -RoutingConfig.routing_table_purge_delay
+
+ cx = await pool.acquire(READ_ACCESS, 30, "test_db1", None)
+ await pool.release(cx)
+ assert pool.routing_tables["test_db1"].last_updated_time > old_value
+ assert "test_db2" not in pool.routing_tables
+
+
+@pytest.mark.parametrize("type_", ("r", "w"))
+@mark_async_test
+async def test_chooses_right_connection_type(opener, type_):
+ pool = AsyncNeo4jPool(
+ opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS
+ )
+ cx1 = await pool.acquire(READ_ACCESS if type_ == "r" else WRITE_ACCESS,
+ 30, "test_db", None)
+ await pool.release(cx1)
+ if type_ == "r":
+ assert cx1.addr == READER_ADDRESS
+ else:
+ assert cx1.addr == WRITER_ADDRESS
+
+
+@mark_async_test
+async def test_reuses_connection(opener):
+ pool = AsyncNeo4jPool(
+ opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS
+ )
+ cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None)
+ await pool.release(cx1)
+ cx2 = await pool.acquire(READ_ACCESS, 30, "test_db", None)
+ assert cx1 is cx2
+
+
+@pytest.mark.parametrize("break_on_close", (True, False))
+@mark_async_test
+async def test_closes_stale_connections(opener, break_on_close):
+ def break_connection():
+ pool.deactivate(cx1.addr)
+
+ if cx_close_mock_side_effect:
+ cx_close_mock_side_effect()
+
+ pool = AsyncNeo4jPool(
+ opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS
+ )
+ cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None)
+ await pool.release(cx1)
+ assert cx1 in pool.connections[cx1.addr]
+ # simulate connection going stale (e.g. exceeding) and then breaking when
+ # the pool tries to close the connection
+ cx1.stale.return_value = True
+ cx_close_mock = cx1.close
+ if break_on_close:
+ cx_close_mock_side_effect = cx_close_mock.side_effect
+ cx_close_mock.side_effect = break_connection
+ cx2 = await pool.acquire(READ_ACCESS, 30, "test_db", None)
+ await pool.release(cx2)
+ if break_on_close:
+ cx1.close.assert_called()
+ else:
+ cx1.close.assert_called_once()
+ assert cx2 is not cx1
+ assert cx2.addr == cx1.addr
+ assert cx1 not in pool.connections[cx1.addr]
+ assert cx2 in pool.connections[cx2.addr]
+
+
+@mark_async_test
+async def test_does_not_close_stale_connections_in_use(opener):
+ pool = AsyncNeo4jPool(
+ opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS
+ )
+ cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None)
+ assert cx1 in pool.connections[cx1.addr]
+ # simulate connection going stale (e.g. exceeding) while being in use
+ cx1.stale.return_value = True
+ cx2 = await pool.acquire(READ_ACCESS, 30, "test_db", None)
+ await pool.release(cx2)
+ cx1.close.assert_not_called()
+ assert cx2 is not cx1
+ assert cx2.addr == cx1.addr
+ assert cx1 in pool.connections[cx1.addr]
+ assert cx2 in pool.connections[cx2.addr]
+
+ await pool.release(cx1)
+ # now that cx1 is back in the pool and still stale,
+ # it should be closed when trying to acquire the next connection
+ cx1.close.assert_not_called()
+
+ cx3 = await pool.acquire(READ_ACCESS, 30, "test_db", None)
+ await pool.release(cx3)
+ cx1.close.assert_called_once()
+ assert cx2 is cx3
+ assert cx3.addr == cx1.addr
+ assert cx1 not in pool.connections[cx1.addr]
+ assert cx3 in pool.connections[cx2.addr]
+
+
+@mark_async_test
+async def test_release_resets_connections(opener):
+ pool = AsyncNeo4jPool(
+ opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS
+ )
+ cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None)
+ cx1.is_reset_mock.return_value = False
+ cx1.is_reset_mock.reset_mock()
+ await pool.release(cx1)
+ cx1.is_reset_mock.assert_called_once()
+ cx1.reset.assert_called_once()
+
+
+@mark_async_test
+async def test_release_does_not_resets_closed_connections(opener):
+    pool = AsyncNeo4jPool(
+        opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS
+    )
+    cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None)
+    cx1.closed.return_value = True
+    cx1.closed.reset_mock()
+    cx1.is_reset_mock.reset_mock()
+    await pool.release(cx1)
+    cx1.closed.assert_called_once()
+    cx1.is_reset_mock.assert_not_called()
+    cx1.reset.assert_not_called()
+
+
+@mark_async_test
+async def test_release_does_not_resets_defunct_connections(opener):
+    pool = AsyncNeo4jPool(
+        opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS
+    )
+    cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None)
+    cx1.defunct.return_value = True
+    cx1.defunct.reset_mock()
+    cx1.is_reset_mock.reset_mock()
+    await pool.release(cx1)
+    cx1.defunct.assert_called_once()
+    cx1.is_reset_mock.assert_not_called()
+    cx1.reset.assert_not_called()
diff --git a/tests/unit/async_/test_addressing.py b/tests/unit/async_/test_addressing.py
new file mode 100644
index 00000000..69a5556f
--- /dev/null
+++ b/tests/unit/async_/test_addressing.py
@@ -0,0 +1,125 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from socket import (
+ AF_INET,
+ AF_INET6,
+)
+import unittest.mock as mock
+
+import pytest
+
+from neo4j import (
+ Address,
+ IPv4Address,
+)
+from neo4j._async_compat.network import AsyncNetworkUtil
+from neo4j._async_compat.util import AsyncUtil
+
+from .._async_compat import mark_async_test
+
+
+mock_socket_ipv4 = mock.Mock()
+mock_socket_ipv4.getpeername = lambda: ("127.0.0.1", 7687) # (address, port)
+
+mock_socket_ipv6 = mock.Mock()
+mock_socket_ipv6.getpeername = lambda: ("[::1]", 7687, 0, 0) # (address, port, flow info, scope id)
+
+
+@mark_async_test
+async def test_address_resolve():
+ address = Address(("127.0.0.1", 7687))
+ resolved = AsyncNetworkUtil.resolve_address(address)
+ resolved = await AsyncUtil.list(resolved)
+ assert isinstance(resolved, Address) is False
+ assert isinstance(resolved, list) is True
+ assert len(resolved) == 1
+ assert resolved[0] == IPv4Address(('127.0.0.1', 7687))
+
+
+@mark_async_test
+async def test_address_resolve_with_custom_resolver_none():
+ address = Address(("127.0.0.1", 7687))
+ resolved = AsyncNetworkUtil.resolve_address(address, resolver=None)
+ resolved = await AsyncUtil.list(resolved)
+ assert isinstance(resolved, Address) is False
+ assert isinstance(resolved, list) is True
+ assert len(resolved) == 1
+ assert resolved[0] == IPv4Address(('127.0.0.1', 7687))
+
+
+@pytest.mark.parametrize(
+ "test_input, expected",
+ [
+ (Address(("127.0.0.1", "abcd")), ValueError),
+ (Address((None, None)), ValueError),
+ ]
+
+)
+@mark_async_test
+async def test_address_resolve_with_unresolvable_address(test_input, expected):
+ with pytest.raises(expected):
+ await AsyncUtil.list(
+ AsyncNetworkUtil.resolve_address(test_input, resolver=None)
+ )
+
+
+@pytest.mark.parametrize("resolver_type", ("sync", "async"))
+@mark_async_test
+async def test_address_resolve_with_custom_resolver(resolver_type):
+    def custom_resolver_sync(_):
+        return [("127.0.0.1", 7687), ("localhost", 1234)]
+
+    async def custom_resolver_async(_):
+        return [("127.0.0.1", 7687), ("localhost", 1234)]
+
+    if resolver_type == "sync":
+        custom_resolver = custom_resolver_sync
+    else:
+        custom_resolver = custom_resolver_async
+
+    address = Address(("127.0.0.1", 7687))
+    resolved = AsyncNetworkUtil.resolve_address(
+        address, family=AF_INET, resolver=custom_resolver
+    )
+    resolved = await AsyncUtil.list(resolved)
+    assert isinstance(resolved, Address) is False
+    assert isinstance(resolved, list) is True
+    assert len(resolved) == 2  # IPv4 only
+    assert resolved[0] == IPv4Address(('127.0.0.1', 7687))
+    assert resolved[1] == IPv4Address(('127.0.0.1', 1234))
+
+
+@mark_async_test
+async def test_address_unresolve():
+ custom_resolved = [("127.0.0.1", 7687), ("localhost", 4321)]
+ custom_resolver = lambda _: custom_resolved
+
+ address = Address(("foobar", 1234))
+ unresolved = address.unresolved
+ assert address.__class__ == unresolved.__class__
+ assert address == unresolved
+ resolved = AsyncNetworkUtil.resolve_address(
+ address, family=AF_INET, resolver=custom_resolver
+ )
+ resolved = await AsyncUtil.list(resolved)
+ custom_resolved = sorted(Address(a) for a in custom_resolved)
+ unresolved = sorted(a.unresolved for a in resolved)
+ assert custom_resolved == unresolved
+ assert (list(map(lambda a: a.__class__, custom_resolved))
+ == list(map(lambda a: a.__class__, unresolved)))
diff --git a/tests/unit/async_/test_driver.py b/tests/unit/async_/test_driver.py
new file mode 100644
index 00000000..7f23af09
--- /dev/null
+++ b/tests/unit/async_/test_driver.py
@@ -0,0 +1,157 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import pytest
+
+from neo4j import (
+ AsyncBoltDriver,
+ AsyncGraphDatabase,
+ AsyncNeo4jDriver,
+ TRUST_ALL_CERTIFICATES,
+ TRUST_SYSTEM_CA_SIGNED_CERTIFICATES,
+)
+from neo4j.api import WRITE_ACCESS
+from neo4j.exceptions import ConfigurationError
+
+from .._async_compat import (
+ mark_async_test,
+ mock,
+)
+
+
+@pytest.mark.parametrize("protocol", ("bolt://", "bolt+s://", "bolt+ssc://"))
+@pytest.mark.parametrize("host", ("localhost", "127.0.0.1",
+ "[::1]", "[0:0:0:0:0:0:0:1]"))
+@pytest.mark.parametrize("port", (":1234", "", ":7687"))
+@pytest.mark.parametrize("auth_token", (("test", "test"), None))
+def test_direct_driver_constructor(protocol, host, port, auth_token):
+ uri = protocol + host + port
+ driver = AsyncGraphDatabase.driver(uri, auth=auth_token)
+ assert isinstance(driver, AsyncBoltDriver)
+
+
+@pytest.mark.parametrize("protocol", ("neo4j://", "neo4j+s://", "neo4j+ssc://"))
+@pytest.mark.parametrize("host", ("localhost", "127.0.0.1",
+ "[::1]", "[0:0:0:0:0:0:0:1]"))
+@pytest.mark.parametrize("port", (":1234", "", ":7687"))
+@pytest.mark.parametrize("auth_token", (("test", "test"), None))
+def test_routing_driver_constructor(protocol, host, port, auth_token):
+ uri = protocol + host + port
+ driver = AsyncGraphDatabase.driver(uri, auth=auth_token)
+ assert isinstance(driver, AsyncNeo4jDriver)
+
+
+@pytest.mark.parametrize("test_uri", (
+ "bolt+ssc://127.0.0.1:9001",
+ "bolt+s://127.0.0.1:9001",
+ "bolt://127.0.0.1:9001",
+ "neo4j+ssc://127.0.0.1:9001",
+ "neo4j+s://127.0.0.1:9001",
+ "neo4j://127.0.0.1:9001",
+))
+@pytest.mark.parametrize(
+ ("test_config", "expected_failure", "expected_failure_message"),
+ (
+ ({"encrypted": False}, ConfigurationError, "The config settings"),
+ ({"encrypted": True}, ConfigurationError, "The config settings"),
+ (
+ {"encrypted": True, "trust": TRUST_ALL_CERTIFICATES},
+ ConfigurationError, "The config settings"
+ ),
+ (
+ {"trust": TRUST_ALL_CERTIFICATES},
+ ConfigurationError, "The config settings"
+ ),
+ (
+ {"trust": TRUST_SYSTEM_CA_SIGNED_CERTIFICATES},
+ ConfigurationError, "The config settings"
+ ),
+ )
+)
+def test_driver_config_error(
+ test_uri, test_config, expected_failure, expected_failure_message
+):
+ if "+" in test_uri:
+ # `+s` and `+ssc` are short hand syntax for not having to configure the
+ # encryption behavior of the driver. Specifying both is invalid.
+ with pytest.raises(expected_failure, match=expected_failure_message):
+ AsyncGraphDatabase.driver(test_uri, **test_config)
+ else:
+ AsyncGraphDatabase.driver(test_uri, **test_config)
+
+
+@pytest.mark.parametrize("test_uri", (
+ "http://localhost:9001",
+ "ftp://localhost:9001",
+ "x://localhost:9001",
+))
+def test_invalid_protocol(test_uri):
+ with pytest.raises(ConfigurationError, match="scheme"):
+ AsyncGraphDatabase.driver(test_uri)
+
+
+@pytest.mark.parametrize(
+ ("test_config", "expected_failure", "expected_failure_message"),
+ (
+ ({"trust": 1}, ConfigurationError, "The config setting `trust`"),
+ ({"trust": True}, ConfigurationError, "The config setting `trust`"),
+ ({"trust": None}, ConfigurationError, "The config setting `trust`"),
+ )
+)
+def test_driver_trust_config_error(
+ test_config, expected_failure, expected_failure_message
+):
+ with pytest.raises(expected_failure, match=expected_failure_message):
+ AsyncGraphDatabase.driver("bolt://127.0.0.1:9001", **test_config)
+
+
+@pytest.mark.parametrize("uri", (
+ "bolt://127.0.0.1:9000",
+ "neo4j://127.0.0.1:9000",
+))
+@mark_async_test
+async def test_driver_opens_write_session_by_default(uri, mocker):
+ driver = AsyncGraphDatabase.driver(uri)
+ from neo4j import AsyncTransaction
+
+ # we set a specific db, because else the driver would try to fetch a RT
+ # to get hold of the actual home database (which won't work in this
+ # unittest)
+ async with driver.session(database="foobar") as session:
+ with mock.patch.object(
+ session._pool, "acquire", autospec=True
+ ) as acquire_mock:
+ with mock.patch.object(
+ AsyncTransaction, "_begin", autospec=True
+ ) as tx_begin_mock:
+ tx = await session.begin_transaction()
+ acquire_mock.assert_called_once_with(
+ access_mode=WRITE_ACCESS,
+ timeout=mocker.ANY,
+ database=mocker.ANY,
+ bookmarks=mocker.ANY
+ )
+ tx_begin_mock.assert_called_once_with(
+ tx,
+ mocker.ANY,
+ mocker.ANY,
+ mocker.ANY,
+ WRITE_ACCESS,
+ mocker.ANY,
+ mocker.ANY
+ )
diff --git a/tests/unit/async_/work/__init__.py b/tests/unit/async_/work/__init__.py
new file mode 100644
index 00000000..3bfbf0ed
--- /dev/null
+++ b/tests/unit/async_/work/__init__.py
@@ -0,0 +1,22 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from ._fake_connection import (
+ async_fake_connection,
+ AsyncFakeConnection,
+)
diff --git a/tests/unit/async_/work/_fake_connection.py b/tests/unit/async_/work/_fake_connection.py
new file mode 100644
index 00000000..c3bf9b96
--- /dev/null
+++ b/tests/unit/async_/work/_fake_connection.py
@@ -0,0 +1,111 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import inspect
+
+import pytest
+
+from neo4j import ServerInfo
+from neo4j._async.io import AsyncBolt
+
+from ..._async_compat import (
+ AsyncMock,
+ mock,
+ Mock,
+)
+
+
+class AsyncFakeConnection(mock.NonCallableMagicMock):
+ callbacks = []
+ server_info = ServerInfo("127.0.0.1", (4, 3))
+
+ def __init__(self, *args, **kwargs):
+ kwargs["spec"] = AsyncBolt
+ super().__init__(*args, **kwargs)
+ self.attach_mock(Mock(return_value=True), "is_reset_mock")
+ self.attach_mock(Mock(return_value=False), "defunct")
+ self.attach_mock(Mock(return_value=False), "stale")
+ self.attach_mock(Mock(return_value=False), "closed")
+ self.attach_mock(Mock(), "unresolved_address")
+
+ def close_side_effect():
+ self.closed.return_value = True
+
+ self.attach_mock(AsyncMock(side_effect=close_side_effect),
+ "close")
+
+ @property
+ def is_reset(self):
+ if self.closed.return_value or self.defunct.return_value:
+ raise AssertionError(
+ "is_reset should not be called on a closed or defunct "
+ "connection."
+ )
+ return self.is_reset_mock()
+
+ async def fetch_message(self, *args, **kwargs):
+ if self.callbacks:
+ cb = self.callbacks.pop(0)
+ await cb()
+ return await super().__getattr__("fetch_message")(*args, **kwargs)
+
+ async def fetch_all(self, *args, **kwargs):
+ while self.callbacks:
+ cb = self.callbacks.pop(0)
+ cb()
+ return await super().__getattr__("fetch_all")(*args, **kwargs)
+
+ def __getattr__(self, name):
+ parent = super()
+
+ def build_message_handler(name):
+ def func(*args, **kwargs):
+ async def callback():
+ for cb_name, param_count in (
+ ("on_success", 1),
+ ("on_summary", 0)
+ ):
+ cb = kwargs.get(cb_name, None)
+ if callable(cb):
+ try:
+ param_count = \
+ len(inspect.signature(cb).parameters)
+ except ValueError:
+ # e.g. built-in method as cb
+ pass
+ if param_count == 1:
+ res = cb({})
+ else:
+ res = cb()
+ try:
+ await res # maybe the callback is async
+ except TypeError:
+ pass # or maybe it wasn't ;)
+ self.callbacks.append(callback)
+
+ return func
+
+ method_mock = parent.__getattr__(name)
+ if name in ("run", "commit", "pull", "rollback", "discard"):
+ method_mock.side_effect = build_message_handler(name)
+ return method_mock
+
+
+@pytest.fixture
+def async_fake_connection():
+ return AsyncFakeConnection()
diff --git a/tests/unit/async_/work/test_result.py b/tests/unit/async_/work/test_result.py
new file mode 100644
index 00000000..d73acaab
--- /dev/null
+++ b/tests/unit/async_/work/test_result.py
@@ -0,0 +1,456 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from unittest import mock
+
+import pytest
+
+from neo4j import (
+ Address,
+ AsyncResult,
+ Record,
+ ResultSummary,
+ ServerInfo,
+ SummaryCounters,
+ Version,
+)
+from neo4j._async_compat.util import AsyncUtil
+from neo4j.data import DataHydrator
+
+from ..._async_compat import mark_async_test
+
+
+class Records:
+ def __init__(self, fields, records):
+ assert all(len(fields) == len(r) for r in records)
+ self.fields = fields
+ # self.records = [{"record_values": r} for r in records]
+ self.records = records
+
+ def __len__(self):
+ return self.records.__len__()
+
+ def __iter__(self):
+ return self.records.__iter__()
+
+ def __getitem__(self, item):
+ return self.records.__getitem__(item)
+
+
+class AsyncConnectionStub:
+ class Message:
+ def __init__(self, message, *args, **kwargs):
+ self.message = message
+ self.args = args
+ self.kwargs = kwargs
+
+ async def _cb(self, cb_name, *args, **kwargs):
+ # print(self.message, cb_name.upper(), args, kwargs)
+ cb = self.kwargs.get(cb_name)
+ await AsyncUtil.callback(cb, *args, **kwargs)
+
+ async def on_success(self, metadata):
+ await self._cb("on_success", metadata)
+
+ async def on_summary(self):
+ await self._cb("on_summary")
+
+ async def on_records(self, records):
+ await self._cb("on_records", records)
+
+ def __eq__(self, other):
+ return self.message == other
+
+ def __repr__(self):
+ return "Message(%s)" % self.message
+
+ def __init__(self, records=None, run_meta=None, summary_meta=None,
+ force_qid=False):
+ self._multi_result = isinstance(records, (list, tuple))
+ if self._multi_result:
+ self._records = records
+ self._use_qid = True
+ else:
+ self._records = records,
+ self._use_qid = force_qid
+ self.fetch_idx = 0
+ self._qid = -1
+ self.most_recent_qid = None
+ self.record_idxs = [0] * len(self._records)
+ self.to_pull = [None] * len(self._records)
+ self._exhausted = [False] * len(self._records)
+ self.queued = []
+ self.sent = []
+ self.run_meta = run_meta
+ self.summary_meta = summary_meta
+ AsyncConnectionStub.server_info.update({"server": "Neo4j/4.3.0"})
+ self.unresolved_address = None
+
+ async def send_all(self):
+ self.sent += self.queued
+ self.queued = []
+
+ async def fetch_message(self):
+ if self.fetch_idx >= len(self.sent):
+ pytest.fail("Waits for reply to never sent message")
+ msg = self.sent[self.fetch_idx]
+ if msg == "RUN":
+ self.fetch_idx += 1
+ self._qid += 1
+ meta = {"fields": self._records[self._qid].fields,
+ **(self.run_meta or {})}
+ if self._use_qid:
+ meta.update(qid=self._qid)
+ await msg.on_success(meta)
+ elif msg == "DISCARD":
+ self.fetch_idx += 1
+ qid = msg.kwargs.get("qid", -1)
+ if qid < 0:
+ qid = self._qid
+ self.record_idxs[qid] = len(self._records[qid])
+ await msg.on_success(self.summary_meta or {})
+ await msg.on_summary()
+ elif msg == "PULL":
+ qid = msg.kwargs.get("qid", -1)
+ if qid < 0:
+ qid = self._qid
+ if self._exhausted[qid]:
+ pytest.fail("PULLing exhausted result")
+ if self.to_pull[qid] is None:
+ n = msg.kwargs.get("n", -1)
+ if n < 0:
+ n = len(self._records[qid])
+ self.to_pull[qid] = \
+ min(n, len(self._records[qid]) - self.record_idxs[qid])
+ # if to == len(self._records):
+ # self.fetch_idx += 1
+ if self.to_pull[qid] > 0:
+ record = self._records[qid][self.record_idxs[qid]]
+ self.record_idxs[qid] += 1
+ self.to_pull[qid] -= 1
+ await msg.on_records([record])
+ elif self.to_pull[qid] == 0:
+ self.to_pull[qid] = None
+ self.fetch_idx += 1
+ if self.record_idxs[qid] < len(self._records[qid]):
+ await msg.on_success({"has_more": True})
+ else:
+ await msg.on_success(
+ {"bookmark": "foo", **(self.summary_meta or {})}
+ )
+ self._exhausted[qid] = True
+ await msg.on_summary()
+
+ async def fetch_all(self):
+ while self.fetch_idx < len(self.sent):
+ await self.fetch_message()
+
+ def run(self, *args, **kwargs):
+ self.queued.append(AsyncConnectionStub.Message("RUN", *args, **kwargs))
+
+ def discard(self, *args, **kwargs):
+ self.queued.append(AsyncConnectionStub.Message("DISCARD", *args, **kwargs))
+
+ def pull(self, *args, **kwargs):
+ self.queued.append(AsyncConnectionStub.Message("PULL", *args, **kwargs))
+
+ server_info = ServerInfo(Address(("bolt://localhost", 7687)), Version(4, 3))
+
+ def defunct(self):
+ return False
+
+
+class HydratorStub(DataHydrator):
+ def hydrate(self, values):
+ return values
+
+
+def noop(*_, **__):
+ pass
+
+
+async def fetch_and_compare_all_records(
+ result, key, expected_records, method, limit=None
+):
+ received_records = []
+ if method == "for loop":
+ async for record in result:
+ assert isinstance(record, Record)
+ received_records.append([record.data().get(key, None)])
+ if limit is not None and len(received_records) == limit:
+ break
+ if limit is None:
+ assert result._closed
+ elif method == "next":
+ iter_ = AsyncUtil.iter(result)
+ n = len(expected_records) if limit is None else limit
+ for _ in range(n):
+ record = await AsyncUtil.next(iter_)
+ received_records.append([record.get(key, None)])
+ if limit is None:
+ with pytest.raises(StopAsyncIteration):
+ await AsyncUtil.next(iter_)
+ assert result._closed
+ elif method == "new iter":
+ n = len(expected_records) if limit is None else limit
+ for _ in range(n):
+ iter_ = AsyncUtil.iter(result)
+ record = await AsyncUtil.next(iter_)
+ received_records.append([record.get(key, None)])
+ if limit is None:
+ iter_ = AsyncUtil.iter(result)
+ with pytest.raises(StopAsyncIteration):
+ await AsyncUtil.next(iter_)
+ assert result._closed
+ else:
+ raise ValueError()
+ assert received_records == expected_records
+
+
+@pytest.mark.parametrize("method", ("for loop", "next", "new iter"))
+@pytest.mark.parametrize("records", (
+ [],
+ [[42]],
+ [[1], [2], [3], [4], [5]],
+))
+@mark_async_test
+async def test_result_iteration(method, records):
+ connection = AsyncConnectionStub(records=Records(["x"], records))
+ result = AsyncResult(connection, HydratorStub(), 2, noop, noop)
+ await result._run("CYPHER", {}, None, None, "r", None)
+ await fetch_and_compare_all_records(result, "x", records, method)
+
+
+@pytest.mark.parametrize("method", ("for loop", "next", "new iter"))
+@pytest.mark.parametrize("invert_fetch", (True, False))
+@mark_async_test
+async def test_parallel_result_iteration(method, invert_fetch):
+ records1 = [[i] for i in range(1, 6)]
+ records2 = [[i] for i in range(6, 11)]
+ connection = AsyncConnectionStub(
+ records=(Records(["x"], records1), Records(["x"], records2))
+ )
+ result1 = AsyncResult(connection, HydratorStub(), 2, noop, noop)
+ await result1._run("CYPHER1", {}, None, None, "r", None)
+ result2 = AsyncResult(connection, HydratorStub(), 2, noop, noop)
+ await result2._run("CYPHER2", {}, None, None, "r", None)
+ if invert_fetch:
+ await fetch_and_compare_all_records(
+ result2, "x", records2, method
+ )
+ await fetch_and_compare_all_records(
+ result1, "x", records1, method
+ )
+ else:
+ await fetch_and_compare_all_records(
+ result1, "x", records1, method
+ )
+ await fetch_and_compare_all_records(
+ result2, "x", records2, method
+ )
+
+
+@pytest.mark.parametrize("method", ("for loop", "next", "new iter"))
+@pytest.mark.parametrize("invert_fetch", (True, False))
+@mark_async_test
+async def test_interwoven_result_iteration(method, invert_fetch):
+ records1 = [[i] for i in range(1, 10)]
+ records2 = [[i] for i in range(11, 20)]
+ connection = AsyncConnectionStub(
+ records=(Records(["x"], records1), Records(["y"], records2))
+ )
+ result1 = AsyncResult(connection, HydratorStub(), 2, noop, noop)
+ await result1._run("CYPHER1", {}, None, None, "r", None)
+ result2 = AsyncResult(connection, HydratorStub(), 2, noop, noop)
+ await result2._run("CYPHER2", {}, None, None, "r", None)
+ start = 0
+ for n in (1, 2, 3, 1, None):
+ end = n if n is None else start + n
+ if invert_fetch:
+ await fetch_and_compare_all_records(
+ result2, "y", records2[start:end], method, n
+ )
+ await fetch_and_compare_all_records(
+ result1, "x", records1[start:end], method, n
+ )
+ else:
+ await fetch_and_compare_all_records(
+ result1, "x", records1[start:end], method, n
+ )
+ await fetch_and_compare_all_records(
+ result2, "y", records2[start:end], method, n
+ )
+ start = end
+
+
+@pytest.mark.parametrize("records", ([[1], [2]], [[1]], []))
+@pytest.mark.parametrize("fetch_size", (1, 2))
+@mark_async_test
+async def test_result_peek(records, fetch_size):
+ connection = AsyncConnectionStub(records=Records(["x"], records))
+ result = AsyncResult(connection, HydratorStub(), fetch_size, noop, noop)
+ await result._run("CYPHER", {}, None, None, "r", None)
+ for i in range(len(records) + 1):
+ record = await result.peek()
+ if i == len(records):
+ assert record is None
+ else:
+ assert isinstance(record, Record)
+ assert record.get("x") == records[i][0]
+ iter_ = AsyncUtil.iter(result)
+ await AsyncUtil.next(iter_) # consume the record
+
+
+@pytest.mark.parametrize("records", ([[1], [2]], [[1]], []))
+@pytest.mark.parametrize("fetch_size", (1, 2))
+@mark_async_test
+async def test_result_single(records, fetch_size):
+ connection = AsyncConnectionStub(records=Records(["x"], records))
+ result = AsyncResult(connection, HydratorStub(), fetch_size, noop, noop)
+ await result._run("CYPHER", {}, None, None, "r", None)
+ with pytest.warns(None) as warning_record:
+ record = await result.single()
+ if not records:
+ assert not warning_record
+ assert record is None
+ else:
+ if len(records) > 1:
+ assert len(warning_record) == 1
+ else:
+ assert not warning_record
+ assert isinstance(record, Record)
+ assert record.get("x") == records[0][0]
+
+
+@mark_async_test
+async def test_keys_are_available_before_and_after_stream():
+ connection = AsyncConnectionStub(records=Records(["x"], [[1], [2]]))
+ result = AsyncResult(connection, HydratorStub(), 1, noop, noop)
+ await result._run("CYPHER", {}, None, None, "r", None)
+ assert list(result.keys()) == ["x"]
+ await AsyncUtil.list(result)
+ assert list(result.keys()) == ["x"]
+
+
+@pytest.mark.parametrize("records", ([[1], [2]], [[1]], []))
+@pytest.mark.parametrize("consume_one", (True, False))
+@pytest.mark.parametrize("summary_meta", (None, {"database": "foobar"}))
+@mark_async_test
+async def test_consume(records, consume_one, summary_meta):
+ connection = AsyncConnectionStub(
+ records=Records(["x"], records), summary_meta=summary_meta
+ )
+ result = AsyncResult(connection, HydratorStub(), 1, noop, noop)
+ await result._run("CYPHER", {}, None, None, "r", None)
+ if consume_one:
+ try:
+ await AsyncUtil.next(AsyncUtil.iter(result))
+ except StopAsyncIteration:
+ pass
+ summary = await result.consume()
+ assert isinstance(summary, ResultSummary)
+ if summary_meta and "db" in summary_meta:
+ assert summary.database == summary_meta["db"]
+ else:
+ assert summary.database is None
+ server_info = summary.server
+ assert isinstance(server_info, ServerInfo)
+ assert server_info.version_info() == Version(4, 3)
+ assert server_info.protocol_version == Version(4, 3)
+ assert isinstance(summary.counters, SummaryCounters)
+
+
+@pytest.mark.parametrize("t_first", (None, 0, 1, 123456789))
+@pytest.mark.parametrize("t_last", (None, 0, 1, 123456789))
+@mark_async_test
+async def test_time_in_summary(t_first, t_last):
+ run_meta = None
+ if t_first is not None:
+ run_meta = {"t_first": t_first}
+ summary_meta = None
+ if t_last is not None:
+ summary_meta = {"t_last": t_last}
+ connection = AsyncConnectionStub(
+ records=Records(["n"], [[i] for i in range(100)]), run_meta=run_meta,
+ summary_meta=summary_meta
+ )
+
+ result = AsyncResult(connection, HydratorStub(), 1, noop, noop)
+ await result._run("CYPHER", {}, None, None, "r", None)
+ summary = await result.consume()
+
+ if t_first is not None:
+ assert isinstance(summary.result_available_after, int)
+ assert summary.result_available_after == t_first
+ else:
+ assert summary.result_available_after is None
+ if t_last is not None:
+ assert isinstance(summary.result_consumed_after, int)
+ assert summary.result_consumed_after == t_last
+ else:
+ assert summary.result_consumed_after is None
+ assert not hasattr(summary, "t_first")
+ assert not hasattr(summary, "t_last")
+
+
+@mark_async_test
+async def test_counts_in_summary():
+ connection = AsyncConnectionStub(records=Records(["n"], [[1], [2]]))
+
+ result = AsyncResult(connection, HydratorStub(), 1, noop, noop)
+ await result._run("CYPHER", {}, None, None, "r", None)
+ summary = await result.consume()
+
+ assert isinstance(summary.counters, SummaryCounters)
+
+
+@pytest.mark.parametrize("query_type", ("r", "w", "rw", "s"))
+@mark_async_test
+async def test_query_type(query_type):
+ connection = AsyncConnectionStub(
+ records=Records(["n"], [[1], [2]]), summary_meta={"type": query_type}
+ )
+
+ result = AsyncResult(connection, HydratorStub(), 1, noop, noop)
+ await result._run("CYPHER", {}, None, None, "r", None)
+ summary = await result.consume()
+
+ assert isinstance(summary.query_type, str)
+ assert summary.query_type == query_type
+
+
+@pytest.mark.parametrize("num_records", range(0, 5))
+@mark_async_test
+async def test_data(num_records):
+ connection = AsyncConnectionStub(
+ records=Records(["n"], [[i + 1] for i in range(num_records)])
+ )
+
+ result = AsyncResult(connection, HydratorStub(), 1, noop, noop)
+ await result._run("CYPHER", {}, None, None, "r", None)
+ await result._buffer_all()
+ records = result._record_buffer.copy()
+ assert len(records) == num_records
+ expected_data = []
+ for i, record in enumerate(records):
+ record.data = mock.Mock()
+ expected_data.append("magic_return_%s" % i)
+ record.data.return_value = expected_data[-1]
+ assert await result.data("hello", "world") == expected_data
+ for record in records:
+ assert record.data.called_once_with("hello", "world")
diff --git a/tests/unit/async_/work/test_session.py b/tests/unit/async_/work/test_session.py
new file mode 100644
index 00000000..5da701da
--- /dev/null
+++ b/tests/unit/async_/work/test_session.py
@@ -0,0 +1,285 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from contextlib import contextmanager
+
+import pytest
+
+from neo4j import (
+ AsyncSession,
+ AsyncTransaction,
+ SessionConfig,
+ unit_of_work,
+)
+from neo4j._async.io._pool import AsyncIOPool
+
+from ..._async_compat import (
+ AsyncMock,
+ mark_async_test,
+ mock,
+)
+from ._fake_connection import AsyncFakeConnection
+
+
+@pytest.fixture()
+def pool():
+ pool = AsyncMock(spec=AsyncIOPool)
+ pool.acquire.side_effect = iter(AsyncFakeConnection, 0)
+ return pool
+
+
+@mark_async_test
+async def test_session_context_calls_close():
+ s = AsyncSession(None, SessionConfig())
+ with mock.patch.object(s, 'close', autospec=True) as mock_close:
+ async with s:
+ pass
+ mock_close.assert_called_once_with()
+
+
+@pytest.mark.parametrize("test_run_args", (
+ ("RETURN $x", {"x": 1}), ("RETURN 1",)
+))
+@pytest.mark.parametrize(("repetitions", "consume"), (
+ (1, False), (2, False), (2, True)
+))
+@mark_async_test
+async def test_opens_connection_on_run(
+ pool, test_run_args, repetitions, consume
+):
+ async with AsyncSession(pool, SessionConfig()) as session:
+ assert session._connection is None
+ result = await session.run(*test_run_args)
+ assert session._connection is not None
+ if consume:
+ await result.consume()
+
+
+@pytest.mark.parametrize("test_run_args", (
+ ("RETURN $x", {"x": 1}), ("RETURN 1",)
+))
+@pytest.mark.parametrize("repetitions", range(1, 3))
+@mark_async_test
+async def test_closes_connection_after_consume(
+ pool, test_run_args, repetitions
+):
+ async with AsyncSession(pool, SessionConfig()) as session:
+ result = await session.run(*test_run_args)
+ await result.consume()
+ assert session._connection is None
+ assert session._connection is None
+
+
+@pytest.mark.parametrize("test_run_args", (
+ ("RETURN $x", {"x": 1}), ("RETURN 1",)
+))
+@mark_async_test
+async def test_keeps_connection_until_last_result_consumed(
+ pool, test_run_args
+):
+ async with AsyncSession(pool, SessionConfig()) as session:
+ result1 = await session.run(*test_run_args)
+ result2 = await session.run(*test_run_args)
+ assert session._connection is not None
+ await result1.consume()
+ assert session._connection is not None
+ await result2.consume()
+ assert session._connection is None
+
+
+@mark_async_test
+async def test_opens_connection_on_tx_begin(pool):
+ async with AsyncSession(pool, SessionConfig()) as session:
+ assert session._connection is None
+ async with await session.begin_transaction() as _:
+ assert session._connection is not None
+
+
+@pytest.mark.parametrize("test_run_args", (
+ ("RETURN $x", {"x": 1}), ("RETURN 1",)
+))
+@pytest.mark.parametrize("repetitions", range(1, 3))
+@mark_async_test
+async def test_keeps_connection_on_tx_run(pool, test_run_args, repetitions):
+ async with AsyncSession(pool, SessionConfig()) as session:
+ async with await session.begin_transaction() as tx:
+ for _ in range(repetitions):
+ await tx.run(*test_run_args)
+ assert session._connection is not None
+
+
+@pytest.mark.parametrize("test_run_args", (
+ ("RETURN $x", {"x": 1}), ("RETURN 1",)
+))
+@pytest.mark.parametrize("repetitions", range(1, 3))
+@mark_async_test
+async def test_keeps_connection_on_tx_consume(
+ pool, test_run_args, repetitions
+):
+ async with AsyncSession(pool, SessionConfig()) as session:
+ async with await session.begin_transaction() as tx:
+ for _ in range(repetitions):
+ result = await tx.run(*test_run_args)
+ await result.consume()
+ assert session._connection is not None
+
+
+@pytest.mark.parametrize("test_run_args", (
+ ("RETURN $x", {"x": 1}), ("RETURN 1",)
+))
+@mark_async_test
+async def test_closes_connection_after_tx_close(pool, test_run_args):
+ async with AsyncSession(pool, SessionConfig()) as session:
+ async with await session.begin_transaction() as tx:
+ for _ in range(2):
+ result = await tx.run(*test_run_args)
+ await result.consume()
+ await tx.close()
+ assert session._connection is None
+ assert session._connection is None
+
+
+@pytest.mark.parametrize("test_run_args", (
+ ("RETURN $x", {"x": 1}), ("RETURN 1",)
+))
+@mark_async_test
+async def test_closes_connection_after_tx_commit(pool, test_run_args):
+ async with AsyncSession(pool, SessionConfig()) as session:
+ async with await session.begin_transaction() as tx:
+ for _ in range(2):
+ result = await tx.run(*test_run_args)
+ await result.consume()
+ await tx.commit()
+ assert session._connection is None
+ assert session._connection is None
+
+
+@pytest.mark.parametrize("bookmarks", (None, [], ["abc"], ["foo", "bar"]))
+@mark_async_test
+async def test_session_returns_bookmark_directly(pool, bookmarks):
+ async with AsyncSession(
+ pool, SessionConfig(bookmarks=bookmarks)
+ ) as session:
+ if bookmarks:
+ assert await session.last_bookmark() == bookmarks[-1]
+ else:
+ assert await session.last_bookmark() is None
+
+
+@pytest.mark.parametrize(("query", "error_type"), (
+ (None, ValueError),
+ (1234, TypeError),
+ ({"how about": "no?"}, TypeError),
+ (["I don't", "think so"], TypeError),
+))
+@mark_async_test
+async def test_session_run_wrong_types(pool, query, error_type):
+ async with AsyncSession(pool, SessionConfig()) as session:
+ with pytest.raises(error_type):
+ await session.run(query)
+
+
+@pytest.mark.parametrize("tx_type", ("write_transaction", "read_transaction"))
+@mark_async_test
+async def test_tx_function_argument_type(pool, tx_type):
+ async def work(tx):
+ assert isinstance(tx, AsyncTransaction)
+
+ async with AsyncSession(pool, SessionConfig()) as session:
+ getattr(session, tx_type)(work)
+
+
+@pytest.mark.parametrize("tx_type", ("write_transaction", "read_transaction"))
+@pytest.mark.parametrize("decorator_kwargs", (
+ {},
+ {"timeout": 5},
+ {"metadata": {"foo": "bar"}},
+ {"timeout": 5, "metadata": {"foo": "bar"}},
+
+))
+@mark_async_test
+async def test_decorated_tx_function_argument_type(pool, tx_type, decorator_kwargs):
+ @unit_of_work(**decorator_kwargs)
+ async def work(tx):
+ assert isinstance(tx, AsyncTransaction)
+
+ async with AsyncSession(pool, SessionConfig()) as session:
+ getattr(session, tx_type)(work)
+
+
+@mark_async_test
+async def test_session_tx_type(pool):
+ async with AsyncSession(pool, SessionConfig()) as session:
+ tx = await session.begin_transaction()
+ assert isinstance(tx, AsyncTransaction)
+
+
+@pytest.mark.parametrize(("parameters", "error_type"), (
+ ({"x": None}, None),
+ ({"x": True}, None),
+ ({"x": False}, None),
+ ({"x": 123456789}, None),
+ ({"x": 3.1415926}, None),
+ ({"x": float("nan")}, None),
+ ({"x": float("inf")}, None),
+ ({"x": float("-inf")}, None),
+ ({"x": "foo"}, None),
+ ({"x": bytearray([0x00, 0x33, 0x66, 0x99, 0xCC, 0xFF])}, None),
+ ({"x": b"\x00\x33\x66\x99\xcc\xff"}, None),
+ ({"x": [1, 2, 3]}, None),
+ ({"x": ["a", "b", "c"]}, None),
+ ({"x": ["a", 2, 1.234]}, None),
+ ({"x": ["a", 2, ["c"]]}, None),
+ ({"x": {"one": "eins", "two": "zwei", "three": "drei"}}, None),
+ ({"x": {"one": ["eins", "uno", 1], "two": ["zwei", "dos", 2]}}, None),
+
+ # maps must have string keys
+ ({"x": {1: 'eins', 2: 'zwei', 3: 'drei'}}, TypeError),
+ ({"x": {(1, 2): '1+2i', (2, 0): '2'}}, TypeError),
+))
+@pytest.mark.parametrize("run_type", ("auto", "unmanaged", "managed"))
+@mark_async_test
+async def test_session_run_with_parameters(
+ pool, parameters, error_type, run_type
+):
+ @contextmanager
+ def raises():
+ if error_type is not None:
+ with pytest.raises(error_type) as exc:
+ yield exc
+ else:
+ yield None
+
+ async with AsyncSession(pool, SessionConfig()) as session:
+ if run_type == "auto":
+ with raises():
+ await session.run("RETURN $x", **parameters)
+ elif run_type == "unmanaged":
+ tx = await session.begin_transaction()
+ with raises():
+ await tx.run("RETURN $x", **parameters)
+ elif run_type == "managed":
+ async def work(tx):
+ with raises() as exc:
+ await tx.run("RETURN $x", **parameters)
+ if exc is not None:
+ raise exc
+ with raises():
+ await session.write_transaction(work)
+ else:
+ raise ValueError(run_type)
diff --git a/tests/unit/async_/work/test_transaction.py b/tests/unit/async_/work/test_transaction.py
new file mode 100644
index 00000000..6eafbb1a
--- /dev/null
+++ b/tests/unit/async_/work/test_transaction.py
@@ -0,0 +1,185 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from unittest.mock import MagicMock
+from uuid import uuid4
+
+import pytest
+
+from neo4j import (
+ Query,
+ Transaction,
+)
+
+from ._fake_connection import async_fake_connection
+
+
+@pytest.mark.parametrize(("explicit_commit", "close"), (
+ (False, False),
+ (True, False),
+ (True, True),
+))
+def test_transaction_context_when_committing(mocker, async_fake_connection,
+ explicit_commit, close):
+ on_closed = MagicMock()
+ on_error = MagicMock()
+ tx = Transaction(async_fake_connection, 2, on_closed, on_error)
+ mock_commit = mocker.patch.object(tx, "commit", wraps=tx.commit)
+ mock_rollback = mocker.patch.object(tx, "rollback", wraps=tx.rollback)
+ with tx as tx_:
+ assert mock_commit.call_count == 0
+ assert mock_rollback.call_count == 0
+ assert tx is tx_
+ if explicit_commit:
+ tx_.commit()
+ mock_commit.assert_called_once_with()
+ assert tx.closed()
+ if close:
+ tx_.close()
+ assert tx_.closed()
+ mock_commit.assert_called_once_with()
+ assert mock_rollback.call_count == 0
+ assert tx_.closed()
+
+
+@pytest.mark.parametrize(("rollback", "close"), (
+ (True, False),
+ (False, True),
+ (True, True),
+))
+def test_transaction_context_with_explicit_rollback(mocker, async_fake_connection,
+ rollback, close):
+ on_closed = MagicMock()
+ on_error = MagicMock()
+ tx = Transaction(async_fake_connection, 2, on_closed, on_error)
+ mock_commit = mocker.patch.object(tx, "commit", wraps=tx.commit)
+ mock_rollback = mocker.patch.object(tx, "rollback", wraps=tx.rollback)
+ with tx as tx_:
+ assert mock_commit.call_count == 0
+ assert mock_rollback.call_count == 0
+ assert tx is tx_
+ if rollback:
+ tx_.rollback()
+ mock_rollback.assert_called_once_with()
+ assert tx_.closed()
+ if close:
+ tx_.close()
+ mock_rollback.assert_called_once_with()
+ assert tx_.closed()
+ assert mock_commit.call_count == 0
+ mock_rollback.assert_called_once_with()
+ assert tx_.closed()
+
+
+def test_transaction_context_calls_rollback_on_error(mocker, async_fake_connection):
+ class OopsError(RuntimeError):
+ pass
+
+ on_closed = MagicMock()
+ on_error = MagicMock()
+ tx = Transaction(async_fake_connection, 2, on_closed, on_error)
+ mock_commit = mocker.patch.object(tx, "commit", wraps=tx.commit)
+ mock_rollback = mocker.patch.object(tx, "rollback", wraps=tx.rollback)
+ with pytest.raises(OopsError):
+ with tx as tx_:
+ assert mock_commit.call_count == 0
+ assert mock_rollback.call_count == 0
+ assert tx is tx_
+ raise OopsError
+ assert mock_commit.call_count == 0
+ mock_rollback.assert_called_once_with()
+ assert tx_.closed()
+
+
+@pytest.mark.parametrize(("parameters", "error_type"), (
+ # maps must have string keys
+ ({"x": {1: 'eins', 2: 'zwei', 3: 'drei'}}, TypeError),
+ ({"x": {(1, 2): '1+2i', (2, 0): '2'}}, TypeError),
+ ({"x": uuid4()}, TypeError),
+))
+def test_transaction_run_with_invalid_parameters(async_fake_connection, parameters,
+ error_type):
+ on_closed = MagicMock()
+ on_error = MagicMock()
+ tx = Transaction(async_fake_connection, 2, on_closed, on_error)
+ with pytest.raises(error_type):
+ tx.run("RETURN $x", **parameters)
+
+
+def test_transaction_run_takes_no_query_object(async_fake_connection):
+ on_closed = MagicMock()
+ on_error = MagicMock()
+ tx = Transaction(async_fake_connection, 2, on_closed, on_error)
+ with pytest.raises(ValueError):
+ tx.run(Query("RETURN 1"))
+
+
+def test_transaction_rollbacks_on_open_connections(async_fake_connection):
+ tx = Transaction(async_fake_connection, 2,
+ lambda *args, **kwargs: None,
+ lambda *args, **kwargs: None)
+ with tx as tx_:
+ async_fake_connection.is_reset_mock.return_value = False
+ async_fake_connection.is_reset_mock.reset_mock()
+ tx_.rollback()
+ async_fake_connection.is_reset_mock.assert_called_once()
+ async_fake_connection.reset.assert_not_called()
+ async_fake_connection.rollback.assert_called_once()
+
+
+def test_transaction_no_rollback_on_reset_connections(async_fake_connection):
+    tx = Transaction(async_fake_connection, 2,
+                     lambda *args, **kwargs: None,
+                     lambda *args, **kwargs: None)
+    with tx as tx_:
+        async_fake_connection.is_reset_mock.return_value = True
+        async_fake_connection.is_reset_mock.reset_mock()
+        tx_.rollback()
+        async_fake_connection.is_reset_mock.assert_called_once()
+        async_fake_connection.reset.assert_not_called()  # was "asset_not_called": silent no-op on a Mock
+        async_fake_connection.rollback.assert_not_called()  # was "asset_not_called"
+
+
+def test_transaction_no_rollback_on_closed_connections(async_fake_connection):
+    tx = Transaction(async_fake_connection, 2,
+                     lambda *args, **kwargs: None,
+                     lambda *args, **kwargs: None)
+    with tx as tx_:
+        async_fake_connection.closed.return_value = True
+        async_fake_connection.closed.reset_mock()
+        async_fake_connection.is_reset_mock.reset_mock()
+        tx_.rollback()
+        async_fake_connection.closed.assert_called_once()
+        async_fake_connection.is_reset_mock.assert_not_called()  # was "asset_not_called": silent no-op on a Mock
+        async_fake_connection.reset.assert_not_called()  # was "asset_not_called"
+        async_fake_connection.rollback.assert_not_called()  # was "asset_not_called"
+
+
+def test_transaction_no_rollback_on_defunct_connections(async_fake_connection):
+    tx = Transaction(async_fake_connection, 2,
+                     lambda *args, **kwargs: None,
+                     lambda *args, **kwargs: None)
+    with tx as tx_:
+        async_fake_connection.defunct.return_value = True
+        async_fake_connection.defunct.reset_mock()
+        async_fake_connection.is_reset_mock.reset_mock()
+        tx_.rollback()
+        async_fake_connection.defunct.assert_called_once()
+        async_fake_connection.is_reset_mock.assert_not_called()  # was "asset_not_called": silent no-op on a Mock
+        async_fake_connection.reset.assert_not_called()  # was "asset_not_called"
+        async_fake_connection.rollback.assert_not_called()  # was "asset_not_called"
diff --git a/tests/unit/common/__init__.py b/tests/unit/common/__init__.py
new file mode 100644
index 00000000..b81a309d
--- /dev/null
+++ b/tests/unit/common/__init__.py
@@ -0,0 +1,16 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/unit/common/data/__init__.py b/tests/unit/common/data/__init__.py
new file mode 100644
index 00000000..b81a309d
--- /dev/null
+++ b/tests/unit/common/data/__init__.py
@@ -0,0 +1,16 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/unit/data/test_packing.py b/tests/unit/common/data/test_packing.py
similarity index 96%
rename from tests/unit/data/test_packing.py
rename to tests/unit/common/data/test_packing.py
index 6fc296af..39d32576 100644
--- a/tests/unit/data/test_packing.py
+++ b/tests/unit/common/data/test_packing.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -19,16 +16,20 @@
# limitations under the License.
-import struct
-from collections import OrderedDict
from io import BytesIO
from math import pi
+import struct
from unittest import TestCase
from uuid import uuid4
from pytest import raises
-from neo4j.packstream import Packer, UnpackableBuffer, Unpacker, Structure
+from neo4j.packstream import (
+ Packer,
+ Structure,
+ UnpackableBuffer,
+ Unpacker,
+)
class PackStreamTestCase(TestCase):
@@ -254,7 +255,7 @@ def test_empty_map(self):
def test_tiny_maps(self):
for size in range(0x10):
- data_in = OrderedDict()
+ data_in = dict()
data_out = bytearray([0xA0 + size])
for el in range(1, size + 1):
data_in[chr(64 + el)] = el
@@ -262,17 +263,17 @@ def test_tiny_maps(self):
self.assert_packable(data_in, bytes(data_out))
def test_map_8(self):
- d = OrderedDict([(u"A%s" % i, 1) for i in range(40)])
+ d = dict([(u"A%s" % i, 1) for i in range(40)])
b = b"".join(self.packb(u"A%s" % i, 1) for i in range(40))
self.assert_packable(d, b"\xD8\x28" + b)
def test_map_16(self):
- d = OrderedDict([(u"A%s" % i, 1) for i in range(40000)])
+ d = dict([(u"A%s" % i, 1) for i in range(40000)])
b = b"".join(self.packb(u"A%s" % i, 1) for i in range(40000))
self.assert_packable(d, b"\xD9\x9C\x40" + b)
def test_map_32(self):
- d = OrderedDict([(u"A%s" % i, 1) for i in range(80000)])
+ d = dict([(u"A%s" % i, 1) for i in range(80000)])
b = b"".join(self.packb(u"A%s" % i, 1) for i in range(80000))
self.assert_packable(d, b"\xDA\x00\x01\x38\x80" + b)
diff --git a/tests/unit/common/io/__init__.py b/tests/unit/common/io/__init__.py
new file mode 100644
index 00000000..b81a309d
--- /dev/null
+++ b/tests/unit/common/io/__init__.py
@@ -0,0 +1,16 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/unit/io/test_routing.py b/tests/unit/common/io/test_routing.py
similarity index 84%
rename from tests/unit/io/test_routing.py
rename to tests/unit/common/io/test_routing.py
index ff555bcb..41f09171 100644
--- a/tests/unit/io/test_routing.py
+++ b/tests/unit/common/io/test_routing.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -19,19 +16,14 @@
# limitations under the License.
-from unittest import TestCase
+import pytest
-from neo4j.io import (
- Bolt,
- Neo4jPool,
-)
+from neo4j.api import DEFAULT_DATABASE
from neo4j.routing import (
OrderedSet,
RoutingTable,
)
-from neo4j.api import (
- DEFAULT_DATABASE,
-)
+
VALID_ROUTING_RECORD = {
"ttl": 300,
@@ -53,7 +45,7 @@
}
-class OrderedSetTestCase(TestCase):
+class TestOrderedSet:
def test_should_repr_as_set(self):
s = OrderedSet([1, 2, 3])
assert repr(s) == "{1, 2, 3}"
@@ -68,14 +60,14 @@ def test_should_not_contain_non_element(self):
def test_should_be_able_to_get_item_if_empty(self):
s = OrderedSet([])
- with self.assertRaises(IndexError):
+ with pytest.raises(IndexError):
_ = s[0]
def test_should_be_able_to_get_items_by_index(self):
s = OrderedSet([1, 2, 3])
- self.assertEqual(s[0], 1)
- self.assertEqual(s[1], 2)
- self.assertEqual(s[2], 3)
+ assert s[0] == 1
+ assert s[1] == 2
+ assert s[2] == 3
def test_should_be_iterable(self):
s = OrderedSet([1, 2, 3])
@@ -117,7 +109,7 @@ def test_should_be_able_to_remove_existing(self):
def test_should_not_be_able_to_remove_non_existing(self):
s = OrderedSet([1, 2, 3])
- with self.assertRaises(ValueError):
+ with pytest.raises(ValueError):
s.remove(4)
def test_should_be_able_to_update(self):
@@ -131,14 +123,14 @@ def test_should_be_able_to_replace(self):
assert list(s) == [3, 4, 5]
-class RoutingTableConstructionTestCase(TestCase):
+class TestRoutingTableConstruction:
def test_should_be_initially_stale(self):
table = RoutingTable(database=DEFAULT_DATABASE)
assert not table.is_fresh(readonly=True)
assert not table.is_fresh(readonly=False)
-class RoutingTableParseRoutingInfoTestCase(TestCase):
+class TestRoutingTableParseRoutingInfo:
def test_should_return_routing_table_on_valid_record(self):
table = RoutingTable.parse_routing_info(
database=DEFAULT_DATABASE,
@@ -162,7 +154,7 @@ def test_should_return_routing_table_on_valid_record_with_extra_role(self):
assert table.ttl == 300
-class RoutingTableServersTestCase(TestCase):
+class TestRoutingTableServers:
def test_should_return_all_distinct_servers_in_routing_table(self):
routing_table = {
"ttl": 300,
@@ -180,7 +172,7 @@ def test_should_return_all_distinct_servers_in_routing_table(self):
assert table.servers() == {('127.0.0.1', 9001), ('127.0.0.1', 9002), ('127.0.0.1', 9003), ('127.0.0.1', 9005)}
-class RoutingTableFreshnessTestCase(TestCase):
+class TestRoutingTableFreshness:
def test_should_be_fresh_after_update(self):
table = RoutingTable.parse_routing_info(
database=DEFAULT_DATABASE,
@@ -221,16 +213,20 @@ def test_should_become_stale_if_no_writers(self):
assert not table.is_fresh(readonly=False)
-class RoutingTableUpdateTestCase(TestCase):
- def setUp(self):
- self.table = RoutingTable(
+class TestRoutingTableUpdate:
+ @pytest.fixture
+ def table(self):
+ return RoutingTable(
database=DEFAULT_DATABASE,
routers=[("192.168.1.1", 7687), ("192.168.1.2", 7687)],
readers=[("192.168.1.3", 7687)],
writers=[],
ttl=0,
)
- self.new_table = RoutingTable(
+
+ @pytest.fixture
+ def new_table(self):
+ return RoutingTable(
database=DEFAULT_DATABASE,
routers=[("127.0.0.1", 9001), ("127.0.0.1", 9002), ("127.0.0.1", 9003)],
readers=[("127.0.0.1", 9004), ("127.0.0.1", 9005)],
@@ -238,18 +234,19 @@ def setUp(self):
ttl=300,
)
- def test_update_should_replace_routers(self):
- self.table.update(self.new_table)
- assert self.table.routers == {("127.0.0.1", 9001), ("127.0.0.1", 9002), ("127.0.0.1", 9003)}
+ def test_update_should_replace_routers(self, table, new_table):
+ table.update(new_table)
+ assert table.routers == {("127.0.0.1", 9001), ("127.0.0.1", 9002),
+ ("127.0.0.1", 9003)}
- def test_update_should_replace_readers(self):
- self.table.update(self.new_table)
- assert self.table.readers == {("127.0.0.1", 9004), ("127.0.0.1", 9005)}
+ def test_update_should_replace_readers(self, table, new_table):
+ table.update(new_table)
+ assert table.readers == {("127.0.0.1", 9004), ("127.0.0.1", 9005)}
- def test_update_should_replace_writers(self):
- self.table.update(self.new_table)
- assert self.table.writers == {("127.0.0.1", 9006)}
+ def test_update_should_replace_writers(self, table, new_table):
+ table.update(new_table)
+ assert table.writers == {("127.0.0.1", 9006)}
- def test_update_should_replace_ttl(self):
- self.table.update(self.new_table)
- assert self.table.ttl == 300
+ def test_update_should_replace_ttl(self, table, new_table):
+ table.update(new_table)
+ assert table.ttl == 300
diff --git a/tests/unit/common/spatial/__init__.py b/tests/unit/common/spatial/__init__.py
new file mode 100644
index 00000000..b81a309d
--- /dev/null
+++ b/tests/unit/common/spatial/__init__.py
@@ -0,0 +1,16 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/unit/spatial/test_cartesian_point.py b/tests/unit/common/spatial/test_cartesian_point.py
similarity index 98%
rename from tests/unit/spatial/test_cartesian_point.py
rename to tests/unit/common/spatial/test_cartesian_point.py
index ee86e5b9..c33a90a1 100644
--- a/tests/unit/spatial/test_cartesian_point.py
+++ b/tests/unit/common/spatial/test_cartesian_point.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# coding: utf-8
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -18,6 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+
import io
import struct
from unittest import TestCase
diff --git a/tests/unit/spatial/test_point.py b/tests/unit/common/spatial/test_point.py
similarity index 98%
rename from tests/unit/spatial/test_point.py
rename to tests/unit/common/spatial/test_point.py
index 082f95c5..74eaa3a7 100644
--- a/tests/unit/spatial/test_point.py
+++ b/tests/unit/common/spatial/test_point.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# coding: utf-8
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -18,6 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+
import io
import struct
from unittest import TestCase
diff --git a/tests/unit/spatial/test_wgs84_point.py b/tests/unit/common/spatial/test_wgs84_point.py
similarity index 98%
rename from tests/unit/spatial/test_wgs84_point.py
rename to tests/unit/common/spatial/test_wgs84_point.py
index 8f725a58..0dee1913 100644
--- a/tests/unit/spatial/test_wgs84_point.py
+++ b/tests/unit/common/spatial/test_wgs84_point.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# coding: utf-8
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -18,6 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+
import io
import struct
from unittest import TestCase
diff --git a/tests/unit/test_addressing.py b/tests/unit/common/test_addressing.py
similarity index 61%
rename from tests/unit/test_addressing.py
rename to tests/unit/common/test_addressing.py
index 317ae5bf..925e7017 100644
--- a/tests/unit/test_addressing.py
+++ b/tests/unit/common/test_addressing.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -19,19 +16,19 @@
# limitations under the License.
-import pytest
-import unittest.mock as mock
from socket import (
AF_INET,
AF_INET6,
)
+import unittest.mock as mock
-from neo4j.addressing import (
+import pytest
+
+from neo4j import (
Address,
IPv4Address,
- IPv6Address,
)
-from neo4j import GraphDatabase
+
mock_socket_ipv4 = mock.Mock()
mock_socket_ipv4.getpeername = lambda: ("127.0.0.1", 7687) # (address, port)
@@ -39,8 +36,6 @@
mock_socket_ipv6 = mock.Mock()
mock_socket_ipv6.getpeername = lambda: ("[::1]", 7687, 0, 0) # (address, port, flow info, scope id)
-# python -m pytest tests/unit/test_addressing.py -s
-
@pytest.mark.parametrize(
"test_input, expected",
@@ -57,7 +52,6 @@
]
)
def test_address_initialization(test_input, expected):
- # python -m pytest tests/unit/test_addressing.py -s -k test_address_initialization
address = Address(test_input)
assert address.family == expected["family"]
assert address.host == expected["host"]
@@ -74,7 +68,6 @@ def test_address_initialization(test_input, expected):
]
)
def test_address_init_with_address_object_returns_same_instance(test_input):
- # python -m pytest tests/unit/test_addressing.py -s -k test_address_init_with_address_object_returns_same_instance
address = Address(test_input)
assert address is test_input
assert id(address) == id(test_input)
@@ -90,7 +83,6 @@ def test_address_init_with_address_object_returns_same_instance(test_input):
]
)
def test_address_initialization_with_incorrect_input(test_input, expected):
- # python -m pytest tests/unit/test_addressing.py -s -k test_address_initialization_with_incorrect_input
with pytest.raises(expected):
address = Address(test_input)
@@ -103,14 +95,12 @@ def test_address_initialization_with_incorrect_input(test_input, expected):
]
)
def test_address_from_socket(test_input, expected):
- # python -m pytest tests/unit/test_addressing.py -s -k test_address_from_socket
address = Address.from_socket(test_input)
assert address == expected
def test_address_from_socket_with_none():
- # python -m pytest tests/unit/test_addressing.py -s -k test_address_from_socket_with_none
with pytest.raises(AttributeError):
address = Address.from_socket(None)
@@ -128,7 +118,6 @@ def test_address_from_socket_with_none():
]
)
def test_address_parse_with_ipv4(test_input, expected):
- # python -m pytest tests/unit/test_addressing.py -s -k test_address_parse_with_ipv4
parsed = Address.parse(test_input)
assert parsed == expected
@@ -143,7 +132,6 @@ def test_address_parse_with_ipv4(test_input, expected):
]
)
def test_address_should_parse_ipv6(test_input, expected):
- # python -m pytest tests/unit/test_addressing.py -s -k test_address_should_parse_ipv6
parsed = Address.parse(test_input)
assert parsed == expected
@@ -159,7 +147,6 @@ def test_address_should_parse_ipv6(test_input, expected):
]
)
def test_address_parse_with_invalid_input(test_input, expected):
- # python -m pytest tests/unit/test_addressing.py -s -k test_address_parse_with_invalid_input
with pytest.raises(expected):
parsed = Address.parse(test_input)
@@ -174,7 +161,6 @@ def test_address_parse_with_invalid_input(test_input, expected):
]
)
def test_address_parse_list(test_input, expected):
- # python -m pytest tests/unit/test_addressing.py -s -k test_address_parse_list
addresses = Address.parse_list(*test_input)
assert len(addresses) == expected
@@ -190,69 +176,5 @@ def test_address_parse_list(test_input, expected):
]
)
def test_address_parse_list_with_invalid_input(test_input, expected):
- # python -m pytest tests/unit/test_addressing.py -s -k test_address_parse_list_with_invalid_input
with pytest.raises(TypeError):
addresses = Address.parse_list(*test_input)
-
-
-def test_address_resolve():
- # python -m pytest tests/unit/test_addressing.py -s -k test_address_resolve
- address = Address(("127.0.0.1", 7687))
- resolved = address.resolve()
- assert isinstance(resolved, Address) is False
- assert isinstance(resolved, list) is True
- assert len(resolved) == 1
- assert resolved[0] == IPv4Address(('127.0.0.1', 7687))
-
-
-def test_address_resolve_with_custom_resolver_none():
- # python -m pytest tests/unit/test_addressing.py -s -k test_address_resolve_with_custom_resolver_none
- address = Address(("127.0.0.1", 7687))
- resolved = address.resolve(resolver=None)
- assert isinstance(resolved, Address) is False
- assert isinstance(resolved, list) is True
- assert len(resolved) == 1
- assert resolved[0] == IPv4Address(('127.0.0.1', 7687))
-
-
-@pytest.mark.parametrize(
- "test_input, expected",
- [
- (Address(("127.0.0.1", "abcd")), ValueError),
- (Address((None, None)), ValueError),
- ]
-)
-def test_address_resolve_with_unresolvable_address(test_input, expected):
- # python -m pytest tests/unit/test_addressing.py -s -k test_address_resolve_with_unresolvable_address
- with pytest.raises(expected):
- test_input.resolve(resolver=None)
-
-
-def test_address_resolve_with_custom_resolver():
- # python -m pytest tests/unit/test_addressing.py -s -k test_address_resolve_with_custom_resolver
- custom_resolver = lambda _: [("127.0.0.1", 7687), ("localhost", 1234)]
-
- address = Address(("127.0.0.1", 7687))
- resolved = address.resolve(family=AF_INET, resolver=custom_resolver)
- assert isinstance(resolved, Address) is False
- assert isinstance(resolved, list) is True
- assert len(resolved) == 2 # IPv4 only
- assert resolved[0] == IPv4Address(('127.0.0.1', 7687))
- assert resolved[1] == IPv4Address(('127.0.0.1', 1234))
-
-
-def test_address_unresolve():
- # python -m pytest tests/unit/test_addressing.py -s -k test_address_unresolve
- custom_resolved = [("127.0.0.1", 7687), ("localhost", 4321)]
- custom_resolver = lambda _: custom_resolved
-
- address = Address(("foobar", 1234))
- unresolved = address.unresolved
- assert address.__class__ == unresolved.__class__
- assert address == unresolved
- resolved = address.resolve(family=AF_INET, resolver=custom_resolver)
- custom_resolved = sorted(Address(a) for a in custom_resolved)
- unresolved = sorted(a.unresolved for a in resolved)
- assert custom_resolved == unresolved
- assert (list(map(lambda a: a.__class__, custom_resolved))
- == list(map(lambda a: a.__class__, unresolved)))
diff --git a/tests/unit/test_api.py b/tests/unit/common/test_api.py
similarity index 83%
rename from tests/unit/test_api.py
rename to tests/unit/common/test_api.py
index 3e06bf8b..6e3caa07 100644
--- a/tests/unit/test_api.py
+++ b/tests/unit/common/test_api.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -19,27 +16,24 @@
# limitations under the License.
-import pytest
from uuid import uuid4
+import pytest
+
import neo4j.api
from neo4j.data import DataDehydrator
-from neo4j.exceptions import (
- ConfigurationError,
-)
+from neo4j.exceptions import ConfigurationError
+
standard_ascii = [chr(i) for i in range(128)]
not_ascii = "♥O◘♦♥O◘♦"
-# python -m pytest tests/unit/test_api.py -s
-
def dehydrated_value(value):
return DataDehydrator.fix_parameters({"_": value})["_"]
def test_value_dehydration_should_allow_none():
- # python -m pytest tests/unit/test_api.py -s -k test_value_dehydration_should_allow_none
assert dehydrated_value(None) is None
@@ -51,7 +45,6 @@ def test_value_dehydration_should_allow_none():
]
)
def test_value_dehydration_should_allow_boolean(test_input, expected):
- # python -m pytest tests/unit/test_api.py -s -k test_value_dehydration_should_allow_boolean
assert dehydrated_value(test_input) is expected
@@ -67,7 +60,6 @@ def test_value_dehydration_should_allow_boolean(test_input, expected):
]
)
def test_value_dehydration_should_allow_integer(test_input, expected):
- # python -m pytest tests/unit/test_api.py -s -k test_value_dehydration_should_allow_integer
assert dehydrated_value(test_input) == expected
@@ -79,7 +71,6 @@ def test_value_dehydration_should_allow_integer(test_input, expected):
]
)
def test_value_dehydration_should_disallow_oversized_integer(test_input, expected):
- # python -m pytest tests/unit/test_api.py -s -k test_value_dehydration_should_disallow_oversized_integer
with pytest.raises(expected):
dehydrated_value(test_input)
@@ -94,7 +85,6 @@ def test_value_dehydration_should_disallow_oversized_integer(test_input, expecte
]
)
def test_value_dehydration_should_allow_float(test_input, expected):
- # python -m pytest tests/unit/test_api.py -s -k test_value_dehydration_should_allow_float
assert dehydrated_value(test_input) == expected
@@ -107,7 +97,6 @@ def test_value_dehydration_should_allow_float(test_input, expected):
]
)
def test_value_dehydration_should_allow_string(test_input, expected):
- # python -m pytest tests/unit/test_api.py -s -k test_value_dehydration_should_allow_string
assert dehydrated_value(test_input) == expected
@@ -119,7 +108,6 @@ def test_value_dehydration_should_allow_string(test_input, expected):
]
)
def test_value_dehydration_should_allow_bytes(test_input, expected):
- # python -m pytest tests/unit/test_api.py -s -k test_value_dehydration_should_allow_bytes
assert dehydrated_value(test_input) == expected
@@ -132,7 +120,6 @@ def test_value_dehydration_should_allow_bytes(test_input, expected):
]
)
def test_value_dehydration_should_allow_list(test_input, expected):
- # python -m pytest tests/unit/test_api.py -s -k test_value_dehydration_should_allow_list
assert dehydrated_value(test_input) == expected
@@ -146,7 +133,6 @@ def test_value_dehydration_should_allow_list(test_input, expected):
]
)
def test_value_dehydration_should_allow_dict(test_input, expected):
- # python -m pytest tests/unit/test_api.py -s -k test_value_dehydration_should_allow_dict
assert dehydrated_value(test_input) == expected
@@ -158,13 +144,11 @@ def test_value_dehydration_should_allow_dict(test_input, expected):
]
)
def test_value_dehydration_should_disallow_object(test_input, expected):
- # python -m pytest tests/unit/test_api.py -s -k test_value_dehydration_should_disallow_object
with pytest.raises(expected):
dehydrated_value(test_input)
def test_bookmark_initialization_with_no_values():
- # python -m pytest tests/unit/test_api.py -s -k test_bookmark_initialization_with_no_values
bookmark = neo4j.api.Bookmark()
assert bookmark.values == frozenset()
assert bool(bookmark) is False
@@ -182,7 +166,6 @@ def test_bookmark_initialization_with_no_values():
]
)
def test_bookmark_initialization_with_values_none(test_input, expected_values, expected_bool, expected_repr):
- # python -m pytest tests/unit/test_api.py -s -k test_bookmark_initialization_with_values_none
bookmark = neo4j.api.Bookmark(*test_input)
assert bookmark.values == expected_values
assert bool(bookmark) is expected_bool
@@ -200,7 +183,6 @@ def test_bookmark_initialization_with_values_none(test_input, expected_values, e
]
)
def test_bookmark_initialization_with_values_empty_string(test_input, expected_values, expected_bool, expected_repr):
- # python -m pytest tests/unit/test_api.py -s -k test_bookmark_initialization_with_values_empty_string
bookmark = neo4j.api.Bookmark(*test_input)
assert bookmark.values == expected_values
assert bool(bookmark) is expected_bool
@@ -216,7 +198,6 @@ def test_bookmark_initialization_with_values_empty_string(test_input, expected_v
]
)
def test_bookmark_initialization_with_valid_strings(test_input, expected_values, expected_bool, expected_repr):
- # python -m pytest tests/unit/test_api.py -s -k test_bookmark_initialization_with_valid_strings
bookmark = neo4j.api.Bookmark(*test_input)
assert bookmark.values == expected_values
assert bool(bookmark) is expected_bool
@@ -232,7 +213,6 @@ def test_bookmark_initialization_with_valid_strings(test_input, expected_values,
]
)
def test_bookmark_initialization_with_invalid_strings(test_input, expected):
- # python -m pytest tests/unit/test_api.py -s -k test_bookmark_initialization_with_invalid_strings
with pytest.raises(expected) as e:
bookmark = neo4j.api.Bookmark(*test_input)
@@ -251,7 +231,6 @@ def test_bookmark_initialization_with_invalid_strings(test_input, expected):
]
)
def test_version_initialization(test_input, expected_str, expected_repr):
- # python -m pytest tests/unit/test_api.py -s -k test_version_initialization
version = neo4j.api.Version(*test_input)
assert str(version) == expected_str
assert repr(version) == expected_repr
@@ -268,7 +247,6 @@ def test_version_initialization(test_input, expected_str, expected_repr):
]
)
def test_version_from_bytes_with_valid_bolt_version_handshake(test_input, expected_str, expected_repr):
- # python -m pytest tests/unit/test_api.py -s -k test_version_from_bytes_with_valid_bolt_version_handshake
version = neo4j.api.Version.from_bytes(test_input)
assert str(version) == expected_str
assert repr(version) == expected_repr
@@ -285,7 +263,6 @@ def test_version_from_bytes_with_valid_bolt_version_handshake(test_input, expect
]
)
def test_version_from_bytes_with_not_valid_bolt_version_handshake(test_input, expected):
- # python -m pytest tests/unit/test_api.py -s -k test_version_from_bytes_with_not_valid_bolt_version_handshake
with pytest.raises(expected):
version = neo4j.api.Version.from_bytes(test_input)
@@ -303,13 +280,11 @@ def test_version_from_bytes_with_not_valid_bolt_version_handshake(test_input, ex
]
)
def test_version_to_bytes_with_valid_bolt_version(test_input, expected):
- # python -m pytest tests/unit/test_api.py -s -k test_version_to_bytes_with_valid_bolt_version
version = neo4j.api.Version(*test_input)
assert version.to_bytes() == expected
def test_serverinfo_initialization():
- # python -m pytest tests/unit/test_api.py -s -k test_serverinfo_initialization
from neo4j.addressing import Address
@@ -332,7 +307,6 @@ def test_serverinfo_initialization():
]
)
def test_serverinfo_with_metadata(test_input, expected_agent, expected_version_info):
- # python -m pytest tests/unit/test_api.py -s -k test_serverinfo_with_metadata
from neo4j.addressing import Address
address = Address(("bolt://localhost", 7687))
@@ -365,7 +339,6 @@ def test_serverinfo_with_metadata(test_input, expected_agent, expected_version_i
]
)
def test_uri_scheme(test_input, expected_driver_type, expected_security_type, expected_error):
- # python -m pytest tests/unit/test_api.py -s -k test_uri_scheme
if expected_error:
with pytest.raises(expected_error):
neo4j.api.parse_neo4j_uri(test_input)
@@ -376,18 +349,15 @@ def test_uri_scheme(test_input, expected_driver_type, expected_security_type, ex
def test_parse_routing_context():
- # python -m pytest tests/unit/test_api.py -s -v -k test_parse_routing_context
context = neo4j.api.parse_routing_context(query="name=molly&color=white")
assert context == {"name": "molly", "color": "white"}
def test_parse_routing_context_should_error_when_value_missing():
- # python -m pytest tests/unit/test_api.py -s -v -k test_parse_routing_context_should_error_when_value_missing
with pytest.raises(ConfigurationError):
neo4j.api.parse_routing_context("name=&color=white")
def test_parse_routing_context_should_error_when_key_duplicate():
- # python -m pytest tests/unit/test_api.py -s -v -k test_parse_routing_context_should_error_when_key_duplicate
with pytest.raises(ConfigurationError):
neo4j.api.parse_routing_context("name=molly&name=white")
diff --git a/tests/unit/test_conf.py b/tests/unit/common/test_conf.py
similarity index 98%
rename from tests/unit/test_conf.py
rename to tests/unit/common/test_conf.py
index 9eb71819..192bad34 100644
--- a/tests/unit/test_conf.py
+++ b/tests/unit/common/test_conf.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -21,24 +18,23 @@
import pytest
-from neo4j.exceptions import (
- ConfigurationError,
+from neo4j.api import (
+ READ_ACCESS,
+ TRUST_SYSTEM_CA_SIGNED_CERTIFICATES,
+ WRITE_ACCESS,
)
from neo4j.conf import (
Config,
PoolConfig,
- WorkspaceConfig,
SessionConfig,
+ WorkspaceConfig,
)
-from neo4j.api import (
- TRUST_SYSTEM_CA_SIGNED_CERTIFICATES,
- WRITE_ACCESS,
- READ_ACCESS,
-)
+from neo4j.debug import watch
+from neo4j.exceptions import ConfigurationError
+
# python -m pytest tests/unit/test_conf.py -s -v
-from neo4j.debug import watch
watch("neo4j")
test_pool_config = {
diff --git a/tests/unit/test_data.py b/tests/unit/common/test_data.py
similarity index 97%
rename from tests/unit/test_data.py
rename to tests/unit/common/test_data.py
index 7f48af9d..14577efd 100644
--- a/tests/unit/test_data.py
+++ b/tests/unit/common/test_data.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -19,11 +16,10 @@
# limitations under the License.
-import pytest
-
from neo4j.data import DataHydrator
from neo4j.packstream import Structure
+
# python -m pytest -s -v tests/unit/test_data.py
diff --git a/tests/unit/test_exceptions.py b/tests/unit/common/test_exceptions.py
similarity index 98%
rename from tests/unit/test_exceptions.py
rename to tests/unit/common/test_exceptions.py
index f574ce67..802e480f 100644
--- a/tests/unit/test_exceptions.py
+++ b/tests/unit/common/test_exceptions.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -21,52 +18,48 @@
import pytest
+from neo4j._exceptions import (
+ BoltConnectionBroken,
+ BoltConnectionClosed,
+ BoltConnectionError,
+ BoltError,
+ BoltFailure,
+ BoltHandshakeError,
+ BoltProtocolError,
+ BoltSecurityError,
+)
+from neo4j._sync.io import Bolt
from neo4j.exceptions import (
- Neo4jError,
+ AuthConfigurationError,
+ AuthError,
+ CertificateConfigurationError,
+ CLASSIFICATION_CLIENT,
+ CLASSIFICATION_DATABASE,
+ CLASSIFICATION_TRANSIENT,
ClientError,
+ ConfigurationError,
+ ConstraintError,
CypherSyntaxError,
CypherTypeError,
- ConstraintError,
- AuthError,
- Forbidden,
- ForbiddenOnReadOnlyDatabase,
- NotALeader,
DatabaseError,
- TransientError,
DatabaseUnavailable,
DriverError,
+ Forbidden,
+ ForbiddenOnReadOnlyDatabase,
+ IncompleteCommit,
+ Neo4jError,
+ NotALeader,
+ ReadServiceUnavailable,
+ ResultConsumedError,
+ RoutingServiceUnavailable,
+ ServiceUnavailable,
+ SessionExpired,
TransactionError,
TransactionNestingError,
- SessionExpired,
- ServiceUnavailable,
- RoutingServiceUnavailable,
+ TransientError,
WriteServiceUnavailable,
- ReadServiceUnavailable,
- IncompleteCommit,
- ConfigurationError,
- AuthConfigurationError,
- CertificateConfigurationError,
- ResultConsumedError,
- CLASSIFICATION_CLIENT,
- CLASSIFICATION_DATABASE,
- CLASSIFICATION_TRANSIENT,
)
-from neo4j._exceptions import (
- BoltError,
- BoltHandshakeError,
- BoltConnectionError,
- BoltSecurityError,
- BoltConnectionBroken,
- BoltConnectionClosed,
- BoltFailure,
- BoltProtocolError,
-)
-
-from neo4j.io import Bolt
-
-
-# python -m pytest tests/unit/test_exceptions.py -s -v
def test_bolt_error():
with pytest.raises(BoltError) as e:
diff --git a/tests/unit/common/test_import_neo4j.py b/tests/unit/common/test_import_neo4j.py
new file mode 100644
index 00000000..89704d9e
--- /dev/null
+++ b/tests/unit/common/test_import_neo4j.py
@@ -0,0 +1,164 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+def test_import_dunder_version():
+ from neo4j import __version__
+
+
+def test_import_graphdatabase():
+ from neo4j import GraphDatabase
+
+
+def test_import_async_graphdatabase():
+ from neo4j import AsyncGraphDatabase
+
+
+def test_import_driver():
+ from neo4j import Driver
+
+
+def test_import_async_driver():
+ from neo4j import AsyncDriver
+
+
+def test_import_boltdriver():
+ from neo4j import BoltDriver
+
+
+def test_import_async_boltdriver():
+ from neo4j import AsyncBoltDriver
+
+
+def test_import_neo4jdriver():
+ from neo4j import Neo4jDriver
+
+
+def test_import_async_neo4jdriver():
+ from neo4j import AsyncNeo4jDriver
+
+
+def test_import_auth():
+ from neo4j import Auth
+
+
+def test_import_authtoken():
+ from neo4j import AuthToken
+
+
+def test_import_basic_auth():
+ from neo4j import basic_auth
+
+
+def test_import_bearer_auth():
+ from neo4j import bearer_auth
+
+
+def test_import_kerberos_auth():
+ from neo4j import kerberos_auth
+
+
+def test_import_custom_auth():
+ from neo4j import custom_auth
+
+
+def test_import_read_access():
+ from neo4j import READ_ACCESS
+
+
+def test_import_write_access():
+ from neo4j import WRITE_ACCESS
+
+
+def test_import_transaction():
+ from neo4j import Transaction
+
+
+def test_import_async_transaction():
+ from neo4j import AsyncTransaction
+
+
+def test_import_record():
+ from neo4j import Record
+
+
+def test_import_session():
+ from neo4j import Session
+
+
+def test_import_async_session():
+ from neo4j import AsyncSession
+
+
+def test_import_sessionconfig():
+ from neo4j import SessionConfig
+
+
+def test_import_query():
+ from neo4j import Query
+
+
+def test_import_result():
+ from neo4j import Result
+
+
+def test_import_async_result():
+ from neo4j import AsyncResult
+
+
+def test_import_resultsummary():
+ from neo4j import ResultSummary
+
+
+def test_import_unit_of_work():
+ from neo4j import unit_of_work
+
+
+def test_import_config():
+ from neo4j import Config
+
+
+def test_import_poolconfig():
+ from neo4j import PoolConfig
+
+
+def test_import_graph():
+ import neo4j.graph as graph
+
+
+def test_import_graph_node():
+ from neo4j.graph import Node
+
+
+def test_import_graph_path():
+ from neo4j.graph import Path
+
+
+def test_import_graph_graph():
+ from neo4j.graph import Graph
+
+
+def test_import_spatial():
+ import neo4j.spatial as spatial
+
+
+def test_import_time():
+ import neo4j.time as time
+
+
+def test_import_exceptions():
+ import neo4j.exceptions as exceptions
diff --git a/tests/unit/test_record.py b/tests/unit/common/test_record.py
similarity index 99%
rename from tests/unit/test_record.py
rename to tests/unit/common/test_record.py
index 778b7ea3..85e0bc1f 100644
--- a/tests/unit/test_record.py
+++ b/tests/unit/common/test_record.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -27,6 +24,7 @@
Record,
)
+
# python -m pytest -s -v tests/unit/test_record.py
diff --git a/tests/unit/test_security.py b/tests/unit/common/test_security.py
similarity index 98%
rename from tests/unit/test_security.py
rename to tests/unit/common/test_security.py
index 3ac760fc..247cb4f5 100644
--- a/tests/unit/test_security.py
+++ b/tests/unit/common/test_security.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -20,12 +17,13 @@
from neo4j.api import (
- kerberos_auth,
basic_auth,
bearer_auth,
custom_auth,
+ kerberos_auth,
)
+
# python -m pytest -s -v tests/unit/test_security.py
diff --git a/tests/unit/test_types.py b/tests/unit/common/test_types.py
similarity index 98%
rename from tests/unit/test_types.py
rename to tests/unit/common/test_types.py
index c2454316..979c4080 100644
--- a/tests/unit/test_types.py
+++ b/tests/unit/common/test_types.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -18,18 +15,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+
from itertools import product
import pytest
-from neo4j.data import DataHydrator
from neo4j.graph import (
+ Graph,
Node,
Path,
- Graph,
Relationship,
)
-from neo4j.packstream import Structure
+
# python -m pytest -s -v tests/unit/test_types.py
diff --git a/tests/unit/common/time/__init__.py b/tests/unit/common/time/__init__.py
new file mode 100644
index 00000000..b81a309d
--- /dev/null
+++ b/tests/unit/common/time/__init__.py
@@ -0,0 +1,16 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/unit/time/test_clock.py b/tests/unit/common/time/test_clock.py
similarity index 96%
rename from tests/unit/time/test_clock.py
rename to tests/unit/common/time/test_clock.py
index 4b4c18a7..c7bef0e4 100644
--- a/tests/unit/time/test_clock.py
+++ b/tests/unit/common/time/test_clock.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# coding: utf-8
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -21,7 +18,10 @@
from unittest import TestCase
-from neo4j.time import Clock, ClockTime
+from neo4j.time import (
+ Clock,
+ ClockTime,
+)
class TestClock(TestCase):
diff --git a/tests/unit/time/test_clocktime.py b/tests/unit/common/time/test_clocktime.py
similarity index 97%
rename from tests/unit/time/test_clocktime.py
rename to tests/unit/common/time/test_clocktime.py
index ca801d74..253672cd 100644
--- a/tests/unit/time/test_clocktime.py
+++ b/tests/unit/common/time/test_clocktime.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# coding: utf-8
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -21,7 +18,10 @@
from unittest import TestCase
-from neo4j.time import ClockTime, Duration
+from neo4j.time import (
+ ClockTime,
+ Duration,
+)
class TestClockTime(TestCase):
diff --git a/tests/unit/time/test_date.py b/tests/unit/common/time/test_date.py
similarity index 99%
rename from tests/unit/time/test_date.py
rename to tests/unit/common/time/test_date.py
index 3b34bddf..bf48b719 100644
--- a/tests/unit/time/test_date.py
+++ b/tests/unit/common/time/test_date.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# coding: utf-8
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -19,14 +16,19 @@
# limitations under the License.
+import copy
from datetime import date
from time import struct_time
from unittest import TestCase
import pytz
-import copy
-from neo4j.time import Duration, Date, UnixEpoch, ZeroDate
+from neo4j.time import (
+ Date,
+ Duration,
+ UnixEpoch,
+ ZeroDate,
+)
eastern = pytz.timezone("US/Eastern")
diff --git a/tests/unit/time/test_datetime.py b/tests/unit/common/time/test_datetime.py
similarity index 98%
rename from tests/unit/time/test_datetime.py
rename to tests/unit/common/time/test_datetime.py
index cc3f1ef2..ffc551e6 100644
--- a/tests/unit/time/test_datetime.py
+++ b/tests/unit/common/time/test_datetime.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# coding: utf-8
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -18,24 +15,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+
import copy
-from decimal import Decimal
from datetime import (
datetime,
timedelta,
)
+from decimal import Decimal
import pytest
from pytz import (
- timezone,
FixedOffset,
+ timezone,
)
from neo4j.time import (
DateTime,
- MIN_YEAR,
- MAX_YEAR,
Duration,
+ MAX_YEAR,
+ MIN_YEAR,
)
from neo4j.time.arithmetic import (
nano_add,
@@ -45,17 +43,7 @@
Clock,
ClockTime,
)
-from neo4j.time.hydration import (
- hydrate_date,
- dehydrate_date,
- hydrate_time,
- dehydrate_time,
- hydrate_datetime,
- dehydrate_datetime,
- hydrate_duration,
- dehydrate_duration,
- dehydrate_timedelta,
-)
+
timezone_us_eastern = timezone("US/Eastern")
timezone_utc = timezone("UTC")
diff --git a/tests/unit/time/test_duration.py b/tests/unit/common/time/test_duration.py
similarity index 99%
rename from tests/unit/time/test_duration.py
rename to tests/unit/common/time/test_duration.py
index aa29f750..5edc2c9c 100644
--- a/tests/unit/time/test_duration.py
+++ b/tests/unit/common/time/test_duration.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# coding: utf-8
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -19,9 +16,9 @@
# limitations under the License.
+import copy
from datetime import timedelta
from decimal import Decimal
-import copy
import pytest
diff --git a/tests/unit/time/test_hydration.py b/tests/unit/common/time/test_hydration.py
similarity index 97%
rename from tests/unit/time/test_hydration.py
rename to tests/unit/common/time/test_hydration.py
index d66f6949..0efe868d 100644
--- a/tests/unit/time/test_hydration.py
+++ b/tests/unit/common/time/test_hydration.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# coding: utf-8
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
diff --git a/tests/unit/time/test_time.py b/tests/unit/common/time/test_time.py
similarity index 99%
rename from tests/unit/time/test_time.py
rename to tests/unit/common/time/test_time.py
index 0336781f..b3bdc938 100644
--- a/tests/unit/time/test_time.py
+++ b/tests/unit/common/time/test_time.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# coding: utf-8
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -24,9 +21,8 @@
import pytest
from pytz import (
- build_tzinfo,
- timezone,
FixedOffset,
+ timezone,
)
from neo4j.time import Time
diff --git a/tests/unit/data/__init__.py b/tests/unit/data/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/tests/unit/io/__init__.py b/tests/unit/io/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/tests/unit/io/test__common.py b/tests/unit/io/test__common.py
deleted file mode 100644
index 3b61c710..00000000
--- a/tests/unit/io/test__common.py
+++ /dev/null
@@ -1,32 +0,0 @@
-import pytest
-
-from neo4j.io._common import Outbox
-
-
-@pytest.mark.parametrize(("chunk_size", "data", "result"), (
- (
- 2,
- (bytes(range(10, 15)),),
- bytes((0, 2, 10, 11, 0, 2, 12, 13, 0, 1, 14))
- ),
- (
- 2,
- (bytes(range(10, 14)),),
- bytes((0, 2, 10, 11, 0, 2, 12, 13))
- ),
- (
- 2,
- (bytes((5, 6, 7)), bytes((8, 9))),
- bytes((0, 2, 5, 6, 0, 2, 7, 8, 0, 1, 9))
- ),
-))
-def test_outbox_chunking(chunk_size, data, result):
- outbox = Outbox(max_chunk_size=chunk_size)
- assert bytes(outbox.view()) == b""
- for d in data:
- outbox.write(d)
- assert bytes(outbox.view()) == result
- # make sure this works multiple times
- assert bytes(outbox.view()) == result
- outbox.clear()
- assert bytes(outbox.view()) == b""
diff --git a/tests/unit/io/test_direct.py b/tests/unit/io/test_direct.py
deleted file mode 100644
index 9e572fb2..00000000
--- a/tests/unit/io/test_direct.py
+++ /dev/null
@@ -1,309 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
-# Copyright (c) "Neo4j"
-# Neo4j Sweden AB [http://neo4j.com]
-#
-# This file is part of Neo4j.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-from unittest import (
- mock,
- TestCase,
-)
-import pytest
-from threading import (
- Condition,
- Event,
- Lock,
- Thread,
-)
-import time
-
-from neo4j import (
- Config,
- PoolConfig,
- WorkspaceConfig,
-)
-from neo4j.io import (
- Bolt,
- BoltPool,
- IOPool
-)
-from neo4j.exceptions import (
- ClientError,
- ServiceUnavailable,
-)
-
-
-class FakeSocket:
- def __init__(self, address):
- self.address = address
-
- def getpeername(self):
- return self.address
-
- def sendall(self, data):
- return
-
- def close(self):
- return
-
-
-class QuickConnection:
-
- def __init__(self, socket):
- self.socket = socket
- self.address = socket.getpeername()
-
- @property
- def is_reset(self):
- return True
-
- def stale(self):
- return False
-
- def reset(self):
- pass
-
- def close(self):
- self.socket.close()
-
- def closed(self):
- return False
-
- def defunct(self):
- return False
-
- def timedout(self):
- return False
-
-
-class FakeBoltPool(IOPool):
-
- def __init__(self, address, *, auth=None, **config):
- self.pool_config, self.workspace_config = Config.consume_chain(config, PoolConfig, WorkspaceConfig)
- if config:
- raise ValueError("Unexpected config keys: %s" % ", ".join(config.keys()))
-
- def opener(addr, timeout):
- return QuickConnection(FakeSocket(addr))
-
- super().__init__(opener, self.pool_config, self.workspace_config)
- self.address = address
-
- def acquire(self, access_mode=None, timeout=None, database=None,
- bookmarks=None):
- return self._acquire(self.address, timeout)
-
-
-class BoltTestCase(TestCase):
-
- def test_open(self):
- with pytest.raises(ServiceUnavailable):
- connection = Bolt.open(("localhost", 9999), auth=("test", "test"))
-
- def test_open_timeout(self):
- with pytest.raises(ServiceUnavailable):
- connection = Bolt.open(("localhost", 9999), auth=("test", "test"), timeout=1)
-
- def test_ping(self):
- protocol_version = Bolt.ping(("localhost", 9999))
- assert protocol_version is None
-
- def test_ping_timeout(self):
- protocol_version = Bolt.ping(("localhost", 9999), timeout=1)
- assert protocol_version is None
-
-
-class MultiEvent:
- # Adopted from threading.Event
-
- def __init__(self):
- super().__init__()
- self._cond = Condition(Lock())
- self._counter = 0
-
- def _reset_internal_locks(self):
- # private! called by Thread._reset_internal_locks by _after_fork()
- self._cond.__init__(Lock())
-
- def counter(self):
- return self._counter
-
- def increment(self):
- with self._cond:
- self._counter += 1
- self._cond.notify_all()
-
- def decrement(self):
- with self._cond:
- self._counter -= 1
- self._cond.notify_all()
-
- def clear(self):
- with self._cond:
- self._counter = 0
- self._cond.notify_all()
-
- def wait(self, value=0, timeout=None):
- with self._cond:
- t_start = time.time()
- while True:
- if value == self._counter:
- return True
- if timeout is None:
- time_left = None
- else:
- time_left = timeout - (time.time() - t_start)
- if time_left <= 0:
- return False
- if not self._cond.wait(time_left):
- return False
-
-
-class ConnectionPoolTestCase(TestCase):
-
- def setUp(self):
- self.pool = FakeBoltPool(("127.0.0.1", 7687))
-
- def tearDown(self):
- self.pool.close()
-
- def assert_pool_size(self, address, expected_active, expected_inactive, pool=None):
- if pool is None:
- pool = self.pool
- try:
- connections = pool.connections[address]
- except KeyError:
- self.assertEqual(0, expected_active)
- self.assertEqual(0, expected_inactive)
- else:
- self.assertEqual(expected_active, len([cx for cx in connections if cx.in_use]))
- self.assertEqual(expected_inactive, len([cx for cx in connections if not cx.in_use]))
-
- def test_can_acquire(self):
- address = ("127.0.0.1", 7687)
- connection = self.pool._acquire(address, timeout=3)
- assert connection.address == address
- self.assert_pool_size(address, 1, 0)
-
- def test_can_acquire_twice(self):
- address = ("127.0.0.1", 7687)
- connection_1 = self.pool._acquire(address, timeout=3)
- connection_2 = self.pool._acquire(address, timeout=3)
- assert connection_1.address == address
- assert connection_2.address == address
- assert connection_1 is not connection_2
- self.assert_pool_size(address, 2, 0)
-
- def test_can_acquire_two_addresses(self):
- address_1 = ("127.0.0.1", 7687)
- address_2 = ("127.0.0.1", 7474)
- connection_1 = self.pool._acquire(address_1, timeout=3)
- connection_2 = self.pool._acquire(address_2, timeout=3)
- assert connection_1.address == address_1
- assert connection_2.address == address_2
- self.assert_pool_size(address_1, 1, 0)
- self.assert_pool_size(address_2, 1, 0)
-
- def test_can_acquire_and_release(self):
- address = ("127.0.0.1", 7687)
- connection = self.pool._acquire(address, timeout=3)
- self.assert_pool_size(address, 1, 0)
- self.pool.release(connection)
- self.assert_pool_size(address, 0, 1)
-
- def test_releasing_twice(self):
- address = ("127.0.0.1", 7687)
- connection = self.pool._acquire(address, timeout=3)
- self.pool.release(connection)
- self.assert_pool_size(address, 0, 1)
- self.pool.release(connection)
- self.assert_pool_size(address, 0, 1)
-
- def test_in_use_count(self):
- address = ("127.0.0.1", 7687)
- self.assertEqual(self.pool.in_use_connection_count(address), 0)
- connection = self.pool._acquire(address, timeout=3)
- self.assertEqual(self.pool.in_use_connection_count(address), 1)
- self.pool.release(connection)
- self.assertEqual(self.pool.in_use_connection_count(address), 0)
-
- def test_max_conn_pool_size(self):
- with FakeBoltPool((), max_connection_pool_size=1) as pool:
- address = ("127.0.0.1", 7687)
- pool._acquire(address, timeout=0)
- self.assertEqual(pool.in_use_connection_count(address), 1)
- with self.assertRaises(ClientError):
- pool._acquire(address, timeout=0)
- self.assertEqual(pool.in_use_connection_count(address), 1)
-
- def test_multithread(self):
- with FakeBoltPool((), max_connection_pool_size=5) as pool:
- address = ("127.0.0.1", 7687)
- acquired_counter = MultiEvent()
- release_event = Event()
-
- # start 10 threads competing for connections from a pool of size 5
- threads = []
- for i in range(10):
- t = Thread(
- target=acquire_release_conn,
- args=(pool, address, acquired_counter, release_event),
- daemon=True
- )
- t.start()
- threads.append(t)
-
- if not acquired_counter.wait(5, timeout=1):
- raise RuntimeError("Acquire threads not fast enough")
- # The pool size should be 5, all are in-use
- self.assert_pool_size(address, 5, 0, pool)
- # Now we allow thread to release connections they obtained from pool
- release_event.set()
-
- # wait for all threads to release connections back to pool
- for t in threads:
- t.join(timeout=1)
- # The pool size is still 5, but all are free
- self.assert_pool_size(address, 0, 5, pool)
-
- def test_reset_when_released(self):
- def test(is_reset):
- with mock.patch(__name__ + ".QuickConnection.is_reset",
- new_callable=mock.PropertyMock) as is_reset_mock:
- with mock.patch(__name__ + ".QuickConnection.reset",
- new_callable=mock.MagicMock) as reset_mock:
- is_reset_mock.return_value = is_reset
- connection = self.pool._acquire(address, timeout=3)
- self.assertIsInstance(connection, QuickConnection)
- self.assertEqual(is_reset_mock.call_count, 0)
- self.assertEqual(reset_mock.call_count, 0)
- self.pool.release(connection)
- self.assertEqual(is_reset_mock.call_count, 1)
- self.assertEqual(reset_mock.call_count, int(not is_reset))
-
- address = ("127.0.0.1", 7687)
- for is_reset in (True, False):
- with self.subTest():
- test(is_reset)
-
-
-def acquire_release_conn(pool, address, acquired_counter, release_event):
- conn = pool._acquire(address, timeout=3)
- acquired_counter.increment()
- release_event.wait()
- pool.release(conn)
diff --git a/tests/unit/mixed/__init__.py b/tests/unit/mixed/__init__.py
new file mode 100644
index 00000000..b81a309d
--- /dev/null
+++ b/tests/unit/mixed/__init__.py
@@ -0,0 +1,16 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/unit/mixed/async_compat/__init__.py b/tests/unit/mixed/async_compat/__init__.py
new file mode 100644
index 00000000..b81a309d
--- /dev/null
+++ b/tests/unit/mixed/async_compat/__init__.py
@@ -0,0 +1,16 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/unit/mixed/async_compat/test_concurrency.py b/tests/unit/mixed/async_compat/test_concurrency.py
new file mode 100644
index 00000000..1bce780d
--- /dev/null
+++ b/tests/unit/mixed/async_compat/test_concurrency.py
@@ -0,0 +1,115 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import asyncio
+
+import pytest
+
+from neo4j._async_compat.concurrency import AsyncRLock
+
+
+@pytest.mark.asyncio
+async def test_async_r_lock():
+ counter = 1
+ lock = AsyncRLock()
+
+ async def worker():
+ nonlocal counter
+ async with lock:
+ counter_ = counter
+ counter += 1
+ await asyncio.sleep(0)
+ # assert no one else touched the counter
+ assert counter == counter_ + 1
+
+ assert not lock.locked()
+ await asyncio.gather(worker(), worker(), worker())
+ assert not lock.locked()
+
+
+@pytest.mark.asyncio
+async def test_async_r_lock_is_reentrant():
+ lock = AsyncRLock()
+
+ async def worker():
+ async with lock:
+ assert lock._count == 1
+ async with lock:
+ assert lock._count == 2
+ assert lock.locked()
+
+ assert not lock.locked()
+ await asyncio.gather(worker(), worker(), worker())
+ assert not lock.locked()
+
+
+@pytest.mark.asyncio
+async def test_async_r_lock_acquire_timeout_blocked():
+ lock = AsyncRLock()
+
+ async def blocker():
+ await lock.acquire()
+
+ async def waiter():
+ # make sure blocker has a chance to acquire the lock
+ await asyncio.sleep(0)
+ await lock.acquire(0.1)
+
+ assert not lock.locked()
+ with pytest.raises(asyncio.TimeoutError):
+ await asyncio.gather(blocker(), waiter())
+
+
+@pytest.mark.asyncio
+async def test_async_r_lock_acquire_timeout_released():
+ lock = AsyncRLock()
+
+ async def blocker():
+ await lock.acquire()
+ await asyncio.sleep(0)
+ # waiter: lock.acquire(0.1)
+ lock.release()
+
+ async def waiter():
+ await asyncio.sleep(0)
+ # blocker: lock.acquire()
+ await lock.acquire(0.1)
+ # blocker: lock.release()
+
+ assert not lock.locked()
+ await asyncio.gather(blocker(), waiter())
+ assert lock.locked() # waiter still owns it!
+
+
+@pytest.mark.asyncio
+async def test_async_r_lock_acquire_timeout_reentrant():
+ lock = AsyncRLock()
+ assert not lock.locked()
+
+ await lock.acquire()
+ assert lock._count == 1
+ await lock.acquire()
+ assert lock._count == 2
+ await lock.acquire(0.1)
+ assert lock._count == 3
+ await lock.acquire(0.1)
+ assert lock._count == 4
+ for _ in range(4):
+ lock.release()
+
+ assert not lock.locked()
diff --git a/tests/unit/mixed/io/__init__.py b/tests/unit/mixed/io/__init__.py
new file mode 100644
index 00000000..b81a309d
--- /dev/null
+++ b/tests/unit/mixed/io/__init__.py
@@ -0,0 +1,16 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/unit/mixed/io/test_direct.py b/tests/unit/mixed/io/test_direct.py
new file mode 100644
index 00000000..4cf6afa5
--- /dev/null
+++ b/tests/unit/mixed/io/test_direct.py
@@ -0,0 +1,243 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import asyncio
+from asyncio import (
+ Condition as AsyncCondition,
+ Event as AsyncEvent,
+ Lock as AsyncLock,
+)
+from threading import (
+ Condition,
+ Event,
+ Lock,
+ Thread,
+)
+import time
+from unittest import (
+ mock,
+ TestCase,
+)
+
+import pytest
+
+from neo4j import (
+ Config,
+ PoolConfig,
+ WorkspaceConfig,
+)
+from neo4j._async.io import AsyncBolt
+from neo4j._async.io._pool import AsyncIOPool
+from neo4j.exceptions import (
+ ClientError,
+ ServiceUnavailable,
+)
+
+from ...async_.io.test_direct import AsyncFakeBoltPool
+from ...sync.io.test_direct import FakeBoltPool
+
+
+class MultiEvent:
+ # Adapted from threading.Event
+
+ def __init__(self):
+ super().__init__()
+ self._cond = Condition(Lock())
+ self._counter = 0
+
+ def _reset_internal_locks(self):
+ # private! called by Thread._reset_internal_locks by _after_fork()
+ self._cond.__init__(Lock())
+
+ def counter(self):
+ return self._counter
+
+ def increment(self):
+ with self._cond:
+ self._counter += 1
+ self._cond.notify_all()
+
+ def decrement(self):
+ with self._cond:
+ self._counter -= 1
+ self._cond.notify_all()
+
+ def clear(self):
+ with self._cond:
+ self._counter = 0
+ self._cond.notify_all()
+
+ def wait(self, value=0, timeout=None):
+ with self._cond:
+ t_start = time.time()
+ while True:
+ if value == self._counter:
+ return True
+ if timeout is None:
+ time_left = None
+ else:
+ time_left = timeout - (time.time() - t_start)
+ if time_left <= 0:
+ return False
+ if not self._cond.wait(time_left):
+ return False
+
+
+class AsyncMultiEvent:
+ # Adapted from threading.Event
+
+ def __init__(self):
+ super().__init__()
+ self._cond = AsyncCondition()
+ self._counter = 0
+
+ def _reset_internal_locks(self):
+ # private! called by Thread._reset_internal_locks by _after_fork()
+ self._cond.__init__(AsyncLock())
+
+ def counter(self):
+ return self._counter
+
+ async def increment(self):
+ async with self._cond:
+ self._counter += 1
+ self._cond.notify_all()
+
+ async def decrement(self):
+ async with self._cond:
+ self._counter -= 1
+ self._cond.notify_all()
+
+ async def clear(self):
+ async with self._cond:
+ self._counter = 0
+ self._cond.notify_all()
+
+ async def wait(self, value=0, timeout=None):
+ async with self._cond:
+ t_start = time.time()
+ while True:
+ if value == self._counter:
+ return True
+ if timeout is None:
+ time_left = None
+ else:
+ time_left = timeout - (time.time() - t_start)
+ if time_left <= 0:
+ return False
+ try:
+ await asyncio.wait_for(self._cond.wait(), time_left)
+ except asyncio.TimeoutError:
+ return False
+
+
+class TestMixedConnectionPoolTestCase:
+ def assert_pool_size(self, address, expected_active, expected_inactive,
+ pool):
+ try:
+ connections = pool.connections[address]
+ except KeyError:
+ assert 0 == expected_active
+ assert 0 == expected_inactive
+ else:
+ assert (expected_active
+ == len([cx for cx in connections if cx.in_use]))
+ assert (expected_inactive
+ == len([cx for cx in connections if not cx.in_use]))
+
+ def test_multithread(self):
+ def acquire_release_conn(pool, address, acquired_counter,
+ release_event):
+ conn = pool._acquire(address, timeout=3)
+ acquired_counter.increment()
+ release_event.wait()
+ pool.release(conn)
+
+ with FakeBoltPool((), max_connection_pool_size=5) as pool:
+ address = ("127.0.0.1", 7687)
+ acquired_counter = MultiEvent()
+ release_event = Event()
+
+ # start 10 threads competing for connections from a pool of size 5
+ threads = []
+ for i in range(10):
+ t = Thread(
+ target=acquire_release_conn,
+ args=(pool, address, acquired_counter, release_event),
+ daemon=True
+ )
+ t.start()
+ threads.append(t)
+
+ if not acquired_counter.wait(5, timeout=1):
+ raise RuntimeError("Acquire threads not fast enough")
+ # The pool size should be 5, all are in-use
+ self.assert_pool_size(address, 5, 0, pool)
+ # Now we allow the threads to release connections they obtained
+ # from the pool
+ release_event.set()
+
+ # wait for all threads to release connections back to pool
+ for t in threads:
+ t.join(timeout=1)
+ # The pool size is still 5, but all are free
+ self.assert_pool_size(address, 0, 5, pool)
+
+ @pytest.mark.asyncio
+ async def test_multi_coroutine(self):
+ async def acquire_release_conn(pool_, address_, acquired_counter_,
+ release_event_):
+ try:
+ conn = await pool_._acquire(address_, timeout=3)
+ await acquired_counter_.increment()
+ await release_event_.wait()
+ await pool_.release(conn)
+ except ClientError:
+ raise
+
+ async def waiter(pool_, acquired_counter_, release_event_):
+ if not await acquired_counter_.wait(5, timeout=1):
+ raise RuntimeError("Acquire coroutines not fast enough")
+ # The pool size should be 5, all are in-use
+ self.assert_pool_size(address, 5, 0, pool_)
+ # Now we allow the coroutines to release connections they obtained
+ # from the pool
+ release_event_.set()
+
+ # wait for all coroutines to release connections back to pool
+ if not await acquired_counter_.wait(10, timeout=5):
+ raise RuntimeError("Acquire coroutines not fast enough")
+ # The pool size is still 5, but all are free
+ self.assert_pool_size(address, 0, 5, pool_)
+
+ async with AsyncFakeBoltPool((), max_connection_pool_size=5) as pool:
+ address = ("127.0.0.1", 7687)
+ acquired_counter = AsyncMultiEvent()
+ release_event = AsyncEvent()
+
+ # start 10 coroutines competing for connections from a pool of size
+ # 5
+ coroutines = [
+ acquire_release_conn(
+ pool, address, acquired_counter, release_event
+ ) for _ in range(10)
+ ]
+ await asyncio.gather(
+ waiter(pool, acquired_counter, release_event),
+ *coroutines
+ )
diff --git a/tests/unit/spatial/__init__.py b/tests/unit/spatial/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/tests/unit/sync/__init__.py b/tests/unit/sync/__init__.py
new file mode 100644
index 00000000..b81a309d
--- /dev/null
+++ b/tests/unit/sync/__init__.py
@@ -0,0 +1,16 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/unit/sync/io/__init__.py b/tests/unit/sync/io/__init__.py
new file mode 100644
index 00000000..b81a309d
--- /dev/null
+++ b/tests/unit/sync/io/__init__.py
@@ -0,0 +1,16 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/unit/io/conftest.py b/tests/unit/sync/io/conftest.py
similarity index 92%
rename from tests/unit/io/conftest.py
rename to tests/unit/sync/io/conftest.py
index 0f8d21ea..33309fc9 100644
--- a/tests/unit/io/conftest.py
+++ b/tests/unit/sync/io/conftest.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -20,12 +17,19 @@
from io import BytesIO
-from struct import pack as struct_pack, unpack as struct_unpack
+from struct import (
+ pack as struct_pack,
+ unpack as struct_unpack,
+)
import pytest
-from neo4j.io._common import MessageInbox
-from neo4j.packstream import Packer, UnpackableBuffer, Unpacker
+from neo4j._sync.io._common import MessageInbox
+from neo4j.packstream import (
+ Packer,
+ UnpackableBuffer,
+ Unpacker,
+)
class FakeSocket:
@@ -89,7 +93,7 @@ def close(self):
def inject(self, data):
self.recv_buffer += data
- def pop_chunk(self):
+ def _pop_chunk(self):
chunk_size, = struct_unpack(">H", self.recv_buffer[:2])
print("CHUNK SIZE %r" % chunk_size)
end = 2 + chunk_size
@@ -99,7 +103,7 @@ def pop_chunk(self):
def pop_message(self):
data = bytearray()
while True:
- chunk = self.pop_chunk()
+ chunk = self._pop_chunk()
print("CHUNK %r" % chunk)
if chunk:
data.extend(chunk)
diff --git a/tests/unit/sync/io/test__common.py b/tests/unit/sync/io/test__common.py
new file mode 100644
index 00000000..5106a2da
--- /dev/null
+++ b/tests/unit/sync/io/test__common.py
@@ -0,0 +1,50 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import pytest
+
+from neo4j._sync.io._common import Outbox
+
+
+@pytest.mark.parametrize(("chunk_size", "data", "result"), (
+ (
+ 2,
+ (bytes(range(10, 15)),),
+ bytes((0, 2, 10, 11, 0, 2, 12, 13, 0, 1, 14))
+ ),
+ (
+ 2,
+ (bytes(range(10, 14)),),
+ bytes((0, 2, 10, 11, 0, 2, 12, 13))
+ ),
+ (
+ 2,
+ (bytes((5, 6, 7)), bytes((8, 9))),
+ bytes((0, 2, 5, 6, 0, 2, 7, 8, 0, 1, 9))
+ ),
+))
+def test_async_outbox_chunking(chunk_size, data, result):
+ outbox = Outbox(max_chunk_size=chunk_size)
+ assert bytes(outbox.view()) == b""
+ for d in data:
+ outbox.write(d)
+ assert bytes(outbox.view()) == result
+ # make sure this works multiple times
+ assert bytes(outbox.view()) == result
+ outbox.clear()
+ assert bytes(outbox.view()) == b""
diff --git a/tests/unit/io/test_class_bolt.py b/tests/unit/sync/io/test_class_bolt.py
similarity index 96%
rename from tests/unit/io/test_class_bolt.py
rename to tests/unit/sync/io/test_class_bolt.py
index 2001546c..b7d1e6c5 100644
--- a/tests/unit/io/test_class_bolt.py
+++ b/tests/unit/sync/io/test_class_bolt.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -20,7 +17,9 @@
import pytest
-from neo4j.io import Bolt
+
+from neo4j._sync.io import Bolt
+
# python -m pytest tests/unit/io/test_class_bolt.py -s -v
diff --git a/tests/unit/io/test_class_bolt3.py b/tests/unit/sync/io/test_class_bolt3.py
similarity index 88%
rename from tests/unit/io/test_class_bolt3.py
rename to tests/unit/sync/io/test_class_bolt3.py
index f7d63e85..b42512d0 100644
--- a/tests/unit/io/test_class_bolt3.py
+++ b/tests/unit/sync/io/test_class_bolt3.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -18,17 +15,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from unittest.mock import MagicMock
import pytest
-from neo4j.io._bolt3 import Bolt3
+from neo4j._sync.io._bolt3 import Bolt3
from neo4j.conf import PoolConfig
-from neo4j.exceptions import (
- ConfigurationError,
-)
+from neo4j.exceptions import ConfigurationError
-# python -m pytest tests/unit/io/test_class_bolt3.py -s -v
+from ..._async_compat import (
+ MagicMock,
+ mark_sync_test,
+)
@pytest.mark.parametrize("set_stale", (True, False))
@@ -75,6 +72,7 @@ def test_db_extra_not_supported_in_run(fake_socket):
connection.run("", db="something")
+@mark_sync_test
def test_simple_discard(fake_socket):
address = ("127.0.0.1", 7687)
socket = fake_socket(address)
@@ -86,6 +84,7 @@ def test_simple_discard(fake_socket):
assert len(fields) == 0
+@mark_sync_test
def test_simple_pull(fake_socket):
address = ("127.0.0.1", 7687)
socket = fake_socket(address)
@@ -98,7 +97,10 @@ def test_simple_pull(fake_socket):
@pytest.mark.parametrize("recv_timeout", (1, -1))
-def test_hint_recv_timeout_seconds_gets_ignored(fake_socket_pair, recv_timeout):
+@mark_sync_test
+def test_hint_recv_timeout_seconds_gets_ignored(
+ fake_socket_pair, recv_timeout
+):
address = ("127.0.0.1", 7687)
sockets = fake_socket_pair(address)
sockets.client.settimeout = MagicMock()
@@ -106,7 +108,8 @@ def test_hint_recv_timeout_seconds_gets_ignored(fake_socket_pair, recv_timeout):
"server": "Neo4j/3.5.0",
"hints": {"connection.recv_timeout_seconds": recv_timeout},
})
- connection = Bolt3(address, sockets.client,
- PoolConfig.max_connection_lifetime)
+ connection = Bolt3(
+ address, sockets.client, PoolConfig.max_connection_lifetime
+ )
connection.hello()
sockets.client.settimeout.assert_not_called()
diff --git a/tests/unit/io/test_class_bolt4x0.py b/tests/unit/sync/io/test_class_bolt4x0.py
similarity index 92%
rename from tests/unit/io/test_class_bolt4x0.py
rename to tests/unit/sync/io/test_class_bolt4x0.py
index 3879acb0..5f94d5c0 100644
--- a/tests/unit/io/test_class_bolt4x0.py
+++ b/tests/unit/sync/io/test_class_bolt4x0.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -18,13 +15,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+
from unittest.mock import MagicMock
import pytest
-from neo4j.io._bolt4 import Bolt4x0
+from neo4j._sync.io._bolt4 import Bolt4x0
from neo4j.conf import PoolConfig
+from ..._async_compat import mark_sync_test
+
@pytest.mark.parametrize("set_stale", (True, False))
def test_conn_is_stale(fake_socket, set_stale):
@@ -56,6 +56,7 @@ def test_conn_is_not_stale(fake_socket, set_stale):
assert connection.stale() is set_stale
+@mark_sync_test
def test_db_extra_in_begin(fake_socket):
address = ("127.0.0.1", 7687)
socket = fake_socket(address)
@@ -68,6 +69,7 @@ def test_db_extra_in_begin(fake_socket):
assert fields[0] == {"db": "something"}
+@mark_sync_test
def test_db_extra_in_run(fake_socket):
address = ("127.0.0.1", 7687)
socket = fake_socket(address)
@@ -82,6 +84,7 @@ def test_db_extra_in_run(fake_socket):
assert fields[2] == {"db": "something"}
+@mark_sync_test
def test_n_extra_in_discard(fake_socket):
address = ("127.0.0.1", 7687)
socket = fake_socket(address)
@@ -101,6 +104,7 @@ def test_n_extra_in_discard(fake_socket):
(-1, {"n": -1}),
]
)
+@mark_sync_test
def test_qid_extra_in_discard(fake_socket, test_input, expected):
address = ("127.0.0.1", 7687)
socket = fake_socket(address)
@@ -120,8 +124,8 @@ def test_qid_extra_in_discard(fake_socket, test_input, expected):
(-1, {"n": 666}),
]
)
+@mark_sync_test
def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected):
- # python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_n_and_qid_extras_in_discard
address = ("127.0.0.1", 7687)
socket = fake_socket(address)
connection = Bolt4x0(address, socket, PoolConfig.max_connection_lifetime)
@@ -140,6 +144,7 @@ def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected):
(-1, {"n": -1}),
]
)
+@mark_sync_test
def test_n_extra_in_pull(fake_socket, test_input, expected):
address = ("127.0.0.1", 7687)
socket = fake_socket(address)
@@ -159,8 +164,8 @@ def test_n_extra_in_pull(fake_socket, test_input, expected):
(-1, {"n": -1}),
]
)
+@mark_sync_test
def test_qid_extra_in_pull(fake_socket, test_input, expected):
- # python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_qid_extra_in_pull
address = ("127.0.0.1", 7687)
socket = fake_socket(address)
connection = Bolt4x0(address, socket, PoolConfig.max_connection_lifetime)
@@ -172,6 +177,7 @@ def test_qid_extra_in_pull(fake_socket, test_input, expected):
assert fields[0] == expected
+@mark_sync_test
def test_n_and_qid_extras_in_pull(fake_socket):
address = ("127.0.0.1", 7687)
socket = fake_socket(address)
@@ -185,7 +191,10 @@ def test_n_and_qid_extras_in_pull(fake_socket):
@pytest.mark.parametrize("recv_timeout", (1, -1))
-def test_hint_recv_timeout_seconds_gets_ignored(fake_socket_pair, recv_timeout):
+@mark_sync_test
+def test_hint_recv_timeout_seconds_gets_ignored(
+ fake_socket_pair, recv_timeout
+):
address = ("127.0.0.1", 7687)
sockets = fake_socket_pair(address)
sockets.client.settimeout = MagicMock()
@@ -193,7 +202,8 @@ def test_hint_recv_timeout_seconds_gets_ignored(fake_socket_pair, recv_timeout):
"server": "Neo4j/4.0.0",
"hints": {"connection.recv_timeout_seconds": recv_timeout},
})
- connection = Bolt4x0(address, sockets.client,
- PoolConfig.max_connection_lifetime)
+ connection = Bolt4x0(
+ address, sockets.client, PoolConfig.max_connection_lifetime
+ )
connection.hello()
sockets.client.settimeout.assert_not_called()
diff --git a/tests/unit/io/test_class_bolt4x1.py b/tests/unit/sync/io/test_class_bolt4x1.py
similarity index 92%
rename from tests/unit/io/test_class_bolt4x1.py
rename to tests/unit/sync/io/test_class_bolt4x1.py
index 663d3cbe..2d69b9de 100644
--- a/tests/unit/io/test_class_bolt4x1.py
+++ b/tests/unit/sync/io/test_class_bolt4x1.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -18,13 +15,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from unittest.mock import MagicMock
import pytest
-from neo4j.io._bolt4 import Bolt4x1
+from neo4j._sync.io._bolt4 import Bolt4x1
from neo4j.conf import PoolConfig
+from ..._async_compat import (
+ MagicMock,
+ mark_sync_test,
+)
+
@pytest.mark.parametrize("set_stale", (True, False))
def test_conn_is_stale(fake_socket, set_stale):
@@ -56,6 +57,7 @@ def test_conn_is_not_stale(fake_socket, set_stale):
assert connection.stale() is set_stale
+@mark_sync_test
def test_db_extra_in_begin(fake_socket):
address = ("127.0.0.1", 7687)
socket = fake_socket(address)
@@ -68,6 +70,7 @@ def test_db_extra_in_begin(fake_socket):
assert fields[0] == {"db": "something"}
+@mark_sync_test
def test_db_extra_in_run(fake_socket):
address = ("127.0.0.1", 7687)
socket = fake_socket(address)
@@ -82,6 +85,7 @@ def test_db_extra_in_run(fake_socket):
assert fields[2] == {"db": "something"}
+@mark_sync_test
def test_n_extra_in_discard(fake_socket):
address = ("127.0.0.1", 7687)
socket = fake_socket(address)
@@ -101,6 +105,7 @@ def test_n_extra_in_discard(fake_socket):
(-1, {"n": -1}),
]
)
+@mark_sync_test
def test_qid_extra_in_discard(fake_socket, test_input, expected):
address = ("127.0.0.1", 7687)
socket = fake_socket(address)
@@ -120,6 +125,7 @@ def test_qid_extra_in_discard(fake_socket, test_input, expected):
(-1, {"n": 666}),
]
)
+@mark_sync_test
def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected):
# python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_n_and_qid_extras_in_discard
address = ("127.0.0.1", 7687)
@@ -140,6 +146,7 @@ def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected):
(-1, {"n": -1}),
]
)
+@mark_sync_test
def test_n_extra_in_pull(fake_socket, test_input, expected):
address = ("127.0.0.1", 7687)
socket = fake_socket(address)
@@ -159,6 +166,7 @@ def test_n_extra_in_pull(fake_socket, test_input, expected):
(-1, {"n": -1}),
]
)
+@mark_sync_test
def test_qid_extra_in_pull(fake_socket, test_input, expected):
# python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_qid_extra_in_pull
address = ("127.0.0.1", 7687)
@@ -172,6 +180,7 @@ def test_qid_extra_in_pull(fake_socket, test_input, expected):
assert fields[0] == expected
+@mark_sync_test
def test_n_and_qid_extras_in_pull(fake_socket):
address = ("127.0.0.1", 7687)
socket = fake_socket(address)
@@ -184,12 +193,15 @@ def test_n_and_qid_extras_in_pull(fake_socket):
assert fields[0] == {"n": 666, "qid": 777}
+@mark_sync_test
def test_hello_passes_routing_metadata(fake_socket_pair):
address = ("127.0.0.1", 7687)
sockets = fake_socket_pair(address)
sockets.server.send_message(0x70, {"server": "Neo4j/4.1.0"})
- connection = Bolt4x1(address, sockets.client, PoolConfig.max_connection_lifetime,
- routing_context={"foo": "bar"})
+ connection = Bolt4x1(
+ address, sockets.client, PoolConfig.max_connection_lifetime,
+ routing_context={"foo": "bar"}
+ )
connection.hello()
tag, fields = sockets.server.pop_message()
assert tag == 0x01
@@ -198,7 +210,10 @@ def test_hello_passes_routing_metadata(fake_socket_pair):
@pytest.mark.parametrize("recv_timeout", (1, -1))
-def test_hint_recv_timeout_seconds_gets_ignored(fake_socket_pair, recv_timeout):
+@mark_sync_test
+def test_hint_recv_timeout_seconds_gets_ignored(
+ fake_socket_pair, recv_timeout
+):
address = ("127.0.0.1", 7687)
sockets = fake_socket_pair(address)
sockets.client.settimeout = MagicMock()
diff --git a/tests/unit/io/test_class_bolt4x2.py b/tests/unit/sync/io/test_class_bolt4x2.py
similarity index 91%
rename from tests/unit/io/test_class_bolt4x2.py
rename to tests/unit/sync/io/test_class_bolt4x2.py
index 470adf5c..03605796 100644
--- a/tests/unit/io/test_class_bolt4x2.py
+++ b/tests/unit/sync/io/test_class_bolt4x2.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -19,13 +16,16 @@
# limitations under the License.
-from unittest.mock import MagicMock
-
import pytest
-from neo4j.io._bolt4 import Bolt4x2
+from neo4j._sync.io._bolt4 import Bolt4x2
from neo4j.conf import PoolConfig
+from ..._async_compat import (
+ MagicMock,
+ mark_sync_test,
+)
+
@pytest.mark.parametrize("set_stale", (True, False))
def test_conn_is_stale(fake_socket, set_stale):
@@ -57,6 +57,7 @@ def test_conn_is_not_stale(fake_socket, set_stale):
assert connection.stale() is set_stale
+@mark_sync_test
def test_db_extra_in_begin(fake_socket):
address = ("127.0.0.1", 7687)
socket = fake_socket(address)
@@ -69,6 +70,7 @@ def test_db_extra_in_begin(fake_socket):
assert fields[0] == {"db": "something"}
+@mark_sync_test
def test_db_extra_in_run(fake_socket):
address = ("127.0.0.1", 7687)
socket = fake_socket(address)
@@ -83,6 +85,7 @@ def test_db_extra_in_run(fake_socket):
assert fields[2] == {"db": "something"}
+@mark_sync_test
def test_n_extra_in_discard(fake_socket):
address = ("127.0.0.1", 7687)
socket = fake_socket(address)
@@ -102,6 +105,7 @@ def test_n_extra_in_discard(fake_socket):
(-1, {"n": -1}),
]
)
+@mark_sync_test
def test_qid_extra_in_discard(fake_socket, test_input, expected):
address = ("127.0.0.1", 7687)
socket = fake_socket(address)
@@ -121,6 +125,7 @@ def test_qid_extra_in_discard(fake_socket, test_input, expected):
(-1, {"n": 666}),
]
)
+@mark_sync_test
def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected):
# python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_n_and_qid_extras_in_discard
address = ("127.0.0.1", 7687)
@@ -141,6 +146,7 @@ def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected):
(-1, {"n": -1}),
]
)
+@mark_sync_test
def test_n_extra_in_pull(fake_socket, test_input, expected):
address = ("127.0.0.1", 7687)
socket = fake_socket(address)
@@ -160,6 +166,7 @@ def test_n_extra_in_pull(fake_socket, test_input, expected):
(-1, {"n": -1}),
]
)
+@mark_sync_test
def test_qid_extra_in_pull(fake_socket, test_input, expected):
# python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_qid_extra_in_pull
address = ("127.0.0.1", 7687)
@@ -173,6 +180,7 @@ def test_qid_extra_in_pull(fake_socket, test_input, expected):
assert fields[0] == expected
+@mark_sync_test
def test_n_and_qid_extras_in_pull(fake_socket):
address = ("127.0.0.1", 7687)
socket = fake_socket(address)
@@ -185,12 +193,15 @@ def test_n_and_qid_extras_in_pull(fake_socket):
assert fields[0] == {"n": 666, "qid": 777}
+@mark_sync_test
def test_hello_passes_routing_metadata(fake_socket_pair):
address = ("127.0.0.1", 7687)
sockets = fake_socket_pair(address)
sockets.server.send_message(0x70, {"server": "Neo4j/4.2.0"})
- connection = Bolt4x2(address, sockets.client, PoolConfig.max_connection_lifetime,
- routing_context={"foo": "bar"})
+ connection = Bolt4x2(
+ address, sockets.client, PoolConfig.max_connection_lifetime,
+ routing_context={"foo": "bar"}
+ )
connection.hello()
tag, fields = sockets.server.pop_message()
assert tag == 0x01
@@ -199,7 +210,10 @@ def test_hello_passes_routing_metadata(fake_socket_pair):
@pytest.mark.parametrize("recv_timeout", (1, -1))
-def test_hint_recv_timeout_seconds_gets_ignored(fake_socket_pair, recv_timeout):
+@mark_sync_test
+def test_hint_recv_timeout_seconds_gets_ignored(
+ fake_socket_pair, recv_timeout
+):
address = ("127.0.0.1", 7687)
sockets = fake_socket_pair(address)
sockets.client.settimeout = MagicMock()
@@ -207,7 +221,8 @@ def test_hint_recv_timeout_seconds_gets_ignored(fake_socket_pair, recv_timeout):
"server": "Neo4j/4.2.0",
"hints": {"connection.recv_timeout_seconds": recv_timeout},
})
- connection = Bolt4x2(address, sockets.client,
- PoolConfig.max_connection_lifetime)
+ connection = Bolt4x2(
+ address, sockets.client, PoolConfig.max_connection_lifetime
+ )
connection.hello()
sockets.client.settimeout.assert_not_called()
diff --git a/tests/unit/io/test_class_bolt4x3.py b/tests/unit/sync/io/test_class_bolt4x3.py
similarity index 92%
rename from tests/unit/io/test_class_bolt4x3.py
rename to tests/unit/sync/io/test_class_bolt4x3.py
index fc08f5b9..43469fc9 100644
--- a/tests/unit/io/test_class_bolt4x3.py
+++ b/tests/unit/sync/io/test_class_bolt4x3.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -18,14 +15,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+
import logging
-from unittest.mock import MagicMock
import pytest
-from neo4j.io._bolt4 import Bolt4x3
+from neo4j._sync.io._bolt4 import Bolt4x3
from neo4j.conf import PoolConfig
+from ..._async_compat import (
+ MagicMock,
+ mark_sync_test,
+)
+
@pytest.mark.parametrize("set_stale", (True, False))
def test_conn_is_stale(fake_socket, set_stale):
@@ -57,6 +59,7 @@ def test_conn_is_not_stale(fake_socket, set_stale):
assert connection.stale() is set_stale
+@mark_sync_test
def test_db_extra_in_begin(fake_socket):
address = ("127.0.0.1", 7687)
socket = fake_socket(address)
@@ -69,6 +72,7 @@ def test_db_extra_in_begin(fake_socket):
assert fields[0] == {"db": "something"}
+@mark_sync_test
def test_db_extra_in_run(fake_socket):
address = ("127.0.0.1", 7687)
socket = fake_socket(address)
@@ -83,6 +87,7 @@ def test_db_extra_in_run(fake_socket):
assert fields[2] == {"db": "something"}
+@mark_sync_test
def test_n_extra_in_discard(fake_socket):
address = ("127.0.0.1", 7687)
socket = fake_socket(address)
@@ -102,6 +107,7 @@ def test_n_extra_in_discard(fake_socket):
(-1, {"n": -1}),
]
)
+@mark_sync_test
def test_qid_extra_in_discard(fake_socket, test_input, expected):
address = ("127.0.0.1", 7687)
socket = fake_socket(address)
@@ -121,6 +127,7 @@ def test_qid_extra_in_discard(fake_socket, test_input, expected):
(-1, {"n": 666}),
]
)
+@mark_sync_test
def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected):
# python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_n_and_qid_extras_in_discard
address = ("127.0.0.1", 7687)
@@ -141,6 +148,7 @@ def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected):
(-1, {"n": -1}),
]
)
+@mark_sync_test
def test_n_extra_in_pull(fake_socket, test_input, expected):
address = ("127.0.0.1", 7687)
socket = fake_socket(address)
@@ -160,6 +168,7 @@ def test_n_extra_in_pull(fake_socket, test_input, expected):
(-1, {"n": -1}),
]
)
+@mark_sync_test
def test_qid_extra_in_pull(fake_socket, test_input, expected):
# python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_qid_extra_in_pull
address = ("127.0.0.1", 7687)
@@ -173,6 +182,7 @@ def test_qid_extra_in_pull(fake_socket, test_input, expected):
assert fields[0] == expected
+@mark_sync_test
def test_n_and_qid_extras_in_pull(fake_socket):
address = ("127.0.0.1", 7687)
socket = fake_socket(address)
@@ -185,13 +195,15 @@ def test_n_and_qid_extras_in_pull(fake_socket):
assert fields[0] == {"n": 666, "qid": 777}
+@mark_sync_test
def test_hello_passes_routing_metadata(fake_socket_pair):
address = ("127.0.0.1", 7687)
sockets = fake_socket_pair(address)
sockets.server.send_message(0x70, {"server": "Neo4j/4.3.0"})
- connection = Bolt4x3(address, sockets.client,
- PoolConfig.max_connection_lifetime,
- routing_context={"foo": "bar"})
+ connection = Bolt4x3(
+ address, sockets.client, PoolConfig.max_connection_lifetime,
+ routing_context={"foo": "bar"}
+ )
connection.hello()
tag, fields = sockets.server.pop_message()
assert tag == 0x01
@@ -211,12 +223,16 @@ def test_hello_passes_routing_metadata(fake_socket_pair):
({"connection.recv_timeout_seconds": False}, False),
({"connection.recv_timeout_seconds": "1"}, False),
))
-def test_hint_recv_timeout_seconds(fake_socket_pair, hints, valid,
- caplog):
+@mark_sync_test
+def test_hint_recv_timeout_seconds(
+ fake_socket_pair, hints, valid, caplog
+):
address = ("127.0.0.1", 7687)
sockets = fake_socket_pair(address)
sockets.client.settimeout = MagicMock()
- sockets.server.send_message(0x70, {"server": "Neo4j/4.3.0", "hints": hints})
+ sockets.server.send_message(
+ 0x70, {"server": "Neo4j/4.3.0", "hints": hints}
+ )
connection = Bolt4x3(address, sockets.client,
PoolConfig.max_connection_lifetime)
with caplog.at_level(logging.INFO):
diff --git a/tests/unit/io/test_class_bolt4x4.py b/tests/unit/sync/io/test_class_bolt4x4.py
similarity index 92%
rename from tests/unit/io/test_class_bolt4x4.py
rename to tests/unit/sync/io/test_class_bolt4x4.py
index 19378a1c..b2523b1c 100644
--- a/tests/unit/io/test_class_bolt4x4.py
+++ b/tests/unit/sync/io/test_class_bolt4x4.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -18,14 +15,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+
import logging
-from unittest.mock import MagicMock
import pytest
-from neo4j.io._bolt4 import Bolt4x4
+from neo4j._sync.io._bolt4 import Bolt4x4
from neo4j.conf import PoolConfig
+from ..._async_compat import (
+ MagicMock,
+ mark_sync_test,
+)
+
@pytest.mark.parametrize("set_stale", (True, False))
def test_conn_is_stale(fake_socket, set_stale):
@@ -56,6 +59,7 @@ def test_conn_is_not_stale(fake_socket, set_stale):
connection.set_stale()
assert connection.stale() is set_stale
+
@pytest.mark.parametrize(("args", "kwargs", "expected_fields"), (
(("", {}), {"db": "something"}, ({"db": "something"},)),
(("", {}), {"imp_user": "imposter"}, ({"imp_user": "imposter"},)),
@@ -65,6 +69,7 @@ def test_conn_is_not_stale(fake_socket, set_stale):
({"db": "something", "imp_user": "imposter"},)
),
))
+@mark_sync_test
def test_extra_in_begin(fake_socket, args, kwargs, expected_fields):
address = ("127.0.0.1", 7687)
socket = fake_socket(address)
@@ -75,6 +80,7 @@ def test_extra_in_begin(fake_socket, args, kwargs, expected_fields):
assert tag == b"\x11"
assert tuple(is_fields) == expected_fields
+
@pytest.mark.parametrize(("args", "kwargs", "expected_fields"), (
(("", {}), {"db": "something"}, ("", {}, {"db": "something"})),
(("", {}), {"imp_user": "imposter"}, ("", {}, {"imp_user": "imposter"})),
@@ -84,6 +90,7 @@ def test_extra_in_begin(fake_socket, args, kwargs, expected_fields):
("", {}, {"db": "something", "imp_user": "imposter"})
),
))
+@mark_sync_test
def test_extra_in_run(fake_socket, args, kwargs, expected_fields):
address = ("127.0.0.1", 7687)
socket = fake_socket(address)
@@ -95,6 +102,7 @@ def test_extra_in_run(fake_socket, args, kwargs, expected_fields):
assert tuple(is_fields) == expected_fields
+@mark_sync_test
def test_n_extra_in_discard(fake_socket):
address = ("127.0.0.1", 7687)
socket = fake_socket(address)
@@ -114,6 +122,7 @@ def test_n_extra_in_discard(fake_socket):
(-1, {"n": -1}),
]
)
+@mark_sync_test
def test_qid_extra_in_discard(fake_socket, test_input, expected):
address = ("127.0.0.1", 7687)
socket = fake_socket(address)
@@ -133,6 +142,7 @@ def test_qid_extra_in_discard(fake_socket, test_input, expected):
(-1, {"n": 666}),
]
)
+@mark_sync_test
def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected):
# python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_n_and_qid_extras_in_discard
address = ("127.0.0.1", 7687)
@@ -153,6 +163,7 @@ def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected):
(-1, {"n": -1}),
]
)
+@mark_sync_test
def test_n_extra_in_pull(fake_socket, test_input, expected):
address = ("127.0.0.1", 7687)
socket = fake_socket(address)
@@ -172,6 +183,7 @@ def test_n_extra_in_pull(fake_socket, test_input, expected):
(-1, {"n": -1}),
]
)
+@mark_sync_test
def test_qid_extra_in_pull(fake_socket, test_input, expected):
# python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_qid_extra_in_pull
address = ("127.0.0.1", 7687)
@@ -185,6 +197,7 @@ def test_qid_extra_in_pull(fake_socket, test_input, expected):
assert fields[0] == expected
+@mark_sync_test
def test_n_and_qid_extras_in_pull(fake_socket):
address = ("127.0.0.1", 7687)
socket = fake_socket(address)
@@ -197,13 +210,15 @@ def test_n_and_qid_extras_in_pull(fake_socket):
assert fields[0] == {"n": 666, "qid": 777}
+@mark_sync_test
def test_hello_passes_routing_metadata(fake_socket_pair):
address = ("127.0.0.1", 7687)
sockets = fake_socket_pair(address)
sockets.server.send_message(0x70, {"server": "Neo4j/4.4.0"})
- connection = Bolt4x4(address, sockets.client,
- PoolConfig.max_connection_lifetime,
- routing_context={"foo": "bar"})
+ connection = Bolt4x4(
+ address, sockets.client, PoolConfig.max_connection_lifetime,
+ routing_context={"foo": "bar"}
+ )
connection.hello()
tag, fields = sockets.server.pop_message()
assert tag == 0x01
@@ -223,14 +238,19 @@ def test_hello_passes_routing_metadata(fake_socket_pair):
({"connection.recv_timeout_seconds": False}, False),
({"connection.recv_timeout_seconds": "1"}, False),
))
-def test_hint_recv_timeout_seconds(fake_socket_pair, hints, valid,
- caplog):
+@mark_sync_test
+def test_hint_recv_timeout_seconds(
+ fake_socket_pair, hints, valid, caplog
+):
address = ("127.0.0.1", 7687)
sockets = fake_socket_pair(address)
sockets.client.settimeout = MagicMock()
- sockets.server.send_message(0x70, {"server": "Neo4j/4.3.4", "hints": hints})
- connection = Bolt4x4(address, sockets.client,
- PoolConfig.max_connection_lifetime)
+ sockets.server.send_message(
+ 0x70, {"server": "Neo4j/4.3.4", "hints": hints}
+ )
+ connection = Bolt4x4(
+ address, sockets.client, PoolConfig.max_connection_lifetime
+ )
with caplog.at_level(logging.INFO):
connection.hello()
if valid:
diff --git a/tests/unit/sync/io/test_direct.py b/tests/unit/sync/io/test_direct.py
new file mode 100644
index 00000000..d5ff16cb
--- /dev/null
+++ b/tests/unit/sync/io/test_direct.py
@@ -0,0 +1,231 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import pytest
+
+from neo4j import (
+ Config,
+ PoolConfig,
+ WorkspaceConfig,
+)
+from neo4j._sync.io import Bolt
+from neo4j._sync.io._pool import IOPool
+from neo4j.exceptions import (
+ ClientError,
+ ServiceUnavailable,
+)
+
+from ..._async_compat import (
+ mark_sync_test,
+ Mock,
+ mock,
+)
+
+
+class FakeSocket:
+ def __init__(self, address):
+ self.address = address
+
+ def getpeername(self):
+ return self.address
+
+ def sendall(self, data):
+ return
+
+ def close(self):
+ return
+
+
+class QuickConnection:
+ def __init__(self, socket):
+ self.socket = socket
+ self.address = socket.getpeername()
+
+ @property
+ def is_reset(self):
+ return True
+
+ def stale(self):
+ return False
+
+ def reset(self):
+ pass
+
+ def close(self):
+ self.socket.close()
+
+ def closed(self):
+ return False
+
+ def defunct(self):
+ return False
+
+ def timedout(self):
+ return False
+
+
+class FakeBoltPool(IOPool):
+
+ def __init__(self, address, *, auth=None, **config):
+ self.pool_config, self.workspace_config = Config.consume_chain(config, PoolConfig, WorkspaceConfig)
+ if config:
+ raise ValueError("Unexpected config keys: %s" % ", ".join(config.keys()))
+
+ def opener(addr, timeout):
+ return QuickConnection(FakeSocket(addr))
+
+ super().__init__(opener, self.pool_config, self.workspace_config)
+ self.address = address
+
+ def acquire(
+ self, access_mode=None, timeout=None, database=None, bookmarks=None
+ ):
+ return self._acquire(self.address, timeout)
+
+
+@mark_sync_test
+def test_bolt_connection_open():
+ with pytest.raises(ServiceUnavailable):
+ Bolt.open(("localhost", 9999), auth=("test", "test"))
+
+
+@mark_sync_test
+def test_bolt_connection_open_timeout():
+ with pytest.raises(ServiceUnavailable):
+ Bolt.open(("localhost", 9999), auth=("test", "test"),
+ timeout=1)
+
+
+@mark_sync_test
+def test_bolt_connection_ping():
+ protocol_version = Bolt.ping(("localhost", 9999))
+ assert protocol_version is None
+
+
+@mark_sync_test
+def test_bolt_connection_ping_timeout():
+ protocol_version = Bolt.ping(("localhost", 9999), timeout=1)
+ assert protocol_version is None
+
+
+@pytest.fixture
+def pool():
+ with FakeBoltPool(("127.0.0.1", 7687)) as pool:
+ yield pool
+
+
+def assert_pool_size(address, expected_active, expected_inactive, pool):
+ try:
+ connections = pool.connections[address]
+ except KeyError:
+ assert 0 == expected_active
+ assert 0 == expected_inactive
+ else:
+ assert expected_active == len([cx for cx in connections if cx.in_use])
+ assert (expected_inactive
+ == len([cx for cx in connections if not cx.in_use]))
+
+
+@mark_sync_test
+def test_pool_can_acquire(pool):
+ address = ("127.0.0.1", 7687)
+ connection = pool._acquire(address, timeout=3)
+ assert connection.address == address
+ assert_pool_size(address, 1, 0, pool)
+
+
+@mark_sync_test
+def test_pool_can_acquire_twice(pool):
+ address = ("127.0.0.1", 7687)
+ connection_1 = pool._acquire(address, timeout=3)
+ connection_2 = pool._acquire(address, timeout=3)
+ assert connection_1.address == address
+ assert connection_2.address == address
+ assert connection_1 is not connection_2
+ assert_pool_size(address, 2, 0, pool)
+
+
+@mark_sync_test
+def test_pool_can_acquire_two_addresses(pool):
+ address_1 = ("127.0.0.1", 7687)
+ address_2 = ("127.0.0.1", 7474)
+ connection_1 = pool._acquire(address_1, timeout=3)
+ connection_2 = pool._acquire(address_2, timeout=3)
+ assert connection_1.address == address_1
+ assert connection_2.address == address_2
+ assert_pool_size(address_1, 1, 0, pool)
+ assert_pool_size(address_2, 1, 0, pool)
+
+
+@mark_sync_test
+def test_pool_can_acquire_and_release(pool):
+ address = ("127.0.0.1", 7687)
+ connection = pool._acquire(address, timeout=3)
+ assert_pool_size(address, 1, 0, pool)
+ pool.release(connection)
+ assert_pool_size(address, 0, 1, pool)
+
+
+@mark_sync_test
+def test_pool_releasing_twice(pool):
+ address = ("127.0.0.1", 7687)
+ connection = pool._acquire(address, timeout=3)
+ pool.release(connection)
+ assert_pool_size(address, 0, 1, pool)
+ pool.release(connection)
+ assert_pool_size(address, 0, 1, pool)
+
+
+@mark_sync_test
+def test_pool_in_use_count(pool):
+ address = ("127.0.0.1", 7687)
+ assert pool.in_use_connection_count(address) == 0
+ connection = pool._acquire(address, timeout=3)
+ assert pool.in_use_connection_count(address) == 1
+ pool.release(connection)
+ assert pool.in_use_connection_count(address) == 0
+
+
+@mark_sync_test
+def test_pool_max_conn_pool_size(pool):
+ with FakeBoltPool((), max_connection_pool_size=1) as pool:
+ address = ("127.0.0.1", 7687)
+ pool._acquire(address, timeout=0)
+ assert pool.in_use_connection_count(address) == 1
+ with pytest.raises(ClientError):
+ pool._acquire(address, timeout=0)
+ assert pool.in_use_connection_count(address) == 1
+
+
+@pytest.mark.parametrize("is_reset", (True, False))
+@mark_sync_test
+def test_pool_reset_when_released(is_reset, pool):
+ address = ("127.0.0.1", 7687)
+ quick_connection_name = QuickConnection.__name__
+ with mock.patch(f"{__name__}.{quick_connection_name}.is_reset",
+ new_callable=mock.PropertyMock) as is_reset_mock:
+ with mock.patch(f"{__name__}.{quick_connection_name}.reset",
+ new_callable=Mock) as reset_mock:
+ is_reset_mock.return_value = is_reset
+ connection = pool._acquire(address, timeout=3)
+ assert isinstance(connection, QuickConnection)
+ assert is_reset_mock.call_count == 0
+ assert reset_mock.call_count == 0
+ pool.release(connection)
+ assert is_reset_mock.call_count == 1
+ assert reset_mock.call_count == int(not is_reset)
diff --git a/tests/unit/io/test_neo4j_pool.py b/tests/unit/sync/io/test_neo4j_pool.py
similarity index 85%
rename from tests/unit/io/test_neo4j_pool.py
rename to tests/unit/sync/io/test_neo4j_pool.py
index a5df0a90..6fb57b98 100644
--- a/tests/unit/io/test_neo4j_pool.py
+++ b/tests/unit/sync/io/test_neo4j_pool.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -23,19 +20,23 @@
import pytest
-from ..work import FakeConnection
-
from neo4j import (
READ_ACCESS,
WRITE_ACCESS,
)
+from neo4j._sync.io import Neo4jPool
from neo4j.addressing import ResolvedAddress
from neo4j.conf import (
PoolConfig,
RoutingConfig,
- WorkspaceConfig
+ WorkspaceConfig,
)
-from neo4j.io import Neo4jPool
+
+from ..._async_compat import (
+ mark_sync_test,
+ Mock,
+)
+from ..work import FakeConnection
ROUTER_ADDRESS = ResolvedAddress(("1.2.3.1", 9001), host_name="host")
@@ -68,8 +69,11 @@ def open_(addr, timeout):
return opener_
+@mark_sync_test
def test_acquires_new_routing_table_if_deleted(opener):
- pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS)
+ pool = Neo4jPool(
+ opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS
+ )
cx = pool.acquire(READ_ACCESS, 30, "test_db", None)
pool.release(cx)
assert pool.routing_tables.get("test_db")
@@ -81,8 +85,11 @@ def test_acquires_new_routing_table_if_deleted(opener):
assert pool.routing_tables.get("test_db")
+@mark_sync_test
def test_acquires_new_routing_table_if_stale(opener):
- pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS)
+ pool = Neo4jPool(
+ opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS
+ )
cx = pool.acquire(READ_ACCESS, 30, "test_db", None)
pool.release(cx)
assert pool.routing_tables.get("test_db")
@@ -95,8 +102,11 @@ def test_acquires_new_routing_table_if_stale(opener):
assert pool.routing_tables["test_db"].last_updated_time > old_value
+@mark_sync_test
def test_removes_old_routing_table(opener):
- pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS)
+ pool = Neo4jPool(
+ opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS
+ )
cx = pool.acquire(READ_ACCESS, 30, "test_db1", None)
pool.release(cx)
assert pool.routing_tables.get("test_db1")
@@ -116,8 +126,11 @@ def test_removes_old_routing_table(opener):
@pytest.mark.parametrize("type_", ("r", "w"))
+@mark_sync_test
def test_chooses_right_connection_type(opener, type_):
- pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS)
+ pool = Neo4jPool(
+ opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS
+ )
cx1 = pool.acquire(READ_ACCESS if type_ == "r" else WRITE_ACCESS,
30, "test_db", None)
pool.release(cx1)
@@ -127,8 +140,11 @@ def test_chooses_right_connection_type(opener, type_):
assert cx1.addr == WRITER_ADDRESS
+@mark_sync_test
def test_reuses_connection(opener):
- pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS)
+ pool = Neo4jPool(
+ opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS
+ )
cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None)
pool.release(cx1)
cx2 = pool.acquire(READ_ACCESS, 30, "test_db", None)
@@ -136,6 +152,7 @@ def test_reuses_connection(opener):
@pytest.mark.parametrize("break_on_close", (True, False))
+@mark_sync_test
def test_closes_stale_connections(opener, break_on_close):
def break_connection():
pool.deactivate(cx1.addr)
@@ -143,7 +160,9 @@ def break_connection():
if cx_close_mock_side_effect:
cx_close_mock_side_effect()
- pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS)
+ pool = Neo4jPool(
+ opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS
+ )
cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None)
pool.release(cx1)
assert cx1 in pool.connections[cx1.addr]
@@ -166,8 +185,11 @@ def break_connection():
assert cx2 in pool.connections[cx2.addr]
+@mark_sync_test
def test_does_not_close_stale_connections_in_use(opener):
- pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS)
+ pool = Neo4jPool(
+ opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS
+ )
cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None)
assert cx1 in pool.connections[cx1.addr]
# simulate connection going stale (e.g. exceeding) while being in use
@@ -194,8 +216,11 @@ def test_does_not_close_stale_connections_in_use(opener):
assert cx3 in pool.connections[cx2.addr]
+@mark_sync_test
def test_release_resets_connections(opener):
- pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS)
+ pool = Neo4jPool(
+ opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS
+ )
cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None)
cx1.is_reset_mock.return_value = False
cx1.is_reset_mock.reset_mock()
@@ -204,8 +229,11 @@ def test_release_resets_connections(opener):
cx1.reset.assert_called_once()
+@mark_sync_test
def test_release_does_not_resets_closed_connections(opener):
- pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS)
+ pool = Neo4jPool(
+ opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS
+ )
cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None)
cx1.closed.return_value = True
cx1.closed.reset_mock()
@@ -216,8 +244,11 @@ def test_release_does_not_resets_closed_connections(opener):
cx1.reset.asset_not_called()
+@mark_sync_test
def test_release_does_not_resets_defunct_connections(opener):
- pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS)
+ pool = Neo4jPool(
+ opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS
+ )
cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None)
cx1.defunct.return_value = True
cx1.defunct.reset_mock()
diff --git a/tests/unit/sync/test_addressing.py b/tests/unit/sync/test_addressing.py
new file mode 100644
index 00000000..4fd814b2
--- /dev/null
+++ b/tests/unit/sync/test_addressing.py
@@ -0,0 +1,125 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [http://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from socket import (
+ AF_INET,
+ AF_INET6,
+)
+import unittest.mock as mock
+
+import pytest
+
+from neo4j import (
+ Address,
+ IPv4Address,
+)
+from neo4j._async_compat.network import NetworkUtil
+from neo4j._async_compat.util import Util
+
+from .._async_compat import mark_sync_test
+
+
+mock_socket_ipv4 = mock.Mock()
+mock_socket_ipv4.getpeername = lambda: ("127.0.0.1", 7687) # (address, port)
+
+mock_socket_ipv6 = mock.Mock()
+mock_socket_ipv6.getpeername = lambda: ("[::1]", 7687, 0, 0) # (address, port, flow info, scope id)
+
+
+@mark_sync_test
+def test_address_resolve():
+ address = Address(("127.0.0.1", 7687))
+ resolved = NetworkUtil.resolve_address(address)
+ resolved = Util.list(resolved)
+ assert isinstance(resolved, Address) is False
+ assert isinstance(resolved, list) is True
+ assert len(resolved) == 1
+ assert resolved[0] == IPv4Address(('127.0.0.1', 7687))
+
+
+@mark_sync_test
+def test_address_resolve_with_custom_resolver_none():
+ address = Address(("127.0.0.1", 7687))
+ resolved = NetworkUtil.resolve_address(address, resolver=None)
+ resolved = Util.list(resolved)
+ assert isinstance(resolved, Address) is False
+ assert isinstance(resolved, list) is True
+ assert len(resolved) == 1
+ assert resolved[0] == IPv4Address(('127.0.0.1', 7687))
+
+
+@pytest.mark.parametrize(
+ "test_input, expected",
+ [
+ (Address(("127.0.0.1", "abcd")), ValueError),
+ (Address((None, None)), ValueError),
+ ]
+
+)
+@mark_sync_test
+def test_address_resolve_with_unresolvable_address(test_input, expected):
+ with pytest.raises(expected):
+ Util.list(
+ NetworkUtil.resolve_address(test_input, resolver=None)
+ )
+
+
+@pytest.mark.parametrize("resolver_type", ("sync", "async"))
+@mark_sync_test
+def test_address_resolve_with_custom_resolver(resolver_type):
+ def custom_resolver_sync(_):
+ return [("127.0.0.1", 7687), ("localhost", 1234)]
+
+ def custom_resolver_async(_):
+ return [("127.0.0.1", 7687), ("localhost", 1234)]
+
+ if resolver_type == "sync":
+ custom_resolver = custom_resolver_sync
+ else:
+ custom_resolver = custom_resolver_async
+
+ address = Address(("127.0.0.1", 7687))
+ resolved = NetworkUtil.resolve_address(
+ address, family=AF_INET, resolver=custom_resolver
+ )
+ resolved = Util.list(resolved)
+ assert isinstance(resolved, Address) is False
+ assert isinstance(resolved, list) is True
+ assert len(resolved) == 2 # IPv4 only
+ assert resolved[0] == IPv4Address(('127.0.0.1', 7687))
+ assert resolved[1] == IPv4Address(('127.0.0.1', 1234))
+
+
+@mark_sync_test
+def test_address_unresolve():
+ custom_resolved = [("127.0.0.1", 7687), ("localhost", 4321)]
+ custom_resolver = lambda _: custom_resolved
+
+ address = Address(("foobar", 1234))
+ unresolved = address.unresolved
+ assert address.__class__ == unresolved.__class__
+ assert address == unresolved
+ resolved = NetworkUtil.resolve_address(
+ address, family=AF_INET, resolver=custom_resolver
+ )
+ resolved = Util.list(resolved)
+ custom_resolved = sorted(Address(a) for a in custom_resolved)
+ unresolved = sorted(a.unresolved for a in resolved)
+ assert custom_resolved == unresolved
+ assert (list(map(lambda a: a.__class__, custom_resolved))
+ == list(map(lambda a: a.__class__, unresolved)))
diff --git a/tests/unit/test_driver.py b/tests/unit/sync/test_driver.py
similarity index 92%
rename from tests/unit/test_driver.py
rename to tests/unit/sync/test_driver.py
index 0c1192e4..93579b2e 100644
--- a/tests/unit/test_driver.py
+++ b/tests/unit/sync/test_driver.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -18,18 +15,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+
import pytest
from neo4j import (
BoltDriver,
GraphDatabase,
Neo4jDriver,
- TRUST_SYSTEM_CA_SIGNED_CERTIFICATES,
TRUST_ALL_CERTIFICATES,
+ TRUST_SYSTEM_CA_SIGNED_CERTIFICATES,
)
from neo4j.api import WRITE_ACCESS
from neo4j.exceptions import ConfigurationError
+from .._async_compat import (
+ mark_sync_test,
+ mock,
+)
+
@pytest.mark.parametrize("protocol", ("bolt://", "bolt+s://", "bolt+ssc://"))
@pytest.mark.parametrize("host", ("localhost", "127.0.0.1",
@@ -121,18 +124,22 @@ def test_driver_trust_config_error(
"bolt://127.0.0.1:9000",
"neo4j://127.0.0.1:9000",
))
+@mark_sync_test
def test_driver_opens_write_session_by_default(uri, mocker):
driver = GraphDatabase.driver(uri)
- from neo4j.work.transaction import Transaction
+ from neo4j import Transaction
+
# we set a specific db, because else the driver would try to fetch a RT
# to get hold of the actual home database (which won't work in this
# unittest)
with driver.session(database="foobar") as session:
- acquire_mock = mocker.patch.object(session._pool, "acquire",
- autospec=True)
- tx_begin_mock = mocker.patch.object(Transaction, "_begin",
- autospec=True)
- tx = session.begin_transaction()
+ with mock.patch.object(
+ session._pool, "acquire", autospec=True
+ ) as acquire_mock:
+ with mock.patch.object(
+ Transaction, "_begin", autospec=True
+ ) as tx_begin_mock:
+ tx = session.begin_transaction()
acquire_mock.assert_called_once_with(
access_mode=WRITE_ACCESS,
timeout=mocker.ANY,
diff --git a/tests/unit/work/__init__.py b/tests/unit/sync/work/__init__.py
similarity index 93%
rename from tests/unit/work/__init__.py
rename to tests/unit/sync/work/__init__.py
index 238e61d3..2613b53d 100644
--- a/tests/unit/work/__init__.py
+++ b/tests/unit/sync/work/__init__.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -18,7 +15,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+
from ._fake_connection import (
- FakeConnection,
fake_connection,
+ FakeConnection,
)
diff --git a/tests/unit/work/_fake_connection.py b/tests/unit/sync/work/_fake_connection.py
similarity index 73%
rename from tests/unit/work/_fake_connection.py
rename to tests/unit/sync/work/_fake_connection.py
index fef0b580..1748ea61 100644
--- a/tests/unit/work/_fake_connection.py
+++ b/tests/unit/sync/work/_fake_connection.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -20,11 +17,16 @@
import inspect
-from unittest import mock
import pytest
from neo4j import ServerInfo
+from neo4j._sync.io import Bolt
+
+from ..._async_compat import (
+ Mock,
+ mock,
+)
class FakeConnection(mock.NonCallableMagicMock):
@@ -32,22 +34,27 @@ class FakeConnection(mock.NonCallableMagicMock):
server_info = ServerInfo("127.0.0.1", (4, 3))
def __init__(self, *args, **kwargs):
+ kwargs["spec"] = Bolt
super().__init__(*args, **kwargs)
- self.attach_mock(mock.Mock(return_value=True), "is_reset_mock")
- self.attach_mock(mock.Mock(return_value=False), "defunct")
- self.attach_mock(mock.Mock(return_value=False), "stale")
- self.attach_mock(mock.Mock(return_value=False), "closed")
+ self.attach_mock(Mock(return_value=True), "is_reset_mock")
+ self.attach_mock(Mock(return_value=False), "defunct")
+ self.attach_mock(Mock(return_value=False), "stale")
+ self.attach_mock(Mock(return_value=False), "closed")
+ self.attach_mock(Mock(), "unresolved_address")
def close_side_effect():
self.closed.return_value = True
- self.attach_mock(mock.Mock(side_effect=close_side_effect), "close")
+ self.attach_mock(Mock(side_effect=close_side_effect),
+ "close")
@property
def is_reset(self):
if self.closed.return_value or self.defunct.return_value:
- raise AssertionError("is_reset should not be called on a closed or "
- "defunct connection.")
+ raise AssertionError(
+ "is_reset should not be called on a closed or defunct "
+ "connection."
+ )
return self.is_reset_mock()
def fetch_message(self, *args, **kwargs):
@@ -81,9 +88,13 @@ def callback():
# e.g. built-in method as cb
pass
if param_count == 1:
- cb({})
+ res = cb({})
else:
- cb()
+ res = cb()
+ try:
+ res # maybe the callback is async
+ except TypeError:
+ pass # or maybe it wasn't ;)
self.callbacks.append(callback)
return func
diff --git a/tests/unit/work/test_result.py b/tests/unit/sync/work/test_result.py
similarity index 83%
rename from tests/unit/work/test_result.py
rename to tests/unit/sync/work/test_result.py
index c4b5aa13..863df838 100644
--- a/tests/unit/work/test_result.py
+++ b/tests/unit/sync/work/test_result.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -26,13 +23,16 @@
from neo4j import (
Address,
Record,
+ Result,
ResultSummary,
ServerInfo,
SummaryCounters,
Version,
)
+from neo4j._async_compat.util import Util
from neo4j.data import DataHydrator
-from neo4j.work.result import Result
+
+from ..._async_compat import mark_sync_test
class Records:
@@ -62,8 +62,7 @@ def __init__(self, message, *args, **kwargs):
def _cb(self, cb_name, *args, **kwargs):
# print(self.message, cb_name.upper(), args, kwargs)
cb = self.kwargs.get(cb_name)
- if callable(self.kwargs.get(cb_name)):
- cb(*args, **kwargs)
+ Util.callback(cb, *args, **kwargs)
def on_success(self, metadata):
self._cb("on_success", metadata)
@@ -151,8 +150,9 @@ def fetch_message(self):
if self.record_idxs[qid] < len(self._records[qid]):
msg.on_success({"has_more": True})
else:
- msg.on_success({"bookmark": "foo",
- **(self.summary_meta or {})})
+ msg.on_success(
+ {"bookmark": "foo", **(self.summary_meta or {})}
+ )
self._exhausted[qid] = True
msg.on_summary()
@@ -184,8 +184,9 @@ def noop(*_, **__):
pass
-def _fetch_and_compare_all_records(result, key, expected_records, method,
- limit=None):
+def fetch_and_compare_all_records(
+ result, key, expected_records, method, limit=None
+):
received_records = []
if method == "for loop":
for record in result:
@@ -196,42 +197,48 @@ def _fetch_and_compare_all_records(result, key, expected_records, method,
if limit is None:
assert result._closed
elif method == "next":
- iter_ = iter(result)
+ iter_ = Util.iter(result)
n = len(expected_records) if limit is None else limit
for _ in range(n):
- received_records.append([next(iter_).get(key, None)])
+ record = Util.next(iter_)
+ received_records.append([record.get(key, None)])
if limit is None:
with pytest.raises(StopIteration):
- received_records.append([next(iter_).get(key, None)])
+ Util.next(iter_)
assert result._closed
elif method == "new iter":
n = len(expected_records) if limit is None else limit
for _ in range(n):
- received_records.append([next(iter(result)).get(key, None)])
+ iter_ = Util.iter(result)
+ record = Util.next(iter_)
+ received_records.append([record.get(key, None)])
if limit is None:
+ iter_ = Util.iter(result)
with pytest.raises(StopIteration):
- received_records.append([next(iter(result)).get(key, None)])
+ Util.next(iter_)
assert result._closed
else:
raise ValueError()
assert received_records == expected_records
-@pytest.mark.parametrize("method", ("for loop", "next", "new iter"))
+@pytest.mark.parametrize("method", ("for loop", "next", "new iter"))
@pytest.mark.parametrize("records", (
[],
[[42]],
[[1], [2], [3], [4], [5]],
))
+@mark_sync_test
def test_result_iteration(method, records):
connection = ConnectionStub(records=Records(["x"], records))
result = Result(connection, HydratorStub(), 2, noop, noop)
result._run("CYPHER", {}, None, None, "r", None)
- _fetch_and_compare_all_records(result, "x", records, method)
+ fetch_and_compare_all_records(result, "x", records, method)
-@pytest.mark.parametrize("method", ("for loop", "next", "new iter"))
+@pytest.mark.parametrize("method", ("for loop", "next", "new iter"))
@pytest.mark.parametrize("invert_fetch", (True, False))
+@mark_sync_test
def test_parallel_result_iteration(method, invert_fetch):
records1 = [[i] for i in range(1, 6)]
records2 = [[i] for i in range(6, 11)]
@@ -243,15 +250,24 @@ def test_parallel_result_iteration(method, invert_fetch):
result2 = Result(connection, HydratorStub(), 2, noop, noop)
result2._run("CYPHER2", {}, None, None, "r", None)
if invert_fetch:
- _fetch_and_compare_all_records(result2, "x", records2, method)
- _fetch_and_compare_all_records(result1, "x", records1, method)
+ fetch_and_compare_all_records(
+ result2, "x", records2, method
+ )
+ fetch_and_compare_all_records(
+ result1, "x", records1, method
+ )
else:
- _fetch_and_compare_all_records(result1, "x", records1, method)
- _fetch_and_compare_all_records(result2, "x", records2, method)
+ fetch_and_compare_all_records(
+ result1, "x", records1, method
+ )
+ fetch_and_compare_all_records(
+ result2, "x", records2, method
+ )
-@pytest.mark.parametrize("method", ("for loop", "next", "new iter"))
+@pytest.mark.parametrize("method", ("for loop", "next", "new iter"))
@pytest.mark.parametrize("invert_fetch", (True, False))
+@mark_sync_test
def test_interwoven_result_iteration(method, invert_fetch):
records1 = [[i] for i in range(1, 10)]
records2 = [[i] for i in range(11, 20)]
@@ -266,20 +282,25 @@ def test_interwoven_result_iteration(method, invert_fetch):
for n in (1, 2, 3, 1, None):
end = n if n is None else start + n
if invert_fetch:
- _fetch_and_compare_all_records(result2, "y", records2[start:end],
- method, n)
- _fetch_and_compare_all_records(result1, "x", records1[start:end],
- method, n)
+ fetch_and_compare_all_records(
+ result2, "y", records2[start:end], method, n
+ )
+ fetch_and_compare_all_records(
+ result1, "x", records1[start:end], method, n
+ )
else:
- _fetch_and_compare_all_records(result1, "x", records1[start:end],
- method, n)
- _fetch_and_compare_all_records(result2, "y", records2[start:end],
- method, n)
+ fetch_and_compare_all_records(
+ result1, "x", records1[start:end], method, n
+ )
+ fetch_and_compare_all_records(
+ result2, "y", records2[start:end], method, n
+ )
start = end
@pytest.mark.parametrize("records", ([[1], [2]], [[1]], []))
@pytest.mark.parametrize("fetch_size", (1, 2))
+@mark_sync_test
def test_result_peek(records, fetch_size):
connection = ConnectionStub(records=Records(["x"], records))
result = Result(connection, HydratorStub(), fetch_size, noop, noop)
@@ -291,11 +312,13 @@ def test_result_peek(records, fetch_size):
else:
assert isinstance(record, Record)
assert record.get("x") == records[i][0]
- next(iter(result)) # consume the record
+ iter_ = Util.iter(result)
+ Util.next(iter_) # consume the record
@pytest.mark.parametrize("records", ([[1], [2]], [[1]], []))
@pytest.mark.parametrize("fetch_size", (1, 2))
+@mark_sync_test
def test_result_single(records, fetch_size):
connection = ConnectionStub(records=Records(["x"], records))
result = Result(connection, HydratorStub(), fetch_size, noop, noop)
@@ -314,26 +337,29 @@ def test_result_single(records, fetch_size):
assert record.get("x") == records[0][0]
+@mark_sync_test
def test_keys_are_available_before_and_after_stream():
connection = ConnectionStub(records=Records(["x"], [[1], [2]]))
result = Result(connection, HydratorStub(), 1, noop, noop)
result._run("CYPHER", {}, None, None, "r", None)
assert list(result.keys()) == ["x"]
- list(result)
+ Util.list(result)
assert list(result.keys()) == ["x"]
@pytest.mark.parametrize("records", ([[1], [2]], [[1]], []))
@pytest.mark.parametrize("consume_one", (True, False))
@pytest.mark.parametrize("summary_meta", (None, {"database": "foobar"}))
+@mark_sync_test
def test_consume(records, consume_one, summary_meta):
- connection = ConnectionStub(records=Records(["x"], records),
- summary_meta=summary_meta)
+ connection = ConnectionStub(
+ records=Records(["x"], records), summary_meta=summary_meta
+ )
result = Result(connection, HydratorStub(), 1, noop, noop)
result._run("CYPHER", {}, None, None, "r", None)
if consume_one:
try:
- next(iter(result))
+ Util.next(Util.iter(result))
except StopIteration:
pass
summary = result.consume()
@@ -351,6 +377,7 @@ def test_consume(records, consume_one, summary_meta):
@pytest.mark.parametrize("t_first", (None, 0, 1, 123456789))
@pytest.mark.parametrize("t_last", (None, 0, 1, 123456789))
+@mark_sync_test
def test_time_in_summary(t_first, t_last):
run_meta = None
if t_first is not None:
@@ -358,9 +385,10 @@ def test_time_in_summary(t_first, t_last):
summary_meta = None
if t_last is not None:
summary_meta = {"t_last": t_last}
- connection = ConnectionStub(records=Records(["n"],
- [[i] for i in range(100)]),
- run_meta=run_meta, summary_meta=summary_meta)
+ connection = ConnectionStub(
+ records=Records(["n"], [[i] for i in range(100)]), run_meta=run_meta,
+ summary_meta=summary_meta
+ )
result = Result(connection, HydratorStub(), 1, noop, noop)
result._run("CYPHER", {}, None, None, "r", None)
@@ -380,6 +408,7 @@ def test_time_in_summary(t_first, t_last):
assert not hasattr(summary, "t_last")
+@mark_sync_test
def test_counts_in_summary():
connection = ConnectionStub(records=Records(["n"], [[1], [2]]))
@@ -391,9 +420,11 @@ def test_counts_in_summary():
@pytest.mark.parametrize("query_type", ("r", "w", "rw", "s"))
+@mark_sync_test
def test_query_type(query_type):
- connection = ConnectionStub(records=Records(["n"], [[1], [2]]),
- summary_meta={"type": query_type})
+ connection = ConnectionStub(
+ records=Records(["n"], [[1], [2]]), summary_meta={"type": query_type}
+ )
result = Result(connection, HydratorStub(), 1, noop, noop)
result._run("CYPHER", {}, None, None, "r", None)
@@ -404,6 +435,7 @@ def test_query_type(query_type):
@pytest.mark.parametrize("num_records", range(0, 5))
+@mark_sync_test
def test_data(num_records):
connection = ConnectionStub(
records=Records(["n"], [[i + 1] for i in range(num_records)])
diff --git a/tests/unit/work/test_session.py b/tests/unit/sync/work/test_session.py
similarity index 87%
rename from tests/unit/work/test_session.py
rename to tests/unit/sync/work/test_session.py
index 71a3bde8..4a12f695 100644
--- a/tests/unit/work/test_session.py
+++ b/tests/unit/sync/work/test_session.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -18,6 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+
from contextlib import contextmanager
import pytest
@@ -28,23 +26,30 @@
Transaction,
unit_of_work,
)
+from neo4j._sync.io._pool import IOPool
+from ..._async_compat import (
+ mark_sync_test,
+ Mock,
+ mock,
+)
from ._fake_connection import FakeConnection
@pytest.fixture()
-def pool(mocker):
- pool = mocker.MagicMock()
- pool.acquire = mocker.MagicMock(side_effect=iter(FakeConnection, 0))
+def pool():
+ pool = Mock(spec=IOPool)
+ pool.acquire.side_effect = iter(FakeConnection, 0)
return pool
-def test_session_context_calls_close(mocker):
+@mark_sync_test
+def test_session_context_calls_close():
s = Session(None, SessionConfig())
- mock_close = mocker.patch.object(s, 'close', autospec=True)
- with s:
- pass
- mock_close.assert_called_once_with()
+ with mock.patch.object(s, 'close', autospec=True) as mock_close:
+ with s:
+ pass
+ mock_close.assert_called_once_with()
@pytest.mark.parametrize("test_run_args", (
@@ -53,7 +58,10 @@ def test_session_context_calls_close(mocker):
@pytest.mark.parametrize(("repetitions", "consume"), (
(1, False), (2, False), (2, True)
))
-def test_opens_connection_on_run(pool, test_run_args, repetitions, consume):
+@mark_sync_test
+def test_opens_connection_on_run(
+ pool, test_run_args, repetitions, consume
+):
with Session(pool, SessionConfig()) as session:
assert session._connection is None
result = session.run(*test_run_args)
@@ -66,7 +74,10 @@ def test_opens_connection_on_run(pool, test_run_args, repetitions, consume):
("RETURN $x", {"x": 1}), ("RETURN 1",)
))
@pytest.mark.parametrize("repetitions", range(1, 3))
-def test_closes_connection_after_consume(pool, test_run_args, repetitions):
+@mark_sync_test
+def test_closes_connection_after_consume(
+ pool, test_run_args, repetitions
+):
with Session(pool, SessionConfig()) as session:
result = session.run(*test_run_args)
result.consume()
@@ -77,7 +88,10 @@ def test_closes_connection_after_consume(pool, test_run_args, repetitions):
@pytest.mark.parametrize("test_run_args", (
("RETURN $x", {"x": 1}), ("RETURN 1",)
))
-def test_keeps_connection_until_last_result_consumed(pool, test_run_args):
+@mark_sync_test
+def test_keeps_connection_until_last_result_consumed(
+ pool, test_run_args
+):
with Session(pool, SessionConfig()) as session:
result1 = session.run(*test_run_args)
result2 = session.run(*test_run_args)
@@ -88,6 +102,7 @@ def test_keeps_connection_until_last_result_consumed(pool, test_run_args):
assert session._connection is None
+@mark_sync_test
def test_opens_connection_on_tx_begin(pool):
with Session(pool, SessionConfig()) as session:
assert session._connection is None
@@ -99,6 +114,7 @@ def test_opens_connection_on_tx_begin(pool):
("RETURN $x", {"x": 1}), ("RETURN 1",)
))
@pytest.mark.parametrize("repetitions", range(1, 3))
+@mark_sync_test
def test_keeps_connection_on_tx_run(pool, test_run_args, repetitions):
with Session(pool, SessionConfig()) as session:
with session.begin_transaction() as tx:
@@ -111,7 +127,10 @@ def test_keeps_connection_on_tx_run(pool, test_run_args, repetitions):
("RETURN $x", {"x": 1}), ("RETURN 1",)
))
@pytest.mark.parametrize("repetitions", range(1, 3))
-def test_keeps_connection_on_tx_consume(pool, test_run_args, repetitions):
+@mark_sync_test
+def test_keeps_connection_on_tx_consume(
+ pool, test_run_args, repetitions
+):
with Session(pool, SessionConfig()) as session:
with session.begin_transaction() as tx:
for _ in range(repetitions):
@@ -123,6 +142,7 @@ def test_keeps_connection_on_tx_consume(pool, test_run_args, repetitions):
@pytest.mark.parametrize("test_run_args", (
("RETURN $x", {"x": 1}), ("RETURN 1",)
))
+@mark_sync_test
def test_closes_connection_after_tx_close(pool, test_run_args):
with Session(pool, SessionConfig()) as session:
with session.begin_transaction() as tx:
@@ -137,6 +157,7 @@ def test_closes_connection_after_tx_close(pool, test_run_args):
@pytest.mark.parametrize("test_run_args", (
("RETURN $x", {"x": 1}), ("RETURN 1",)
))
+@mark_sync_test
def test_closes_connection_after_tx_commit(pool, test_run_args):
with Session(pool, SessionConfig()) as session:
with session.begin_transaction() as tx:
@@ -149,8 +170,11 @@ def test_closes_connection_after_tx_commit(pool, test_run_args):
@pytest.mark.parametrize("bookmarks", (None, [], ["abc"], ["foo", "bar"]))
+@mark_sync_test
def test_session_returns_bookmark_directly(pool, bookmarks):
- with Session(pool, SessionConfig(bookmarks=bookmarks)) as session:
+ with Session(
+ pool, SessionConfig(bookmarks=bookmarks)
+ ) as session:
if bookmarks:
assert session.last_bookmark() == bookmarks[-1]
else:
@@ -163,6 +187,7 @@ def test_session_returns_bookmark_directly(pool, bookmarks):
({"how about": "no?"}, TypeError),
(["I don't", "think so"], TypeError),
))
+@mark_sync_test
def test_session_run_wrong_types(pool, query, error_type):
with Session(pool, SessionConfig()) as session:
with pytest.raises(error_type):
@@ -170,6 +195,7 @@ def test_session_run_wrong_types(pool, query, error_type):
@pytest.mark.parametrize("tx_type", ("write_transaction", "read_transaction"))
+@mark_sync_test
def test_tx_function_argument_type(pool, tx_type):
def work(tx):
assert isinstance(tx, Transaction)
@@ -186,6 +212,7 @@ def work(tx):
{"timeout": 5, "metadata": {"foo": "bar"}},
))
+@mark_sync_test
def test_decorated_tx_function_argument_type(pool, tx_type, decorator_kwargs):
@unit_of_work(**decorator_kwargs)
def work(tx):
@@ -195,6 +222,7 @@ def work(tx):
getattr(session, tx_type)(work)
+@mark_sync_test
def test_session_tx_type(pool):
with Session(pool, SessionConfig()) as session:
tx = session.begin_transaction()
@@ -225,7 +253,10 @@ def test_session_tx_type(pool):
({"x": {(1, 2): '1+2i', (2, 0): '2'}}, TypeError),
))
@pytest.mark.parametrize("run_type", ("auto", "unmanaged", "managed"))
-def test_session_run_with_parameters(pool, parameters, error_type, run_type):
+@mark_sync_test
+def test_session_run_with_parameters(
+ pool, parameters, error_type, run_type
+):
@contextmanager
def raises():
if error_type is not None:
diff --git a/tests/unit/work/test_transaction.py b/tests/unit/sync/work/test_transaction.py
similarity index 98%
rename from tests/unit/work/test_transaction.py
rename to tests/unit/sync/work/test_transaction.py
index 06e75566..b5b40283 100644
--- a/tests/unit/work/test_transaction.py
+++ b/tests/unit/sync/work/test_transaction.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
@@ -18,11 +15,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+
+from unittest.mock import MagicMock
from uuid import uuid4
-from unittest.mock import (
- MagicMock,
- NonCallableMagicMock,
-)
import pytest
diff --git a/tests/unit/test_import_neo4j.py b/tests/unit/test_import_neo4j.py
deleted file mode 100644
index a7ab16e1..00000000
--- a/tests/unit/test_import_neo4j.py
+++ /dev/null
@@ -1,171 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
-# Copyright (c) "Neo4j"
-# Neo4j Sweden AB [http://neo4j.com]
-#
-# This file is part of Neo4j.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-import pytest
-
-# python -m pytest tests/unit/test_import_neo4j.py -s -v
-
-
-def test_import_dunder_version():
- # python -m pytest tests/unit/test_import_neo4j.py -s -v -k test_import_dunder_version
- from neo4j import __version__
-
-
-def test_import_graphdatabase():
- # python -m pytest tests/unit/test_import_neo4j.py -s -v -k test_import_graphdatabase
- from neo4j import GraphDatabase
-
-
-def test_import_driver():
- # python -m pytest tests/unit/test_import_neo4j.py -s -v -k test_import_driver
- from neo4j import Driver
-
-
-def test_import_boltdriver():
- # python -m pytest tests/unit/test_import_neo4j.py -s -v -k test_import_boltdriver
- from neo4j import BoltDriver
-
-
-def test_import_neo4jdriver():
- # python -m pytest tests/unit/test_import_neo4j.py -s -v -k test_import_neo4jdriver
- from neo4j import Neo4jDriver
-
-
-def test_import_auth():
- # python -m pytest tests/unit/test_import_neo4j.py -s -v -k test_import_auth
- from neo4j import Auth
-
-
-def test_import_authtoken():
- # python -m pytest tests/unit/test_import_neo4j.py -s -v -k test_import_authtoken
- from neo4j import AuthToken
-
-
-def test_import_basic_auth():
- # python -m pytest tests/unit/test_import_neo4j.py -s -v -k test_import_auth
- from neo4j import basic_auth
-
-
-def test_import_kerberos_auth():
- # python -m pytest tests/unit/test_import_neo4j.py -s -v -k test_import_kerberos_auth
- from neo4j import kerberos_auth
-
-
-def test_import_custom_auth():
- # python -m pytest tests/unit/test_import_neo4j.py -s -v -k test_import_custom_auth
- from neo4j import custom_auth
-
-
-def test_import_read_access():
- # python -m pytest tests/unit/test_import_neo4j.py -s -v -k test_import_read_access
- from neo4j import READ_ACCESS
-
-
-def test_import_write_access():
- # python -m pytest tests/unit/test_import_neo4j.py -s -v -k test_import_write_access
- from neo4j import WRITE_ACCESS
-
-
-def test_import_transaction():
- # python -m pytest tests/unit/test_import_neo4j.py -s -v -k test_import_transaction
- from neo4j import Transaction
-
-
-def test_import_record():
- # python -m pytest tests/unit/test_import_neo4j.py -s -v -k test_import_record
- from neo4j import Record
-
-
-def test_import_session():
- # python -m pytest tests/unit/test_import_neo4j.py -s -v -k test_import_session
- from neo4j import Session
-
-
-def test_import_sessionconfig():
- # python -m pytest tests/unit/test_import_neo4j.py -s -v -k test_import_sessionconfig
- from neo4j import SessionConfig
-
-
-def test_import_query():
- # python -m pytest tests/unit/test_import_neo4j.py -s -v -k test_import_query
- from neo4j import Query
-
-
-def test_import_result():
- # python -m pytest tests/unit/test_import_neo4j.py -s -v -k test_import_result
- from neo4j import Result
-
-
-def test_import_resultsummary():
- # python -m pytest tests/unit/test_import_neo4j.py -s -v -k test_import_resultsummary
- from neo4j import ResultSummary
-
-
-def test_import_unit_of_work():
- # python -m pytest tests/unit/test_import_neo4j.py -s -v -k test_import_unit_of_work
- from neo4j import unit_of_work
-
-
-def test_import_config():
- # python -m pytest tests/unit/test_import_neo4j.py -s -v -k test_import_config
- from neo4j import Config
-
-
-def test_import_poolconfig():
- # python -m pytest tests/unit/test_import_neo4j.py -s -v -k test_import_poolconfig
- from neo4j import PoolConfig
-
-
-def test_import_graph():
- # python -m pytest tests/unit/test_import_neo4j.py -s -v -k test_import_graph
- import neo4j.graph as graph
-
-
-def test_import_graph_node():
- # python -m pytest tests/unit/test_import_neo4j.py -s -v -k test_import_graph_node
- from neo4j.graph import Node
-
-
-def test_import_graph_path():
- # python -m pytest tests/unit/test_import_neo4j.py -s -v -k test_import_graph_path
- from neo4j.graph import Path
-
-
-def test_import_graph_graph():
- # python -m pytest tests/unit/test_import_neo4j.py -s -v -k test_import_graph_graph
- from neo4j.graph import Graph
-
-
-def test_import_spatial():
- # python -m pytest tests/unit/test_import_neo4j.py -s -v -k test_import_spatial
- import neo4j.spatial as spatial
-
-
-def test_import_time():
- # python -m pytest tests/unit/test_import_neo4j.py -s -v -k test_import_time
- import neo4j.time as time
-
-
-def test_import_exceptions():
- # python -m pytest tests/unit/test_import_neo4j.py -s -v -k test_import_exceptions
- import neo4j.exceptions as exceptions
-
-