diff --git a/docs/source/api.rst b/docs/source/api.rst index 46008581..5ae6ff11 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -107,10 +107,17 @@ Auth To authenticate with Neo4j the authentication details are supplied at driver creation. -The auth token is an object of the class :class:`neo4j.Auth` containing the details. +The auth token is an object of the class :class:`neo4j.Auth` containing static details or :class:`neo4j.auth_management.AuthManager` object. .. autoclass:: neo4j.Auth +.. autoclass:: neo4j.auth_management.AuthManager + :members: + +.. autoclass:: neo4j.auth_management.AuthManagers + :members: + +.. autoclass:: neo4j.auth_management.ExpiringAuth Example: @@ -154,7 +161,8 @@ Closing a driver will immediately shut down all connections in the pool. .. autoclass:: neo4j.Driver() :members: session, query_bookmark_manager, encrypted, close, - verify_connectivity, get_server_info + verify_connectivity, get_server_info, verify_authentication, + supports_session_auth, supports_multi_db .. method:: execute_query(query, parameters_=None,routing_=neo4j.RoutingControl.WRITERS, database_=None, impersonated_user_=None, bookmark_manager_=self.query_bookmark_manager, result_transformer_=Result.to_eager_result, **kwargs) @@ -174,7 +182,7 @@ Closing a driver will immediately shut down all connections in the pool. def execute_query( query_, parameters_, routing_, database_, impersonated_user_, - bookmark_manager_, result_transformer_, **kwargs + bookmark_manager_, auth_, result_transformer_, **kwargs ): def work(tx): result = tx.run(query_, parameters_, **kwargs) @@ -184,6 +192,7 @@ Closing a driver will immediately shut down all connections in the pool. database=database_, impersonated_user=impersonated_user_, bookmark_manager=bookmark_manager_, + auth=auth_, ) as session: if routing_ == RoutingControl.WRITERS: return session.execute_write(work) @@ -263,6 +272,20 @@ Closing a driver will immediately shut down all connections in the pool. See also the Session config :ref:`impersonated-user-ref`. :type impersonated_user_: typing.Optional[str] + :param auth_: + Authentication information to use for this query. + + By default, the driver configuration is used. + + **This is a preview** (see :ref:`filter-warnings-ref`). + It might be changed without following the deprecation policy. + See also + https://github.com/neo4j/neo4j-python-driver/wiki/preview-features + + See also the Session config :ref:`session-auth-ref`. + :type auth_: typing.Union[ + typing.Tuple[typing.Any, typing.Any], neo4j.Auth, None + ] :param result_transformer_: A function that gets passed the :class:`neo4j.Result` object resulting from the query and converts it to a different type. The @@ -273,7 +296,6 @@ Closing a driver will immediately shut down all connections in the pool. The transformer function must **not** return the :class:`neo4j.Result` itself. - .. warning:: N.B. the driver might retry the underlying transaction so the @@ -334,7 +356,7 @@ Closing a driver will immediately shut down all connections in the pool. Defaults to the driver's :attr:`.query_bookmark_manager`. - Pass :const:`None` to disable causal consistency. + Pass :data:`None` to disable causal consistency. :type bookmark_manager_: typing.Union[neo4j.BookmarkManager, neo4j.BookmarkManager, None] @@ -348,7 +370,7 @@ Closing a driver will immediately shut down all connections in the pool. 
:returns: the result of the ``result_transformer`` :rtype: T - **This is experimental.** (See :ref:`filter-warnings-ref`) + **This is experimental** (see :ref:`filter-warnings-ref`). It might be changed or removed any time even without prior notice. We are looking for feedback on this feature. Please let us know what @@ -357,6 +379,9 @@ Closing a driver will immediately shut down all connections in the pool. .. versionadded:: 5.5 + .. versionchanged:: 5.8 + Added the ``auth_`` parameter. + .. _driver-configuration-ref: @@ -433,7 +458,7 @@ Specify whether TCP keep-alive should be enabled. :Type: ``bool`` :Default: ``True`` -**This is experimental.** (See :ref:`filter-warnings-ref`) +**This is experimental** (see :ref:`filter-warnings-ref`). It might be changed or removed any time even without prior notice. @@ -784,6 +809,7 @@ Session .. automethod:: execute_write + Query ===== @@ -804,6 +830,7 @@ To construct a :class:`neo4j.Session` use the :meth:`neo4j.Driver.session` metho + :ref:`default-access-mode-ref` + :ref:`fetch-size-ref` + :ref:`bookmark-manager-ref` ++ :ref:`session-auth-ref` + :ref:`session-notifications-min-severity-ref` + :ref:`session-notifications-disabled-categories-ref` @@ -816,7 +843,7 @@ Optional :class:`neo4j.Bookmarks`. Use this to causally chain sessions. See :meth:`Session.last_bookmarks` or :meth:`AsyncSession.last_bookmarks` for more information. -:Default: ``None`` +:Default: :data:`None` .. deprecated:: 5.0 Alternatively, an iterable of strings can be passed. This usage is @@ -995,10 +1022,33 @@ See :class:`.BookmarkManager` for more information. .. versionadded:: 5.0 -**This is experimental.** (See :ref:`filter-warnings-ref`) +**This is experimental** (see :ref:`filter-warnings-ref`). It might be changed or removed any time even without prior notice. +.. _session-auth-ref: + +``auth`` +-------- +Optional :class:`neo4j.Auth` or ``(user, password)``-tuple. Use this overwrite the +authentication information for the session. +This requires the server to support re-authentication on the protocol level. You can +check this by calling :meth:`.Driver.supports_session_auth` / :meth:`.AsyncDriver.supports_session_auth`. + +It is not possible to overwrite the authentication information for the session with no authentication, +i.e., downgrade the authentication at session level. +Instead, you should create a driver with no authentication and upgrade the authentication at session level as needed. + +**This is a preview** (see :ref:`filter-warnings-ref`). +It might be changed without following the deprecation policy. +See also https://github.com/neo4j/neo4j-python-driver/wiki/preview-features + +:Type: :data:`None`, :class:`.Auth` or ``(user, password)``-tuple +:Default: :data:`None` - use the authentication information provided during driver creation. + +.. versionadded:: 5.x + + .. _session-notifications-min-severity-ref: ``notifications_min_severity`` @@ -1293,7 +1343,7 @@ Graph .. automethod:: relationship_type -**This is experimental.** (See :ref:`filter-warnings-ref`) +**This is experimental** (see :ref:`filter-warnings-ref`). It might be changed or removed any time even without prior notice. @@ -1751,6 +1801,8 @@ Server-side errors * :class:`neo4j.exceptions.TokenExpired` + * :class:`neo4j.exceptions.TokenExpiredRetryable` + * :class:`neo4j.exceptions.Forbidden` * :class:`neo4j.exceptions.DatabaseError` @@ -1786,6 +1838,9 @@ Server-side errors .. autoexception:: neo4j.exceptions.TokenExpired() :show-inheritance: +.. 
autoexception:: neo4j.exceptions.TokenExpiredRetryable() + :show-inheritance: + .. autoexception:: neo4j.exceptions.Forbidden() :show-inheritance: diff --git a/docs/source/async_api.rst b/docs/source/async_api.rst index 5fdcd8e8..35e08a13 100644 --- a/docs/source/async_api.rst +++ b/docs/source/async_api.rst @@ -40,7 +40,7 @@ The :class:`neo4j.AsyncDriver` construction is done via a ``classmethod`` on the await driver.close() # close the driver object - asyncio.run(main()) + asyncio.run(main()) For basic authentication, ``auth`` can be a simple tuple, for example: @@ -67,7 +67,7 @@ The :class:`neo4j.AsyncDriver` construction is done via a ``classmethod`` on the async with AsyncGraphDatabase.driver(uri, auth=auth) as driver: ... # use the driver - asyncio.run(main()) + asyncio.run(main()) @@ -118,6 +118,20 @@ Each supported scheme maps to a particular :class:`neo4j.AsyncDriver` subclass t See https://neo4j.com/docs/operations-manual/current/configuration/ports/ for Neo4j ports. +.. _async-auth-ref: + +Async Auth +========== + +Authentication works the same as in the synchronous driver. +With the exception that when using AuthManagers, their asynchronous equivalents have to be used. + +.. autoclass:: neo4j.auth_management.AsyncAuthManager + :members: + +.. autoclass:: neo4j.auth_management.AsyncAuthManagers + :members: + *********** AsyncDriver @@ -136,7 +150,8 @@ Closing a driver will immediately shut down all connections in the pool. .. autoclass:: neo4j.AsyncDriver() :members: session, query_bookmark_manager, encrypted, close, - verify_connectivity, get_server_info + verify_connectivity, get_server_info, verify_authentication, + supports_session_auth, supports_multi_db .. method:: execute_query(query, parameters_=None, routing_=neo4j.RoutingControl.WRITERS, database_=None, impersonated_user_=None, bookmark_manager_=self.query_bookmark_manager, result_transformer_=AsyncResult.to_eager_result, **kwargs) :async: @@ -157,7 +172,7 @@ Closing a driver will immediately shut down all connections in the pool. async def execute_query( query_, parameters_, routing_, database_, impersonated_user_, - bookmark_manager_, result_transformer_, **kwargs + bookmark_manager_, auth_, result_transformer_, **kwargs ): async def work(tx): result = await tx.run(query_, parameters_, **kwargs) @@ -167,6 +182,7 @@ Closing a driver will immediately shut down all connections in the pool. database=database_, impersonated_user=impersonated_user_, bookmark_manager=bookmark_manager_, + auth=auth_, ) as session: if routing_ == RoutingControl.WRITERS: return await session.execute_write(work) @@ -202,13 +218,14 @@ Closing a driver will immediately shut down all connections in the pool. async def example(driver: neo4j.AsyncDriver) -> int: """Call all young people "My dear" and get their count.""" record = await driver.execute_query( - "MATCH (p:Person) WHERE p.age <= 15 " + "MATCH (p:Person) WHERE p.age <= $age " "SET p.nickname = 'My dear' " "RETURN count(*)", # optional routing parameter, as write is default # routing_=neo4j.RoutingControl.WRITERS, # or just "w", database_="neo4j", result_transformer_=neo4j.AsyncResult.single, + age=15, ) assert record is not None # for typechecking and illustration count = record[0] @@ -245,6 +262,20 @@ Closing a driver will immediately shut down all connections in the pool. See also the Session config :ref:`impersonated-user-ref`. :type impersonated_user_: typing.Optional[str] + :param auth_: + Authentication information to use for this query. 
+ + By default, the driver configuration is used. + + **This is a preview** (see :ref:`filter-warnings-ref`). + It might be changed without following the deprecation policy. + See also + https://github.com/neo4j/neo4j-python-driver/wiki/preview-features + + See also the Session config :ref:`session-auth-ref`. + :type auth_: typing.Union[ + typing.Tuple[typing.Any, typing.Any], neo4j.Auth, None + ] :param result_transformer_: A function that gets passed the :class:`neo4j.AsyncResult` object resulting from the query and converts it to a different type. The @@ -315,7 +346,7 @@ Closing a driver will immediately shut down all connections in the pool. Defaults to the driver's :attr:`.query_bookmark_manager`. - Pass :const:`None` to disable causal consistency. + Pass :data:`None` to disable causal consistency. :type bookmark_manager_: typing.Union[neo4j.AsyncBookmarkManager, neo4j.BookmarkManager, None] @@ -329,7 +360,7 @@ Closing a driver will immediately shut down all connections in the pool. :returns: the result of the ``result_transformer`` :rtype: T - **This is experimental.** (See :ref:`filter-warnings-ref`) + **This is experimental** (see :ref:`filter-warnings-ref`). It might be changed or removed any time even without prior notice. We are looking for feedback on this feature. Please let us know what @@ -338,6 +369,9 @@ Closing a driver will immediately shut down all connections in the pool. .. versionadded:: 5.5 + .. versionchanged:: 5.8 + Added the ``auth_`` parameter. + .. _async-driver-configuration-ref: @@ -345,8 +379,11 @@ Async Driver Configuration ========================== :class:`neo4j.AsyncDriver` is configured exactly like :class:`neo4j.Driver` -(see :ref:`driver-configuration-ref`). The only difference is that the async -driver accepts an async custom resolver function: +(see :ref:`driver-configuration-ref`). The only differences are that the async +driver accepts + + * a sync as well as an async custom resolver function (see :ref:`async-resolver-ref`) + * as sync as well as an async auth token manager (see :class:`neo4j.AsyncAuthManager`). .. _async-resolver-ref: @@ -622,7 +659,7 @@ See :class:`BookmarkManager` for more information. :Type: :data:`None`, :class:`BookmarkManager`, or :class:`AsyncBookmarkManager` :Default: :data:`None` -**This is experimental.** (See :ref:`filter-warnings-ref`) +**This is experimental** (see :ref:`filter-warnings-ref`). It might be changed or removed any time even without prior notice. 
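Taken together, the documentation changes above introduce auth managers, a per-session ``auth`` override, and a per-query ``auth_`` parameter for ``execute_query``. The following is an illustrative sketch only, not part of this patch: the URI, the credentials, and the ``get_sso_token()`` helper are placeholders, and all of the auth-manager and user-switching APIs shown are previews that may change::

    import asyncio

    import neo4j
    from neo4j import AsyncGraphDatabase
    from neo4j.auth_management import AsyncAuthManagers, ExpiringAuth


    async def get_sso_token() -> str:
        # Placeholder for however the application actually obtains a token.
        return "some-opaque-token"


    async def auth_provider() -> ExpiringAuth:
        # Called by the driver whenever the current auth info has expired.
        token = await get_sso_token()
        # Assume the token is valid for 60 seconds; refresh a bit earlier.
        return ExpiringAuth(auth=neo4j.bearer_auth(token), expires_in=50)


    async def main() -> None:
        async with AsyncGraphDatabase.driver(
            "neo4j://example.com:7687",
            auth=AsyncAuthManagers.expiration_based(auth_provider),
        ) as driver:
            # Per-query auth override (preview); requires protocol-level
            # re-authentication support on the server.
            if await driver.supports_session_auth():
                records, _, _ = await driver.execute_query(
                    "RETURN 1 AS x",
                    auth_=("neo4j", "password"),
                )
                print(records[0]["x"])


    asyncio.run(main())

diff --git a/requirements-dev.txt (continued below)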
diff --git a/requirements-dev.txt b/requirements-dev.txt index 7a157f09..1845d8e3 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -17,6 +17,7 @@ tomlkit~=0.11.6 # needed for running tests coverage[toml]>=5.5 +freezegun >= 1.2.2 mock>=4.0.3 pyarrow>=1.0.0 pytest>=6.2.5 diff --git a/src/neo4j/__init__.py b/src/neo4j/__init__.py index ed42a115..b18803d5 100644 --- a/src/neo4j/__init__.py +++ b/src/neo4j/__init__.py @@ -52,6 +52,7 @@ deprecation_warn as _deprecation_warn, ExperimentalWarning, get_user_agent, + PreviewWarning, version as __version__, ) from ._sync.driver import ( @@ -140,11 +141,13 @@ "NotificationFilter", "NotificationSeverity", "PoolConfig", + "PreviewWarning", "Query", "READ_ACCESS", "Record", "Result", "ResultSummary", + "RenewableAuth", "RoutingControl", "ServerInfo", "Session", diff --git a/src/neo4j/_async/auth_management.py b/src/neo4j/_async/auth_management.py new file mode 100644 index 00000000..de6d0055 --- /dev/null +++ b/src/neo4j/_async/auth_management.py @@ -0,0 +1,206 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# from __future__ import annotations +# work around for https://github.com/sphinx-doc/sphinx/pull/10880 +# make sure TAuth is resolved in the docs, else they're pretty useless + +import time +import typing as t +from logging import getLogger + +from .._async_compat.concurrency import AsyncLock +from .._auth_management import ( + AsyncAuthManager, + ExpiringAuth, +) +from .._meta import preview + +# work around for https://github.com/sphinx-doc/sphinx/pull/10880 +# make sure TAuth is resolved in the docs, else they're pretty useless +# if t.TYPE_CHECKING: +from ..api import _TAuth + + +log = getLogger("neo4j") + + +class AsyncStaticAuthManager(AsyncAuthManager): + _auth: _TAuth + + def __init__(self, auth: _TAuth) -> None: + self._auth = auth + + async def get_auth(self) -> _TAuth: + return self._auth + + async def on_auth_expired(self, auth: _TAuth) -> None: + pass + + +class _ExpiringAuthHolder: + def __init__(self, auth: ExpiringAuth) -> None: + self._auth = auth + self._expiry = None + if auth.expires_in is not None: + self._expiry = time.monotonic() + auth.expires_in + + @property + def auth(self) -> _TAuth: + return self._auth.auth + + def expired(self) -> bool: + if self._expiry is None: + return False + return time.monotonic() > self._expiry + +class AsyncExpirationBasedAuthManager(AsyncAuthManager): + _current_auth: t.Optional[_ExpiringAuthHolder] + _provider: t.Callable[[], t.Awaitable[ExpiringAuth]] + _lock: AsyncLock + + + def __init__( + self, + provider: t.Callable[[], t.Awaitable[ExpiringAuth]] + ) -> None: + self._provider = provider + self._current_auth = None + self._lock = AsyncLock() + + async def _refresh_auth(self): + self._current_auth = _ExpiringAuthHolder(await self._provider()) + + async def get_auth(self) -> _TAuth: + async with self._lock: + auth = self._current_auth + if auth is not None and not 
auth.expired(): + return auth.auth + log.debug("[ ] _: refreshing (time out)") + await self._refresh_auth() + assert self._current_auth is not None + return self._current_auth.auth + + async def on_auth_expired(self, auth: _TAuth) -> None: + async with self._lock: + cur_auth = self._current_auth + if cur_auth is not None and cur_auth.auth == auth: + log.debug("[ ] _: refreshing (error)") + await self._refresh_auth() + + +class AsyncAuthManagers: + """A collection of :class:`.AsyncAuthManager` factories. + + **This is a preview** (see :ref:`filter-warnings-ref`). + It might be changed without following the deprecation policy. + See also https://github.com/neo4j/neo4j-python-driver/wiki/preview-features + + .. versionadded:: 5.8 + """ + + @staticmethod + @preview("Auth managers are a preview feature.") + def static(auth: _TAuth) -> AsyncAuthManager: + """Create a static auth manager. + + Example:: + + # NOTE: this example is for illustration purposes only. + # The driver will automatically wrap static auth info in a + # static auth manager. + + import neo4j + from neo4j.auth_management import AsyncAuthManagers + + + auth = neo4j.basic_auth("neo4j", "password") + + with neo4j.GraphDatabase.driver( + "neo4j://example.com:7687", + auth=AsyncAuthManagers.static(auth) + # auth=auth # this is equivalent + ) as driver: + ... # do stuff + + :param auth: The auth to return. + + :returns: + An instance of an implementation of :class:`.AsyncAuthManager` that + always returns the same auth. + """ + return AsyncStaticAuthManager(auth) + + @staticmethod + @preview("Auth managers are a preview feature.") + def expiration_based( + provider: t.Callable[[], t.Awaitable[ExpiringAuth]] + ) -> AsyncAuthManager: + """Create an auth manager for potentially expiring auth info. + + .. warning:: + + The provider function **must not** interact with the driver in any + way as this can cause deadlocks and undefined behaviour. + + The provider function only ever return auth information belonging + to the same identity. + Switching identities is undefined behavior. + + Example:: + + import neo4j + from neo4j.auth_management import ( + AsyncAuthManagers, + ExpiringAuth, + ) + + + async def auth_provider(): + # some way to getting a token + sso_token = await get_sso_token() + # assume we know our tokens expire every 60 seconds + expires_in = 60 + + return ExpiringAuth( + auth=neo4j.bearer_auth(sso_token), + # Include a little buffer so that we fetch a new token + # *before* the old one expires + expires_in=expires_in - 10 + ) + + + with neo4j.GraphDatabase.driver( + "neo4j://example.com:7687", + auth=AsyncAuthManagers.temporal(auth_provider) + ) as driver: + ... # do stuff + + :param provider: + A callable that provides a :class:`.ExpiringAuth` instance. + + :returns: + An instance of an implementation of :class:`.AsyncAuthManager` that + returns auth info from the given provider and refreshes it, calling + the provider again, when the auth info expires (either because it's + reached its expiry time or because the server flagged it as + expired). 
+ + + """ + return AsyncExpirationBasedAuthManager(provider) diff --git a/src/neo4j/_async/driver.py b/src/neo4j/_async/driver.py index b9a5110c..81180bb7 100644 --- a/src/neo4j/_async/driver.py +++ b/src/neo4j/_async/driver.py @@ -32,7 +32,6 @@ T_NotificationMinimumSeverity, ) - from .._api import RoutingControl from .._async_compat.util import AsyncUtil from .._conf import ( @@ -48,6 +47,9 @@ experimental, experimental_warn, ExperimentalWarning, + preview, + preview_warn, + PreviewWarning, unclosed_resource_warn, ) from .._work import EagerResult @@ -74,6 +76,11 @@ URI_SCHEME_NEO4J_SECURE, URI_SCHEME_NEO4J_SELF_SIGNED_CERTIFICATE, ) +from ..auth_management import ( + AsyncAuthManager, + AsyncAuthManagers, +) +from ..exceptions import Neo4jError from .bookmark_manager import ( AsyncNeo4jBookmarkManager, TBmConsumer as _TBmConsumer, @@ -93,6 +100,7 @@ import typing_extensions as te from .._api import T_RoutingControl + from ..api import _TAuth class _DefaultEnum(Enum): @@ -117,7 +125,13 @@ def driver( cls, uri: str, *, - auth: t.Union[t.Tuple[t.Any, t.Any], Auth, None] = ..., + auth: t.Union[ + # work around https://github.com/sphinx-doc/sphinx/pull/10880 + # make sure TAuth is resolved in the docs + # TAuth, + t.Union[t.Tuple[t.Any, t.Any], Auth, None], + AsyncAuthManager, + ] = ..., max_connection_lifetime: float = ..., max_connection_pool_size: int = ..., connection_timeout: float = ..., @@ -161,7 +175,13 @@ def driver( @classmethod def driver( cls, uri: str, *, - auth: t.Union[t.Tuple[t.Any, t.Any], Auth, None] = None, + auth: t.Union[ + # work around https://github.com/sphinx-doc/sphinx/pull/10880 + # make sure TAuth is resolved in the docs + # TAuth, + t.Union[t.Tuple[t.Any, t.Any], Auth, None], + AsyncAuthManager, + ] = None, **config ) -> AsyncDriver: """Create a driver. @@ -177,6 +197,18 @@ def driver( driver_type, security_type, parsed = parse_neo4j_uri(uri) + if not isinstance(auth, AsyncAuthManager): + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", message=r".*\bAuth managers\b.*", + category=PreviewWarning + ) + auth = AsyncAuthManagers.static(auth) + else: + preview_warn("Auth managers are a preview feature.", + stack_level=2) + config["auth"] = auth + # TODO: 6.0 - remove "trust" config option if "trust" in config.keys(): if config["trust"] not in ( @@ -254,10 +286,10 @@ def driver( # 'Routing parameters are not supported with scheme ' # '"bolt". Given URI "{}".'.format(uri) # ) - return cls.bolt_driver(parsed.netloc, auth=auth, **config) + return cls.bolt_driver(parsed.netloc, **config) # else driver_type == DRIVER_NEO4J routing_context = parse_routing_context(parsed.query) - return cls.neo4j_driver(parsed.netloc, auth=auth, + return cls.neo4j_driver(parsed.netloc, routing_context=routing_context, **config) @classmethod @@ -325,7 +357,7 @@ def bookmark_manager( :returns: A default implementation of :class:`AsyncBookmarkManager`. - **This is experimental.** (See :ref:`filter-warnings-ref`) + **This is experimental** (see :ref:`filter-warnings-ref`). It might be changed or removed any time even without prior notice. .. versionadded:: 5.0 @@ -349,7 +381,7 @@ def bookmark_manager( ) @classmethod - def bolt_driver(cls, target, *, auth=None, **config): + def bolt_driver(cls, target, **config): """ Create a driver for direct Bolt server access that uses socket I/O and thread-based concurrency. 
""" @@ -359,13 +391,13 @@ def bolt_driver(cls, target, *, auth=None, **config): ) try: - return AsyncBoltDriver.open(target, auth=auth, **config) + return AsyncBoltDriver.open(target, **config) except (BoltHandshakeError, BoltSecurityError) as error: from ..exceptions import ServiceUnavailable raise ServiceUnavailable(str(error)) from error @classmethod - def neo4j_driver(cls, *targets, auth=None, routing_context=None, **config): + def neo4j_driver(cls, *targets, routing_context=None, **config): """ Create a driver for routing-capable Neo4j service access that uses socket I/O and thread-based concurrency. """ @@ -375,7 +407,7 @@ def neo4j_driver(cls, *targets, auth=None, routing_context=None, **config): ) try: - return AsyncNeo4jDriver.open(*targets, auth=auth, routing_context=routing_context, **config) + return AsyncNeo4jDriver.open(*targets, routing_context=routing_context, **config) except (BoltHandshakeError, BoltSecurityError) as error: from ..exceptions import ServiceUnavailable raise ServiceUnavailable(str(error)) from error @@ -482,6 +514,9 @@ def encrypted(self) -> bool: return bool(self._pool.pool_config.encrypted) def _prepare_session_config(self, **config): + if "auth" in config: + preview_warn("User switching is a preview features.", + stack_level=3) _normalize_notifications_config(config) return config @@ -499,6 +534,7 @@ def session( default_access_mode: str = ..., bookmark_manager: t.Union[AsyncBookmarkManager, BookmarkManager, None] = ..., + auth: _TAuth = ..., notifications_min_severity: t.Optional[ T_NotificationMinimumSeverity ] = ..., @@ -549,6 +585,7 @@ async def execute_query( bookmark_manager_: t.Union[ AsyncBookmarkManager, BookmarkManager, None ] = ..., + auth_: _TAuth = None, result_transformer_: t.Callable[ [AsyncResult], t.Awaitable[EagerResult] ] = ..., @@ -567,6 +604,7 @@ async def execute_query( bookmark_manager_: t.Union[ AsyncBookmarkManager, BookmarkManager, None ] = ..., + auth_: _TAuth = None, result_transformer_: t.Callable[ [AsyncResult], t.Awaitable[_T] ] = ..., @@ -589,6 +627,7 @@ async def execute_query( AsyncBookmarkManager, BookmarkManager, None, te.Literal[_DefaultEnum.default] ] = _default, + auth_: _TAuth = None, result_transformer_: t.Callable[ [AsyncResult], t.Awaitable[t.Any] ] = AsyncResult.to_eager_result, @@ -610,7 +649,7 @@ async def execute_query( async def execute_query( query_, parameters_, routing_, database_, impersonated_user_, - bookmark_manager_, result_transformer_, **kwargs + bookmark_manager_, auth_, result_transformer_, **kwargs ): async def work(tx): result = await tx.run(query_, parameters_, **kwargs) @@ -620,6 +659,7 @@ async def work(tx): database=database_, impersonated_user=impersonated_user_, bookmark_manager=bookmark_manager_, + auth=auth_, ) as session: if routing_ == RoutingControl.WRITERS: return await session.execute_write(work) @@ -655,13 +695,14 @@ async def example(driver: neo4j.AsyncDriver) -> List[str]: async def example(driver: neo4j.AsyncDriver) -> int: \"""Call all young people "My dear" and get their count.\""" record = await driver.execute_query( - "MATCH (p:Person) WHERE p.age <= 15 " + "MATCH (p:Person) WHERE p.age <= $age " "SET p.nickname = 'My dear' " "RETURN count(*)", # optional routing parameter, as write is default # routing_=neo4j.RoutingControl.WRITERS, # or just "w", database_="neo4j", result_transformer_=neo4j.AsyncResult.single, + age=15, ) assert record is not None # for typechecking and illustration count = record[0] @@ -698,6 +739,20 @@ async def example(driver: neo4j.AsyncDriver) -> 
int: See also the Session config :ref:`impersonated-user-ref`. :type impersonated_user_: typing.Optional[str] + :param auth_: + Authentication information to use for this query. + + By default, the driver configuration is used. + + **This is a preview** (see :ref:`filter-warnings-ref`). + It might be changed without following the deprecation policy. + See also + https://github.com/neo4j/neo4j-python-driver/wiki/preview-features + + See also the Session config :ref:`session-auth-ref`. + :type auth_: typing.Union[ + typing.Tuple[typing.Any, typing.Any], neo4j.Auth, None + ] :param result_transformer_: A function that gets passed the :class:`neo4j.AsyncResult` object resulting from the query and converts it to a different type. The @@ -768,7 +823,7 @@ async def example(driver: neo4j.AsyncDriver) -> neo4j.Record:: Defaults to the driver's :attr:`.query_bookmark_manager`. - Pass :const:`None` to disable causal consistency. + Pass :data:`None` to disable causal consistency. :type bookmark_manager_: typing.Union[neo4j.AsyncBookmarkManager, neo4j.BookmarkManager, None] @@ -782,10 +837,17 @@ async def example(driver: neo4j.AsyncDriver) -> neo4j.Record:: :returns: the result of the ``result_transformer`` :rtype: T - **This is experimental.** (See :ref:`filter-warnings-ref`) + **This is experimental** (see :ref:`filter-warnings-ref`). It might be changed or removed any time even without prior notice. + We are looking for feedback on this feature. Please let us know what + you think about it here: + https://github.com/neo4j/neo4j-python-driver/discussions/896 + .. versionadded:: 5.5 + + .. versionchanged:: 5.8 + Added the ``auth_`` parameter. """ invalid_kwargs = [k for k in kwargs if k[-2:-1] != "_" and k[-1:] == "_"] @@ -807,9 +869,13 @@ async def example(driver: neo4j.AsyncDriver) -> neo4j.Record:: warnings.filterwarnings("ignore", message=r".*\bbookmark_manager\b.*", category=ExperimentalWarning) + warnings.filterwarnings("ignore", + message=r"^User switching\b.*", + category=PreviewWarning) session = self.session(database=database_, impersonated_user=impersonated_user_, - bookmark_manager=bookmark_manager_) + bookmark_manager=bookmark_manager_, + auth=auth_) async with session: if routing_ == RoutingControl.WRITERS: executor = session.execute_write @@ -846,7 +912,7 @@ async def example(driver: neo4j.AsyncDriver) -> None: # (i.e., can read what was written by ) await driver.execute_query("") - **This is experimental.** (See :ref:`filter-warnings-ref`) + **This is experimental** (see :ref:`filter-warnings-ref`). It might be changed or removed any time even without prior notice. .. versionadded:: 5.5 @@ -870,6 +936,7 @@ async def verify_connectivity( default_access_mode: str = ..., bookmark_manager: t.Union[AsyncBookmarkManager, BookmarkManager, None] = ..., + auth: t.Union[Auth, t.Tuple[t.Any, t.Any]] = ..., notifications_min_severity: t.Optional[ T_NotificationMinimumSeverity ] = ..., @@ -886,7 +953,6 @@ async def verify_connectivity( else: - # TODO: 6.0 - remove config argument async def verify_connectivity(self, **config) -> None: """Verify that the driver can establish a connection to the server. @@ -905,7 +971,7 @@ async def verify_connectivity(self, **config) -> None: They might be changed or removed in any future version without prior notice. - :raises DriverError: if the driver cannot connect to the remote. + :raises Exception: if the driver cannot connect to the remote. Use the exception to further understand the cause of the connectivity problem. 
@@ -921,8 +987,7 @@ async def verify_connectivity(self, **config) -> None: "changed or removed in any future version without prior " "notice." ) - async with self.session(**config) as session: - await session._get_server_info() + await self._get_server_info() if t.TYPE_CHECKING: @@ -941,6 +1006,7 @@ async def get_server_info( default_access_mode: str = ..., bookmark_manager: t.Union[AsyncBookmarkManager, BookmarkManager, None] = ..., + auth: t.Union[Auth, t.Tuple[t.Any, t.Any]] = ..., notifications_min_severity: t.Optional[ T_NotificationMinimumSeverity ] = ..., @@ -979,7 +1045,7 @@ async def get_server_info(self, **config) -> ServerInfo: They might be changed or removed in any future version without prior notice. - :raises DriverError: if the driver cannot connect to the remote. + :raises Exception: if the driver cannot connect to the remote. Use the exception to further understand the cause of the connectivity problem. @@ -988,14 +1054,12 @@ async def get_server_info(self, **config) -> ServerInfo: if config: experimental_warn( "All configuration key-word arguments to " - "verify_connectivity() are experimental. They might be " + "get_server_info() are experimental. They might be " "changed or removed in any future version without prior " "notice." ) - async with self.session(**config) as session: - return await session._get_server_info() + return await self._get_server_info() - @experimental("Feature support query, based on Bolt protocol version and Neo4j server version will change in the future.") async def supports_multi_db(self) -> bool: """ Check if the server or cluster supports multi-databases. @@ -1003,14 +1067,142 @@ async def supports_multi_db(self) -> bool: supports multi-databases, otherwise false. .. note:: - Feature support query, based on Bolt Protocol Version and Neo4j - server version will change in the future. + Feature support query based solely on the Bolt protocol version. + The feature might still be disabled on the server side even if this + function return :const:`True`. It just guarantees that the driver + won't throw a :exc:`ConfigurationError` when trying to use this + driver feature. """ async with self.session() as session: await session._connect(READ_ACCESS) assert session._connection return session._connection.supports_multiple_databases + if t.TYPE_CHECKING: + + async def verify_authentication( + self, + auth: t.Union[Auth, t.Tuple[t.Any, t.Any], None] = None, + # all other arguments are experimental + # they may be change or removed any time without prior notice + session_connection_timeout: float = ..., + connection_acquisition_timeout: float = ..., + max_transaction_retry_time: float = ..., + database: t.Optional[str] = ..., + fetch_size: int = ..., + impersonated_user: t.Optional[str] = ..., + bookmarks: t.Union[t.Iterable[str], Bookmarks, None] = ..., + default_access_mode: str = ..., + bookmark_manager: t.Union[ + AsyncBookmarkManager, BookmarkManager, None + ] = ..., + + # undocumented/unsupported options + initial_retry_delay: float = ..., + retry_delay_multiplier: float = ..., + retry_delay_jitter_factor: float = ... + ) -> bool: + ... + + else: + + @preview("User switching is a preview feature.") + async def verify_authentication( + self, + auth: t.Union[Auth, t.Tuple[t.Any, t.Any], None] = None, + **config + ) -> bool: + """Verify that the authentication information is valid. + + Like :meth:`.verify_connectivity`, but for checking authentication. 
+ + Try to establish a working read connection to the remote server or + a member of a cluster and exchange some data. In a cluster, there + is no guarantee about which server will be contacted. If the data + exchange is successful, the authentication information is valid and + :const:`True` is returned. Otherwise, the error will be matched + against a list of known authentication errors. If the error is on + that list, :const:`False` is returned indicating that the + authentication information is invalid. Otherwise, the error is + re-raised. + + :param auth: authentication information to verify. + Same as the session config :ref:`auth-ref`. + :param config: accepts the same configuration key-word arguments as + :meth:`session`. + + .. warning:: + All configuration key-word arguments (except ``auth``) are + experimental. They might be changed or removed in any + future version without prior notice. + + :raises Exception: if the driver cannot connect to the remote. + Use the exception to further understand the cause of the + connectivity problem. + + **This is a preview** (see :ref:`filter-warnings-ref`). + It might be changed without following the deprecation policy. + See also + https://github.com/neo4j/neo4j-python-driver/wiki/preview-features + + .. versionadded:: 5.8 + """ + if config: + experimental_warn( + "All configuration key-word arguments but auth to " + "verify_authentication() are experimental. They might be " + "changed or removed in any future version without prior " + "notice." + ) + config["auth"] = auth + if "database" not in config: + config["database"] = "system" + try: + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", message=r"^User switching\b.*", + category=PreviewWarning + ) + session = self.session(**config) + async with session as session: + await session._verify_authentication() + except Neo4jError as exc: + if exc.code in ( + "Neo.ClientError.Security.CredentialsExpired", + "Neo.ClientError.Security.Forbidden", + "Neo.ClientError.Security.TokenExpired", + "Neo.ClientError.Security.Unauthorized", + ): + return False + raise + return True + + + async def supports_session_auth(self) -> bool: + """Check if the remote supports connection re-authentication. + + :returns: Returns true if the server or cluster the driver connects to + supports re-authentication of existing connections, otherwise + false. + + .. note:: + Feature support query based solely on the Bolt protocol version. + The feature might still be disabled on the server side even if this + function return :const:`True`. It just guarantees that the driver + won't throw a :exc:`ConfigurationError` when trying to use this + driver feature. + + .. 
versionadded:: 5.x + """ + async with self.session() as session: + await session._connect(READ_ACCESS) + assert session._connection + return session._connection.supports_re_auth + + async def _get_server_info(self, **config) -> ServerInfo: + async with self.session(**config) as session: + return await session._get_server_info() + async def _work( tx: AsyncManagedTransaction, @@ -1041,7 +1233,7 @@ class AsyncBoltDriver(_Direct, AsyncDriver): """ @classmethod - def open(cls, target, *, auth=None, **config): + def open(cls, target, **config): """ :param target: :param auth: @@ -1053,7 +1245,7 @@ def open(cls, target, *, auth=None, **config): from .io import AsyncBoltPool address = cls.parse_target(target) pool_config, default_workspace_config = Config.consume_chain(config, PoolConfig, WorkspaceConfig) - pool = AsyncBoltPool.open(address, auth=auth, pool_config=pool_config, workspace_config=default_workspace_config) + pool = AsyncBoltPool.open(address, pool_config=pool_config, workspace_config=default_workspace_config) return cls(pool, default_workspace_config) def __init__(self, pool, default_workspace_config): @@ -1089,11 +1281,11 @@ class AsyncNeo4jDriver(_Routing, AsyncDriver): """ @classmethod - def open(cls, *targets, auth=None, routing_context=None, **config): + def open(cls, *targets, routing_context=None, **config): from .io import AsyncNeo4jPool addresses = cls.parse_targets(*targets) pool_config, default_workspace_config = Config.consume_chain(config, PoolConfig, WorkspaceConfig) - pool = AsyncNeo4jPool.open(*addresses, auth=auth, routing_context=routing_context, pool_config=pool_config, workspace_config=default_workspace_config) + pool = AsyncNeo4jPool.open(*addresses, routing_context=routing_context, pool_config=pool_config, workspace_config=default_workspace_config) return cls(pool, default_workspace_config) def __init__(self, pool, default_workspace_config): diff --git a/src/neo4j/_async/io/__init__.py b/src/neo4j/_async/io/__init__.py index 3b28553e..f3f628f1 100644 --- a/src/neo4j/_async/io/__init__.py +++ b/src/neo4j/_async/io/__init__.py @@ -24,6 +24,7 @@ __all__ = [ + "AcquireAuth", "AsyncBolt", "AsyncBoltPool", "AsyncNeo4jPool", @@ -38,6 +39,7 @@ ConnectionErrorHandler, ) from ._pool import ( + AcquireAuth, AsyncBoltPool, AsyncNeo4jPool, ) diff --git a/src/neo4j/_async/io/_bolt.py b/src/neo4j/_async/io/_bolt.py index 4c9ebfbb..28472ecd 100644 --- a/src/neo4j/_async/io/_bolt.py +++ b/src/neo4j/_async/io/_bolt.py @@ -25,6 +25,7 @@ from time import perf_counter from ..._async_compat.network import AsyncBoltSocket +from ..._async_compat.util import AsyncUtil from ..._codec.hydration import v1 as hydration_v1 from ..._codec.packstream import v1 as packstream_v1 from ..._conf import PoolConfig @@ -34,7 +35,7 @@ BoltHandshakeError, ) from ..._meta import get_user_agent -from ...addressing import Address +from ...addressing import ResolvedAddress from ...api import ( ServerInfo, Version, @@ -58,6 +59,20 @@ log = getLogger("neo4j") +class ServerStateManagerBase(abc.ABC): + @abc.abstractmethod + def __init__(self, init_state, on_change=None): + ... + + @abc.abstractmethod + def transition(self, message, metadata): + ... + + @abc.abstractmethod + def failed(self): + ... + + class AsyncBolt: """ Server connection for Bolt protocol. 
@@ -104,13 +119,17 @@ class AsyncBolt: most_recent_qid = None def __init__(self, unresolved_address, sock, max_connection_lifetime, *, - auth=None, user_agent=None, routing_context=None, - notifications_min_severity=None, notifications_disabled_categories=None): + auth=None, auth_manager=None, user_agent=None, + routing_context=None, notifications_min_severity=None, + notifications_disabled_categories=None): self.unresolved_address = unresolved_address self.socket = sock self.local_port = self.socket.getsockname()[1] - self.server_info = ServerInfo(Address(sock.getpeername()), - self.PROTOCOL_VERSION) + self.server_info = ServerInfo( + ResolvedAddress(sock.getpeername(), + host_name=unresolved_address.host), + self.PROTOCOL_VERSION + ) # so far `connection.recv_timeout_seconds` is the only available # configuration hint that exists. Therefore, all hints can be stored at # connection level. This might change in the future. @@ -137,26 +156,9 @@ def __init__(self, unresolved_address, sock, max_connection_lifetime, *, else: self.user_agent = get_user_agent() - # Determine auth details - if not auth: - self.auth_dict = {} - elif isinstance(auth, tuple) and 2 <= len(auth) <= 3: - from ...api import Auth - self.auth_dict = vars(Auth("basic", *auth)) - else: - try: - self.auth_dict = vars(auth) - except (KeyError, TypeError): - raise AuthError("Cannot determine auth details from %r" % auth) - - # Check for missing password - try: - credentials = self.auth_dict["credentials"] - except KeyError: - pass - else: - if credentials is None: - raise AuthError("Password cannot be None") + self.auth = auth + self.auth_dict = self._to_auth_dict(auth) + self.auth_manager = auth_manager self.notifications_min_severity = notifications_min_severity self.notifications_disabled_categories = \ @@ -166,6 +168,24 @@ def __del__(self): if not asyncio.iscoroutinefunction(self.close): self.close() + @abc.abstractmethod + def _get_server_state_manager(self) -> ServerStateManagerBase: + ... + + @classmethod + def _to_auth_dict(cls, auth): + # Determine auth details + if not auth: + return {} + elif isinstance(auth, tuple) and 2 <= len(auth) <= 3: + from ...api import Auth + return vars(Auth("basic", *auth)) + else: + try: + return vars(auth) + except (KeyError, TypeError): + raise AuthError("Cannot determine auth details from %r" % auth) + @property def connection_id(self): return self.server_info._metadata.get("connection_id", "") @@ -187,6 +207,20 @@ def supports_multiple_databases(self): """ pass + @property + @abc.abstractmethod + def supports_re_auth(self): + """Whether the connection version supports re-authentication.""" + pass + + def assert_re_auth_support(self): + if not self.supports_re_auth: + raise ConfigurationError( + "User switching is not supported for Bolt " + f"Protocol {self.PROTOCOL_VERSION!r}. Server Agent " + f"{self.server_info.agent!r}" + ) + @property @abc.abstractmethod def supports_notification_filtering(self): @@ -318,13 +352,13 @@ async def ping(cls, address, *, deadline=None, pool_config=None): # [bolt-version-bump] search tag when changing bolt version support @classmethod async def open( - cls, address, *, auth=None, deadline=None, routing_context=None, - pool_config=None + cls, address, *, auth_manager=None, deadline=None, + routing_context=None, pool_config=None ): """Open a new Bolt connection to a given server address. 
:param address: - :param auth: + :param auth_manager: :param deadline: how long to wait for the connection to be established :param routing_context: dict containing routing context :param pool_config: @@ -386,7 +420,7 @@ async def open( from ._bolt3 import AsyncBolt3 bolt_cls = AsyncBolt3 else: - log.debug("[#%04X] S: ", s.getsockname()[1]) + log.debug("[#%04X] C: ", s.getsockname()[1]) await AsyncBoltSocket.close_socket(s) supported_versions = cls.protocol_handlers().keys() @@ -397,9 +431,23 @@ async def open( address=address, request_data=handshake, response_data=data ) + try: + auth = await AsyncUtil.callback(auth_manager.get_auth) + except asyncio.CancelledError as e: + log.debug("[#%04X] C: open auth manager failed: %r", + s.getsockname()[1], e) + s.kill() + raise + except Exception as e: + log.debug("[#%04X] C: open auth manager failed: %r", + s.getsockname()[1], e) + await s.close() + raise + connection = bolt_cls( address, s, pool_config.max_connection_lifetime, auth=auth, - user_agent=pool_config.user_agent, routing_context=routing_context, + auth_manager=auth_manager, user_agent=pool_config.user_agent, + routing_context=routing_context, notifications_min_severity=pool_config.notifications_min_severity, notifications_disabled_categories= pool_config.notifications_disabled_categories @@ -448,6 +496,45 @@ async def hello(self, dehydration_hooks=None, hydration_hooks=None): """ pass + @abc.abstractmethod + def logon(self, dehydration_hooks=None, hydration_hooks=None): + """Append a LOGON message to the outgoing queue.""" + pass + + @abc.abstractmethod + def logoff(self, dehydration_hooks=None, hydration_hooks=None): + """Append a LOGOFF message to the outgoing queue.""" + pass + + def mark_unauthenticated(self): + """Mark the connection as unauthenticated.""" + self.auth_dict = {} + + def re_auth( + self, auth, auth_manager, force=False, + dehydration_hooks=None, hydration_hooks=None, + ): + """Append LOGON, LOGOFF to the outgoing queue. + + If auth is the same as the current auth, this method does nothing. + + :returns: whether the auth was changed + """ + new_auth_dict = self._to_auth_dict(auth) + if not force and new_auth_dict == self.auth_dict: + self.auth_manager = auth_manager + self.auth = auth + return False + self.logoff(dehydration_hooks=dehydration_hooks, + hydration_hooks=hydration_hooks) + self.auth_dict = new_auth_dict + self.auth_manager = auth_manager + self.auth = auth + self.logon(dehydration_hooks=dehydration_hooks, + hydration_hooks=hydration_hooks) + return True + + @abc.abstractmethod async def route( self, database=None, imp_user=None, bookmarks=None, @@ -767,7 +854,7 @@ async def _set_defunct(self, message, error=None, silent=False): # remove the connection from the pool, nor to try to close the # connection again. await self.close() - if self.pool: + if self.pool and not self._get_server_state_manager().failed(): await self.pool.deactivate(address=self.unresolved_address) # Iterate through the outstanding responses, and if any correspond diff --git a/src/neo4j/_async/io/_bolt3.py b/src/neo4j/_async/io/_bolt3.py index 483c525a..a3629c5f 100644 --- a/src/neo4j/_async/io/_bolt3.py +++ b/src/neo4j/_async/io/_bolt3.py @@ -16,6 +16,9 @@ # limitations under the License. 
+from __future__ import annotations + +import typing as t from enum import Enum from logging import getLogger from ssl import SSLSocket @@ -33,7 +36,10 @@ NotALeader, ServiceUnavailable, ) -from ._bolt import AsyncBolt +from ._bolt import ( + AsyncBolt, + ServerStateManagerBase, +) from ._common import ( check_supported_server_product, CommitResponse, @@ -53,8 +59,8 @@ class ServerStates(Enum): FAILED = "FAILED" -class ServerStateManager: - _STATE_TRANSITIONS = { +class ServerStateManager(ServerStateManagerBase): + _STATE_TRANSITIONS: t.Dict[Enum, t.Dict[str, Enum]] = { ServerStates.CONNECTED: { "hello": ServerStates.READY, }, @@ -91,6 +97,9 @@ def transition(self, message, metadata): if state_before != self.state and callable(self._on_change): self._on_change(state_before, self.state) + def failed(self): + return self.state == ServerStates.FAILED + class AsyncBolt3(AsyncBolt): """ Protocol handler for Bolt 3. @@ -104,6 +113,8 @@ class AsyncBolt3(AsyncBolt): supports_multiple_databases = False + supports_re_auth = False + supports_notification_filtering = False def __init__(self, *args, **kwargs): @@ -116,6 +127,9 @@ def _on_server_state_change(self, old_state, new_state): log.debug("[#%04X] _: state: %s > %s", self.local_port, old_state.name, new_state.name) + def _get_server_state_manager(self) -> ServerStateManagerBase: + return self._server_state_manager + @property def is_reset(self): # We can't be sure of the server's state if there are still pending @@ -159,6 +173,14 @@ async def hello(self, dehydration_hooks=None, hydration_hooks=None): await self.fetch_all() check_supported_server_product(self.server_info.agent) + def logon(self, dehydration_hooks=None, hydration_hooks=None): + """Append a LOGON message to the outgoing queue.""" + self.assert_re_auth_support() + + def logoff(self, dehydration_hooks=None, hydration_hooks=None): + """Append a LOGOFF message to the outgoing queue.""" + self.assert_re_auth_support() + async def route( self, database=None, imp_user=None, bookmarks=None, dehydration_hooks=None, hydration_hooks=None @@ -397,8 +419,7 @@ async def _process_message(self, tag, fields): self.pool.on_write_failure(address=self.unresolved_address) raise except Neo4jError as e: - if self.pool and e._invalidates_all_connections(): - await self.pool.mark_all_stale() + await self.pool.on_neo4j_error(e, self) raise else: raise BoltProtocolError("Unexpected response message with signature %02X" % summary_signature, address=self.unresolved_address) diff --git a/src/neo4j/_async/io/_bolt4.py b/src/neo4j/_async/io/_bolt4.py index 10dc99e3..2c5dfb78 100644 --- a/src/neo4j/_async/io/_bolt4.py +++ b/src/neo4j/_async/io/_bolt4.py @@ -19,7 +19,6 @@ from logging import getLogger from ssl import SSLSocket -from ..._async_compat.util import AsyncUtil from ..._exceptions import BoltProtocolError from ...api import ( READ_ACCESS, @@ -34,7 +33,10 @@ NotALeader, ServiceUnavailable, ) -from ._bolt import AsyncBolt +from ._bolt import ( + AsyncBolt, + ServerStateManagerBase, +) from ._bolt3 import ( ServerStateManager, ServerStates, @@ -62,6 +64,8 @@ class AsyncBolt4x0(AsyncBolt): supports_multiple_databases = True + supports_re_auth = False + supports_notification_filtering = False def __init__(self, *args, **kwargs): @@ -74,6 +78,9 @@ def _on_server_state_change(self, old_state, new_state): log.debug("[#%04X] _: state: %s > %s", self.local_port, old_state.name, new_state.name) + def _get_server_state_manager(self) -> ServerStateManagerBase: + return self._server_state_manager + @property def 
is_reset(self): # We can't be sure of the server's state if there are still pending @@ -119,6 +126,14 @@ async def hello(self, dehydration_hooks=None, hydration_hooks=None): await self.fetch_all() check_supported_server_product(self.server_info.agent) + def logon(self, dehydration_hooks=None, hydration_hooks=None): + """Append a LOGON message to the outgoing queue.""" + self.assert_re_auth_support() + + def logoff(self, dehydration_hooks=None, hydration_hooks=None): + """Append a LOGOFF message to the outgoing queue.""" + self.assert_re_auth_support() + async def route( self, database=None, imp_user=None, bookmarks=None, dehydration_hooks=None, hydration_hooks=None @@ -353,8 +368,8 @@ async def _process_message(self, tag, fields): self.pool.on_write_failure(address=self.unresolved_address) raise except Neo4jError as e: - if self.pool and e._invalidates_all_connections(): - await self.pool.mark_all_stale() + if self.pool: + await self.pool.on_neo4j_error(e, self) raise else: raise BoltProtocolError("Unexpected response message with signature " diff --git a/src/neo4j/_async/io/_bolt5.py b/src/neo4j/_async/io/_bolt5.py index e545647e..25c8521a 100644 --- a/src/neo4j/_async/io/_bolt5.py +++ b/src/neo4j/_async/io/_bolt5.py @@ -16,6 +16,8 @@ # limitations under the License. +import typing as t +from enum import Enum from logging import getLogger from ssl import SSLSocket @@ -32,7 +34,10 @@ NotALeader, ServiceUnavailable, ) -from ._bolt import AsyncBolt +from ._bolt import ( + AsyncBolt, + ServerStateManagerBase, +) from ._bolt3 import ( ServerStateManager, ServerStates, @@ -41,6 +46,7 @@ check_supported_server_product, CommitResponse, InitResponse, + LogonResponse, Response, ) @@ -49,7 +55,7 @@ class AsyncBolt5x0(AsyncBolt): - """Protocol handler for Bolt 5.0. 
""" + """Protocol handler for Bolt 5.0.""" PROTOCOL_VERSION = Version(5, 0) @@ -59,16 +65,26 @@ class AsyncBolt5x0(AsyncBolt): supports_multiple_databases = True + supports_re_auth = False + + supports_notification_filtering = False + + server_states: t.Any = ServerStates + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._server_state_manager = ServerStateManager( - ServerStates.CONNECTED, on_change=self._on_server_state_change + self.server_states.CONNECTED, + on_change=self._on_server_state_change ) def _on_server_state_change(self, old_state, new_state): log.debug("[#%04X] _: state: %s > %s", self.local_port, old_state.name, new_state.name) + def _get_server_state_manager(self) -> ServerStateManagerBase: + return self._server_state_manager + @property def is_reset(self): # We can't be sure of the server's state if there are still pending @@ -77,7 +93,7 @@ def is_reset(self): if (self.responses and self.responses[-1] and self.responses[-1].message == "reset"): return True - return self._server_state_manager.state == ServerStates.READY + return self._server_state_manager.state == self.server_states.READY @property def encrypted(self): @@ -130,6 +146,14 @@ def on_success(metadata): await self.fetch_all() check_supported_server_product(self.server_info.agent) + def logon(self, dehydration_hooks=None, hydration_hooks=None): + """Append a LOGON message to the outgoing queue.""" + self.assert_re_auth_support() + + def logoff(self, dehydration_hooks=None, hydration_hooks=None): + """Append a LOGOFF message to the outgoing queue.""" + self.assert_re_auth_support() + async def route(self, database=None, imp_user=None, bookmarks=None, dehydration_hooks=None, hydration_hooks=None): routing_context = self.routing_context or {} @@ -331,7 +355,7 @@ async def _process_message(self, tag, fields): elif summary_signature == b"\x7F": log.debug("[#%04X] S: FAILURE %r", self.local_port, summary_metadata) - self._server_state_manager.state = ServerStates.FAILED + self._server_state_manager.state = self.server_states.FAILED try: await response.on_failure(summary_metadata or {}) except (ServiceUnavailable, DatabaseUnavailable): @@ -343,8 +367,8 @@ async def _process_message(self, tag, fields): self.pool.on_write_failure(address=self.unresolved_address) raise except Neo4jError as e: - if self.pool and e._invalidates_all_connections(): - await self.pool.mark_all_stale() + if self.pool: + await self.pool.on_neo4j_error(e, self) raise else: raise BoltProtocolError( @@ -356,10 +380,63 @@ async def _process_message(self, tag, fields): return len(details), 1 +class ServerStates5x1(Enum): + CONNECTED = "CONNECTED" + READY = "READY" + STREAMING = "STREAMING" + TX_READY_OR_TX_STREAMING = "TX_READY||TX_STREAMING" + FAILED = "FAILED" + AUTHENTICATION = "AUTHENTICATION" + + +class ServerStateManager5x1(ServerStateManager): + _STATE_TRANSITIONS = { # type: ignore + ServerStates5x1.CONNECTED: { + "hello": ServerStates5x1.AUTHENTICATION, + }, + ServerStates5x1.AUTHENTICATION: { + "logon": ServerStates5x1.READY, + }, + ServerStates5x1.READY: { + "run": ServerStates5x1.STREAMING, + "begin": ServerStates5x1.TX_READY_OR_TX_STREAMING, + "logoff": ServerStates5x1.AUTHENTICATION, + }, + ServerStates5x1.STREAMING: { + "pull": ServerStates5x1.READY, + "discard": ServerStates5x1.READY, + "reset": ServerStates5x1.READY, + }, + ServerStates5x1.TX_READY_OR_TX_STREAMING: { + "commit": ServerStates5x1.READY, + "rollback": ServerStates5x1.READY, + "reset": ServerStates5x1.READY, + }, + ServerStates5x1.FAILED: { + 
"reset": ServerStates5x1.READY, + } + } + + + def failed(self): + return self.state == ServerStates5x1.FAILED + + class AsyncBolt5x1(AsyncBolt5x0): + """Protocol handler for Bolt 5.1.""" PROTOCOL_VERSION = Version(5, 1) + supports_re_auth = True + + server_states = ServerStates5x1 + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._server_state_manager = ServerStateManager5x1( + ServerStates5x1.CONNECTED, on_change=self._on_server_state_change + ) + async def hello(self, dehydration_hooks=None, hydration_hooks=None): if ( self.notifications_min_severity is not None @@ -383,14 +460,15 @@ def on_success(metadata): "the server and network is set up correctly.", self.local_port, recv_timeout) - extra = self.get_base_headers() - log.debug("[#%04X] C: HELLO %r", self.local_port, extra) - self._append(b"\x01", (extra,), + headers = self.get_base_headers() + logged_headers = dict(headers) + log.debug("[#%04X] C: HELLO %r", self.local_port, logged_headers) + self._append(b"\x01", (headers,), response=InitResponse(self, "hello", hydration_hooks, on_success=on_success), dehydration_hooks=dehydration_hooks) - - self.logon(dehydration_hooks, hydration_hooks) + self.logon(dehydration_hooks=dehydration_hooks, + hydration_hooks=hydration_hooks) await self.send_all() await self.fetch_all() check_supported_server_product(self.server_info.agent) @@ -401,7 +479,13 @@ def logon(self, dehydration_hooks=None, hydration_hooks=None): logged_auth_dict["credentials"] = "*******" log.debug("[#%04X] C: LOGON %r", self.local_port, logged_auth_dict) self._append(b"\x6A", (self.auth_dict,), - response=Response(self, "logon", hydration_hooks), + response=LogonResponse(self, "logon", hydration_hooks), + dehydration_hooks=dehydration_hooks) + + def logoff(self, dehydration_hooks=None, hydration_hooks=None): + log.debug("[#%04X] C: LOGOFF", self.local_port) + self._append(b"\x6B", + response=LogonResponse(self, "logoff", hydration_hooks), dehydration_hooks=dehydration_hooks) @@ -409,10 +493,12 @@ class AsyncBolt5x2(AsyncBolt5x1): PROTOCOL_VERSION = Version(5, 2) + supports_notification_filtering = True + def get_base_headers(self): headers = super().get_base_headers() if self.notifications_min_severity is not None: - headers["notifications_minimum_severity"] =\ + headers["notifications_minimum_severity"] = \ self.notifications_min_severity if self.notifications_disabled_categories is not None: headers["notifications_disabled_categories"] = \ @@ -448,15 +534,6 @@ def on_success(metadata): await self.fetch_all() check_supported_server_product(self.server_info.agent) - def logon(self, dehydration_hooks=None, hydration_hooks=None): - logged_auth_dict = dict(self.auth_dict) - if "credentials" in logged_auth_dict: - logged_auth_dict["credentials"] = "*******" - log.debug("[#%04X] C: LOGON %r", self.local_port, logged_auth_dict) - self._append(b"\x6A", (self.auth_dict,), - response=Response(self, "logon", hydration_hooks), - dehydration_hooks=dehydration_hooks) - def run(self, query, parameters=None, mode=None, bookmarks=None, metadata=None, timeout=None, db=None, imp_user=None, notifications_min_severity=None, @@ -535,7 +612,7 @@ def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, extra["notifications_minimum_severity"] = \ notifications_min_severity if notifications_disabled_categories is not None: - extra["notifications_disabled_categories"] =\ + extra["notifications_disabled_categories"] = \ notifications_disabled_categories log.debug("[#%04X] C: BEGIN %r", self.local_port, 
extra) self._append(b"\x11", (extra,), diff --git a/src/neo4j/_async/io/_common.py b/src/neo4j/_async/io/_common.py index fa21d3b6..a83ab09c 100644 --- a/src/neo4j/_async/io/_common.py +++ b/src/neo4j/_async/io/_common.py @@ -256,19 +256,34 @@ async def on_ignored(self, metadata=None): class InitResponse(Response): + async def on_failure(self, metadata): + # No sense in resetting the connection, + # the server will have closed it already. + self.connection.kill() + handler = self.handlers.get("on_failure") + await AsyncUtil.callback(handler, metadata) + handler = self.handlers.get("on_summary") + await AsyncUtil.callback(handler) + metadata["message"] = metadata.get( + "message", + "Connection initialisation failed due to an unknown error" + ) + raise Neo4jError.hydrate(**metadata) + +class LogonResponse(InitResponse): async def on_failure(self, metadata): - code = metadata.get("code") - if code == "Neo.ClientError.Security.Unauthorized": - raise Neo4jError.hydrate(**metadata) - else: - raise ServiceUnavailable( - metadata.get("message", "Connection initialisation failed") - ) + # No sense in resetting the connection, + # the server will have closed it already. + self.connection.kill() + handler = self.handlers.get("on_failure") + await AsyncUtil.callback(handler, metadata) + handler = self.handlers.get("on_summary") + await AsyncUtil.callback(handler) + raise Neo4jError.hydrate(**metadata) class CommitResponse(Response): - pass diff --git a/src/neo4j/_async/io/_pool.py b/src/neo4j/_async/io/_pool.py index 9058bb5a..714a9dfc 100644 --- a/src/neo4j/_async/io/_pool.py +++ b/src/neo4j/_async/io/_pool.py @@ -16,13 +16,18 @@ # limitations under the License. +from __future__ import annotations + import abc import asyncio import logging +import typing as t from collections import ( defaultdict, deque, ) +from copy import copy +from dataclasses import dataclass from logging import getLogger from random import choice @@ -32,6 +37,7 @@ AsyncRLock, ) from ..._async_compat.network import AsyncNetworkUtil +from ..._async_compat.util import AsyncUtil from ..._conf import ( PoolConfig, WorkspaceConfig, @@ -42,10 +48,15 @@ ) from ..._exceptions import BoltError from ..._routing import RoutingTable +from ..._sync.auth_management import StaticAuthManager from ...api import ( READ_ACCESS, WRITE_ACCESS, ) +from ...auth_management import ( + AsyncAuthManager, + AuthManager, +) from ...exceptions import ( ClientError, ConfigurationError, @@ -54,8 +65,11 @@ ReadServiceUnavailable, ServiceUnavailable, SessionExpired, + TokenExpired, + TokenExpiredRetryable, WriteServiceUnavailable, ) +from ..auth_management import AsyncStaticAuthManager from ._bolt import AsyncBolt @@ -63,6 +77,12 @@ log = getLogger("neo4j") +@dataclass +class AcquireAuth: + auth: t.Union[AsyncAuthManager, AuthManager, None] + force_auth: bool = False + + class AsyncIOPool(abc.ABC): """ A collection of connections to one or more server addresses. 
""" @@ -86,6 +106,9 @@ async def __aenter__(self): async def __aexit__(self, exc_type, exc_value, traceback): await self.close() + async def get_auth(self): + return await AsyncUtil.callback(self.pool_config.auth.get_auth) + async def _acquire_from_pool(self, address): with self.lock: for connection in list(self.connections.get(address, [])): @@ -96,6 +119,22 @@ async def _acquire_from_pool(self, address): return connection return None # no free connection available + def _remove_connection(self, connection): + address = connection.unresolved_address + with self.lock: + log.debug( + "[#%04X] _: remove connection from pool %r %s", + connection.local_port, address, connection.connection_id + ) + try: + self.connections.get(address, []).remove(connection) + except ValueError: + # If closure fails (e.g. because the server went + # down), all connections to the same address will + # be removed. Therefore, we silently ignore if the + # connection isn't in the pool anymore. + pass + async def _acquire_from_pool_checked( self, address, health_check, deadline ): @@ -110,35 +149,40 @@ async def _acquire_from_pool_checked( # `stale` but still alive. if log.isEnabledFor(logging.DEBUG): log.debug( - "[#%04X] _: removing old connection %s " + "[#%04X] _: found unhealthy connection %s " "(closed=%s, defunct=%s, stale=%s, in_use=%s)", connection.local_port, connection.connection_id, connection.closed(), connection.defunct(), connection.stale(), connection.in_use ) await connection.close() - with self.lock: - try: - self.connections.get(address, []).remove(connection) - except ValueError: - # If closure fails (e.g. because the server went - # down), all connections to the same address will - # be removed. Therefore, we silently ignore if the - # connection isn't in the pool anymore. - pass + self._remove_connection(connection) continue # try again with a new connection else: return connection - def _acquire_new_later(self, address, deadline): + def _acquire_new_later(self, address, auth, deadline): async def connection_creator(): released_reservation = False try: try: - connection = await self.opener(address, deadline) + connection = await self.opener( + address, auth or self.pool_config.auth, deadline + ) except ServiceUnavailable: await self.deactivate(address) raise + if auth: + # It's unfortunate that we have to create a connection + # first to determine if the session-level auth is supported + # by the protocol or not. + try: + connection.assert_re_auth_support() + except ConfigurationError: + log.debug("[#%04X] _: no re-auth support", + connection.local_port) + await connection.close() + raise connection.pool = self connection.in_use = True with self.lock: @@ -164,13 +208,47 @@ async def connection_creator(): return connection_creator return None - async def _acquire(self, address, deadline, liveness_check_timeout): + async def _re_auth_connection(self, connection, auth, force): + if auth: + # Assert session auth is supported by the protocol. + # The Bolt implementation will try as hard as it can to make the + # re-auth work. So if the session auth token is identical to the + # driver auth token, the connection doesn't have to do anything so + # it won't fail regardless of the protocol version. 
+ connection.assert_re_auth_support() + new_auth_manager = auth or self.pool_config.auth + log_auth = "******" if auth else "None" + new_auth = await AsyncUtil.callback(new_auth_manager.get_auth) + try: + updated = connection.re_auth(new_auth, new_auth_manager, + force=force) + log.debug("[#%04X] _: checked re_auth auth=%s updated=%s " + "force=%s", + connection.local_port, log_auth, updated, force) + except Exception as exc: + log.debug("[#%04X] _: check re_auth failed %r auth=%s " + "force=%s", + connection.local_port, exc, log_auth, force) + raise + assert not force or updated # force=True implies updated=True + if force: + await connection.send_all() + await connection.fetch_all() + + async def _acquire( + self, address, auth, deadline, liveness_check_timeout + ): """ Acquire a connection to a given address from the pool. The address supplied should always be an IP address, not a host name. This method is thread safe. """ + if auth is None: + auth = AcquireAuth(None) + force_auth = auth.force_auth + auth = auth.auth + async def health_check(connection_, deadline_): if (connection_.closed() or connection_.defunct() @@ -193,13 +271,33 @@ async def health_check(connection_, deadline_): address, health_check, deadline ) if connection: - log.debug("[#%04X] _: handing out existing connection " - "%s", connection.local_port, - connection.connection_id) + log.debug("[#%04X] _: picked existing connection %s", + connection.local_port, connection.connection_id) + try: + await self._re_auth_connection( + connection, auth, force_auth + ) + except ConfigurationError: + if auth: + # protocol version lacks support for re-auth + # => session auth token is not supported + raise + # expiring tokens supported by flushing the pool + # => give up this connection + log.debug("[#%04X] _: backwards compatible " + "auth token refresh: purge connection", + connection.local_port) + await connection.close() + await self.release(connection) + continue + log.debug("[#%04X] _: handing out existing connection", + connection.local_port) return connection # all connections in pool are in-use with self.lock: - connection_creator = self._acquire_new_later(address, deadline) + connection_creator = self._acquire_new_later( + address, auth, deadline, + ) if connection_creator: break @@ -220,7 +318,8 @@ async def health_check(connection_, deadline_): @abc.abstractmethod async def acquire( - self, access_mode, timeout, database, bookmarks, liveness_check_timeout + self, access_mode, timeout, database, bookmarks, auth: AcquireAuth, + liveness_check_timeout ): """ Acquire a connection to a server that can satisfy a set of parameters. @@ -229,6 +328,7 @@ async def acquire( (excluding potential preparation like fetching routing tables). :param database: :param bookmarks: + :param auth: :param liveness_check_timeout: """ ... 
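# Illustrative sketch of how the AcquireAuth value threaded through _acquire()
# above is unpacked: "no value" behaves like "use the driver-level auth
# manager, don't force re-authentication". Simplified stand-ins, not the
# driver's internals.
from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeAcquireAuth:
    auth: Optional[object] = None   # an auth manager, or None for driver auth
    force_auth: bool = False        # True: re-authenticate even if unchanged


def unpack_acquire_auth(value: Optional[FakeAcquireAuth]):
    if value is None:
        value = FakeAcquireAuth(None)
    return value.auth, value.force_auth


# usage
assert unpack_acquire_auth(None) == (None, False)
assert unpack_acquire_auth(FakeAcquireAuth(object(), True))[1] is True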
@@ -263,24 +363,24 @@ async def release(self, *connections): or connection.is_reset): if cancelled is not None: log.debug( - "[#%04X] _: released unclean connection %s", + "[#%04X] _: kill unclean connection %s", connection.local_port, connection.connection_id ) connection.kill() continue try: log.debug( - "[#%04X] _: released unclean connection %s", + "[#%04X] _: release unclean connection %s", connection.local_port, connection.connection_id ) await connection.reset() - except (Neo4jError, DriverError, BoltError) as e: + except (Neo4jError, DriverError, BoltError) as exc: log.debug("[#%04X] _: failed to reset connection " - "on release: %r", connection.local_port, e) - except asyncio.CancelledError as e: + "on release: %r", connection.local_port, exc) + except asyncio.CancelledError as exc: log.debug("[#%04X] _: cancelled reset connection " - "on release: %r", connection.local_port, e) - cancelled = e + "on release: %r", connection.local_port, exc) + cancelled = exc connection.kill() with self.lock: for connection in connections: @@ -351,6 +451,27 @@ def on_write_failure(self, address): "No write service available for pool {}".format(self) ) + async def on_neo4j_error(self, error, connection): + assert isinstance(error, Neo4jError) + if error._unauthenticates_all_connections(): + address = connection.unresolved_address + log.debug( + "[#0000] _: mark all connections to %r as " + "unauthenticated", address + ) + with self.lock: + for connection in self.connections.get(address, ()): + connection.mark_unauthenticated() + if error._requires_new_credentials(): + await AsyncUtil.callback( + connection.auth_manager.on_auth_expired, + connection.auth + ) + if (isinstance(error, TokenExpired) + and not isinstance(self.pool_config.auth, (AsyncStaticAuthManager, + StaticAuthManager))): + error.__class__ = TokenExpiredRetryable + async def close(self): """ Close all connections and empty the pool. This method is thread safe. @@ -370,20 +491,19 @@ async def close(self): class AsyncBoltPool(AsyncIOPool): @classmethod - def open(cls, address, *, auth, pool_config, workspace_config): + def open(cls, address, *, pool_config, workspace_config): """Create a new BoltPool :param address: - :param auth: :param pool_config: :param workspace_config: :returns: BoltPool """ - async def opener(addr, deadline): + async def opener(addr, auth_manager, deadline): return await AsyncBolt.open( - addr, auth=auth, deadline=deadline, routing_context=None, - pool_config=pool_config + addr, auth_manager=auth_manager, deadline=deadline, + routing_context=None, pool_config=pool_config ) pool = cls(opener, pool_config, workspace_config, address) @@ -399,7 +519,8 @@ def __repr__(self): self.address) async def acquire( - self, access_mode, timeout, database, bookmarks, liveness_check_timeout + self, access_mode, timeout, database, bookmarks, auth: AcquireAuth, + liveness_check_timeout ): # The access_mode and database is not needed for a direct connection, # it's just there for consistency. 
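# Illustrative sketch of the error upgrade performed in on_neo4j_error()
# above: a TokenExpired error is only rewritten to its retryable variant when
# the driver-level auth is *not* a static manager, i.e. when a fresh token
# could actually be fetched on retry. The classes below are simplified
# stand-ins for the driver's exception and manager types.
class TokenExpired(Exception):
    pass


class TokenExpiredRetryable(TokenExpired):
    pass


class StaticAuthManager:
    pass


class RefreshingManager:
    pass


def maybe_make_retryable(error, pool_auth_manager):
    if (isinstance(error, TokenExpired)
            and not isinstance(pool_auth_manager, StaticAuthManager)):
        error.__class__ = TokenExpiredRetryable
    return error


# With a refreshing (non-static) manager the error becomes retryable, so the
# driver's retry machinery may obtain new credentials and run the work again.
err = maybe_make_retryable(TokenExpired("token expired"), RefreshingManager())
assert isinstance(err, TokenExpiredRetryable)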
@@ -407,7 +528,7 @@ async def acquire( "access_mode=%r, database=%r", access_mode, database) deadline = Deadline.from_timeout_or_deadline(timeout) return await self._acquire( - self.address, deadline, liveness_check_timeout + self.address, auth, deadline, liveness_check_timeout ) @@ -416,12 +537,11 @@ class AsyncNeo4jPool(AsyncIOPool): """ @classmethod - def open(cls, *addresses, auth, pool_config, workspace_config, + def open(cls, *addresses, pool_config, workspace_config, routing_context=None): """Create a new Neo4jPool :param addresses: one or more address as positional argument - :param auth: :param pool_config: :param workspace_config: :param routing_context: @@ -435,9 +555,9 @@ def open(cls, *addresses, auth, pool_config, workspace_config, raise ConfigurationError("The key 'address' is reserved for routing context.") routing_context["address"] = str(address) - async def opener(addr, deadline): + async def opener(addr, auth_manager, deadline): return await AsyncBolt.open( - addr, auth=auth, deadline=deadline, + addr, auth_manager=auth_manager, deadline=deadline, routing_context=routing_context, pool_config=pool_config ) @@ -478,7 +598,7 @@ async def get_or_create_routing_table(self, database): return self.routing_tables[database] async def fetch_routing_info( - self, address, database, imp_user, bookmarks, acquisition_timeout + self, address, database, imp_user, bookmarks, auth, acquisition_timeout ): """ Fetch raw routing info from a given router address. @@ -489,6 +609,7 @@ async def fetch_routing_info( :type imp_user: str or None :param bookmarks: iterable of bookmark values after which the routing info should be fetched + :param auth: auth :param acquisition_timeout: connection acquisition timeout :returns: list of routing records, or None if no connection @@ -499,7 +620,10 @@ async def fetch_routing_info( deadline = Deadline.from_timeout_or_deadline(acquisition_timeout) log.debug("[#0000] _: _acquire router connection, " "database=%r, address=%r", database, address) - cx = await self._acquire(address, deadline, None) + if auth: + auth = copy(auth) + auth.force_auth = False + cx = await self._acquire(address, auth, deadline, None) try: routing_table = await cx.route( database=database or self.workspace_config.database, @@ -511,7 +635,8 @@ async def fetch_routing_info( return routing_table async def fetch_routing_table( - self, *, address, acquisition_timeout, database, imp_user, bookmarks + self, *, address, acquisition_timeout, database, imp_user, + bookmarks, auth ): """ Fetch a routing table from a given router address. @@ -523,6 +648,7 @@ async def fetch_routing_table( table :type imp_user: str or None :param bookmarks: bookmarks used when fetching routing table + :param auth: auth :returns: a new RoutingTable instance or None if the given router is currently unable to provide routing information @@ -530,7 +656,8 @@ async def fetch_routing_table( new_routing_info = None try: new_routing_info = await self.fetch_routing_info( - address, database, imp_user, bookmarks, acquisition_timeout + address, database, imp_user, bookmarks, auth, + acquisition_timeout ) except Neo4jError as e: # checks if the code is an error that is caused by the client. 
In @@ -576,8 +703,8 @@ async def fetch_routing_table( return new_routing_table async def _update_routing_table_from( - self, *routers, database, imp_user, bookmarks, acquisition_timeout, - database_callback + self, *routers, database, imp_user, bookmarks, auth, + acquisition_timeout, database_callback ): """ Try to update routing tables with the given routers. @@ -593,7 +720,8 @@ async def _update_routing_table_from( ): new_routing_table = await self.fetch_routing_table( address=address, acquisition_timeout=acquisition_timeout, - database=database, imp_user=imp_user, bookmarks=bookmarks + database=database, imp_user=imp_user, bookmarks=bookmarks, + auth=auth ) if new_routing_table is not None: new_database = new_routing_table.database @@ -613,8 +741,8 @@ async def _update_routing_table_from( return False async def update_routing_table( - self, *, database, imp_user, bookmarks, acquisition_timeout=None, - database_callback=None + self, *, database, imp_user, bookmarks, auth=None, + acquisition_timeout=None, database_callback=None ): """ Update the routing table from the first router able to provide valid routing information. @@ -624,6 +752,7 @@ async def update_routing_table( table :type imp_user: str or None :param bookmarks: bookmarks used when fetching routing table + :param auth: auth :param acquisition_timeout: connection acquisition timeout :param database_callback: A callback function that will be called with the database name as only argument when a new routing table has been @@ -645,15 +774,15 @@ async def update_routing_table( # TODO: Test this state if await self._update_routing_table_from( self.address, database=database, - imp_user=imp_user, bookmarks=bookmarks, + imp_user=imp_user, bookmarks=bookmarks, auth=auth, acquisition_timeout=acquisition_timeout, database_callback=database_callback ): # Why is only the first initial routing address used? return if await self._update_routing_table_from( - *(existing_routers - {self.address}), - database=database, imp_user=imp_user, bookmarks=bookmarks, + *(existing_routers - {self.address}), database=database, + imp_user=imp_user, bookmarks=bookmarks, auth=auth, acquisition_timeout=acquisition_timeout, database_callback=database_callback ): @@ -662,7 +791,7 @@ async def update_routing_table( if not prefer_initial_routing_address: if await self._update_routing_table_from( self.address, database=database, - imp_user=imp_user, bookmarks=bookmarks, + imp_user=imp_user, bookmarks=bookmarks, auth=auth, acquisition_timeout=acquisition_timeout, database_callback=database_callback ): @@ -681,7 +810,7 @@ async def update_connection_pool(self, *, database): await super(AsyncNeo4jPool, self).deactivate(address) async def ensure_routing_table_is_fresh( - self, *, access_mode, database, imp_user, bookmarks, + self, *, access_mode, database, imp_user, bookmarks, auth=None, acquisition_timeout=None, database_callback=None ): """ Update the routing table if stale. 
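# Illustrative sketch of the router-fallback loop that the new auth argument
# is threaded through above: ask each candidate router in turn and stop at the
# first one that yields a routing table. A simplified model only; the real
# method also orders routers and updates the pool's internal state.
def update_routing_table_from(routers, fetch_routing_table, **kwargs):
    for address in routers:
        table = fetch_routing_table(address=address, **kwargs)
        if table is not None:
            return table  # first usable answer wins
    return None  # no router could provide routing information


# usage with a dummy fetcher where only the second router answers
def fake_fetch(address, **_):
    if address == "b:7687":
        return {"readers": ["r1"], "writers": ["w1"]}
    return None


assert update_routing_table_from(["a:7687", "b:7687"], fake_fetch) is not None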
@@ -718,7 +847,7 @@ async def ensure_routing_table_is_fresh( await self.update_routing_table( database=database, imp_user=imp_user, bookmarks=bookmarks, - acquisition_timeout=acquisition_timeout, + auth=auth, acquisition_timeout=acquisition_timeout, database_callback=database_callback ) await self.update_connection_pool(database=database) @@ -755,7 +884,8 @@ async def _select_address(self, *, access_mode, database): return choice(addresses_by_usage[min(addresses_by_usage)]) async def acquire( - self, access_mode, timeout, database, bookmarks, liveness_check_timeout + self, access_mode, timeout, database, bookmarks, + auth: t.Optional[AcquireAuth], liveness_check_timeout ): if access_mode not in (WRITE_ACCESS, READ_ACCESS): raise ClientError("Non valid 'access_mode'; {}".format(access_mode)) @@ -775,7 +905,7 @@ async def acquire( "access_mode=%r, database=%r", access_mode, database) await self.ensure_routing_table_is_fresh( access_mode=access_mode, database=database, - imp_user=None, bookmarks=bookmarks, + imp_user=None, bookmarks=bookmarks, auth=auth, acquisition_timeout=timeout ) @@ -794,7 +924,7 @@ async def acquire( deadline = Deadline.from_timeout_or_deadline(timeout) # should always be a resolved address connection = await self._acquire( - address, deadline, liveness_check_timeout + address, auth, deadline, liveness_check_timeout, ) except (ServiceUnavailable, SessionExpired): await self.deactivate(address=address) diff --git a/src/neo4j/_async/work/result.py b/src/neo4j/_async/work/result.py index 6743556c..5a8dea97 100644 --- a/src/neo4j/_async/work/result.py +++ b/src/neo4j/_async/work/result.py @@ -282,7 +282,7 @@ async def _buffer(self, n=None): Might end up with more records in the buffer if the fetch size makes it overshoot. - Might ent up with fewer records in the buffer if there are not enough + Might end up with fewer records in the buffer if there are not enough records available. """ if self._out_of_scope: @@ -624,7 +624,7 @@ async def to_eager_result(self) -> EagerResult: was obtained has been closed or the Result has been explicitly consumed. - **This is experimental.** (See :ref:`filter-warnings-ref`) + **This is experimental** (see :ref:`filter-warnings-ref`). It might be changed or removed any time even without prior notice. .. 
versionadded:: 5.5 diff --git a/src/neo4j/_async/work/session.py b/src/neo4j/_async/work/session.py index 27a4cbb1..41227ea0 100644 --- a/src/neo4j/_async/work/session.py +++ b/src/neo4j/_async/work/session.py @@ -20,6 +20,7 @@ import asyncio import typing as t +import warnings from logging import getLogger from random import random from time import perf_counter @@ -27,7 +28,10 @@ from ..._async_compat import async_sleep from ..._async_compat.util import AsyncUtil from ..._conf import SessionConfig -from ..._meta import deprecated +from ..._meta import ( + deprecated, + PreviewWarning, +) from ..._work import Query from ...api import ( Bookmarks, @@ -42,6 +46,7 @@ SessionExpired, TransactionError, ) +from ..auth_management import AsyncAuthManagers from .result import AsyncResult from .transaction import ( AsyncManagedTransaction, @@ -97,7 +102,17 @@ class AsyncSession(AsyncWorkspace): def __init__(self, pool, session_config): assert isinstance(session_config, SessionConfig) + if session_config.auth is not None: + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", message=r".*\bAuth managers\b.*", + category=PreviewWarning + ) + session_config.auth = AsyncAuthManagers.static( + session_config.auth + ) super().__init__(pool, session_config) + self._config = session_config self._initialize_bookmarks(session_config.bookmarks) self._bookmark_manager = session_config.bookmark_manager @@ -113,11 +128,13 @@ async def __aexit__(self, exception_type, exception_value, traceback): self._state_failed = True await self.close() - async def _connect(self, access_mode, **access_kwargs): + async def _connect(self, access_mode, **acquire_kwargs): if access_mode is None: access_mode = self._config.default_access_mode try: - await super()._connect(access_mode, **access_kwargs) + await super()._connect( + access_mode, auth=self._config.auth, **acquire_kwargs + ) except asyncio.CancelledError: self._handle_cancellation(message="_connect") raise @@ -162,6 +179,11 @@ async def _get_server_info(self): await self._disconnect() return server_info + async def _verify_authentication(self): + assert not self._connection + await self._connect(READ_ACCESS, force_auth=True) + await self._disconnect() + async def close(self) -> None: """Close the session. 
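# Illustrative usage of the session-level auth wired up above (a preview
# feature in this change set). The URI and credentials are placeholders; this
# assumes a reachable server whose protocol supports re-authentication
# (check with driver.supports_session_auth()).
import neo4j

with neo4j.GraphDatabase.driver(
    "neo4j://localhost:7687", auth=("neo4j", "driver-level-password")
) as driver:
    # The token passed here is wrapped in a static auth manager (as in the
    # session constructor above) and applies only to this session.
    with driver.session(auth=("another_user", "their-password")) as session:
        session.run("RETURN 1").consume()

    # The same token can be supplied per query via execute_query's auth_
    # parameter, which simply forwards it to the session it opens.
    driver.execute_query("RETURN 1", auth_=("another_user", "their-password"))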
diff --git a/src/neo4j/_async/work/workspace.py b/src/neo4j/_async/work/workspace.py
index 21a18d5b..6a3d43d6 100644
--- a/src/neo4j/_async/work/workspace.py
+++ b/src/neo4j/_async/work/workspace.py
@@ -33,7 +33,10 @@
     SessionError,
     SessionExpired,
 )
-from ..io import AsyncNeo4jPool
+from ..io import (
+    AcquireAuth,
+    AsyncNeo4jPool,
+)
 
 
 log = logging.getLogger("neo4j")
@@ -130,8 +133,13 @@ async def _update_bookmark(self, bookmark):
             return
         await self._update_bookmarks((bookmark,))
 
-    async def _connect(self, access_mode, **acquire_kwargs):
+    async def _connect(self, access_mode, auth=None, **acquire_kwargs):
         acquisition_timeout = self._config.connection_acquisition_timeout
+        auth = AcquireAuth(
+            auth,
+            force_auth=acquire_kwargs.pop("force_auth", False),
+        )
+
         if self._connection:
             # TODO: Investigate this
             # log.warning("FIXME: should always disconnect before connect")
@@ -154,6 +162,7 @@ async def _connect(self, access_mode, **acquire_kwargs):
                     database=self._config.database,
                     imp_user=self._config.impersonated_user,
                     bookmarks=await self._get_bookmarks(),
+                    auth=auth,
                     acquisition_timeout=acquisition_timeout,
                     database_callback=self._set_cached_database
                 )
@@ -162,6 +171,7 @@ async def _connect(self, access_mode, **acquire_kwargs):
                 "timeout": acquisition_timeout,
                 "database": self._config.database,
                 "bookmarks": await self._get_bookmarks(),
+                "auth": auth,
                 "liveness_check_timeout": None,
             }
             acquire_kwargs_.update(acquire_kwargs)
diff --git a/src/neo4j/_async_compat/network/_bolt_socket.py b/src/neo4j/_async_compat/network/_bolt_socket.py
index 507d080a..ce2ce11f 100644
--- a/src/neo4j/_async_compat/network/_bolt_socket.py
+++ b/src/neo4j/_async_compat/network/_bolt_socket.py
@@ -20,7 +20,6 @@
 
 import asyncio
 import logging
-import selectors
 import struct
 import typing as t
 from socket import (
diff --git a/src/neo4j/_auth_management.py b/src/neo4j/_auth_management.py
new file mode 100644
index 00000000..6d3c7a92
--- /dev/null
+++ b/src/neo4j/_auth_management.py
@@ -0,0 +1,140 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [https://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# from __future__ import annotations
+# work around for https://github.com/sphinx-doc/sphinx/pull/10880
+# make sure TAuth is resolved in the docs, else they're pretty useless
+
+import abc
+import typing as t
+from dataclasses import dataclass
+
+from ._meta import preview
+from .api import _TAuth
+
+
+@preview("Auth managers are a preview feature.")
+@dataclass
+class ExpiringAuth:
+    """Represents potentially expiring authentication information.
+
+    This class is used with :meth:`.AuthManagers.expiration_based` and
+    :meth:`.AsyncAuthManagers.expiration_based`.
+
+    :param auth: The authentication information.
+    :param expires_in: The number of seconds until the authentication
+        information expires. If :data:`None`, the authentication information
+        is considered to not expire until the server explicitly indicates so.
+
+    **This is a preview** (see :ref:`filter-warnings-ref`).
+    It might be changed without following the deprecation policy.
+    See also https://github.com/neo4j/neo4j-python-driver/wiki/preview-features
+
+    .. seealso::
+        :meth:`.AuthManagers.expiration_based`,
+        :meth:`.AsyncAuthManagers.expiration_based`
+
+    .. versionadded:: 5.8
+    """
+    auth: _TAuth
+    expires_in: t.Optional[float] = None
+
+
+class AuthManager(metaclass=abc.ABCMeta):
+    """Base class for authentication information managers.
+
+    The driver provides some default implementations of this class in
+    :class:`.AuthManagers` for convenience.
+
+    Custom implementations of this class can be used to provide more complex
+    authentication refresh functionality.
+
+    .. warning::
+
+        The manager **must not** interact with the driver in any way as this
+        can cause deadlocks and undefined behaviour.
+
+        Furthermore, the manager is expected to be thread-safe.
+
+        The token returned must always belong to the same identity.
+        Switching identities using the `AuthManager` is undefined behavior.
+
+    **This is a preview** (see :ref:`filter-warnings-ref`).
+    It might be changed without following the deprecation policy.
+    See also https://github.com/neo4j/neo4j-python-driver/wiki/preview-features
+
+    .. seealso:: :class:`.AuthManagers`
+
+    .. versionadded:: 5.8
+    """
+
+    @abc.abstractmethod
+    def get_auth(self) -> _TAuth:
+        """Return the current authentication information.
+
+        The driver will call this method very frequently. It is recommended
+        to implement some form of caching to avoid unnecessary overhead.
+
+        .. warning::
+
+            The method must only ever return auth information belonging to the
+            same identity.
+            Switching identities using the `AuthManager` is undefined behavior.
+        """
+        ...
+
+    @abc.abstractmethod
+    def on_auth_expired(self, auth: _TAuth) -> None:
+        """Handle the server indicating expired authentication information.
+
+        The driver will call this method when the server indicates that the
+        provided authentication information is no longer valid.
+
+        :param auth:
+            The authentication information that the server flagged as no longer
+            valid.
+        """
+        ...
+
+
+class AsyncAuthManager(metaclass=abc.ABCMeta):
+    """Async version of :class:`.AuthManager`.
+
+    **This is a preview** (see :ref:`filter-warnings-ref`).
+    It might be changed without following the deprecation policy.
+    See also https://github.com/neo4j/neo4j-python-driver/wiki/preview-features
+
+    .. seealso:: :class:`.AuthManager`
+
+    .. versionadded:: 5.8
+    """
+
+    @abc.abstractmethod
+    async def get_auth(self) -> _TAuth:
+        """Async version of :meth:`.AuthManager.get_auth`.
+
+        .. seealso:: :meth:`.AuthManager.get_auth`
+        """
+        ...
+
+    @abc.abstractmethod
+    async def on_auth_expired(self, auth: _TAuth) -> None:
+        """Async version of :meth:`.AuthManager.on_auth_expired`.
+
+        .. seealso:: :meth:`.AuthManager.on_auth_expired`
+        """
+        ...
diff --git a/src/neo4j/_conf.py b/src/neo4j/_conf.py
index b3a4d0b0..61b6f676 100644
--- a/src/neo4j/_conf.py
+++ b/src/neo4j/_conf.py
@@ -407,6 +407,9 @@ class PoolConfig(Config):
     keep_alive = True  # Specify whether TCP keep-alive should be enabled.
+ #: Authentication provider + auth = None + #: Lowest notification severity for the server to return notifications_min_severity = None @@ -511,6 +514,9 @@ class SessionConfig(WorkspaceConfig): #: Default AccessMode default_access_mode = WRITE_ACCESS + #: Auth token to temporarily switch the user + auth = None + #: Lowest notification severity for the server to return notifications_min_severity = None diff --git a/src/neo4j/_meta.py b/src/neo4j/_meta.py index f8f8cd9c..5489fc44 100644 --- a/src/neo4j/_meta.py +++ b/src/neo4j/_meta.py @@ -94,6 +94,9 @@ def decorator(f): class ExperimentalWarning(Warning): """ Base class for warnings about experimental features. + + .. deprecated:: 5.8 + we now use "preview" instead of "experimental". """ @@ -101,6 +104,20 @@ def experimental_warn(message, stack_level=1): warn(message, category=ExperimentalWarning, stacklevel=stack_level + 1) +class PreviewWarning(Warning): + """ Base class for warnings about experimental features. + """ + + +def preview_warn(message, stack_level=1): + message += ( + " It might be changed without following the deprecation policy. " + "See also " + "https://github.com/neo4j/neo4j-python-driver/wiki/preview-features." + ) + warn(message, category=PreviewWarning, stacklevel=stack_level + 1) + + def experimental(message) -> t.Callable[[_FuncT], _FuncT]: """ Decorator for tagging experimental functions and methods. @@ -111,6 +128,8 @@ def experimental(message) -> t.Callable[[_FuncT], _FuncT]: def foo(x): pass + .. deprecated:: 5.8 + we now use "preview" instead of "experimental". """ def decorator(f): if asyncio.iscoroutinefunction(f): @@ -131,6 +150,33 @@ def inner(*args, **kwargs): return decorator +def preview(message) -> t.Callable[[_FuncT], _FuncT]: + """ + Decorator for tagging preview functions and methods. + + @preview("foo is a preview.") + def foo(x): + pass + """ + def decorator(f): + if asyncio.iscoroutinefunction(f): + @wraps(f) + async def inner(*args, **kwargs): + preview_warn(message, stack_level=2) + return await f(*args, **kwargs) + + return inner + else: + @wraps(f) + def inner(*args, **kwargs): + preview_warn(message, stack_level=2) + return f(*args, **kwargs) + + return inner + + return decorator + + def unclosed_resource_warn(obj): msg = f"Unclosed {obj!r}." trace = tracemalloc.get_object_traceback(obj) diff --git a/src/neo4j/_sync/auth_management.py b/src/neo4j/_sync/auth_management.py new file mode 100644 index 00000000..fc016f4b --- /dev/null +++ b/src/neo4j/_sync/auth_management.py @@ -0,0 +1,206 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +# from __future__ import annotations +# work around for https://github.com/sphinx-doc/sphinx/pull/10880 +# make sure TAuth is resolved in the docs, else they're pretty useless + +import time +import typing as t +from logging import getLogger + +from .._async_compat.concurrency import Lock +from .._auth_management import ( + AuthManager, + ExpiringAuth, +) +from .._meta import preview + +# work around for https://github.com/sphinx-doc/sphinx/pull/10880 +# make sure TAuth is resolved in the docs, else they're pretty useless +# if t.TYPE_CHECKING: +from ..api import _TAuth + + +log = getLogger("neo4j") + + +class StaticAuthManager(AuthManager): + _auth: _TAuth + + def __init__(self, auth: _TAuth) -> None: + self._auth = auth + + def get_auth(self) -> _TAuth: + return self._auth + + def on_auth_expired(self, auth: _TAuth) -> None: + pass + + +class _ExpiringAuthHolder: + def __init__(self, auth: ExpiringAuth) -> None: + self._auth = auth + self._expiry = None + if auth.expires_in is not None: + self._expiry = time.monotonic() + auth.expires_in + + @property + def auth(self) -> _TAuth: + return self._auth.auth + + def expired(self) -> bool: + if self._expiry is None: + return False + return time.monotonic() > self._expiry + +class ExpirationBasedAuthManager(AuthManager): + _current_auth: t.Optional[_ExpiringAuthHolder] + _provider: t.Callable[[], t.Union[ExpiringAuth]] + _lock: Lock + + + def __init__( + self, + provider: t.Callable[[], t.Union[ExpiringAuth]] + ) -> None: + self._provider = provider + self._current_auth = None + self._lock = Lock() + + def _refresh_auth(self): + self._current_auth = _ExpiringAuthHolder(self._provider()) + + def get_auth(self) -> _TAuth: + with self._lock: + auth = self._current_auth + if auth is not None and not auth.expired(): + return auth.auth + log.debug("[ ] _: refreshing (time out)") + self._refresh_auth() + assert self._current_auth is not None + return self._current_auth.auth + + def on_auth_expired(self, auth: _TAuth) -> None: + with self._lock: + cur_auth = self._current_auth + if cur_auth is not None and cur_auth.auth == auth: + log.debug("[ ] _: refreshing (error)") + self._refresh_auth() + + +class AuthManagers: + """A collection of :class:`.AuthManager` factories. + + **This is a preview** (see :ref:`filter-warnings-ref`). + It might be changed without following the deprecation policy. + See also https://github.com/neo4j/neo4j-python-driver/wiki/preview-features + + .. versionadded:: 5.8 + """ + + @staticmethod + @preview("Auth managers are a preview feature.") + def static(auth: _TAuth) -> AuthManager: + """Create a static auth manager. + + Example:: + + # NOTE: this example is for illustration purposes only. + # The driver will automatically wrap static auth info in a + # static auth manager. + + import neo4j + from neo4j.auth_management import AuthManagers + + + auth = neo4j.basic_auth("neo4j", "password") + + with neo4j.GraphDatabase.driver( + "neo4j://example.com:7687", + auth=AuthManagers.static(auth) + # auth=auth # this is equivalent + ) as driver: + ... # do stuff + + :param auth: The auth to return. + + :returns: + An instance of an implementation of :class:`.AuthManager` that + always returns the same auth. + """ + return StaticAuthManager(auth) + + @staticmethod + @preview("Auth managers are a preview feature.") + def expiration_based( + provider: t.Callable[[], t.Union[ExpiringAuth]] + ) -> AuthManager: + """Create an auth manager for potentially expiring auth info. + + .. 
warning::
+
+            The provider function **must not** interact with the driver in any
+            way as this can cause deadlocks and undefined behaviour.
+
+            The provider function must only ever return auth information
+            belonging to the same identity.
+            Switching identities is undefined behavior.
+
+        Example::
+
+            import neo4j
+            from neo4j.auth_management import (
+                AuthManagers,
+                ExpiringAuth,
+            )
+
+
+            def auth_provider():
+                # some way of getting a token
+                sso_token = get_sso_token()
+                # assume we know our tokens expire every 60 seconds
+                expires_in = 60
+
+                return ExpiringAuth(
+                    auth=neo4j.bearer_auth(sso_token),
+                    # Include a little buffer so that we fetch a new token
+                    # *before* the old one expires
+                    expires_in=expires_in - 10
+                )
+
+
+            with neo4j.GraphDatabase.driver(
+                "neo4j://example.com:7687",
+                auth=AuthManagers.expiration_based(auth_provider)
+            ) as driver:
+                ...  # do stuff
+
+        :param provider:
+            A callable that provides an :class:`.ExpiringAuth` instance.
+
+        :returns:
+            An instance of an implementation of :class:`.AuthManager` that
+            returns auth info from the given provider and refreshes it, calling
+            the provider again, when the auth info expires (either because it's
+            reached its expiry time or because the server flagged it as
+            expired).
+
+
+        """
+        return ExpirationBasedAuthManager(provider)
diff --git a/src/neo4j/_sync/driver.py b/src/neo4j/_sync/driver.py
index 5118cc00..69ccc904 100644
--- a/src/neo4j/_sync/driver.py
+++ b/src/neo4j/_sync/driver.py
@@ -47,6 +47,9 @@
     experimental,
     experimental_warn,
     ExperimentalWarning,
+    preview,
+    preview_warn,
+    PreviewWarning,
     unclosed_resource_warn,
 )
 from .._work import EagerResult
@@ -72,6 +75,11 @@
     URI_SCHEME_NEO4J_SECURE,
     URI_SCHEME_NEO4J_SELF_SIGNED_CERTIFICATE,
 )
+from ..auth_management import (
+    AuthManager,
+    AuthManagers,
+)
+from ..exceptions import Neo4jError
 from .bookmark_manager import (
     Neo4jBookmarkManager,
     TBmConsumer as _TBmConsumer,
@@ -91,6 +99,7 @@
     import typing_extensions as te
 
     from .._api import T_RoutingControl
+    from ..api import _TAuth
 
 
 class _DefaultEnum(Enum):
@@ -115,7 +124,13 @@ def driver(
             cls,
             uri: str,
             *,
-            auth: t.Union[t.Tuple[t.Any, t.Any], Auth, None] = ...,
+            auth: t.Union[
+                # work around https://github.com/sphinx-doc/sphinx/pull/10880
+                # make sure TAuth is resolved in the docs
+                # TAuth,
+                t.Union[t.Tuple[t.Any, t.Any], Auth, None],
+                AuthManager,
+            ] = ...,
             max_connection_lifetime: float = ...,
             max_connection_pool_size: int = ...,
             connection_timeout: float = ...,
@@ -159,7 +174,13 @@ def driver(
     @classmethod
     def driver(
         cls, uri: str, *,
-        auth: t.Union[t.Tuple[t.Any, t.Any], Auth, None] = None,
+        auth: t.Union[
+            # work around https://github.com/sphinx-doc/sphinx/pull/10880
+            # make sure TAuth is resolved in the docs
+            # TAuth,
+            t.Union[t.Tuple[t.Any, t.Any], Auth, None],
+            AuthManager,
+        ] = None,
         **config
     ) -> Driver:
         """Create a driver.
@@ -175,6 +196,18 @@ def driver(
 
         driver_type, security_type, parsed = parse_neo4j_uri(uri)
 
+        if not isinstance(auth, AuthManager):
+            with warnings.catch_warnings():
+                warnings.filterwarnings(
+                    "ignore", message=r".*\bAuth managers\b.*",
+                    category=PreviewWarning
+                )
+                auth = AuthManagers.static(auth)
+        else:
+            preview_warn("Auth managers are a preview feature.",
+                         stack_level=2)
+        config["auth"] = auth
+
         # TODO: 6.0 - remove "trust" config option
         if "trust" in config.keys():
             if config["trust"] not in (
@@ -252,10 +285,10 @@ def driver(
                 #     'Routing parameters are not supported with scheme '
                 #     '"bolt". 
Given URI "{}".'.format(uri) # ) - return cls.bolt_driver(parsed.netloc, auth=auth, **config) + return cls.bolt_driver(parsed.netloc, **config) # else driver_type == DRIVER_NEO4J routing_context = parse_routing_context(parsed.query) - return cls.neo4j_driver(parsed.netloc, auth=auth, + return cls.neo4j_driver(parsed.netloc, routing_context=routing_context, **config) @classmethod @@ -323,7 +356,7 @@ def bookmark_manager( :returns: A default implementation of :class:`BookmarkManager`. - **This is experimental.** (See :ref:`filter-warnings-ref`) + **This is experimental** (see :ref:`filter-warnings-ref`). It might be changed or removed any time even without prior notice. .. versionadded:: 5.0 @@ -347,7 +380,7 @@ def bookmark_manager( ) @classmethod - def bolt_driver(cls, target, *, auth=None, **config): + def bolt_driver(cls, target, **config): """ Create a driver for direct Bolt server access that uses socket I/O and thread-based concurrency. """ @@ -357,13 +390,13 @@ def bolt_driver(cls, target, *, auth=None, **config): ) try: - return BoltDriver.open(target, auth=auth, **config) + return BoltDriver.open(target, **config) except (BoltHandshakeError, BoltSecurityError) as error: from ..exceptions import ServiceUnavailable raise ServiceUnavailable(str(error)) from error @classmethod - def neo4j_driver(cls, *targets, auth=None, routing_context=None, **config): + def neo4j_driver(cls, *targets, routing_context=None, **config): """ Create a driver for routing-capable Neo4j service access that uses socket I/O and thread-based concurrency. """ @@ -373,7 +406,7 @@ def neo4j_driver(cls, *targets, auth=None, routing_context=None, **config): ) try: - return Neo4jDriver.open(*targets, auth=auth, routing_context=routing_context, **config) + return Neo4jDriver.open(*targets, routing_context=routing_context, **config) except (BoltHandshakeError, BoltSecurityError) as error: from ..exceptions import ServiceUnavailable raise ServiceUnavailable(str(error)) from error @@ -480,6 +513,9 @@ def encrypted(self) -> bool: return bool(self._pool.pool_config.encrypted) def _prepare_session_config(self, **config): + if "auth" in config: + preview_warn("User switching is a preview features.", + stack_level=3) _normalize_notifications_config(config) return config @@ -497,6 +533,7 @@ def session( default_access_mode: str = ..., bookmark_manager: t.Union[BookmarkManager, BookmarkManager, None] = ..., + auth: _TAuth = ..., notifications_min_severity: t.Optional[ T_NotificationMinimumSeverity ] = ..., @@ -547,6 +584,7 @@ def execute_query( bookmark_manager_: t.Union[ BookmarkManager, BookmarkManager, None ] = ..., + auth_: _TAuth = None, result_transformer_: t.Callable[ [Result], t.Union[EagerResult] ] = ..., @@ -565,6 +603,7 @@ def execute_query( bookmark_manager_: t.Union[ BookmarkManager, BookmarkManager, None ] = ..., + auth_: _TAuth = None, result_transformer_: t.Callable[ [Result], t.Union[_T] ] = ..., @@ -587,6 +626,7 @@ def execute_query( BookmarkManager, BookmarkManager, None, te.Literal[_DefaultEnum.default] ] = _default, + auth_: _TAuth = None, result_transformer_: t.Callable[ [Result], t.Union[t.Any] ] = Result.to_eager_result, @@ -608,7 +648,7 @@ def execute_query( def execute_query( query_, parameters_, routing_, database_, impersonated_user_, - bookmark_manager_, result_transformer_, **kwargs + bookmark_manager_, auth_, result_transformer_, **kwargs ): def work(tx): result = tx.run(query_, parameters_, **kwargs) @@ -618,6 +658,7 @@ def work(tx): database=database_, impersonated_user=impersonated_user_, 
bookmark_manager=bookmark_manager_, + auth=auth_, ) as session: if routing_ == RoutingControl.WRITERS: return session.execute_write(work) @@ -653,13 +694,14 @@ def example(driver: neo4j.Driver) -> List[str]: def example(driver: neo4j.Driver) -> int: \"""Call all young people "My dear" and get their count.\""" record = driver.execute_query( - "MATCH (p:Person) WHERE p.age <= 15 " + "MATCH (p:Person) WHERE p.age <= $age " "SET p.nickname = 'My dear' " "RETURN count(*)", # optional routing parameter, as write is default # routing_=neo4j.RoutingControl.WRITERS, # or just "w", database_="neo4j", result_transformer_=neo4j.Result.single, + age=15, ) assert record is not None # for typechecking and illustration count = record[0] @@ -696,6 +738,20 @@ def example(driver: neo4j.Driver) -> int: See also the Session config :ref:`impersonated-user-ref`. :type impersonated_user_: typing.Optional[str] + :param auth_: + Authentication information to use for this query. + + By default, the driver configuration is used. + + **This is a preview** (see :ref:`filter-warnings-ref`). + It might be changed without following the deprecation policy. + See also + https://github.com/neo4j/neo4j-python-driver/wiki/preview-features + + See also the Session config :ref:`session-auth-ref`. + :type auth_: typing.Union[ + typing.Tuple[typing.Any, typing.Any], neo4j.Auth, None + ] :param result_transformer_: A function that gets passed the :class:`neo4j.Result` object resulting from the query and converts it to a different type. The @@ -766,7 +822,7 @@ def example(driver: neo4j.Driver) -> neo4j.Record:: Defaults to the driver's :attr:`.query_bookmark_manager`. - Pass :const:`None` to disable causal consistency. + Pass :data:`None` to disable causal consistency. :type bookmark_manager_: typing.Union[neo4j.BookmarkManager, neo4j.BookmarkManager, None] @@ -780,10 +836,17 @@ def example(driver: neo4j.Driver) -> neo4j.Record:: :returns: the result of the ``result_transformer`` :rtype: T - **This is experimental.** (See :ref:`filter-warnings-ref`) + **This is experimental** (see :ref:`filter-warnings-ref`). It might be changed or removed any time even without prior notice. + We are looking for feedback on this feature. Please let us know what + you think about it here: + https://github.com/neo4j/neo4j-python-driver/discussions/896 + .. versionadded:: 5.5 + + .. versionchanged:: 5.8 + Added the ``auth_`` parameter. """ invalid_kwargs = [k for k in kwargs if k[-2:-1] != "_" and k[-1:] == "_"] @@ -805,9 +868,13 @@ def example(driver: neo4j.Driver) -> neo4j.Record:: warnings.filterwarnings("ignore", message=r".*\bbookmark_manager\b.*", category=ExperimentalWarning) + warnings.filterwarnings("ignore", + message=r"^User switching\b.*", + category=PreviewWarning) session = self.session(database=database_, impersonated_user=impersonated_user_, - bookmark_manager=bookmark_manager_) + bookmark_manager=bookmark_manager_, + auth=auth_) with session: if routing_ == RoutingControl.WRITERS: executor = session.execute_write @@ -844,7 +911,7 @@ def example(driver: neo4j.Driver) -> None: # (i.e., can read what was written by ) driver.execute_query("") - **This is experimental.** (See :ref:`filter-warnings-ref`) + **This is experimental** (see :ref:`filter-warnings-ref`). It might be changed or removed any time even without prior notice. .. 
versionadded:: 5.5 @@ -868,6 +935,7 @@ def verify_connectivity( default_access_mode: str = ..., bookmark_manager: t.Union[BookmarkManager, BookmarkManager, None] = ..., + auth: t.Union[Auth, t.Tuple[t.Any, t.Any]] = ..., notifications_min_severity: t.Optional[ T_NotificationMinimumSeverity ] = ..., @@ -884,7 +952,6 @@ def verify_connectivity( else: - # TODO: 6.0 - remove config argument def verify_connectivity(self, **config) -> None: """Verify that the driver can establish a connection to the server. @@ -903,7 +970,7 @@ def verify_connectivity(self, **config) -> None: They might be changed or removed in any future version without prior notice. - :raises DriverError: if the driver cannot connect to the remote. + :raises Exception: if the driver cannot connect to the remote. Use the exception to further understand the cause of the connectivity problem. @@ -919,8 +986,7 @@ def verify_connectivity(self, **config) -> None: "changed or removed in any future version without prior " "notice." ) - with self.session(**config) as session: - session._get_server_info() + self._get_server_info() if t.TYPE_CHECKING: @@ -939,6 +1005,7 @@ def get_server_info( default_access_mode: str = ..., bookmark_manager: t.Union[BookmarkManager, BookmarkManager, None] = ..., + auth: t.Union[Auth, t.Tuple[t.Any, t.Any]] = ..., notifications_min_severity: t.Optional[ T_NotificationMinimumSeverity ] = ..., @@ -977,7 +1044,7 @@ def get_server_info(self, **config) -> ServerInfo: They might be changed or removed in any future version without prior notice. - :raises DriverError: if the driver cannot connect to the remote. + :raises Exception: if the driver cannot connect to the remote. Use the exception to further understand the cause of the connectivity problem. @@ -986,14 +1053,12 @@ def get_server_info(self, **config) -> ServerInfo: if config: experimental_warn( "All configuration key-word arguments to " - "verify_connectivity() are experimental. They might be " + "get_server_info() are experimental. They might be " "changed or removed in any future version without prior " "notice." ) - with self.session(**config) as session: - return session._get_server_info() + return self._get_server_info() - @experimental("Feature support query, based on Bolt protocol version and Neo4j server version will change in the future.") def supports_multi_db(self) -> bool: """ Check if the server or cluster supports multi-databases. @@ -1001,14 +1066,142 @@ def supports_multi_db(self) -> bool: supports multi-databases, otherwise false. .. note:: - Feature support query, based on Bolt Protocol Version and Neo4j - server version will change in the future. + Feature support query based solely on the Bolt protocol version. + The feature might still be disabled on the server side even if this + function return :const:`True`. It just guarantees that the driver + won't throw a :exc:`ConfigurationError` when trying to use this + driver feature. 
""" with self.session() as session: session._connect(READ_ACCESS) assert session._connection return session._connection.supports_multiple_databases + if t.TYPE_CHECKING: + + def verify_authentication( + self, + auth: t.Union[Auth, t.Tuple[t.Any, t.Any], None] = None, + # all other arguments are experimental + # they may be change or removed any time without prior notice + session_connection_timeout: float = ..., + connection_acquisition_timeout: float = ..., + max_transaction_retry_time: float = ..., + database: t.Optional[str] = ..., + fetch_size: int = ..., + impersonated_user: t.Optional[str] = ..., + bookmarks: t.Union[t.Iterable[str], Bookmarks, None] = ..., + default_access_mode: str = ..., + bookmark_manager: t.Union[ + BookmarkManager, BookmarkManager, None + ] = ..., + + # undocumented/unsupported options + initial_retry_delay: float = ..., + retry_delay_multiplier: float = ..., + retry_delay_jitter_factor: float = ... + ) -> bool: + ... + + else: + + @preview("User switching is a preview feature.") + def verify_authentication( + self, + auth: t.Union[Auth, t.Tuple[t.Any, t.Any], None] = None, + **config + ) -> bool: + """Verify that the authentication information is valid. + + Like :meth:`.verify_connectivity`, but for checking authentication. + + Try to establish a working read connection to the remote server or + a member of a cluster and exchange some data. In a cluster, there + is no guarantee about which server will be contacted. If the data + exchange is successful, the authentication information is valid and + :const:`True` is returned. Otherwise, the error will be matched + against a list of known authentication errors. If the error is on + that list, :const:`False` is returned indicating that the + authentication information is invalid. Otherwise, the error is + re-raised. + + :param auth: authentication information to verify. + Same as the session config :ref:`auth-ref`. + :param config: accepts the same configuration key-word arguments as + :meth:`session`. + + .. warning:: + All configuration key-word arguments (except ``auth``) are + experimental. They might be changed or removed in any + future version without prior notice. + + :raises Exception: if the driver cannot connect to the remote. + Use the exception to further understand the cause of the + connectivity problem. + + **This is a preview** (see :ref:`filter-warnings-ref`). + It might be changed without following the deprecation policy. + See also + https://github.com/neo4j/neo4j-python-driver/wiki/preview-features + + .. versionadded:: 5.8 + """ + if config: + experimental_warn( + "All configuration key-word arguments but auth to " + "verify_authentication() are experimental. They might be " + "changed or removed in any future version without prior " + "notice." + ) + config["auth"] = auth + if "database" not in config: + config["database"] = "system" + try: + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", message=r"^User switching\b.*", + category=PreviewWarning + ) + session = self.session(**config) + with session as session: + session._verify_authentication() + except Neo4jError as exc: + if exc.code in ( + "Neo.ClientError.Security.CredentialsExpired", + "Neo.ClientError.Security.Forbidden", + "Neo.ClientError.Security.TokenExpired", + "Neo.ClientError.Security.Unauthorized", + ): + return False + raise + return True + + + def supports_session_auth(self) -> bool: + """Check if the remote supports connection re-authentication. 
+ + :returns: Returns true if the server or cluster the driver connects to + supports re-authentication of existing connections, otherwise + false. + + .. note:: + Feature support query based solely on the Bolt protocol version. + The feature might still be disabled on the server side even if this + function return :const:`True`. It just guarantees that the driver + won't throw a :exc:`ConfigurationError` when trying to use this + driver feature. + + .. versionadded:: 5.x + """ + with self.session() as session: + session._connect(READ_ACCESS) + assert session._connection + return session._connection.supports_re_auth + + def _get_server_info(self, **config) -> ServerInfo: + with self.session(**config) as session: + return session._get_server_info() + def _work( tx: ManagedTransaction, @@ -1039,7 +1232,7 @@ class BoltDriver(_Direct, Driver): """ @classmethod - def open(cls, target, *, auth=None, **config): + def open(cls, target, **config): """ :param target: :param auth: @@ -1051,7 +1244,7 @@ def open(cls, target, *, auth=None, **config): from .io import BoltPool address = cls.parse_target(target) pool_config, default_workspace_config = Config.consume_chain(config, PoolConfig, WorkspaceConfig) - pool = BoltPool.open(address, auth=auth, pool_config=pool_config, workspace_config=default_workspace_config) + pool = BoltPool.open(address, pool_config=pool_config, workspace_config=default_workspace_config) return cls(pool, default_workspace_config) def __init__(self, pool, default_workspace_config): @@ -1087,11 +1280,11 @@ class Neo4jDriver(_Routing, Driver): """ @classmethod - def open(cls, *targets, auth=None, routing_context=None, **config): + def open(cls, *targets, routing_context=None, **config): from .io import Neo4jPool addresses = cls.parse_targets(*targets) pool_config, default_workspace_config = Config.consume_chain(config, PoolConfig, WorkspaceConfig) - pool = Neo4jPool.open(*addresses, auth=auth, routing_context=routing_context, pool_config=pool_config, workspace_config=default_workspace_config) + pool = Neo4jPool.open(*addresses, routing_context=routing_context, pool_config=pool_config, workspace_config=default_workspace_config) return cls(pool, default_workspace_config) def __init__(self, pool, default_workspace_config): diff --git a/src/neo4j/_sync/io/__init__.py b/src/neo4j/_sync/io/__init__.py index 18aa1149..f5cddef2 100644 --- a/src/neo4j/_sync/io/__init__.py +++ b/src/neo4j/_sync/io/__init__.py @@ -24,6 +24,7 @@ __all__ = [ + "AcquireAuth", "Bolt", "BoltPool", "Neo4jPool", @@ -38,6 +39,7 @@ ConnectionErrorHandler, ) from ._pool import ( + AcquireAuth, BoltPool, Neo4jPool, ) diff --git a/src/neo4j/_sync/io/_bolt.py b/src/neo4j/_sync/io/_bolt.py index 09f832e2..f2b97e58 100644 --- a/src/neo4j/_sync/io/_bolt.py +++ b/src/neo4j/_sync/io/_bolt.py @@ -25,6 +25,7 @@ from time import perf_counter from ..._async_compat.network import BoltSocket +from ..._async_compat.util import Util from ..._codec.hydration import v1 as hydration_v1 from ..._codec.packstream import v1 as packstream_v1 from ..._conf import PoolConfig @@ -34,7 +35,7 @@ BoltHandshakeError, ) from ..._meta import get_user_agent -from ...addressing import Address +from ...addressing import ResolvedAddress from ...api import ( ServerInfo, Version, @@ -58,6 +59,20 @@ log = getLogger("neo4j") +class ServerStateManagerBase(abc.ABC): + @abc.abstractmethod + def __init__(self, init_state, on_change=None): + ... + + @abc.abstractmethod + def transition(self, message, metadata): + ... 
+ + @abc.abstractmethod + def failed(self): + ... + + class Bolt: """ Server connection for Bolt protocol. @@ -104,13 +119,17 @@ class Bolt: most_recent_qid = None def __init__(self, unresolved_address, sock, max_connection_lifetime, *, - auth=None, user_agent=None, routing_context=None, - notifications_min_severity=None, notifications_disabled_categories=None): + auth=None, auth_manager=None, user_agent=None, + routing_context=None, notifications_min_severity=None, + notifications_disabled_categories=None): self.unresolved_address = unresolved_address self.socket = sock self.local_port = self.socket.getsockname()[1] - self.server_info = ServerInfo(Address(sock.getpeername()), - self.PROTOCOL_VERSION) + self.server_info = ServerInfo( + ResolvedAddress(sock.getpeername(), + host_name=unresolved_address.host), + self.PROTOCOL_VERSION + ) # so far `connection.recv_timeout_seconds` is the only available # configuration hint that exists. Therefore, all hints can be stored at # connection level. This might change in the future. @@ -137,26 +156,9 @@ def __init__(self, unresolved_address, sock, max_connection_lifetime, *, else: self.user_agent = get_user_agent() - # Determine auth details - if not auth: - self.auth_dict = {} - elif isinstance(auth, tuple) and 2 <= len(auth) <= 3: - from ...api import Auth - self.auth_dict = vars(Auth("basic", *auth)) - else: - try: - self.auth_dict = vars(auth) - except (KeyError, TypeError): - raise AuthError("Cannot determine auth details from %r" % auth) - - # Check for missing password - try: - credentials = self.auth_dict["credentials"] - except KeyError: - pass - else: - if credentials is None: - raise AuthError("Password cannot be None") + self.auth = auth + self.auth_dict = self._to_auth_dict(auth) + self.auth_manager = auth_manager self.notifications_min_severity = notifications_min_severity self.notifications_disabled_categories = \ @@ -166,6 +168,24 @@ def __del__(self): if not asyncio.iscoroutinefunction(self.close): self.close() + @abc.abstractmethod + def _get_server_state_manager(self) -> ServerStateManagerBase: + ... + + @classmethod + def _to_auth_dict(cls, auth): + # Determine auth details + if not auth: + return {} + elif isinstance(auth, tuple) and 2 <= len(auth) <= 3: + from ...api import Auth + return vars(Auth("basic", *auth)) + else: + try: + return vars(auth) + except (KeyError, TypeError): + raise AuthError("Cannot determine auth details from %r" % auth) + @property def connection_id(self): return self.server_info._metadata.get("connection_id", "") @@ -187,6 +207,20 @@ def supports_multiple_databases(self): """ pass + @property + @abc.abstractmethod + def supports_re_auth(self): + """Whether the connection version supports re-authentication.""" + pass + + def assert_re_auth_support(self): + if not self.supports_re_auth: + raise ConfigurationError( + "User switching is not supported for Bolt " + f"Protocol {self.PROTOCOL_VERSION!r}. Server Agent " + f"{self.server_info.agent!r}" + ) + @property @abc.abstractmethod def supports_notification_filtering(self): @@ -318,13 +352,13 @@ def ping(cls, address, *, deadline=None, pool_config=None): # [bolt-version-bump] search tag when changing bolt version support @classmethod def open( - cls, address, *, auth=None, deadline=None, routing_context=None, - pool_config=None + cls, address, *, auth_manager=None, deadline=None, + routing_context=None, pool_config=None ): """Open a new Bolt connection to a given server address. 
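# Illustrative sketch of the auth normalisation performed by _to_auth_dict()
# above: a falsy value means "no auth", a 2- or 3-tuple is treated as basic
# auth, and anything else is expected to expose Auth-like attributes. The
# helper below is a simplified stand-in that uses the public neo4j.Auth class.
import neo4j


def to_auth_dict(auth):
    if not auth:
        return {}
    if isinstance(auth, tuple) and 2 <= len(auth) <= 3:
        return vars(neo4j.Auth("basic", *auth))
    return vars(auth)


# usage
assert to_auth_dict(None) == {}
assert to_auth_dict(("neo4j", "secret"))["principal"] == "neo4j"
assert to_auth_dict(neo4j.basic_auth("neo4j", "secret"))["scheme"] == "basic"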
:param address: - :param auth: + :param auth_manager: :param deadline: how long to wait for the connection to be established :param routing_context: dict containing routing context :param pool_config: @@ -386,7 +420,7 @@ def open( from ._bolt3 import Bolt3 bolt_cls = Bolt3 else: - log.debug("[#%04X] S: ", s.getsockname()[1]) + log.debug("[#%04X] C: ", s.getsockname()[1]) BoltSocket.close_socket(s) supported_versions = cls.protocol_handlers().keys() @@ -397,9 +431,23 @@ def open( address=address, request_data=handshake, response_data=data ) + try: + auth = Util.callback(auth_manager.get_auth) + except asyncio.CancelledError as e: + log.debug("[#%04X] C: open auth manager failed: %r", + s.getsockname()[1], e) + s.kill() + raise + except Exception as e: + log.debug("[#%04X] C: open auth manager failed: %r", + s.getsockname()[1], e) + s.close() + raise + connection = bolt_cls( address, s, pool_config.max_connection_lifetime, auth=auth, - user_agent=pool_config.user_agent, routing_context=routing_context, + auth_manager=auth_manager, user_agent=pool_config.user_agent, + routing_context=routing_context, notifications_min_severity=pool_config.notifications_min_severity, notifications_disabled_categories= pool_config.notifications_disabled_categories @@ -448,6 +496,45 @@ def hello(self, dehydration_hooks=None, hydration_hooks=None): """ pass + @abc.abstractmethod + def logon(self, dehydration_hooks=None, hydration_hooks=None): + """Append a LOGON message to the outgoing queue.""" + pass + + @abc.abstractmethod + def logoff(self, dehydration_hooks=None, hydration_hooks=None): + """Append a LOGOFF message to the outgoing queue.""" + pass + + def mark_unauthenticated(self): + """Mark the connection as unauthenticated.""" + self.auth_dict = {} + + def re_auth( + self, auth, auth_manager, force=False, + dehydration_hooks=None, hydration_hooks=None, + ): + """Append LOGON, LOGOFF to the outgoing queue. + + If auth is the same as the current auth, this method does nothing. + + :returns: whether the auth was changed + """ + new_auth_dict = self._to_auth_dict(auth) + if not force and new_auth_dict == self.auth_dict: + self.auth_manager = auth_manager + self.auth = auth + return False + self.logoff(dehydration_hooks=dehydration_hooks, + hydration_hooks=hydration_hooks) + self.auth_dict = new_auth_dict + self.auth_manager = auth_manager + self.auth = auth + self.logon(dehydration_hooks=dehydration_hooks, + hydration_hooks=hydration_hooks) + return True + + @abc.abstractmethod def route( self, database=None, imp_user=None, bookmarks=None, @@ -767,7 +854,7 @@ def _set_defunct(self, message, error=None, silent=False): # remove the connection from the pool, nor to try to close the # connection again. self.close() - if self.pool: + if self.pool and not self._get_server_state_manager().failed(): self.pool.deactivate(address=self.unresolved_address) # Iterate through the outstanding responses, and if any correspond diff --git a/src/neo4j/_sync/io/_bolt3.py b/src/neo4j/_sync/io/_bolt3.py index b98b3f25..4a125797 100644 --- a/src/neo4j/_sync/io/_bolt3.py +++ b/src/neo4j/_sync/io/_bolt3.py @@ -16,6 +16,9 @@ # limitations under the License. 
+from __future__ import annotations + +import typing as t from enum import Enum from logging import getLogger from ssl import SSLSocket @@ -33,7 +36,10 @@ NotALeader, ServiceUnavailable, ) -from ._bolt import Bolt +from ._bolt import ( + Bolt, + ServerStateManagerBase, +) from ._common import ( check_supported_server_product, CommitResponse, @@ -53,8 +59,8 @@ class ServerStates(Enum): FAILED = "FAILED" -class ServerStateManager: - _STATE_TRANSITIONS = { +class ServerStateManager(ServerStateManagerBase): + _STATE_TRANSITIONS: t.Dict[Enum, t.Dict[str, Enum]] = { ServerStates.CONNECTED: { "hello": ServerStates.READY, }, @@ -91,6 +97,9 @@ def transition(self, message, metadata): if state_before != self.state and callable(self._on_change): self._on_change(state_before, self.state) + def failed(self): + return self.state == ServerStates.FAILED + class Bolt3(Bolt): """ Protocol handler for Bolt 3. @@ -104,6 +113,8 @@ class Bolt3(Bolt): supports_multiple_databases = False + supports_re_auth = False + supports_notification_filtering = False def __init__(self, *args, **kwargs): @@ -116,6 +127,9 @@ def _on_server_state_change(self, old_state, new_state): log.debug("[#%04X] _: state: %s > %s", self.local_port, old_state.name, new_state.name) + def _get_server_state_manager(self) -> ServerStateManagerBase: + return self._server_state_manager + @property def is_reset(self): # We can't be sure of the server's state if there are still pending @@ -159,6 +173,14 @@ def hello(self, dehydration_hooks=None, hydration_hooks=None): self.fetch_all() check_supported_server_product(self.server_info.agent) + def logon(self, dehydration_hooks=None, hydration_hooks=None): + """Append a LOGON message to the outgoing queue.""" + self.assert_re_auth_support() + + def logoff(self, dehydration_hooks=None, hydration_hooks=None): + """Append a LOGOFF message to the outgoing queue.""" + self.assert_re_auth_support() + def route( self, database=None, imp_user=None, bookmarks=None, dehydration_hooks=None, hydration_hooks=None @@ -397,8 +419,7 @@ def _process_message(self, tag, fields): self.pool.on_write_failure(address=self.unresolved_address) raise except Neo4jError as e: - if self.pool and e._invalidates_all_connections(): - self.pool.mark_all_stale() + self.pool.on_neo4j_error(e, self) raise else: raise BoltProtocolError("Unexpected response message with signature %02X" % summary_signature, address=self.unresolved_address) diff --git a/src/neo4j/_sync/io/_bolt4.py b/src/neo4j/_sync/io/_bolt4.py index 147a1686..96cc8167 100644 --- a/src/neo4j/_sync/io/_bolt4.py +++ b/src/neo4j/_sync/io/_bolt4.py @@ -19,7 +19,6 @@ from logging import getLogger from ssl import SSLSocket -from ..._async_compat.util import Util from ..._exceptions import BoltProtocolError from ...api import ( READ_ACCESS, @@ -34,7 +33,10 @@ NotALeader, ServiceUnavailable, ) -from ._bolt import Bolt +from ._bolt import ( + Bolt, + ServerStateManagerBase, +) from ._bolt3 import ( ServerStateManager, ServerStates, @@ -62,6 +64,8 @@ class Bolt4x0(Bolt): supports_multiple_databases = True + supports_re_auth = False + supports_notification_filtering = False def __init__(self, *args, **kwargs): @@ -74,6 +78,9 @@ def _on_server_state_change(self, old_state, new_state): log.debug("[#%04X] _: state: %s > %s", self.local_port, old_state.name, new_state.name) + def _get_server_state_manager(self) -> ServerStateManagerBase: + return self._server_state_manager + @property def is_reset(self): # We can't be sure of the server's state if there are still pending @@ -119,6 +126,14 
@@ def hello(self, dehydration_hooks=None, hydration_hooks=None): self.fetch_all() check_supported_server_product(self.server_info.agent) + def logon(self, dehydration_hooks=None, hydration_hooks=None): + """Append a LOGON message to the outgoing queue.""" + self.assert_re_auth_support() + + def logoff(self, dehydration_hooks=None, hydration_hooks=None): + """Append a LOGOFF message to the outgoing queue.""" + self.assert_re_auth_support() + def route( self, database=None, imp_user=None, bookmarks=None, dehydration_hooks=None, hydration_hooks=None @@ -353,8 +368,8 @@ def _process_message(self, tag, fields): self.pool.on_write_failure(address=self.unresolved_address) raise except Neo4jError as e: - if self.pool and e._invalidates_all_connections(): - self.pool.mark_all_stale() + if self.pool: + self.pool.on_neo4j_error(e, self) raise else: raise BoltProtocolError("Unexpected response message with signature " diff --git a/src/neo4j/_sync/io/_bolt5.py b/src/neo4j/_sync/io/_bolt5.py index dcf52050..a511d1d8 100644 --- a/src/neo4j/_sync/io/_bolt5.py +++ b/src/neo4j/_sync/io/_bolt5.py @@ -16,6 +16,8 @@ # limitations under the License. +import typing as t +from enum import Enum from logging import getLogger from ssl import SSLSocket @@ -32,7 +34,10 @@ NotALeader, ServiceUnavailable, ) -from ._bolt import Bolt +from ._bolt import ( + Bolt, + ServerStateManagerBase, +) from ._bolt3 import ( ServerStateManager, ServerStates, @@ -41,6 +46,7 @@ check_supported_server_product, CommitResponse, InitResponse, + LogonResponse, Response, ) @@ -49,7 +55,7 @@ class Bolt5x0(Bolt): - """Protocol handler for Bolt 5.0. """ + """Protocol handler for Bolt 5.0.""" PROTOCOL_VERSION = Version(5, 0) @@ -59,16 +65,26 @@ class Bolt5x0(Bolt): supports_multiple_databases = True + supports_re_auth = False + + supports_notification_filtering = False + + server_states: t.Any = ServerStates + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._server_state_manager = ServerStateManager( - ServerStates.CONNECTED, on_change=self._on_server_state_change + self.server_states.CONNECTED, + on_change=self._on_server_state_change ) def _on_server_state_change(self, old_state, new_state): log.debug("[#%04X] _: state: %s > %s", self.local_port, old_state.name, new_state.name) + def _get_server_state_manager(self) -> ServerStateManagerBase: + return self._server_state_manager + @property def is_reset(self): # We can't be sure of the server's state if there are still pending @@ -77,7 +93,7 @@ def is_reset(self): if (self.responses and self.responses[-1] and self.responses[-1].message == "reset"): return True - return self._server_state_manager.state == ServerStates.READY + return self._server_state_manager.state == self.server_states.READY @property def encrypted(self): @@ -130,6 +146,14 @@ def on_success(metadata): self.fetch_all() check_supported_server_product(self.server_info.agent) + def logon(self, dehydration_hooks=None, hydration_hooks=None): + """Append a LOGON message to the outgoing queue.""" + self.assert_re_auth_support() + + def logoff(self, dehydration_hooks=None, hydration_hooks=None): + """Append a LOGOFF message to the outgoing queue.""" + self.assert_re_auth_support() + def route(self, database=None, imp_user=None, bookmarks=None, dehydration_hooks=None, hydration_hooks=None): routing_context = self.routing_context or {} @@ -331,7 +355,7 @@ def _process_message(self, tag, fields): elif summary_signature == b"\x7F": log.debug("[#%04X] S: FAILURE %r", self.local_port, summary_metadata) - 
self._server_state_manager.state = ServerStates.FAILED + self._server_state_manager.state = self.server_states.FAILED try: response.on_failure(summary_metadata or {}) except (ServiceUnavailable, DatabaseUnavailable): @@ -343,8 +367,8 @@ def _process_message(self, tag, fields): self.pool.on_write_failure(address=self.unresolved_address) raise except Neo4jError as e: - if self.pool and e._invalidates_all_connections(): - self.pool.mark_all_stale() + if self.pool: + self.pool.on_neo4j_error(e, self) raise else: raise BoltProtocolError( @@ -356,10 +380,63 @@ def _process_message(self, tag, fields): return len(details), 1 +class ServerStates5x1(Enum): + CONNECTED = "CONNECTED" + READY = "READY" + STREAMING = "STREAMING" + TX_READY_OR_TX_STREAMING = "TX_READY||TX_STREAMING" + FAILED = "FAILED" + AUTHENTICATION = "AUTHENTICATION" + + +class ServerStateManager5x1(ServerStateManager): + _STATE_TRANSITIONS = { # type: ignore + ServerStates5x1.CONNECTED: { + "hello": ServerStates5x1.AUTHENTICATION, + }, + ServerStates5x1.AUTHENTICATION: { + "logon": ServerStates5x1.READY, + }, + ServerStates5x1.READY: { + "run": ServerStates5x1.STREAMING, + "begin": ServerStates5x1.TX_READY_OR_TX_STREAMING, + "logoff": ServerStates5x1.AUTHENTICATION, + }, + ServerStates5x1.STREAMING: { + "pull": ServerStates5x1.READY, + "discard": ServerStates5x1.READY, + "reset": ServerStates5x1.READY, + }, + ServerStates5x1.TX_READY_OR_TX_STREAMING: { + "commit": ServerStates5x1.READY, + "rollback": ServerStates5x1.READY, + "reset": ServerStates5x1.READY, + }, + ServerStates5x1.FAILED: { + "reset": ServerStates5x1.READY, + } + } + + + def failed(self): + return self.state == ServerStates5x1.FAILED + + class Bolt5x1(Bolt5x0): + """Protocol handler for Bolt 5.1.""" PROTOCOL_VERSION = Version(5, 1) + supports_re_auth = True + + server_states = ServerStates5x1 + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._server_state_manager = ServerStateManager5x1( + ServerStates5x1.CONNECTED, on_change=self._on_server_state_change + ) + def hello(self, dehydration_hooks=None, hydration_hooks=None): if ( self.notifications_min_severity is not None @@ -383,14 +460,15 @@ def on_success(metadata): "the server and network is set up correctly.", self.local_port, recv_timeout) - extra = self.get_base_headers() - log.debug("[#%04X] C: HELLO %r", self.local_port, extra) - self._append(b"\x01", (extra,), + headers = self.get_base_headers() + logged_headers = dict(headers) + log.debug("[#%04X] C: HELLO %r", self.local_port, logged_headers) + self._append(b"\x01", (headers,), response=InitResponse(self, "hello", hydration_hooks, on_success=on_success), dehydration_hooks=dehydration_hooks) - - self.logon(dehydration_hooks, hydration_hooks) + self.logon(dehydration_hooks=dehydration_hooks, + hydration_hooks=hydration_hooks) self.send_all() self.fetch_all() check_supported_server_product(self.server_info.agent) @@ -401,7 +479,13 @@ def logon(self, dehydration_hooks=None, hydration_hooks=None): logged_auth_dict["credentials"] = "*******" log.debug("[#%04X] C: LOGON %r", self.local_port, logged_auth_dict) self._append(b"\x6A", (self.auth_dict,), - response=Response(self, "logon", hydration_hooks), + response=LogonResponse(self, "logon", hydration_hooks), + dehydration_hooks=dehydration_hooks) + + def logoff(self, dehydration_hooks=None, hydration_hooks=None): + log.debug("[#%04X] C: LOGOFF", self.local_port) + self._append(b"\x6B", + response=LogonResponse(self, "logoff", hydration_hooks), dehydration_hooks=dehydration_hooks) 
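The LOGON and LOGOFF handlers above are what back the new session-level ``auth`` setting: when a pooled Bolt 5.1+ connection is acquired with different credentials, the driver queues LOGOFF followed by LOGON instead of opening a new connection. Roughly, the user-facing flow looks like the sketch below; the URI and credentials are placeholders, and session-level auth is a preview in this driver version, so it may emit :class:`neo4j.PreviewWarning`.

.. code-block:: python

    import neo4j

    # Placeholder URI and credentials.
    driver = neo4j.GraphDatabase.driver(
        "neo4j://localhost:7687", auth=("neo4j", "password")
    )

    # True only if the Bolt protocol version of the remote supports
    # re-authentication (LOGOFF/LOGON) on existing connections.
    if driver.supports_session_auth():
        # The acquired connection is switched to the new user via LOGOFF
        # followed by LOGON rather than by opening a fresh connection.
        with driver.session(auth=("alice", "alice-password")) as session:
            session.run("RETURN 1").consume()

    driver.close()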
@@ -409,6 +493,8 @@ class Bolt5x2(Bolt5x1): PROTOCOL_VERSION = Version(5, 2) + supports_notification_filtering = True + def get_base_headers(self): headers = super().get_base_headers() if self.notifications_min_severity is not None: @@ -448,15 +534,6 @@ def on_success(metadata): self.fetch_all() check_supported_server_product(self.server_info.agent) - def logon(self, dehydration_hooks=None, hydration_hooks=None): - logged_auth_dict = dict(self.auth_dict) - if "credentials" in logged_auth_dict: - logged_auth_dict["credentials"] = "*******" - log.debug("[#%04X] C: LOGON %r", self.local_port, logged_auth_dict) - self._append(b"\x6A", (self.auth_dict,), - response=Response(self, "logon", hydration_hooks), - dehydration_hooks=dehydration_hooks) - def run(self, query, parameters=None, mode=None, bookmarks=None, metadata=None, timeout=None, db=None, imp_user=None, notifications_min_severity=None, diff --git a/src/neo4j/_sync/io/_common.py b/src/neo4j/_sync/io/_common.py index 89681a38..c4e29546 100644 --- a/src/neo4j/_sync/io/_common.py +++ b/src/neo4j/_sync/io/_common.py @@ -256,19 +256,34 @@ def on_ignored(self, metadata=None): class InitResponse(Response): + def on_failure(self, metadata): + # No sense in resetting the connection, + # the server will have closed it already. + self.connection.kill() + handler = self.handlers.get("on_failure") + Util.callback(handler, metadata) + handler = self.handlers.get("on_summary") + Util.callback(handler) + metadata["message"] = metadata.get( + "message", + "Connection initialisation failed due to an unknown error" + ) + raise Neo4jError.hydrate(**metadata) + +class LogonResponse(InitResponse): def on_failure(self, metadata): - code = metadata.get("code") - if code == "Neo.ClientError.Security.Unauthorized": - raise Neo4jError.hydrate(**metadata) - else: - raise ServiceUnavailable( - metadata.get("message", "Connection initialisation failed") - ) + # No sense in resetting the connection, + # the server will have closed it already. + self.connection.kill() + handler = self.handlers.get("on_failure") + Util.callback(handler, metadata) + handler = self.handlers.get("on_summary") + Util.callback(handler) + raise Neo4jError.hydrate(**metadata) class CommitResponse(Response): - pass diff --git a/src/neo4j/_sync/io/_pool.py b/src/neo4j/_sync/io/_pool.py index 229cc8c9..b0af10b5 100644 --- a/src/neo4j/_sync/io/_pool.py +++ b/src/neo4j/_sync/io/_pool.py @@ -16,13 +16,18 @@ # limitations under the License. 
+from __future__ import annotations + import abc import asyncio import logging +import typing as t from collections import ( defaultdict, deque, ) +from copy import copy +from dataclasses import dataclass from logging import getLogger from random import choice @@ -32,6 +37,7 @@ RLock, ) from ..._async_compat.network import NetworkUtil +from ..._async_compat.util import Util from ..._conf import ( PoolConfig, WorkspaceConfig, @@ -42,10 +48,12 @@ ) from ..._exceptions import BoltError from ..._routing import RoutingTable +from ..._sync.auth_management import StaticAuthManager from ...api import ( READ_ACCESS, WRITE_ACCESS, ) +from ...auth_management import AuthManager from ...exceptions import ( ClientError, ConfigurationError, @@ -54,8 +62,11 @@ ReadServiceUnavailable, ServiceUnavailable, SessionExpired, + TokenExpired, + TokenExpiredRetryable, WriteServiceUnavailable, ) +from ..auth_management import StaticAuthManager from ._bolt import Bolt @@ -63,6 +74,12 @@ log = getLogger("neo4j") +@dataclass +class AcquireAuth: + auth: t.Union[AuthManager, AuthManager, None] + force_auth: bool = False + + class IOPool(abc.ABC): """ A collection of connections to one or more server addresses. """ @@ -86,6 +103,9 @@ def __enter__(self): def __exit__(self, exc_type, exc_value, traceback): self.close() + def get_auth(self): + return Util.callback(self.pool_config.auth.get_auth) + def _acquire_from_pool(self, address): with self.lock: for connection in list(self.connections.get(address, [])): @@ -96,6 +116,22 @@ def _acquire_from_pool(self, address): return connection return None # no free connection available + def _remove_connection(self, connection): + address = connection.unresolved_address + with self.lock: + log.debug( + "[#%04X] _: remove connection from pool %r %s", + connection.local_port, address, connection.connection_id + ) + try: + self.connections.get(address, []).remove(connection) + except ValueError: + # If closure fails (e.g. because the server went + # down), all connections to the same address will + # be removed. Therefore, we silently ignore if the + # connection isn't in the pool anymore. + pass + def _acquire_from_pool_checked( self, address, health_check, deadline ): @@ -110,35 +146,40 @@ def _acquire_from_pool_checked( # `stale` but still alive. if log.isEnabledFor(logging.DEBUG): log.debug( - "[#%04X] _: removing old connection %s " + "[#%04X] _: found unhealthy connection %s " "(closed=%s, defunct=%s, stale=%s, in_use=%s)", connection.local_port, connection.connection_id, connection.closed(), connection.defunct(), connection.stale(), connection.in_use ) connection.close() - with self.lock: - try: - self.connections.get(address, []).remove(connection) - except ValueError: - # If closure fails (e.g. because the server went - # down), all connections to the same address will - # be removed. Therefore, we silently ignore if the - # connection isn't in the pool anymore. 
- pass + self._remove_connection(connection) continue # try again with a new connection else: return connection - def _acquire_new_later(self, address, deadline): + def _acquire_new_later(self, address, auth, deadline): def connection_creator(): released_reservation = False try: try: - connection = self.opener(address, deadline) + connection = self.opener( + address, auth or self.pool_config.auth, deadline + ) except ServiceUnavailable: self.deactivate(address) raise + if auth: + # It's unfortunate that we have to create a connection + # first to determine if the session-level auth is supported + # by the protocol or not. + try: + connection.assert_re_auth_support() + except ConfigurationError: + log.debug("[#%04X] _: no re-auth support", + connection.local_port) + connection.close() + raise connection.pool = self connection.in_use = True with self.lock: @@ -164,13 +205,47 @@ def connection_creator(): return connection_creator return None - def _acquire(self, address, deadline, liveness_check_timeout): + def _re_auth_connection(self, connection, auth, force): + if auth: + # Assert session auth is supported by the protocol. + # The Bolt implementation will try as hard as it can to make the + # re-auth work. So if the session auth token is identical to the + # driver auth token, the connection doesn't have to do anything so + # it won't fail regardless of the protocol version. + connection.assert_re_auth_support() + new_auth_manager = auth or self.pool_config.auth + log_auth = "******" if auth else "None" + new_auth = Util.callback(new_auth_manager.get_auth) + try: + updated = connection.re_auth(new_auth, new_auth_manager, + force=force) + log.debug("[#%04X] _: checked re_auth auth=%s updated=%s " + "force=%s", + connection.local_port, log_auth, updated, force) + except Exception as exc: + log.debug("[#%04X] _: check re_auth failed %r auth=%s " + "force=%s", + connection.local_port, exc, log_auth, force) + raise + assert not force or updated # force=True implies updated=True + if force: + connection.send_all() + connection.fetch_all() + + def _acquire( + self, address, auth, deadline, liveness_check_timeout + ): """ Acquire a connection to a given address from the pool. The address supplied should always be an IP address, not a host name. This method is thread safe. 
""" + if auth is None: + auth = AcquireAuth(None) + force_auth = auth.force_auth + auth = auth.auth + def health_check(connection_, deadline_): if (connection_.closed() or connection_.defunct() @@ -193,13 +268,33 @@ def health_check(connection_, deadline_): address, health_check, deadline ) if connection: - log.debug("[#%04X] _: handing out existing connection " - "%s", connection.local_port, - connection.connection_id) + log.debug("[#%04X] _: picked existing connection %s", + connection.local_port, connection.connection_id) + try: + self._re_auth_connection( + connection, auth, force_auth + ) + except ConfigurationError: + if auth: + # protocol version lacks support for re-auth + # => session auth token is not supported + raise + # expiring tokens supported by flushing the pool + # => give up this connection + log.debug("[#%04X] _: backwards compatible " + "auth token refresh: purge connection", + connection.local_port) + connection.close() + self.release(connection) + continue + log.debug("[#%04X] _: handing out existing connection", + connection.local_port) return connection # all connections in pool are in-use with self.lock: - connection_creator = self._acquire_new_later(address, deadline) + connection_creator = self._acquire_new_later( + address, auth, deadline, + ) if connection_creator: break @@ -220,7 +315,8 @@ def health_check(connection_, deadline_): @abc.abstractmethod def acquire( - self, access_mode, timeout, database, bookmarks, liveness_check_timeout + self, access_mode, timeout, database, bookmarks, auth: AcquireAuth, + liveness_check_timeout ): """ Acquire a connection to a server that can satisfy a set of parameters. @@ -229,6 +325,7 @@ def acquire( (excluding potential preparation like fetching routing tables). :param database: :param bookmarks: + :param auth: :param liveness_check_timeout: """ ... 
@@ -263,24 +360,24 @@ def release(self, *connections): or connection.is_reset): if cancelled is not None: log.debug( - "[#%04X] _: released unclean connection %s", + "[#%04X] _: kill unclean connection %s", connection.local_port, connection.connection_id ) connection.kill() continue try: log.debug( - "[#%04X] _: released unclean connection %s", + "[#%04X] _: release unclean connection %s", connection.local_port, connection.connection_id ) connection.reset() - except (Neo4jError, DriverError, BoltError) as e: + except (Neo4jError, DriverError, BoltError) as exc: log.debug("[#%04X] _: failed to reset connection " - "on release: %r", connection.local_port, e) - except asyncio.CancelledError as e: + "on release: %r", connection.local_port, exc) + except asyncio.CancelledError as exc: log.debug("[#%04X] _: cancelled reset connection " - "on release: %r", connection.local_port, e) - cancelled = e + "on release: %r", connection.local_port, exc) + cancelled = exc connection.kill() with self.lock: for connection in connections: @@ -351,6 +448,27 @@ def on_write_failure(self, address): "No write service available for pool {}".format(self) ) + def on_neo4j_error(self, error, connection): + assert isinstance(error, Neo4jError) + if error._unauthenticates_all_connections(): + address = connection.unresolved_address + log.debug( + "[#0000] _: mark all connections to %r as " + "unauthenticated", address + ) + with self.lock: + for connection in self.connections.get(address, ()): + connection.mark_unauthenticated() + if error._requires_new_credentials(): + Util.callback( + connection.auth_manager.on_auth_expired, + connection.auth + ) + if (isinstance(error, TokenExpired) + and not isinstance(self.pool_config.auth, (StaticAuthManager, + StaticAuthManager))): + error.__class__ = TokenExpiredRetryable + def close(self): """ Close all connections and empty the pool. This method is thread safe. @@ -370,20 +488,19 @@ def close(self): class BoltPool(IOPool): @classmethod - def open(cls, address, *, auth, pool_config, workspace_config): + def open(cls, address, *, pool_config, workspace_config): """Create a new BoltPool :param address: - :param auth: :param pool_config: :param workspace_config: :returns: BoltPool """ - def opener(addr, deadline): + def opener(addr, auth_manager, deadline): return Bolt.open( - addr, auth=auth, deadline=deadline, routing_context=None, - pool_config=pool_config + addr, auth_manager=auth_manager, deadline=deadline, + routing_context=None, pool_config=pool_config ) pool = cls(opener, pool_config, workspace_config, address) @@ -399,7 +516,8 @@ def __repr__(self): self.address) def acquire( - self, access_mode, timeout, database, bookmarks, liveness_check_timeout + self, access_mode, timeout, database, bookmarks, auth: AcquireAuth, + liveness_check_timeout ): # The access_mode and database is not needed for a direct connection, # it's just there for consistency. 
@@ -407,7 +525,7 @@ def acquire( "access_mode=%r, database=%r", access_mode, database) deadline = Deadline.from_timeout_or_deadline(timeout) return self._acquire( - self.address, deadline, liveness_check_timeout + self.address, auth, deadline, liveness_check_timeout ) @@ -416,12 +534,11 @@ class Neo4jPool(IOPool): """ @classmethod - def open(cls, *addresses, auth, pool_config, workspace_config, + def open(cls, *addresses, pool_config, workspace_config, routing_context=None): """Create a new Neo4jPool :param addresses: one or more address as positional argument - :param auth: :param pool_config: :param workspace_config: :param routing_context: @@ -435,9 +552,9 @@ def open(cls, *addresses, auth, pool_config, workspace_config, raise ConfigurationError("The key 'address' is reserved for routing context.") routing_context["address"] = str(address) - def opener(addr, deadline): + def opener(addr, auth_manager, deadline): return Bolt.open( - addr, auth=auth, deadline=deadline, + addr, auth_manager=auth_manager, deadline=deadline, routing_context=routing_context, pool_config=pool_config ) @@ -478,7 +595,7 @@ def get_or_create_routing_table(self, database): return self.routing_tables[database] def fetch_routing_info( - self, address, database, imp_user, bookmarks, acquisition_timeout + self, address, database, imp_user, bookmarks, auth, acquisition_timeout ): """ Fetch raw routing info from a given router address. @@ -489,6 +606,7 @@ def fetch_routing_info( :type imp_user: str or None :param bookmarks: iterable of bookmark values after which the routing info should be fetched + :param auth: auth :param acquisition_timeout: connection acquisition timeout :returns: list of routing records, or None if no connection @@ -499,7 +617,10 @@ def fetch_routing_info( deadline = Deadline.from_timeout_or_deadline(acquisition_timeout) log.debug("[#0000] _: _acquire router connection, " "database=%r, address=%r", database, address) - cx = self._acquire(address, deadline, None) + if auth: + auth = copy(auth) + auth.force_auth = False + cx = self._acquire(address, auth, deadline, None) try: routing_table = cx.route( database=database or self.workspace_config.database, @@ -511,7 +632,8 @@ def fetch_routing_info( return routing_table def fetch_routing_table( - self, *, address, acquisition_timeout, database, imp_user, bookmarks + self, *, address, acquisition_timeout, database, imp_user, + bookmarks, auth ): """ Fetch a routing table from a given router address. @@ -523,6 +645,7 @@ def fetch_routing_table( table :type imp_user: str or None :param bookmarks: bookmarks used when fetching routing table + :param auth: auth :returns: a new RoutingTable instance or None if the given router is currently unable to provide routing information @@ -530,7 +653,8 @@ def fetch_routing_table( new_routing_info = None try: new_routing_info = self.fetch_routing_info( - address, database, imp_user, bookmarks, acquisition_timeout + address, database, imp_user, bookmarks, auth, + acquisition_timeout ) except Neo4jError as e: # checks if the code is an error that is caused by the client. In @@ -576,8 +700,8 @@ def fetch_routing_table( return new_routing_table def _update_routing_table_from( - self, *routers, database, imp_user, bookmarks, acquisition_timeout, - database_callback + self, *routers, database, imp_user, bookmarks, auth, + acquisition_timeout, database_callback ): """ Try to update routing tables with the given routers. 
@@ -593,7 +717,8 @@ def _update_routing_table_from( ): new_routing_table = self.fetch_routing_table( address=address, acquisition_timeout=acquisition_timeout, - database=database, imp_user=imp_user, bookmarks=bookmarks + database=database, imp_user=imp_user, bookmarks=bookmarks, + auth=auth ) if new_routing_table is not None: new_database = new_routing_table.database @@ -613,8 +738,8 @@ def _update_routing_table_from( return False def update_routing_table( - self, *, database, imp_user, bookmarks, acquisition_timeout=None, - database_callback=None + self, *, database, imp_user, bookmarks, auth=None, + acquisition_timeout=None, database_callback=None ): """ Update the routing table from the first router able to provide valid routing information. @@ -624,6 +749,7 @@ def update_routing_table( table :type imp_user: str or None :param bookmarks: bookmarks used when fetching routing table + :param auth: auth :param acquisition_timeout: connection acquisition timeout :param database_callback: A callback function that will be called with the database name as only argument when a new routing table has been @@ -645,15 +771,15 @@ def update_routing_table( # TODO: Test this state if self._update_routing_table_from( self.address, database=database, - imp_user=imp_user, bookmarks=bookmarks, + imp_user=imp_user, bookmarks=bookmarks, auth=auth, acquisition_timeout=acquisition_timeout, database_callback=database_callback ): # Why is only the first initial routing address used? return if self._update_routing_table_from( - *(existing_routers - {self.address}), - database=database, imp_user=imp_user, bookmarks=bookmarks, + *(existing_routers - {self.address}), database=database, + imp_user=imp_user, bookmarks=bookmarks, auth=auth, acquisition_timeout=acquisition_timeout, database_callback=database_callback ): @@ -662,7 +788,7 @@ def update_routing_table( if not prefer_initial_routing_address: if self._update_routing_table_from( self.address, database=database, - imp_user=imp_user, bookmarks=bookmarks, + imp_user=imp_user, bookmarks=bookmarks, auth=auth, acquisition_timeout=acquisition_timeout, database_callback=database_callback ): @@ -681,7 +807,7 @@ def update_connection_pool(self, *, database): super(Neo4jPool, self).deactivate(address) def ensure_routing_table_is_fresh( - self, *, access_mode, database, imp_user, bookmarks, + self, *, access_mode, database, imp_user, bookmarks, auth=None, acquisition_timeout=None, database_callback=None ): """ Update the routing table if stale. 
@@ -718,7 +844,7 @@ def ensure_routing_table_is_fresh( self.update_routing_table( database=database, imp_user=imp_user, bookmarks=bookmarks, - acquisition_timeout=acquisition_timeout, + auth=auth, acquisition_timeout=acquisition_timeout, database_callback=database_callback ) self.update_connection_pool(database=database) @@ -755,7 +881,8 @@ def _select_address(self, *, access_mode, database): return choice(addresses_by_usage[min(addresses_by_usage)]) def acquire( - self, access_mode, timeout, database, bookmarks, liveness_check_timeout + self, access_mode, timeout, database, bookmarks, + auth: t.Optional[AcquireAuth], liveness_check_timeout ): if access_mode not in (WRITE_ACCESS, READ_ACCESS): raise ClientError("Non valid 'access_mode'; {}".format(access_mode)) @@ -775,7 +902,7 @@ def acquire( "access_mode=%r, database=%r", access_mode, database) self.ensure_routing_table_is_fresh( access_mode=access_mode, database=database, - imp_user=None, bookmarks=bookmarks, + imp_user=None, bookmarks=bookmarks, auth=auth, acquisition_timeout=timeout ) @@ -794,7 +921,7 @@ def acquire( deadline = Deadline.from_timeout_or_deadline(timeout) # should always be a resolved address connection = self._acquire( - address, deadline, liveness_check_timeout + address, auth, deadline, liveness_check_timeout, ) except (ServiceUnavailable, SessionExpired): self.deactivate(address=address) diff --git a/src/neo4j/_sync/work/result.py b/src/neo4j/_sync/work/result.py index 86255e5f..cfc6df05 100644 --- a/src/neo4j/_sync/work/result.py +++ b/src/neo4j/_sync/work/result.py @@ -282,7 +282,7 @@ def _buffer(self, n=None): Might end up with more records in the buffer if the fetch size makes it overshoot. - Might ent up with fewer records in the buffer if there are not enough + Might end up with fewer records in the buffer if there are not enough records available. """ if self._out_of_scope: @@ -624,7 +624,7 @@ def to_eager_result(self) -> EagerResult: was obtained has been closed or the Result has been explicitly consumed. - **This is experimental.** (See :ref:`filter-warnings-ref`) + **This is experimental** (see :ref:`filter-warnings-ref`). It might be changed or removed any time even without prior notice. .. 
versionadded:: 5.5 diff --git a/src/neo4j/_sync/work/session.py b/src/neo4j/_sync/work/session.py index 3afdd453..8f38739b 100644 --- a/src/neo4j/_sync/work/session.py +++ b/src/neo4j/_sync/work/session.py @@ -20,6 +20,7 @@ import asyncio import typing as t +import warnings from logging import getLogger from random import random from time import perf_counter @@ -27,7 +28,10 @@ from ..._async_compat import sleep from ..._async_compat.util import Util from ..._conf import SessionConfig -from ..._meta import deprecated +from ..._meta import ( + deprecated, + PreviewWarning, +) from ..._work import Query from ...api import ( Bookmarks, @@ -42,6 +46,7 @@ SessionExpired, TransactionError, ) +from ..auth_management import AuthManagers from .result import Result from .transaction import ( ManagedTransaction, @@ -97,7 +102,17 @@ class Session(Workspace): def __init__(self, pool, session_config): assert isinstance(session_config, SessionConfig) + if session_config.auth is not None: + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", message=r".*\bAuth managers\b.*", + category=PreviewWarning + ) + session_config.auth = AuthManagers.static( + session_config.auth + ) super().__init__(pool, session_config) + self._config = session_config self._initialize_bookmarks(session_config.bookmarks) self._bookmark_manager = session_config.bookmark_manager @@ -113,11 +128,13 @@ def __exit__(self, exception_type, exception_value, traceback): self._state_failed = True self.close() - def _connect(self, access_mode, **access_kwargs): + def _connect(self, access_mode, **acquire_kwargs): if access_mode is None: access_mode = self._config.default_access_mode try: - super()._connect(access_mode, **access_kwargs) + super()._connect( + access_mode, auth=self._config.auth, **acquire_kwargs + ) except asyncio.CancelledError: self._handle_cancellation(message="_connect") raise @@ -162,6 +179,11 @@ def _get_server_info(self): self._disconnect() return server_info + def _verify_authentication(self): + assert not self._connection + self._connect(READ_ACCESS, force_auth=True) + self._disconnect() + def close(self) -> None: """Close the session. 
diff --git a/src/neo4j/_sync/work/workspace.py b/src/neo4j/_sync/work/workspace.py index 2cb91648..b098c73d 100644 --- a/src/neo4j/_sync/work/workspace.py +++ b/src/neo4j/_sync/work/workspace.py @@ -33,7 +33,10 @@ SessionError, SessionExpired, ) -from ..io import Neo4jPool +from ..io import ( + AcquireAuth, + Neo4jPool, +) log = logging.getLogger("neo4j") @@ -130,8 +133,13 @@ def _update_bookmark(self, bookmark): return self._update_bookmarks((bookmark,)) - def _connect(self, access_mode, **acquire_kwargs): + def _connect(self, access_mode, auth=None, **acquire_kwargs): acquisition_timeout = self._config.connection_acquisition_timeout + auth = AcquireAuth( + auth, + force_auth=acquire_kwargs.pop("force_auth", False), + ) + if self._connection: # TODO: Investigate this # log.warning("FIXME: should always disconnect before connect") @@ -154,6 +162,7 @@ def _connect(self, access_mode, **acquire_kwargs): database=self._config.database, imp_user=self._config.impersonated_user, bookmarks=self._get_bookmarks(), + auth=auth, acquisition_timeout=acquisition_timeout, database_callback=self._set_cached_database ) @@ -162,6 +171,7 @@ def _connect(self, access_mode, **acquire_kwargs): "timeout": acquisition_timeout, "database": self._config.database, "bookmarks": self._get_bookmarks(), + "auth": auth, "liveness_check_timeout": None, } acquire_kwargs_.update(acquire_kwargs) diff --git a/src/neo4j/_work/eager_result.py b/src/neo4j/_work/eager_result.py index 87abb410..d88e23f5 100644 --- a/src/neo4j/_work/eager_result.py +++ b/src/neo4j/_work/eager_result.py @@ -33,7 +33,7 @@ class EagerResult(t.NamedTuple): * keys - the list of keys returned by the query (see :attr:`AsyncResult.keys` and :attr:`.Result.keys`) - **This is experimental.** (See :ref:`filter-warnings-ref`) + **This is experimental** (see :ref:`filter-warnings-ref`). It might be changed or removed any time even without prior notice. .. seealso:: diff --git a/src/neo4j/api.py b/src/neo4j/api.py index e91aed6e..c5d94a33 100644 --- a/src/neo4j/api.py +++ b/src/neo4j/api.py @@ -103,10 +103,21 @@ def __init__( if parameters: self.parameters = parameters + def __eq__(self, other): + if not isinstance(other, Auth): + return NotImplemented + return vars(self) == vars(other) + # For backwards compatibility AuthToken = Auth +# if t.TYPE_CHECKING: +# commented out as work around for +# https://github.com/sphinx-doc/sphinx/pull/10880 +# make sure TAuth is resolved in the docs, else they're pretty useless +_TAuth = t.Union[t.Tuple[t.Any, t.Any], Auth, None] + def basic_auth( user: str, password: str, realm: t.Optional[str] = None diff --git a/src/neo4j/auth_management.py b/src/neo4j/auth_management.py new file mode 100644 index 00000000..8fdad964 --- /dev/null +++ b/src/neo4j/auth_management.py @@ -0,0 +1,34 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from ._async.auth_management import AsyncAuthManagers +from ._auth_management import ( + AsyncAuthManager, + AuthManager, + ExpiringAuth, +) +from ._sync.auth_management import AuthManagers + + +__all__ = [ + "AsyncAuthManager", + "AsyncAuthManagers", + "AuthManager", + "AuthManagers", + "ExpiringAuth", +] diff --git a/src/neo4j/exceptions.py b/src/neo4j/exceptions.py index 73be03a5..d6708aee 100644 --- a/src/neo4j/exceptions.py +++ b/src/neo4j/exceptions.py @@ -232,15 +232,18 @@ def is_retryable(self) -> bool: """ return False - def _invalidates_all_connections(self) -> bool: + def _unauthenticates_all_connections(self) -> bool: return self.code == "Neo.ClientError.Security.AuthorizationExpired" + def _requires_new_credentials(self) -> bool: + return self.code == "Neo.ClientError.Security.TokenExpired" + # TODO: 6.0 - Remove this alias invalidates_all_connections = deprecated( "Neo4jError.invalidates_all_connections is deprecated and will be " "removed in a future version. It is an internal method and not meant " "for external use." - )(_invalidates_all_connections) + )(_unauthenticates_all_connections) def _is_fatal_during_discovery(self) -> bool: # checks if the code is an error that is caused by the client. In this @@ -309,9 +312,25 @@ class AuthError(ClientError): class TokenExpired(AuthError): """ Raised when the authentication token has expired. - A new driver instance with a fresh authentication token needs to be created. + A new driver instance with a fresh authentication token needs to be + created, unless the driver was configured using a non-static + :class:`.AuthManager`. In that case, the error will be + :exc:`.TokenExpiredRetryable` instead. + """ + + +# Neo4jError > ClientError > AuthError > TokenExpired > TokenExpiredRetryable +class TokenExpiredRetryable(TokenExpired): + """Raised when the authentication token has expired but can be refreshed. + + This is the same server error as :exc:`.TokenExpired`, but raised when + the driver is configured to be able to refresh the token, hence making + the error retryable. """ + def is_retryable(self) -> bool: + return True + # Neo4jError > ClientError > Forbidden class Forbidden(ClientError): diff --git a/src/neo4j/time/__init__.py b/src/neo4j/time/__init__.py index 29fc4f72..04bba43e 100644 --- a/src/neo4j/time/__init__.py +++ b/src/neo4j/time/__init__.py @@ -1392,7 +1392,7 @@ class Time(time_base_class, metaclass=TimeType): :raises ValueError: if one of the parameters is out of range. - ..versionchanged:: 5.0 + .. versionchanged:: 5.0 The parameter ``second`` no longer accepts :class:`float` values. """ @@ -1512,7 +1512,7 @@ def from_ticks(cls, ticks: int, tz: t.Optional[_tzinfo] = None) -> Time: :raises ValueError: if ticks is out of bounds (0 <= ticks < 86400000000000) - ..versionchanged:: 5.0 + .. versionchanged:: 5.0 The parameter ``ticks`` no longer accepts :class:`float` values but only :class:`int`. It's now nanoseconds since midnight instead of seconds. @@ -2545,7 +2545,7 @@ def as_timezone(self, tz: _tzinfo) -> DateTime: :param tz: the new timezone - :returns: the same object if ``tz`` is :const:``None``. + :returns: the same object if ``tz`` is :data:``None``. Else, a new :class:`.DateTime` that's the same point in time but in a different timezone. 
""" diff --git a/testkitbackend/__main__.py b/testkitbackend/__main__.py index 91a4a6e1..94739b13 100644 --- a/testkitbackend/__main__.py +++ b/testkitbackend/__main__.py @@ -30,6 +30,7 @@ def sync_main(): server = Server(("0.0.0.0", 9876)) + print("Start serving") while True: server.handle_request() @@ -39,6 +40,7 @@ async def main(): server = AsyncServer(("0.0.0.0", 9876)) await server.start() try: + print("Start serving") await server.serve_forever() finally: server.stop() diff --git a/testkitbackend/_async/backend.py b/testkitbackend/_async/backend.py index 507fbf97..b3bdd885 100644 --- a/testkitbackend/_async/backend.py +++ b/testkitbackend/_async/backend.py @@ -55,6 +55,10 @@ def __init__(self, rd, wr): self.drivers = {} self.custom_resolutions = {} self.dns_resolutions = {} + self.auth_token_managers = {} + self.auth_token_supplies = {} + self.auth_token_on_expiration_supplies = {} + self.expiring_auth_token_supplies = {} self.bookmark_managers = {} self.bookmarks_consumptions = {} self.bookmarks_supplies = {} @@ -64,6 +68,8 @@ def __init__(self, rd, wr): self.transactions = {} self.errors = {} self.key = 0 + self.fake_time = None + self.fake_time_ticker = None # Collect all request handlers self._requestHandlers = dict( [m for m in getmembers(requests, isfunction)]) diff --git a/testkitbackend/_async/requests.py b/testkitbackend/_async/requests.py index ce5b92d2..7493c7e5 100644 --- a/testkitbackend/_async/requests.py +++ b/testkitbackend/_async/requests.py @@ -16,21 +16,33 @@ # limitations under the License. +import datetime import json import re import warnings from os import path +from freezegun import freeze_time + import neo4j import neo4j.api +import neo4j.auth_management from neo4j._async_compat.util import AsyncUtil +from neo4j.auth_management import ( + AsyncAuthManager, + AsyncAuthManagers, + ExpiringAuth, +) from .. 
import ( fromtestkit, test_subtest_skips, totestkit, ) -from .._warning_check import warning_check +from .._warning_check import ( + warning_check, + warnings_check, +) from ..exceptions import MarkdAsDriverException @@ -94,27 +106,15 @@ async def GetFeatures(backend, data): async def NewDriver(backend, data): - auth_token = data["authorizationToken"]["data"] - data["authorizationToken"].mark_item_as_read_if_equals( - "name", "AuthorizationToken" - ) - scheme = auth_token["scheme"] - if scheme == "basic": - auth = neo4j.basic_auth( - auth_token["principal"], auth_token["credentials"], - realm=auth_token.get("realm", None) + auth = fromtestkit.to_auth_token(data, "authorizationToken") + expected_warnings = [] + if auth is None and data.get("authTokenManagerId") is not None: + auth = backend.auth_token_managers[data["authTokenManagerId"]] + expected_warnings.append( + (neo4j.PreviewWarning, "Auth managers are a preview feature.") ) - elif scheme == "kerberos": - auth = neo4j.kerberos_auth(auth_token["credentials"]) - elif scheme == "bearer": - auth = neo4j.bearer_auth(auth_token["credentials"]) else: - auth = neo4j.custom_auth( - auth_token["principal"], auth_token["credentials"], - auth_token["realm"], auth_token["scheme"], - **auth_token.get("parameters", {}) - ) - auth_token.mark_item_as_read("parameters", recursive=True) + data.mark_item_as_read_if_equals("authTokenManagerId", None) kwargs = {} if data["resolverRegistered"] or data["domainNameResolverRegistered"]: kwargs["resolver"] = resolution_func( @@ -131,12 +131,17 @@ async def NewDriver(backend, data): for k in ("sessionConnectionTimeoutMs", "updateRoutingTableTimeoutMs"): if k in data: data.mark_item_as_read_if_equals(k, None) - if data.get("maxConnectionPoolSize"): - kwargs["max_connection_pool_size"] = data["maxConnectionPoolSize"] - if data.get("fetchSize"): - kwargs["fetch_size"] = data["fetchSize"] - if "encrypted" in data: - kwargs["encrypted"] = data["encrypted"] + for (conf_name, data_name) in ( + ("max_connection_pool_size", "maxConnectionPoolSize"), + ("fetch_size", "fetchSize"), + ): + if data.get(data_name): + kwargs[conf_name] = data[data_name] + for (conf_name, data_name) in ( + ("encrypted", "encrypted"), + ): + if data_name in data: + kwargs[conf_name] = data[data_name] if "trustedCertificates" in data: if data["trustedCertificates"] is None: kwargs["trusted_certificates"] = neo4j.TrustSystemCAs() @@ -149,14 +154,134 @@ async def NewDriver(backend, data): fromtestkit.set_notifications_config(kwargs, data) data.mark_item_as_read_if_equals("livenessCheckTimeoutMs", None) - driver = neo4j.AsyncGraphDatabase.driver( - data["uri"], auth=auth, user_agent=data["userAgent"], **kwargs, - ) + if expected_warnings: + with warnings_check(expected_warnings): + driver = neo4j.AsyncGraphDatabase.driver( + data["uri"], auth=auth, user_agent=data["userAgent"], **kwargs, + ) + else: + driver = neo4j.AsyncGraphDatabase.driver( + data["uri"], auth=auth, user_agent=data["userAgent"], **kwargs, + ) key = backend.next_key() backend.drivers[key] = driver await backend.send_response("Driver", {"id": key}) +async def NewAuthTokenManager(backend, data): + auth_token_manager_id = backend.next_key() + + class TestKitAuthManager(AsyncAuthManager): + async def get_auth(self): + key = backend.next_key() + await backend.send_response("AuthTokenManagerGetAuthRequest", { + "id": key, + "authTokenManagerId": auth_token_manager_id, + }) + if not await backend.process_request(): + # connection was closed before end of next message + return None + if 
key not in backend.auth_token_supplies: + raise RuntimeError( + "Backend did not receive expected " + f"AuthTokenManagerGetAuthCompleted message for id {key}" + ) + return backend.auth_token_supplies.pop(key) + + async def on_auth_expired(self, auth): + key = backend.next_key() + await backend.send_response( + "AuthTokenManagerOnAuthExpiredRequest", { + "id": key, + "authTokenManagerId": auth_token_manager_id, + "auth": totestkit.auth_token(auth), + } + ) + if not await backend.process_request(): + # connection was closed before end of next message + return None + if key not in backend.auth_token_on_expiration_supplies: + raise RuntimeError( + "Backend did not receive expected " + "AuthTokenManagerOnAuthExpiredCompleted message for id " + f"{key}" + ) + backend.auth_token_on_expiration_supplies.pop(key) + + auth_manager = TestKitAuthManager() + backend.auth_token_managers[auth_token_manager_id] = auth_manager + await backend.send_response( + "AuthTokenManager", {"id": auth_token_manager_id} + ) + + +async def AuthTokenManagerGetAuthCompleted(backend, data): + auth_token = fromtestkit.to_auth_token(data, "auth") + + backend.auth_token_supplies[data["requestId"]] = auth_token + + +async def AuthTokenManagerOnAuthExpiredCompleted(backend, data): + backend.auth_token_on_expiration_supplies[data["requestId"]] = True + + +async def AuthTokenManagerClose(backend, data): + auth_token_manager_id = data["id"] + del backend.auth_token_managers[auth_token_manager_id] + await backend.send_response( + "AuthTokenManager", {"id": auth_token_manager_id} + ) + + +async def NewExpirationBasedAuthTokenManager(backend, data): + auth_token_manager_id = backend.next_key() + + async def auth_token_provider(): + key = backend.next_key() + await backend.send_response( + "ExpirationBasedAuthTokenProviderRequest", + { + "id": key, + "expirationBasedAuthTokenManagerId": auth_token_manager_id, + } + ) + if not await backend.process_request(): + # connection was closed before end of next message + return neo4j.auth_management.ExpiringAuth(None, None) + if key not in backend.expiring_auth_token_supplies: + raise RuntimeError( + "Backend did not receive expected " + "ExpirationBasedAuthTokenManagerCompleted message for id " + f"{key}" + ) + return backend.expiring_auth_token_supplies.pop(key) + + with warning_check(neo4j.PreviewWarning, + "Auth managers are a preview feature."): + auth_manager = AsyncAuthManagers.expiration_based(auth_token_provider) + backend.auth_token_managers[auth_token_manager_id] = auth_manager + await backend.send_response( + "ExpirationBasedAuthTokenManager", {"id": auth_token_manager_id} + ) + + +async def ExpirationBasedAuthTokenProviderCompleted(backend, data): + temp_auth_data = data["auth"] + temp_auth_data.mark_item_as_read_if_equals("name", + "AuthTokenAndExpiration") + temp_auth_data = temp_auth_data["data"] + auth_token = fromtestkit.to_auth_token(temp_auth_data, "auth") + if temp_auth_data["expiresInMs"] is not None: + expires_in = temp_auth_data["expiresInMs"] / 1000 + else: + expires_in = None + with warning_check(neo4j.PreviewWarning, + "Auth managers are a preview feature."): + expiring_auth = ExpiringAuth(auth_token, expires_in) + + backend.expiring_auth_token_supplies[data["requestId"]] = expiring_auth + + async def VerifyConnectivity(backend, data): driver_id = data["driverId"] driver = backend.drivers[driver_id] @@ -178,17 +303,33 @@ async def GetServerInfo(backend, data): async def CheckMultiDBSupport(backend, data): driver_id = data["driverId"] driver = 
backend.drivers[driver_id] - with warning_check( - neo4j.ExperimentalWarning, - "Feature support query, based on Bolt protocol version and Neo4j " - "server version will change in the future." - ): - available = await driver.supports_multi_db() + available = await driver.supports_multi_db() await backend.send_response("MultiDBSupport", { "id": backend.next_key(), "available": available }) +async def VerifyAuthentication(backend, data): + driver_id = data["driverId"] + driver = backend.drivers[driver_id] + auth = fromtestkit.to_auth_token(data, "authorizationToken") + with warning_check(neo4j.PreviewWarning, + "User switching is a preview feature."): + authenticated = await driver.verify_authentication(auth=auth) + await backend.send_response("DriverIsAuthenticated", { + "id": backend.next_key(), "authenticated": authenticated + }) + + +async def CheckSessionAuthSupport(backend, data): + driver_id = data["driverId"] + driver = backend.drivers[driver_id] + available = await driver.supports_session_auth() + await backend.send_response("SessionAuthSupport", { + "id": backend.next_key(), "available": available + }) + + async def ExecuteQuery(backend, data): driver = backend.drivers[data["driverId"]] cypher, params = fromtestkit.to_cypher_and_params(data) @@ -396,6 +537,7 @@ def __init__(self, session): async def NewSession(backend, data): driver = backend.drivers[data["driverId"]] access_mode = data["accessMode"] + expected_warnings = [] if access_mode == "r": access_mode = neo4j.READ_ACCESS elif access_mode == "w": @@ -414,19 +556,25 @@ async def NewSession(backend, data): config["bookmark_manager"] = backend.bookmark_managers[ data["bookmarkManagerId"] ] + expected_warnings.append(( + neo4j.ExperimentalWarning, + "The 'bookmark_manager' config key is experimental. It might be " + "changed or removed any time even without prior notice." + )) for (conf_name, data_name) in ( ("fetch_size", "fetchSize"), ("impersonated_user", "impersonatedUser"), ): if data_name in data: config[conf_name] = data[data_name] + if data.get("authorizationToken"): + config["auth"] = fromtestkit.to_auth_token(data, "authorizationToken") + expected_warnings.append( + (neo4j.PreviewWarning, "User switching is a preview feature.") + ) fromtestkit.set_notifications_config(config, data) - if "bookmark_manager" in config: - with warning_check( - neo4j.ExperimentalWarning, - "The 'bookmark_manager' config key is experimental. It might be " - "changed or removed any time even without prior notice."
- ): + if expected_warnings: + with warnings_check(expected_warnings): session = driver.session(**config) else: session = driver.session(**config) @@ -647,3 +795,30 @@ async def GetRoutingTable(backend, data): addresses = routing_table.__getattribute__(role) response_data[role] = list(map(str, addresses)) await backend.send_response("RoutingTable", response_data) + +async def FakeTimeInstall(backend, _data): + assert backend.fake_time is None + assert backend.fake_time_ticker is None + + backend.fake_time = freeze_time() + backend.fake_time_ticker = backend.fake_time.start() + await backend.send_response("FakeTimeAck", {}) + +async def FakeTimeTick(backend, data): + assert backend.fake_time is not None + assert backend.fake_time_ticker is not None + + increment_ms = data["incrementMs"] + delta = datetime.timedelta(milliseconds=increment_ms) + backend.fake_time_ticker.tick(delta=delta) + await backend.send_response("FakeTimeAck", {}) + + +async def FakeTimeUninstall(backend, _data): + assert backend.fake_time is not None + assert backend.fake_time_ticker is not None + + backend.fake_time.stop() + backend.fake_time_ticker = None + backend.fake_time = None + await backend.send_response("FakeTimeAck", {}) diff --git a/testkitbackend/_sync/backend.py b/testkitbackend/_sync/backend.py index 1ac6d3b4..b5ad66ac 100644 --- a/testkitbackend/_sync/backend.py +++ b/testkitbackend/_sync/backend.py @@ -55,6 +55,10 @@ def __init__(self, rd, wr): self.drivers = {} self.custom_resolutions = {} self.dns_resolutions = {} + self.auth_token_managers = {} + self.auth_token_supplies = {} + self.auth_token_on_expiration_supplies = {} + self.expiring_auth_token_supplies = {} self.bookmark_managers = {} self.bookmarks_consumptions = {} self.bookmarks_supplies = {} @@ -64,6 +68,8 @@ def __init__(self, rd, wr): self.transactions = {} self.errors = {} self.key = 0 + self.fake_time = None + self.fake_time_ticker = None # Collect all request handlers self._requestHandlers = dict( [m for m in getmembers(requests, isfunction)]) diff --git a/testkitbackend/_sync/requests.py b/testkitbackend/_sync/requests.py index 07a44673..258b116a 100644 --- a/testkitbackend/_sync/requests.py +++ b/testkitbackend/_sync/requests.py @@ -16,21 +16,33 @@ # limitations under the License. +import datetime import json import re import warnings from os import path +from freezegun import freeze_time + import neo4j import neo4j.api +import neo4j.auth_management from neo4j._async_compat.util import Util +from neo4j.auth_management import ( + AuthManager, + AuthManagers, + ExpiringAuth, +) from .. 
import ( fromtestkit, test_subtest_skips, totestkit, ) -from .._warning_check import warning_check +from .._warning_check import ( + warning_check, + warnings_check, +) from ..exceptions import MarkdAsDriverException @@ -94,27 +106,15 @@ def GetFeatures(backend, data): def NewDriver(backend, data): - auth_token = data["authorizationToken"]["data"] - data["authorizationToken"].mark_item_as_read_if_equals( - "name", "AuthorizationToken" - ) - scheme = auth_token["scheme"] - if scheme == "basic": - auth = neo4j.basic_auth( - auth_token["principal"], auth_token["credentials"], - realm=auth_token.get("realm", None) + auth = fromtestkit.to_auth_token(data, "authorizationToken") + expected_warnings = [] + if auth is None and data.get("authTokenManagerId") is not None: + auth = backend.auth_token_managers[data["authTokenManagerId"]] + expected_warnings.append( + (neo4j.PreviewWarning, "Auth managers are a preview feature.") ) - elif scheme == "kerberos": - auth = neo4j.kerberos_auth(auth_token["credentials"]) - elif scheme == "bearer": - auth = neo4j.bearer_auth(auth_token["credentials"]) else: - auth = neo4j.custom_auth( - auth_token["principal"], auth_token["credentials"], - auth_token["realm"], auth_token["scheme"], - **auth_token.get("parameters", {}) - ) - auth_token.mark_item_as_read("parameters", recursive=True) + data.mark_item_as_read_if_equals("authTokenManagerId", None) kwargs = {} if data["resolverRegistered"] or data["domainNameResolverRegistered"]: kwargs["resolver"] = resolution_func( @@ -131,12 +131,17 @@ def NewDriver(backend, data): for k in ("sessionConnectionTimeoutMs", "updateRoutingTableTimeoutMs"): if k in data: data.mark_item_as_read_if_equals(k, None) - if data.get("maxConnectionPoolSize"): - kwargs["max_connection_pool_size"] = data["maxConnectionPoolSize"] - if data.get("fetchSize"): - kwargs["fetch_size"] = data["fetchSize"] - if "encrypted" in data: - kwargs["encrypted"] = data["encrypted"] + for (conf_name, data_name) in ( + ("max_connection_pool_size", "maxConnectionPoolSize"), + ("fetch_size", "fetchSize"), + ): + if data.get(data_name): + kwargs[conf_name] = data[data_name] + for (conf_name, data_name) in ( + ("encrypted", "encrypted"), + ): + if data_name in data: + kwargs[conf_name] = data[data_name] if "trustedCertificates" in data: if data["trustedCertificates"] is None: kwargs["trusted_certificates"] = neo4j.TrustSystemCAs() @@ -149,14 +154,134 @@ def NewDriver(backend, data): fromtestkit.set_notifications_config(kwargs, data) data.mark_item_as_read_if_equals("livenessCheckTimeoutMs", None) - driver = neo4j.GraphDatabase.driver( - data["uri"], auth=auth, user_agent=data["userAgent"], **kwargs, - ) + if expected_warnings: + with warnings_check(expected_warnings): + driver = neo4j.GraphDatabase.driver( + data["uri"], auth=auth, user_agent=data["userAgent"], **kwargs, + ) + else: + driver = neo4j.GraphDatabase.driver( + data["uri"], auth=auth, user_agent=data["userAgent"], **kwargs, + ) key = backend.next_key() backend.drivers[key] = driver backend.send_response("Driver", {"id": key}) +def NewAuthTokenManager(backend, data): + auth_token_manager_id = backend.next_key() + + class TestKitAuthManager(AuthManager): + def get_auth(self): + key = backend.next_key() + backend.send_response("AuthTokenManagerGetAuthRequest", { + "id": key, + "authTokenManagerId": auth_token_manager_id, + }) + if not backend.process_request(): + # connection was closed before end of next message + return None + if key not in backend.auth_token_supplies: + raise RuntimeError( + "Backend did 
not receive expected " + f"AuthTokenManagerGetAuthCompleted message for id {key}" + ) + return backend.auth_token_supplies.pop(key) + + def on_auth_expired(self, auth): + key = backend.next_key() + backend.send_response( + "AuthTokenManagerOnAuthExpiredRequest", { + "id": key, + "authTokenManagerId": auth_token_manager_id, + "auth": totestkit.auth_token(auth), + } + ) + if not backend.process_request(): + # connection was closed before end of next message + return None + if key not in backend.auth_token_on_expiration_supplies: + raise RuntimeError( + "Backend did not receive expected " + "AuthTokenManagerOnAuthExpiredCompleted message for id " + f"{key}" + ) + backend.auth_token_on_expiration_supplies.pop(key) + + auth_manager = TestKitAuthManager() + backend.auth_token_managers[auth_token_manager_id] = auth_manager + backend.send_response( + "AuthTokenManager", {"id": auth_token_manager_id} + ) + + +def AuthTokenManagerGetAuthCompleted(backend, data): + auth_token = fromtestkit.to_auth_token(data, "auth") + + backend.auth_token_supplies[data["requestId"]] = auth_token + + +def AuthTokenManagerOnAuthExpiredCompleted(backend, data): + backend.auth_token_on_expiration_supplies[data["requestId"]] = True + + +def AuthTokenManagerClose(backend, data): + auth_token_manager_id = data["id"] + del backend.auth_token_managers[auth_token_manager_id] + backend.send_response( + "AuthTokenManager", {"id": auth_token_manager_id} + ) + + +def NewExpirationBasedAuthTokenManager(backend, data): + auth_token_manager_id = backend.next_key() + + def auth_token_provider(): + key = backend.next_key() + backend.send_response( + "ExpirationBasedAuthTokenProviderRequest", + { + "id": key, + "expirationBasedAuthTokenManagerId": auth_token_manager_id, + } + ) + if not backend.process_request(): + # connection was closed before end of next message + return neo4j.auth_management.ExpiringAuth(None, None) + if key not in backend.expiring_auth_token_supplies: + raise RuntimeError( + "Backend did not receive expected " + "ExpirationBasedAuthTokenManagerCompleted message for id " + f"{key}" + ) + return backend.expiring_auth_token_supplies.pop(key) + + with warning_check(neo4j.PreviewWarning, + "Auth managers are a preview feature."): + auth_manager = AuthManagers.expiration_based(auth_token_provider) + backend.auth_token_managers[auth_token_manager_id] = auth_manager + backend.send_response( + "ExpirationBasedAuthTokenManager", {"id": auth_token_manager_id} + ) + + +def ExpirationBasedAuthTokenProviderCompleted(backend, data): + temp_auth_data = data["auth"] + temp_auth_data.mark_item_as_read_if_equals("name", + "AuthTokenAndExpiration") + temp_auth_data = temp_auth_data["data"] + auth_token = fromtestkit.to_auth_token(temp_auth_data, "auth") + if temp_auth_data["expiresInMs"] is not None: + expires_in = temp_auth_data["expiresInMs"] / 1000 + else: + expires_in = None + with warning_check(neo4j.PreviewWarning, + "Auth managers are a preview feature."): + expiring_auth = ExpiringAuth(auth_token, expires_in) + + backend.expiring_auth_token_supplies[data["requestId"]] = expiring_auth + + def VerifyConnectivity(backend, data): driver_id = data["driverId"] driver = backend.drivers[driver_id] @@ -178,17 +303,33 @@ def GetServerInfo(backend, data): def CheckMultiDBSupport(backend, data): driver_id = data["driverId"] driver = backend.drivers[driver_id] - with warning_check( - neo4j.ExperimentalWarning, - "Feature support query, based on Bolt protocol version and Neo4j " - "server version will change in the future." 
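The NewExpirationBasedAuthTokenManager handler above wires TestKit to AuthManagers.expiration_based, the preview helper for credentials that expire. Outside the test harness the same pieces combine roughly as sketched below; fetch_sso_token and the URI are hypothetical placeholders, not driver API::

    import warnings

    import neo4j
    from neo4j.auth_management import AuthManagers, ExpiringAuth

    URI = "neo4j://localhost:7687"  # placeholder


    def fetch_sso_token():
        # Hypothetical helper: a real application would call its identity
        # provider here and return (bearer_token, lifetime_in_seconds).
        return "opaque-bearer-token", 3600.0


    def provider():
        token, lifetime = fetch_sso_token()
        # ExpiringAuth pairs an auth token with an optional lifetime in
        # seconds; the manager calls the provider again once it has expired.
        # (Constructing it emits a PreviewWarning, as the handler expects.)
        return ExpiringAuth(neo4j.bearer_auth(token), lifetime)


    with warnings.catch_warnings():
        # Auth managers are a preview feature in 5.8.
        warnings.simplefilter("ignore", neo4j.PreviewWarning)
        auth_manager = AuthManagers.expiration_based(provider)
        driver = neo4j.GraphDatabase.driver(URI, auth=auth_manager)

The backend's provider does the same thing, except that it round-trips through an ExpirationBasedAuthTokenProviderRequest/Completed message pair instead of an identity provider.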
- ): - available = driver.supports_multi_db() + available = driver.supports_multi_db() backend.send_response("MultiDBSupport", { "id": backend.next_key(), "available": available }) +def VerifyAuthentication(backend, data): + driver_id = data["driverId"] + driver = backend.drivers[driver_id] + auth = fromtestkit.to_auth_token(data, "authorizationToken") + with warning_check(neo4j.PreviewWarning, + "User switching is a preview feature."): + authenticated = driver.verify_authentication(auth=auth) + backend.send_response("DriverIsAuthenticated", { + "id": backend.next_key(), "authenticated": authenticated + }) + + +def CheckSessionAuthSupport(backend, data): + driver_id = data["driverId"] + driver = backend.drivers[driver_id] + available = driver.supports_session_auth() + backend.send_response("SessionAuthSupport", { + "id": backend.next_key(), "available": available + }) + + def ExecuteQuery(backend, data): driver = backend.drivers[data["driverId"]] cypher, params = fromtestkit.to_cypher_and_params(data) @@ -396,6 +537,7 @@ def __init__(self, session): def NewSession(backend, data): driver = backend.drivers[data["driverId"]] access_mode = data["accessMode"] + expected_warnings = [] if access_mode == "r": access_mode = neo4j.READ_ACCESS elif access_mode == "w": @@ -414,19 +556,25 @@ def NewSession(backend, data): config["bookmark_manager"] = backend.bookmark_managers[ data["bookmarkManagerId"] ] + expected_warnings.append(( + neo4j.ExperimentalWarning, + "The 'bookmark_manager' config key is experimental. It might be " + "changed or removed any time even without prior notice." + )) for (conf_name, data_name) in ( ("fetch_size", "fetchSize"), ("impersonated_user", "impersonatedUser"), ): if data_name in data: config[conf_name] = data[data_name] + if data.get("authorizationToken"): + config["auth"] = fromtestkit.to_auth_token(data, "authorizationToken") + expected_warnings.append( + (neo4j.PreviewWarning, "User switching is a preview features.") + ) fromtestkit.set_notifications_config(config, data) - if "bookmark_manager" in config: - with warning_check( - neo4j.ExperimentalWarning, - "The 'bookmark_manager' config key is experimental. It might be " - "changed or removed any time even without prior notice." 
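The FakeTime* handlers added to both the async and sync request modules in this patch drive the backend's clock through freezegun. A standalone sketch of the freezegun calls they rely on (variable names are illustrative)::

    import datetime

    from freezegun import freeze_time

    # freeze_time() creates a freezer; start() activates it and returns a
    # factory whose tick() advances the frozen clock by a timedelta. This
    # mirrors FakeTimeInstall / FakeTimeTick / FakeTimeUninstall.
    freezer = freeze_time()
    factory = freezer.start()

    before = datetime.datetime.now()
    factory.tick(delta=datetime.timedelta(milliseconds=1500))
    after = datetime.datetime.now()
    assert after - before == datetime.timedelta(milliseconds=1500)

    freezer.stop()  # restore the real clock

Keeping the freezer and the factory in separate backend attributes (fake_time and fake_time_ticker) lets FakeTimeTick advance time repeatedly without re-installing the patch.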
- ): + if expected_warnings: + with warnings_check(expected_warnings): session = driver.session(**config) else: session = driver.session(**config) @@ -647,3 +795,30 @@ def GetRoutingTable(backend, data): addresses = routing_table.__getattribute__(role) response_data[role] = list(map(str, addresses)) backend.send_response("RoutingTable", response_data) + +def FakeTimeInstall(backend, _data): + assert backend.fake_time is None + assert backend.fake_time_ticker is None + + backend.fake_time = freeze_time() + backend.fake_time_ticker = backend.fake_time.start() + backend.send_response("FakeTimeAck", {}) + +def FakeTimeTick(backend, data): + assert backend.fake_time is not None + assert backend.fake_time_ticker is not None + + increment_ms = data["incrementMs"] + delta = datetime.timedelta(milliseconds=increment_ms) + backend.fake_time_ticker.tick(delta=delta) + backend.send_response("FakeTimeAck", {}) + + +def FakeTimeUninstall(backend, _data): + assert backend.fake_time is not None + assert backend.fake_time_ticker is not None + + backend.fake_time.stop() + backend.fake_time_ticker = None + backend.fake_time = None + backend.send_response("FakeTimeAck", {}) diff --git a/testkitbackend/fromtestkit.py b/testkitbackend/fromtestkit.py index 67c491dd..f8370a8d 100644 --- a/testkitbackend/fromtestkit.py +++ b/testkitbackend/fromtestkit.py @@ -20,6 +20,7 @@ import pytz +import neo4j from neo4j import ( NotificationDisabledCategory, NotificationMinimumSeverity, @@ -157,6 +158,31 @@ def to_param(m): raise ValueError("Unknown param type " + name) +def to_auth_token(data, key): + if data[key] is None: + return None + auth_token = data[key]["data"] + data[key].mark_item_as_read_if_equals("name", "AuthorizationToken") + scheme = auth_token["scheme"] + if scheme == "basic": + auth = neo4j.basic_auth( + auth_token["principal"], auth_token["credentials"], + realm=auth_token.get("realm", None) + ) + elif scheme == "kerberos": + auth = neo4j.kerberos_auth(auth_token["credentials"]) + elif scheme == "bearer": + auth = neo4j.bearer_auth(auth_token["credentials"]) + else: + auth = neo4j.custom_auth( + auth_token["principal"], auth_token["credentials"], + auth_token["realm"], auth_token["scheme"], + **auth_token.get("parameters", {}) + ) + auth_token.mark_item_as_read("parameters", recursive=True) + return auth + + def set_notifications_config(config, data): if "notificationsMinSeverity" in data: config["notifications_min_severity"] = \ diff --git a/testkitbackend/test_config.json b/testkitbackend/test_config.json index 847adae8..48f24e64 100644 --- a/testkitbackend/test_config.json +++ b/testkitbackend/test_config.json @@ -21,13 +21,16 @@ "Feature:API:Driver.ExecuteQuery": true, "Feature:API:Driver:GetServerInfo": true, "Feature:API:Driver.IsEncrypted": true, + "Feature:API:Driver.VerifyAuthentication": true, "Feature:API:Driver.VerifyConnectivity": true, + "Feature:API:Driver.SupportsSessionAuth": true, "Feature:API:Driver:NotificationsConfig": true, "Feature:API:Liveness.Check": false, "Feature:API:Result.List": true, "Feature:API:Result.Peek": true, "Feature:API:Result.Single": true, "Feature:API:Result.SingleOptional": true, + "Feature:API:Session:AuthConfig": true, "Feature:API:Session:NotificationsConfig": true, "Feature:API:SSLConfig": true, "Feature:API:SSLSchemes": true, @@ -36,6 +39,7 @@ "Feature:Auth:Bearer": true, "Feature:Auth:Custom": true, "Feature:Auth:Kerberos": true, + "Feature:Auth:Managed": true, "Feature:Bolt:3.0": true, "Feature:Bolt:4.1": true, "Feature:Bolt:4.2": true, @@ -62,6 +66,7 @@ 
"ConfHint:connection.recv_timeout_seconds": true, + "Backend:MockTime": true, "Backend:RTFetch": true, "Backend:RTForceUpdate": true } diff --git a/testkitbackend/totestkit.py b/testkitbackend/totestkit.py index b7a17afd..2cf3eee4 100644 --- a/testkitbackend/totestkit.py +++ b/testkitbackend/totestkit.py @@ -255,3 +255,7 @@ def to(name, val): } raise ValueError("Unhandled type:" + str(type(v))) + + +def auth_token(auth): + return {"name": "AuthorizationToken", "data": vars(auth)} diff --git a/tests/integration/examples/test_bearer_auth_example.py b/tests/integration/examples/test_bearer_auth_example.py index bb6988df..0c29e0e9 100644 --- a/tests/integration/examples/test_bearer_auth_example.py +++ b/tests/integration/examples/test_bearer_auth_example.py @@ -17,6 +17,7 @@ import neo4j +from neo4j._sync.auth_management import StaticAuthManager from . import DriverSetupExample @@ -54,7 +55,8 @@ def test_example(uri, mocker): assert len(calls) == 1 args_, kwargs = calls[0] auth = kwargs.get("auth") - assert isinstance(auth, neo4j.Auth) + assert isinstance(auth, StaticAuthManager) + auth = auth._auth assert auth.scheme == "bearer" assert not hasattr(auth, "principal") assert auth.credentials == token diff --git a/tests/unit/async_/fixtures/fake_connection.py b/tests/unit/async_/fixtures/fake_connection.py index f0d0070c..9e3995cf 100644 --- a/tests/unit/async_/fixtures/fake_connection.py +++ b/tests/unit/async_/fixtures/fake_connection.py @@ -23,6 +23,7 @@ from neo4j import ServerInfo from neo4j._async.io import AsyncBolt from neo4j._deadline import Deadline +from neo4j.auth_management import AsyncAuthManager __all__ = [ @@ -50,7 +51,10 @@ def __init__(self, *args, **kwargs): self.attach_mock(mock.Mock(return_value=False), "stale") self.attach_mock(mock.Mock(return_value=False), "closed") self.attach_mock(mock.Mock(return_value=False), "socket") - self.attach_mock(mock.Mock(), "unresolved_address") + self.attach_mock(mock.Mock(return_value=False), "re_auth") + self.attach_mock(mock.AsyncMock(spec=AsyncAuthManager), + "auth_manager") + self.unresolved_address = next(iter(args), "localhost") def close_side_effect(): self.closed.return_value = True diff --git a/tests/unit/async_/io/conftest.py b/tests/unit/async_/io/conftest.py index 08b6e9c4..d697d182 100644 --- a/tests/unit/async_/io/conftest.py +++ b/tests/unit/async_/io/conftest.py @@ -24,10 +24,12 @@ import pytest +from neo4j import PreviewWarning from neo4j._async.io._common import ( AsyncInbox, AsyncOutbox, ) +from neo4j.auth_management import AsyncAuthManagers class AsyncFakeSocket: @@ -103,6 +105,9 @@ async def sendall(self, data): def close(self): return + def kill(self): + return + def inject(self, data): self.recv_buffer += data @@ -142,3 +147,16 @@ def fake_socket_2(): @pytest.fixture def fake_socket_pair(): return AsyncFakeSocketPair + +@pytest.fixture +def static_auth(): + def inner(auth): + with pytest.warns(PreviewWarning, match="Auth managers"): + return AsyncAuthManagers.static(auth) + + return inner + + +@pytest.fixture +def none_auth(static_auth): + return static_auth(None) diff --git a/tests/unit/async_/io/test_class_bolt.py b/tests/unit/async_/io/test_class_bolt.py index 100c57a6..735e7fd5 100644 --- a/tests/unit/async_/io/test_class_bolt.py +++ b/tests/unit/async_/io/test_class_bolt.py @@ -20,6 +20,7 @@ import pytest +import neo4j.auth_management from neo4j._async.io import AsyncBolt from neo4j._async_compat.network import AsyncBoltSocket from neo4j._exceptions import BoltHandshakeError @@ -94,7 +95,7 @@ def 
test_magic_preamble(): @AsyncTestDecorators.mark_async_only_test -async def test_cancel_hello_in_open(mocker): +async def test_cancel_hello_in_open(mocker, none_auth): address = ("localhost", 7687) socket_mock = mocker.AsyncMock(spec=AsyncBoltSocket) @@ -112,7 +113,7 @@ async def test_cancel_hello_in_open(mocker): bolt_mock.local_port = 1234 with pytest.raises(asyncio.CancelledError): - await AsyncBolt.open(address) + await AsyncBolt.open(address, auth_manager=none_auth) bolt_mock.kill.assert_called_once_with() @@ -132,7 +133,9 @@ async def test_cancel_hello_in_open(mocker): ), ) @mark_async_test -async def test_version_negotiation(mocker, bolt_version, bolt_cls_path): +async def test_version_negotiation( + mocker, bolt_version, bolt_cls_path, none_auth +): address = ("localhost", 7687) socket_mock = mocker.AsyncMock(spec=AsyncBoltSocket) @@ -147,7 +150,7 @@ async def test_version_negotiation(mocker, bolt_version, bolt_cls_path): bolt_mock = bolt_cls_mock.return_value bolt_mock.socket = socket_mock - connection = await AsyncBolt.open(address) + connection = await AsyncBolt.open(address, auth_manager=none_auth) bolt_cls_mock.assert_called_once() assert connection is bolt_mock @@ -163,7 +166,7 @@ async def test_version_negotiation(mocker, bolt_version, bolt_cls_path): (6, 0), )) @mark_async_test -async def test_failing_version_negotiation(mocker, bolt_version): +async def test_failing_version_negotiation(mocker, bolt_version, none_auth): supported_protocols = \ "('3.0', '4.1', '4.2', '4.3', '4.4', '5.0', '5.1', '5.2')" @@ -178,6 +181,64 @@ async def test_failing_version_negotiation(mocker, bolt_version): socket_mock.getpeername.return_value = address with pytest.raises(BoltHandshakeError) as exc: - await AsyncBolt.open(address) + await AsyncBolt.open(address, auth_manager=none_auth) assert exc.match(supported_protocols) + + + +@AsyncTestDecorators.mark_async_only_test +async def test_cancel_manager_in_open(mocker): + address = ("localhost", 7687) + socket_mock = mocker.AsyncMock(spec=AsyncBoltSocket) + + socket_cls_mock = mocker.patch("neo4j._async.io._bolt.AsyncBoltSocket", + autospec=True) + socket_cls_mock.connect.return_value = ( + socket_mock, (5, 0), None, None + ) + socket_mock.getpeername.return_value = address + bolt_cls_mock = mocker.patch("neo4j._async.io._bolt5.AsyncBolt5x0", + autospec=True) + bolt_mock = bolt_cls_mock.return_value + bolt_mock.socket = socket_mock + bolt_mock.local_port = 1234 + + auth_manager = mocker.AsyncMock( + spec=neo4j.auth_management.AsyncAuthManager + ) + auth_manager.get_auth.side_effect = asyncio.CancelledError() + + with pytest.raises(asyncio.CancelledError): + await AsyncBolt.open(address, auth_manager=auth_manager) + + socket_mock.kill.assert_called_once_with() + + +@AsyncTestDecorators.mark_async_only_test +async def test_fail_manager_in_open(mocker): + address = ("localhost", 7687) + socket_mock = mocker.AsyncMock(spec=AsyncBoltSocket) + + socket_cls_mock = mocker.patch("neo4j._async.io._bolt.AsyncBoltSocket", + autospec=True) + socket_cls_mock.connect.return_value = ( + socket_mock, (5, 0), None, None + ) + socket_mock.getpeername.return_value = address + bolt_cls_mock = mocker.patch("neo4j._async.io._bolt5.AsyncBolt5x0", + autospec=True) + bolt_mock = bolt_cls_mock.return_value + bolt_mock.socket = socket_mock + bolt_mock.local_port = 1234 + + auth_manager = mocker.AsyncMock( + spec=neo4j.auth_management.AsyncAuthManager + ) + auth_manager.get_auth.side_effect = RuntimeError("token fetching failed") + + with pytest.raises(RuntimeError) as exc: 
+ await AsyncBolt.open(address, auth_manager=auth_manager) + assert exc.value is auth_manager.get_auth.side_effect + + socket_mock.close.assert_called_once_with() diff --git a/tests/unit/async_/io/test_class_bolt3.py b/tests/unit/async_/io/test_class_bolt3.py index aacb57c0..e0954f47 100644 --- a/tests/unit/async_/io/test_class_bolt3.py +++ b/tests/unit/async_/io/test_class_bolt3.py @@ -17,12 +17,13 @@ import logging +from itertools import permutations import pytest +import neo4j from neo4j._async.io._bolt3 import AsyncBolt3 from neo4j._conf import PoolConfig -from neo4j.api import Auth from neo4j.exceptions import ConfigurationError from ...._async_compat import mark_async_test @@ -30,7 +31,7 @@ @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_stale(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = 0 connection = AsyncBolt3(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -40,7 +41,7 @@ def test_conn_is_stale(fake_socket, set_stale): @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = -1 connection = AsyncBolt3(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -50,7 +51,7 @@ def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_not_stale(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = 999999999 connection = AsyncBolt3(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -59,14 +60,14 @@ def test_conn_is_not_stale(fake_socket, set_stale): def test_db_extra_not_supported_in_begin(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) connection = AsyncBolt3(address, fake_socket(address), PoolConfig.max_connection_lifetime) with pytest.raises(ConfigurationError): connection.begin(db="something") def test_db_extra_not_supported_in_run(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) connection = AsyncBolt3(address, fake_socket(address), PoolConfig.max_connection_lifetime) with pytest.raises(ConfigurationError): connection.run("", db="something") @@ -74,7 +75,7 @@ def test_db_extra_not_supported_in_run(fake_socket): @mark_async_test async def test_simple_discard(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt3.UNPACKER_CLS) connection = AsyncBolt3(address, socket, PoolConfig.max_connection_lifetime) connection.discard() @@ -86,7 +87,7 @@ async def test_simple_discard(fake_socket): @mark_async_test async def test_simple_pull(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt3.UNPACKER_CLS) connection = AsyncBolt3(address, socket, PoolConfig.max_connection_lifetime) connection.pull() @@ -101,7 +102,7 @@ async def test_simple_pull(fake_socket): async def test_hint_recv_timeout_seconds_gets_ignored( fake_socket_pair, recv_timeout, mocker ): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair( address, AsyncBolt3.PACKER_CLS, AsyncBolt3.UNPACKER_CLS ) @@ -117,6 +118,92 @@ async def 
test_hint_recv_timeout_seconds_gets_ignored( sockets.client.settimeout.assert_not_called() +CREDENTIALS = "+++super-secret-sauce+++" + + +@pytest.mark.parametrize("auth", ( + ("user", CREDENTIALS), + neo4j.basic_auth("user", CREDENTIALS), + neo4j.kerberos_auth(CREDENTIALS), + neo4j.bearer_auth(CREDENTIALS), + neo4j.custom_auth("user", CREDENTIALS, "realm", "scheme"), + neo4j.Auth("scheme", "principal", CREDENTIALS, "realm", foo="bar"), +)) +@mark_async_test +async def test_credentials_are_not_logged( + auth, fake_socket_pair, mocker, caplog +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt3.PACKER_CLS, + unpacker_cls=AsyncBolt3.UNPACKER_CLS) + sockets.client.settimeout = mocker.Mock() + await sockets.server.send_message(b"\x70", {"server": "Neo4j/4.3.4"}) + connection = AsyncBolt3( + address, sockets.client, PoolConfig.max_connection_lifetime, auth=auth + ) + with caplog.at_level(logging.DEBUG): + await connection.hello() + + if isinstance(auth, tuple): + auth = neo4j.basic_auth(*auth) + for field in ("scheme", "principal", "realm", "parameters"): + value = getattr(auth, field, None) + if value: + assert repr(value) in caplog.text + assert CREDENTIALS not in caplog.text + + +@pytest.mark.parametrize("message", ("logon", "logoff")) +def test_auth_message_raises_configuration_error(message, fake_socket): + address = neo4j.Address(("127.0.0.1", 7687)) + connection = AsyncBolt3(address, fake_socket(address), + PoolConfig.max_connection_lifetime) + with pytest.raises(ConfigurationError, + match="User switching is not supported"): + getattr(connection, message)() + + +@pytest.mark.parametrize("auth", ( + None, + neo4j.Auth("scheme", "principal", "credentials", "realm"), + ("user", "password"), +)) +@mark_async_test +async def test_re_auth_noop(auth, fake_socket, mocker): + address = neo4j.Address(("127.0.0.1", 7687)) + connection = AsyncBolt3(address, fake_socket(address), + PoolConfig.max_connection_lifetime, auth=auth) + logon_spy = mocker.spy(connection, "logon") + logoff_spy = mocker.spy(connection, "logoff") + res = connection.re_auth(auth, None) + + assert res is False + logon_spy.assert_not_called() + logoff_spy.assert_not_called() + + +@pytest.mark.parametrize( + ("auth1", "auth2"), + permutations( + ( + None, + neo4j.Auth("scheme", "principal", "credentials", "realm"), + ("user", "password"), + ), + 2 + ) +) +@mark_async_test +async def test_re_auth(auth1, auth2, fake_socket): + address = neo4j.Address(("127.0.0.1", 7687)) + connection = AsyncBolt3(address, fake_socket(address), + PoolConfig.max_connection_lifetime, auth=auth1) + with pytest.raises(ConfigurationError, + match="User switching is not supported"): + connection.re_auth(auth2, None) + + @pytest.mark.parametrize(("method", "args"), ( ("run", ("RETURN 1",)), ("begin", ()), @@ -132,7 +219,7 @@ async def test_hint_recv_timeout_seconds_gets_ignored( )) def test_does_not_support_notification_filters(fake_socket, method, args, kwargs): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt3.UNPACKER_CLS) connection = AsyncBolt3(address, socket, PoolConfig.max_connection_lifetime) @@ -154,7 +241,7 @@ def test_does_not_support_notification_filters(fake_socket, method, async def test_hello_does_not_support_notification_filters( fake_socket, kwargs ): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt3.UNPACKER_CLS) connection = AsyncBolt3( 
address, socket, PoolConfig.max_connection_lifetime, @@ -162,65 +249,3 @@ async def test_hello_does_not_support_notification_filters( ) with pytest.raises(ConfigurationError, match="Notification filtering"): await connection.hello() - - -class HackedAuth: - def __init__(self, dict_): - self.__dict__ = dict_ - - -@mark_async_test -@pytest.mark.parametrize("auth", ( - ("awesome test user", "safe p4ssw0rd"), - Auth("super nice scheme", "awesome test user", "safe p4ssw0rd"), - Auth("super nice scheme", "awesome test user", "safe p4ssw0rd", - realm="super duper realm"), - Auth("super nice scheme", "awesome test user", "safe p4ssw0rd", - realm="super duper realm"), - Auth("super nice scheme", "awesome test user", "safe p4ssw0rd", - foo="bar"), - HackedAuth({ - "scheme": "super nice scheme", "principal": "awesome test user", - "credentials": "safe p4ssw0rd", "realm": "super duper realm", - "parameters": {"credentials": "should be visible!"}, - }) - -)) -async def test_hello_does_not_log_credentials(fake_socket_pair, caplog, auth): - def items(): - if isinstance(auth, tuple): - yield "scheme", "basic" - yield "principal", auth[0] - yield "credentials", auth[1] - elif isinstance(auth, Auth): - for key in ("scheme", "principal", "credentials", "realm", - "parameters"): - value = getattr(auth, key, None) - if value: - yield key, value - elif isinstance(auth, HackedAuth): - yield from auth.__dict__.items() - else: - raise TypeError(auth) - - address = ("127.0.0.1", 7687) - sockets = fake_socket_pair(address, - packer_cls=AsyncBolt3.PACKER_CLS, - unpacker_cls=AsyncBolt3.UNPACKER_CLS) - await sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) - max_connection_lifetime = 0 - connection = AsyncBolt3(address, sockets.client, - max_connection_lifetime, auth=auth) - - with caplog.at_level(logging.DEBUG): - await connection.hello() - - hellos = [m for m in caplog.messages if "C: HELLO" in m] - assert len(hellos) == 1 - hello = hellos[0] - - for key, value in items(): - if key == "credentials": - assert value not in hello - else: - assert str({key: value})[1:-1] in hello diff --git a/tests/unit/async_/io/test_class_bolt4x0.py b/tests/unit/async_/io/test_class_bolt4x0.py index 3395ae18..ff1edcbd 100644 --- a/tests/unit/async_/io/test_class_bolt4x0.py +++ b/tests/unit/async_/io/test_class_bolt4x0.py @@ -17,12 +17,13 @@ import logging +from itertools import permutations import pytest +import neo4j from neo4j._async.io._bolt4 import AsyncBolt4x0 from neo4j._conf import PoolConfig -from neo4j.api import Auth from neo4j.exceptions import ConfigurationError from ...._async_compat import mark_async_test @@ -30,7 +31,7 @@ @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_stale(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = 0 connection = AsyncBolt4x0(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -40,7 +41,7 @@ def test_conn_is_stale(fake_socket, set_stale): @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = -1 connection = AsyncBolt4x0(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -50,7 +51,7 @@ def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_not_stale(fake_socket, set_stale): - address = 
("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = 999999999 connection = AsyncBolt4x0(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -60,7 +61,7 @@ def test_conn_is_not_stale(fake_socket, set_stale): @mark_async_test async def test_db_extra_in_begin(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x0.UNPACKER_CLS) connection = AsyncBolt4x0(address, socket, PoolConfig.max_connection_lifetime) connection.begin(db="something") @@ -73,7 +74,7 @@ async def test_db_extra_in_begin(fake_socket): @mark_async_test async def test_db_extra_in_run(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x0.UNPACKER_CLS) connection = AsyncBolt4x0(address, socket, PoolConfig.max_connection_lifetime) connection.run("", {}, db="something") @@ -88,7 +89,7 @@ async def test_db_extra_in_run(fake_socket): @mark_async_test async def test_n_extra_in_discard(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x0.UNPACKER_CLS) connection = AsyncBolt4x0(address, socket, PoolConfig.max_connection_lifetime) connection.discard(n=666) @@ -108,7 +109,7 @@ async def test_n_extra_in_discard(fake_socket): ) @mark_async_test async def test_qid_extra_in_discard(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x0.UNPACKER_CLS) connection = AsyncBolt4x0(address, socket, PoolConfig.max_connection_lifetime) connection.discard(qid=test_input) @@ -128,7 +129,7 @@ async def test_qid_extra_in_discard(fake_socket, test_input, expected): ) @mark_async_test async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x0.UNPACKER_CLS) connection = AsyncBolt4x0(address, socket, PoolConfig.max_connection_lifetime) connection.discard(n=666, qid=test_input) @@ -148,7 +149,7 @@ async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): ) @mark_async_test async def test_n_extra_in_pull(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x0.UNPACKER_CLS) connection = AsyncBolt4x0(address, socket, PoolConfig.max_connection_lifetime) connection.pull(n=test_input) @@ -168,7 +169,7 @@ async def test_n_extra_in_pull(fake_socket, test_input, expected): ) @mark_async_test async def test_qid_extra_in_pull(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x0.UNPACKER_CLS) connection = AsyncBolt4x0(address, socket, PoolConfig.max_connection_lifetime) connection.pull(qid=test_input) @@ -181,7 +182,7 @@ async def test_qid_extra_in_pull(fake_socket, test_input, expected): @mark_async_test async def test_n_and_qid_extras_in_pull(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x0.UNPACKER_CLS) connection = AsyncBolt4x0(address, socket, PoolConfig.max_connection_lifetime) connection.pull(n=666, qid=777) @@ -197,7 +198,7 @@ async def test_n_and_qid_extras_in_pull(fake_socket): async def 
test_hint_recv_timeout_seconds_gets_ignored( fake_socket_pair, recv_timeout, mocker ): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair(address, packer_cls=AsyncBolt4x0.PACKER_CLS, unpacker_cls=AsyncBolt4x0.UNPACKER_CLS) @@ -213,6 +214,92 @@ async def test_hint_recv_timeout_seconds_gets_ignored( sockets.client.settimeout.assert_not_called() +CREDENTIALS = "+++super-secret-sauce+++" + + +@pytest.mark.parametrize("auth", ( + ("user", CREDENTIALS), + neo4j.basic_auth("user", CREDENTIALS), + neo4j.kerberos_auth(CREDENTIALS), + neo4j.bearer_auth(CREDENTIALS), + neo4j.custom_auth("user", CREDENTIALS, "realm", "scheme"), + neo4j.Auth("scheme", "principal", CREDENTIALS, "realm", foo="bar"), +)) +@mark_async_test +async def test_credentials_are_not_logged( + auth, fake_socket_pair, mocker, caplog +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt4x0.PACKER_CLS, + unpacker_cls=AsyncBolt4x0.UNPACKER_CLS) + sockets.client.settimeout = mocker.Mock() + await sockets.server.send_message(b"\x70", {"server": "Neo4j/4.3.4"}) + connection = AsyncBolt4x0( + address, sockets.client, PoolConfig.max_connection_lifetime, auth=auth + ) + with caplog.at_level(logging.DEBUG): + await connection.hello() + + if isinstance(auth, tuple): + auth = neo4j.basic_auth(*auth) + for field in ("scheme", "principal", "realm", "parameters"): + value = getattr(auth, field, None) + if value: + assert repr(value) in caplog.text + assert CREDENTIALS not in caplog.text + + +@pytest.mark.parametrize("message", ("logon", "logoff")) +def test_auth_message_raises_configuration_error(message, fake_socket): + address = neo4j.Address(("127.0.0.1", 7687)) + connection = AsyncBolt4x0(address, fake_socket(address), + PoolConfig.max_connection_lifetime) + with pytest.raises(ConfigurationError, + match="User switching is not supported"): + getattr(connection, message)() + + +@pytest.mark.parametrize("auth", ( + None, + neo4j.Auth("scheme", "principal", "credentials", "realm"), + ("user", "password"), +)) +@mark_async_test +async def test_re_auth_noop(auth, fake_socket, mocker): + address = neo4j.Address(("127.0.0.1", 7687)) + connection = AsyncBolt4x0(address, fake_socket(address), + PoolConfig.max_connection_lifetime, auth=auth) + logon_spy = mocker.spy(connection, "logon") + logoff_spy = mocker.spy(connection, "logoff") + res = connection.re_auth(auth, None) + + assert res is False + logon_spy.assert_not_called() + logoff_spy.assert_not_called() + + +@pytest.mark.parametrize( + ("auth1", "auth2"), + permutations( + ( + None, + neo4j.Auth("scheme", "principal", "credentials", "realm"), + ("user", "password"), + ), + 2 + ) +) +@mark_async_test +async def test_re_auth(auth1, auth2, fake_socket): + address = neo4j.Address(("127.0.0.1", 7687)) + connection = AsyncBolt4x0(address, fake_socket(address), + PoolConfig.max_connection_lifetime, auth=auth1) + with pytest.raises(ConfigurationError, + match="User switching is not supported"): + connection.re_auth(auth2, None) + + @pytest.mark.parametrize(("method", "args"), ( ("run", ("RETURN 1",)), ("begin", ()), @@ -228,7 +315,7 @@ async def test_hint_recv_timeout_seconds_gets_ignored( )) def test_does_not_support_notification_filters(fake_socket, method, args, kwargs): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x0.UNPACKER_CLS) connection = AsyncBolt4x0(address, socket, PoolConfig.max_connection_lifetime) @@ -250,7 
+337,7 @@ def test_does_not_support_notification_filters(fake_socket, method, async def test_hello_does_not_support_notification_filters( fake_socket, kwargs ): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x0.UNPACKER_CLS) connection = AsyncBolt4x0( address, socket, PoolConfig.max_connection_lifetime, @@ -258,65 +345,3 @@ async def test_hello_does_not_support_notification_filters( ) with pytest.raises(ConfigurationError, match="Notification filtering"): await connection.hello() - - -class HackedAuth: - def __init__(self, dict_): - self.__dict__ = dict_ - - -@mark_async_test -@pytest.mark.parametrize("auth", ( - ("awesome test user", "safe p4ssw0rd"), - Auth("super nice scheme", "awesome test user", "safe p4ssw0rd"), - Auth("super nice scheme", "awesome test user", "safe p4ssw0rd", - realm="super duper realm"), - Auth("super nice scheme", "awesome test user", "safe p4ssw0rd", - realm="super duper realm"), - Auth("super nice scheme", "awesome test user", "safe p4ssw0rd", - foo="bar"), - HackedAuth({ - "scheme": "super nice scheme", "principal": "awesome test user", - "credentials": "safe p4ssw0rd", "realm": "super duper realm", - "parameters": {"credentials": "should be visible!"}, - }) - -)) -async def test_hello_does_not_log_credentials(fake_socket_pair, caplog, auth): - def items(): - if isinstance(auth, tuple): - yield "scheme", "basic" - yield "principal", auth[0] - yield "credentials", auth[1] - elif isinstance(auth, Auth): - for key in ("scheme", "principal", "credentials", "realm", - "parameters"): - value = getattr(auth, key, None) - if value: - yield key, value - elif isinstance(auth, HackedAuth): - yield from auth.__dict__.items() - else: - raise TypeError(auth) - - address = ("127.0.0.1", 7687) - sockets = fake_socket_pair(address, - packer_cls=AsyncBolt4x0.PACKER_CLS, - unpacker_cls=AsyncBolt4x0.UNPACKER_CLS) - await sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) - max_connection_lifetime = 0 - connection = AsyncBolt4x0(address, sockets.client, - max_connection_lifetime, auth=auth) - - with caplog.at_level(logging.DEBUG): - await connection.hello() - - hellos = [m for m in caplog.messages if "C: HELLO" in m] - assert len(hellos) == 1 - hello = hellos[0] - - for key, value in items(): - if key == "credentials": - assert value not in hello - else: - assert str({key: value})[1:-1] in hello diff --git a/tests/unit/async_/io/test_class_bolt4x1.py b/tests/unit/async_/io/test_class_bolt4x1.py index 6fb4e517..7766c559 100644 --- a/tests/unit/async_/io/test_class_bolt4x1.py +++ b/tests/unit/async_/io/test_class_bolt4x1.py @@ -17,12 +17,13 @@ import logging +from itertools import permutations import pytest +import neo4j from neo4j._async.io._bolt4 import AsyncBolt4x1 from neo4j._conf import PoolConfig -from neo4j.api import Auth from neo4j.exceptions import ConfigurationError from ...._async_compat import mark_async_test @@ -30,7 +31,7 @@ @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_stale(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = 0 connection = AsyncBolt4x1(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -40,7 +41,7 @@ def test_conn_is_stale(fake_socket, set_stale): @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) 
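The mechanical change running through these Bolt unit-test modules replaces bare (host, port) tuples with neo4j.Address when constructing connections. To my understanding of the public Address type (this behaviour is not shown in the diff itself), Address is a tuple subclass, so the surrounding assertions stay interchangeable with the old tuples::

    import neo4j

    # Address still compares equal to the raw pair the tests used before,
    # while also exposing host/port accessors.
    addr = neo4j.Address(("127.0.0.1", 7687))
    assert addr == ("127.0.0.1", 7687)
    assert addr.host == "127.0.0.1"
    assert addr.port == 7687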
max_connection_lifetime = -1 connection = AsyncBolt4x1(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -50,7 +51,7 @@ def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_not_stale(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = 999999999 connection = AsyncBolt4x1(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -60,7 +61,7 @@ def test_conn_is_not_stale(fake_socket, set_stale): @mark_async_test async def test_db_extra_in_begin(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x1.UNPACKER_CLS) connection = AsyncBolt4x1(address, socket, PoolConfig.max_connection_lifetime) connection.begin(db="something") @@ -73,7 +74,7 @@ async def test_db_extra_in_begin(fake_socket): @mark_async_test async def test_db_extra_in_run(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x1.UNPACKER_CLS) connection = AsyncBolt4x1(address, socket, PoolConfig.max_connection_lifetime) connection.run("", {}, db="something") @@ -88,7 +89,7 @@ async def test_db_extra_in_run(fake_socket): @mark_async_test async def test_n_extra_in_discard(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x1.UNPACKER_CLS) connection = AsyncBolt4x1(address, socket, PoolConfig.max_connection_lifetime) connection.discard(n=666) @@ -108,7 +109,7 @@ async def test_n_extra_in_discard(fake_socket): ) @mark_async_test async def test_qid_extra_in_discard(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x1.UNPACKER_CLS) connection = AsyncBolt4x1(address, socket, PoolConfig.max_connection_lifetime) connection.discard(qid=test_input) @@ -128,7 +129,7 @@ async def test_qid_extra_in_discard(fake_socket, test_input, expected): ) @mark_async_test async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x1.UNPACKER_CLS) connection = AsyncBolt4x1(address, socket, PoolConfig.max_connection_lifetime) connection.discard(n=666, qid=test_input) @@ -148,7 +149,7 @@ async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): ) @mark_async_test async def test_n_extra_in_pull(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x1.UNPACKER_CLS) connection = AsyncBolt4x1(address, socket, PoolConfig.max_connection_lifetime) connection.pull(n=test_input) @@ -168,7 +169,7 @@ async def test_n_extra_in_pull(fake_socket, test_input, expected): ) @mark_async_test async def test_qid_extra_in_pull(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x1.UNPACKER_CLS) connection = AsyncBolt4x1(address, socket, PoolConfig.max_connection_lifetime) connection.pull(qid=test_input) @@ -181,7 +182,7 @@ async def test_qid_extra_in_pull(fake_socket, test_input, expected): @mark_async_test async def test_n_and_qid_extras_in_pull(fake_socket): - address = ("127.0.0.1", 
7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x1.UNPACKER_CLS) connection = AsyncBolt4x1(address, socket, PoolConfig.max_connection_lifetime) connection.pull(n=666, qid=777) @@ -194,7 +195,7 @@ async def test_n_and_qid_extras_in_pull(fake_socket): @mark_async_test async def test_hello_passes_routing_metadata(fake_socket_pair): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair(address, packer_cls=AsyncBolt4x1.PACKER_CLS, unpacker_cls=AsyncBolt4x1.UNPACKER_CLS) @@ -215,7 +216,7 @@ async def test_hello_passes_routing_metadata(fake_socket_pair): async def test_hint_recv_timeout_seconds_gets_ignored( fake_socket_pair, recv_timeout, mocker ): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair(address, packer_cls=AsyncBolt4x1.PACKER_CLS, unpacker_cls=AsyncBolt4x1.UNPACKER_CLS) @@ -230,6 +231,92 @@ async def test_hint_recv_timeout_seconds_gets_ignored( sockets.client.settimeout.assert_not_called() +CREDENTIALS = "+++super-secret-sauce+++" + + +@pytest.mark.parametrize("auth", ( + ("user", CREDENTIALS), + neo4j.basic_auth("user", CREDENTIALS), + neo4j.kerberos_auth(CREDENTIALS), + neo4j.bearer_auth(CREDENTIALS), + neo4j.custom_auth("user", CREDENTIALS, "realm", "scheme"), + neo4j.Auth("scheme", "principal", CREDENTIALS, "realm", foo="bar"), +)) +@mark_async_test +async def test_credentials_are_not_logged( + auth, fake_socket_pair, mocker, caplog +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt4x1.PACKER_CLS, + unpacker_cls=AsyncBolt4x1.UNPACKER_CLS) + sockets.client.settimeout = mocker.Mock() + await sockets.server.send_message(b"\x70", {"server": "Neo4j/4.3.4"}) + connection = AsyncBolt4x1( + address, sockets.client, PoolConfig.max_connection_lifetime, auth=auth + ) + with caplog.at_level(logging.DEBUG): + await connection.hello() + + if isinstance(auth, tuple): + auth = neo4j.basic_auth(*auth) + for field in ("scheme", "principal", "realm", "parameters"): + value = getattr(auth, field, None) + if value: + assert repr(value) in caplog.text + assert CREDENTIALS not in caplog.text + + +@pytest.mark.parametrize("message", ("logon", "logoff")) +def test_auth_message_raises_configuration_error(message, fake_socket): + address = neo4j.Address(("127.0.0.1", 7687)) + connection = AsyncBolt4x1(address, fake_socket(address), + PoolConfig.max_connection_lifetime) + with pytest.raises(ConfigurationError, + match="User switching is not supported"): + getattr(connection, message)() + + +@pytest.mark.parametrize("auth", ( + None, + neo4j.Auth("scheme", "principal", "credentials", "realm"), + ("user", "password"), +)) +@mark_async_test +async def test_re_auth_noop(auth, fake_socket, mocker): + address = neo4j.Address(("127.0.0.1", 7687)) + connection = AsyncBolt4x1(address, fake_socket(address), + PoolConfig.max_connection_lifetime, auth=auth) + logon_spy = mocker.spy(connection, "logon") + logoff_spy = mocker.spy(connection, "logoff") + res = connection.re_auth(auth, None) + + assert res is False + logon_spy.assert_not_called() + logoff_spy.assert_not_called() + + +@pytest.mark.parametrize( + ("auth1", "auth2"), + permutations( + ( + None, + neo4j.Auth("scheme", "principal", "credentials", "realm"), + ("user", "password"), + ), + 2 + ) +) +@mark_async_test +async def test_re_auth(auth1, auth2, fake_socket): + address = neo4j.Address(("127.0.0.1", 7687)) + connection = AsyncBolt4x1(address, 
fake_socket(address), + PoolConfig.max_connection_lifetime, auth=auth1) + with pytest.raises(ConfigurationError, + match="User switching is not supported"): + connection.re_auth(auth2, None) + + @pytest.mark.parametrize(("method", "args"), ( ("run", ("RETURN 1",)), ("begin", ()), @@ -245,7 +332,7 @@ async def test_hint_recv_timeout_seconds_gets_ignored( )) def test_does_not_support_notification_filters(fake_socket, method, args, kwargs): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x1.UNPACKER_CLS) connection = AsyncBolt4x1(address, socket, PoolConfig.max_connection_lifetime) @@ -267,7 +354,7 @@ def test_does_not_support_notification_filters(fake_socket, method, async def test_hello_does_not_support_notification_filters( fake_socket, kwargs ): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x1.UNPACKER_CLS) connection = AsyncBolt4x1( address, socket, PoolConfig.max_connection_lifetime, @@ -275,65 +362,3 @@ async def test_hello_does_not_support_notification_filters( ) with pytest.raises(ConfigurationError, match="Notification filtering"): await connection.hello() - - -class HackedAuth: - def __init__(self, dict_): - self.__dict__ = dict_ - - -@mark_async_test -@pytest.mark.parametrize("auth", ( - ("awesome test user", "safe p4ssw0rd"), - Auth("super nice scheme", "awesome test user", "safe p4ssw0rd"), - Auth("super nice scheme", "awesome test user", "safe p4ssw0rd", - realm="super duper realm"), - Auth("super nice scheme", "awesome test user", "safe p4ssw0rd", - realm="super duper realm"), - Auth("super nice scheme", "awesome test user", "safe p4ssw0rd", - foo="bar"), - HackedAuth({ - "scheme": "super nice scheme", "principal": "awesome test user", - "credentials": "safe p4ssw0rd", "realm": "super duper realm", - "parameters": {"credentials": "should be visible!"}, - }) - -)) -async def test_hello_does_not_log_credentials(fake_socket_pair, caplog, auth): - def items(): - if isinstance(auth, tuple): - yield "scheme", "basic" - yield "principal", auth[0] - yield "credentials", auth[1] - elif isinstance(auth, Auth): - for key in ("scheme", "principal", "credentials", "realm", - "parameters"): - value = getattr(auth, key, None) - if value: - yield key, value - elif isinstance(auth, HackedAuth): - yield from auth.__dict__.items() - else: - raise TypeError(auth) - - address = ("127.0.0.1", 7687) - sockets = fake_socket_pair(address, - packer_cls=AsyncBolt4x1.PACKER_CLS, - unpacker_cls=AsyncBolt4x1.UNPACKER_CLS) - await sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) - max_connection_lifetime = 0 - connection = AsyncBolt4x1(address, sockets.client, - max_connection_lifetime, auth=auth) - - with caplog.at_level(logging.DEBUG): - await connection.hello() - - hellos = [m for m in caplog.messages if "C: HELLO" in m] - assert len(hellos) == 1 - hello = hellos[0] - - for key, value in items(): - if key == "credentials": - assert value not in hello - else: - assert str({key: value})[1:-1] in hello diff --git a/tests/unit/async_/io/test_class_bolt4x2.py b/tests/unit/async_/io/test_class_bolt4x2.py index 0d96ad0d..0547606a 100644 --- a/tests/unit/async_/io/test_class_bolt4x2.py +++ b/tests/unit/async_/io/test_class_bolt4x2.py @@ -17,12 +17,13 @@ import logging +from itertools import permutations import pytest +import neo4j from neo4j._async.io._bolt4 import AsyncBolt4x2 from neo4j._conf import PoolConfig -from neo4j.api import Auth from 
neo4j.exceptions import ConfigurationError from ...._async_compat import mark_async_test @@ -30,7 +31,7 @@ @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_stale(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = 0 connection = AsyncBolt4x2(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -40,7 +41,7 @@ def test_conn_is_stale(fake_socket, set_stale): @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = -1 connection = AsyncBolt4x2(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -50,7 +51,7 @@ def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_not_stale(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = 999999999 connection = AsyncBolt4x2(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -60,7 +61,7 @@ def test_conn_is_not_stale(fake_socket, set_stale): @mark_async_test async def test_db_extra_in_begin(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x2.UNPACKER_CLS) connection = AsyncBolt4x2(address, socket, PoolConfig.max_connection_lifetime) connection.begin(db="something") @@ -73,7 +74,7 @@ async def test_db_extra_in_begin(fake_socket): @mark_async_test async def test_db_extra_in_run(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x2.UNPACKER_CLS) connection = AsyncBolt4x2(address, socket, PoolConfig.max_connection_lifetime) connection.run("", {}, db="something") @@ -88,7 +89,7 @@ async def test_db_extra_in_run(fake_socket): @mark_async_test async def test_n_extra_in_discard(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x2.UNPACKER_CLS) connection = AsyncBolt4x2(address, socket, PoolConfig.max_connection_lifetime) connection.discard(n=666) @@ -108,7 +109,7 @@ async def test_n_extra_in_discard(fake_socket): ) @mark_async_test async def test_qid_extra_in_discard(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x2.UNPACKER_CLS) connection = AsyncBolt4x2(address, socket, PoolConfig.max_connection_lifetime) connection.discard(qid=test_input) @@ -128,7 +129,7 @@ async def test_qid_extra_in_discard(fake_socket, test_input, expected): ) @mark_async_test async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x2.UNPACKER_CLS) connection = AsyncBolt4x2(address, socket, PoolConfig.max_connection_lifetime) connection.discard(n=666, qid=test_input) @@ -148,7 +149,7 @@ async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): ) @mark_async_test async def test_n_extra_in_pull(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x2.UNPACKER_CLS) connection = AsyncBolt4x2(address, socket, 
PoolConfig.max_connection_lifetime) connection.pull(n=test_input) @@ -168,7 +169,7 @@ async def test_n_extra_in_pull(fake_socket, test_input, expected): ) @mark_async_test async def test_qid_extra_in_pull(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x2.UNPACKER_CLS) connection = AsyncBolt4x2(address, socket, PoolConfig.max_connection_lifetime) connection.pull(qid=test_input) @@ -181,7 +182,7 @@ async def test_qid_extra_in_pull(fake_socket, test_input, expected): @mark_async_test async def test_n_and_qid_extras_in_pull(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x2.UNPACKER_CLS) connection = AsyncBolt4x2(address, socket, PoolConfig.max_connection_lifetime) connection.pull(n=666, qid=777) @@ -194,7 +195,7 @@ async def test_n_and_qid_extras_in_pull(fake_socket): @mark_async_test async def test_hello_passes_routing_metadata(fake_socket_pair): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair(address, packer_cls=AsyncBolt4x2.PACKER_CLS, unpacker_cls=AsyncBolt4x2.UNPACKER_CLS) @@ -215,7 +216,7 @@ async def test_hello_passes_routing_metadata(fake_socket_pair): async def test_hint_recv_timeout_seconds_gets_ignored( fake_socket_pair, recv_timeout, mocker ): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair(address, packer_cls=AsyncBolt4x2.PACKER_CLS, unpacker_cls=AsyncBolt4x2.UNPACKER_CLS) @@ -231,6 +232,92 @@ async def test_hint_recv_timeout_seconds_gets_ignored( sockets.client.settimeout.assert_not_called() +CREDENTIALS = "+++super-secret-sauce+++" + + +@pytest.mark.parametrize("auth", ( + ("user", CREDENTIALS), + neo4j.basic_auth("user", CREDENTIALS), + neo4j.kerberos_auth(CREDENTIALS), + neo4j.bearer_auth(CREDENTIALS), + neo4j.custom_auth("user", CREDENTIALS, "realm", "scheme"), + neo4j.Auth("scheme", "principal", CREDENTIALS, "realm", foo="bar"), +)) +@mark_async_test +async def test_credentials_are_not_logged( + auth, fake_socket_pair, mocker, caplog +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt4x2.PACKER_CLS, + unpacker_cls=AsyncBolt4x2.UNPACKER_CLS) + sockets.client.settimeout = mocker.Mock() + await sockets.server.send_message(b"\x70", {"server": "Neo4j/4.3.4"}) + connection = AsyncBolt4x2( + address, sockets.client, PoolConfig.max_connection_lifetime, auth=auth + ) + with caplog.at_level(logging.DEBUG): + await connection.hello() + + if isinstance(auth, tuple): + auth = neo4j.basic_auth(*auth) + for field in ("scheme", "principal", "realm", "parameters"): + value = getattr(auth, field, None) + if value: + assert repr(value) in caplog.text + assert CREDENTIALS not in caplog.text + + +@pytest.mark.parametrize("message", ("logon", "logoff")) +def test_auth_message_raises_configuration_error(message, fake_socket): + address = neo4j.Address(("127.0.0.1", 7687)) + connection = AsyncBolt4x2(address, fake_socket(address), + PoolConfig.max_connection_lifetime) + with pytest.raises(ConfigurationError, + match="User switching is not supported"): + getattr(connection, message)() + + +@pytest.mark.parametrize("auth", ( + None, + neo4j.Auth("scheme", "principal", "credentials", "realm"), + ("user", "password"), +)) +@mark_async_test +async def test_re_auth_noop(auth, fake_socket, mocker): + address = neo4j.Address(("127.0.0.1", 
7687)) + connection = AsyncBolt4x2(address, fake_socket(address), + PoolConfig.max_connection_lifetime, auth=auth) + logon_spy = mocker.spy(connection, "logon") + logoff_spy = mocker.spy(connection, "logoff") + res = connection.re_auth(auth, None) + + assert res is False + logon_spy.assert_not_called() + logoff_spy.assert_not_called() + + +@pytest.mark.parametrize( + ("auth1", "auth2"), + permutations( + ( + None, + neo4j.Auth("scheme", "principal", "credentials", "realm"), + ("user", "password"), + ), + 2 + ) +) +@mark_async_test +async def test_re_auth(auth1, auth2, fake_socket): + address = neo4j.Address(("127.0.0.1", 7687)) + connection = AsyncBolt4x2(address, fake_socket(address), + PoolConfig.max_connection_lifetime, auth=auth1) + with pytest.raises(ConfigurationError, + match="User switching is not supported"): + connection.re_auth(auth2, None) + + @pytest.mark.parametrize(("method", "args"), ( ("run", ("RETURN 1",)), ("begin", ()), @@ -246,7 +333,7 @@ async def test_hint_recv_timeout_seconds_gets_ignored( )) def test_does_not_support_notification_filters(fake_socket, method, args, kwargs): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x2.UNPACKER_CLS) connection = AsyncBolt4x2(address, socket, PoolConfig.max_connection_lifetime) @@ -268,7 +355,7 @@ def test_does_not_support_notification_filters(fake_socket, method, async def test_hello_does_not_support_notification_filters( fake_socket, kwargs ): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x2.UNPACKER_CLS) connection = AsyncBolt4x2( address, socket, PoolConfig.max_connection_lifetime, @@ -276,65 +363,3 @@ async def test_hello_does_not_support_notification_filters( ) with pytest.raises(ConfigurationError, match="Notification filtering"): await connection.hello() - - -class HackedAuth: - def __init__(self, dict_): - self.__dict__ = dict_ - - -@mark_async_test -@pytest.mark.parametrize("auth", ( - ("awesome test user", "safe p4ssw0rd"), - Auth("super nice scheme", "awesome test user", "safe p4ssw0rd"), - Auth("super nice scheme", "awesome test user", "safe p4ssw0rd", - realm="super duper realm"), - Auth("super nice scheme", "awesome test user", "safe p4ssw0rd", - realm="super duper realm"), - Auth("super nice scheme", "awesome test user", "safe p4ssw0rd", - foo="bar"), - HackedAuth({ - "scheme": "super nice scheme", "principal": "awesome test user", - "credentials": "safe p4ssw0rd", "realm": "super duper realm", - "parameters": {"credentials": "should be visible!"}, - }) - -)) -async def test_hello_does_not_log_credentials(fake_socket_pair, caplog, auth): - def items(): - if isinstance(auth, tuple): - yield "scheme", "basic" - yield "principal", auth[0] - yield "credentials", auth[1] - elif isinstance(auth, Auth): - for key in ("scheme", "principal", "credentials", "realm", - "parameters"): - value = getattr(auth, key, None) - if value: - yield key, value - elif isinstance(auth, HackedAuth): - yield from auth.__dict__.items() - else: - raise TypeError(auth) - - address = ("127.0.0.1", 7687) - sockets = fake_socket_pair(address, - packer_cls=AsyncBolt4x2.PACKER_CLS, - unpacker_cls=AsyncBolt4x2.UNPACKER_CLS) - await sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) - max_connection_lifetime = 0 - connection = AsyncBolt4x2(address, sockets.client, - max_connection_lifetime, auth=auth) - - with caplog.at_level(logging.DEBUG): - await connection.hello() - - hellos = [m for m in 
caplog.messages if "C: HELLO" in m] - assert len(hellos) == 1 - hello = hellos[0] - - for key, value in items(): - if key == "credentials": - assert value not in hello - else: - assert str({key: value})[1:-1] in hello diff --git a/tests/unit/async_/io/test_class_bolt4x3.py b/tests/unit/async_/io/test_class_bolt4x3.py index cf141913..dca81144 100644 --- a/tests/unit/async_/io/test_class_bolt4x3.py +++ b/tests/unit/async_/io/test_class_bolt4x3.py @@ -17,12 +17,13 @@ import logging +from itertools import permutations import pytest +import neo4j from neo4j._async.io._bolt4 import AsyncBolt4x3 from neo4j._conf import PoolConfig -from neo4j.api import Auth from neo4j.exceptions import ConfigurationError from ...._async_compat import mark_async_test @@ -30,7 +31,7 @@ @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_stale(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = 0 connection = AsyncBolt4x3(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -40,7 +41,7 @@ def test_conn_is_stale(fake_socket, set_stale): @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = -1 connection = AsyncBolt4x3(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -50,7 +51,7 @@ def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_not_stale(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = 999999999 connection = AsyncBolt4x3(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -60,7 +61,7 @@ def test_conn_is_not_stale(fake_socket, set_stale): @mark_async_test async def test_db_extra_in_begin(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x3.UNPACKER_CLS) connection = AsyncBolt4x3(address, socket, PoolConfig.max_connection_lifetime) connection.begin(db="something") @@ -73,7 +74,7 @@ async def test_db_extra_in_begin(fake_socket): @mark_async_test async def test_db_extra_in_run(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x3.UNPACKER_CLS) connection = AsyncBolt4x3(address, socket, PoolConfig.max_connection_lifetime) connection.run("", {}, db="something") @@ -88,7 +89,7 @@ async def test_db_extra_in_run(fake_socket): @mark_async_test async def test_n_extra_in_discard(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x3.UNPACKER_CLS) connection = AsyncBolt4x3(address, socket, PoolConfig.max_connection_lifetime) connection.discard(n=666) @@ -108,7 +109,7 @@ async def test_n_extra_in_discard(fake_socket): ) @mark_async_test async def test_qid_extra_in_discard(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x3.UNPACKER_CLS) connection = AsyncBolt4x3(address, socket, PoolConfig.max_connection_lifetime) connection.discard(qid=test_input) @@ -128,7 +129,7 @@ async def test_qid_extra_in_discard(fake_socket, test_input, expected): ) @mark_async_test async def 
test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x3.UNPACKER_CLS) connection = AsyncBolt4x3(address, socket, PoolConfig.max_connection_lifetime) connection.discard(n=666, qid=test_input) @@ -148,7 +149,7 @@ async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): ) @mark_async_test async def test_n_extra_in_pull(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x3.UNPACKER_CLS) connection = AsyncBolt4x3(address, socket, PoolConfig.max_connection_lifetime) connection.pull(n=test_input) @@ -168,7 +169,7 @@ async def test_n_extra_in_pull(fake_socket, test_input, expected): ) @mark_async_test async def test_qid_extra_in_pull(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x3.UNPACKER_CLS) connection = AsyncBolt4x3(address, socket, PoolConfig.max_connection_lifetime) connection.pull(qid=test_input) @@ -181,7 +182,7 @@ async def test_qid_extra_in_pull(fake_socket, test_input, expected): @mark_async_test async def test_n_and_qid_extras_in_pull(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x3.UNPACKER_CLS) connection = AsyncBolt4x3(address, socket, PoolConfig.max_connection_lifetime) connection.pull(n=666, qid=777) @@ -194,7 +195,7 @@ async def test_n_and_qid_extras_in_pull(fake_socket): @mark_async_test async def test_hello_passes_routing_metadata(fake_socket_pair): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair(address, packer_cls=AsyncBolt4x3.PACKER_CLS, unpacker_cls=AsyncBolt4x3.UNPACKER_CLS) @@ -226,7 +227,7 @@ async def test_hello_passes_routing_metadata(fake_socket_pair): async def test_hint_recv_timeout_seconds( fake_socket_pair, hints, valid, caplog, mocker ): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair(address, packer_cls=AsyncBolt4x3.PACKER_CLS, unpacker_cls=AsyncBolt4x3.UNPACKER_CLS) @@ -257,6 +258,93 @@ async def test_hint_recv_timeout_seconds( for msg in caplog.messages) +CREDENTIALS = "+++super-secret-sauce+++" + + +@pytest.mark.parametrize("auth", ( + ("user", CREDENTIALS), + neo4j.basic_auth("user", CREDENTIALS), + neo4j.kerberos_auth(CREDENTIALS), + neo4j.bearer_auth(CREDENTIALS), + neo4j.custom_auth("user", CREDENTIALS, "realm", "scheme"), + neo4j.Auth("scheme", "principal", CREDENTIALS, "realm", foo="bar"), +)) +@mark_async_test +async def test_credentials_are_not_logged( + auth, fake_socket_pair, mocker, caplog +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt4x3.PACKER_CLS, + unpacker_cls=AsyncBolt4x3.UNPACKER_CLS) + sockets.client.settimeout = mocker.Mock() + await sockets.server.send_message(b"\x70", {"server": "Neo4j/4.3.4"}) + connection = AsyncBolt4x3( + address, sockets.client, PoolConfig.max_connection_lifetime, auth=auth + ) + with caplog.at_level(logging.DEBUG): + await connection.hello() + + if isinstance(auth, tuple): + auth = neo4j.basic_auth(*auth) + for field in ("scheme", "principal", "realm", "parameters"): + value = getattr(auth, field, None) + if value: + assert repr(value) in caplog.text + assert CREDENTIALS not in caplog.text 
+ + + +@pytest.mark.parametrize("message", ("logon", "logoff")) +def test_auth_message_raises_configuration_error(message, fake_socket): + address = neo4j.Address(("127.0.0.1", 7687)) + connection = AsyncBolt4x3(address, fake_socket(address), + PoolConfig.max_connection_lifetime) + with pytest.raises(ConfigurationError, + match="User switching is not supported"): + getattr(connection, message)() + + +@pytest.mark.parametrize("auth", ( + None, + neo4j.Auth("scheme", "principal", "credentials", "realm"), + ("user", "password"), +)) +@mark_async_test +async def test_re_auth_noop(auth, fake_socket, mocker): + address = neo4j.Address(("127.0.0.1", 7687)) + connection = AsyncBolt4x3(address, fake_socket(address), + PoolConfig.max_connection_lifetime, auth=auth) + logon_spy = mocker.spy(connection, "logon") + logoff_spy = mocker.spy(connection, "logoff") + res = connection.re_auth(auth, None) + + assert res is False + logon_spy.assert_not_called() + logoff_spy.assert_not_called() + + +@pytest.mark.parametrize( + ("auth1", "auth2"), + permutations( + ( + None, + neo4j.Auth("scheme", "principal", "credentials", "realm"), + ("user", "password"), + ), + 2 + ) +) +@mark_async_test +async def test_re_auth(auth1, auth2, fake_socket): + address = neo4j.Address(("127.0.0.1", 7687)) + connection = AsyncBolt4x3(address, fake_socket(address), + PoolConfig.max_connection_lifetime, auth=auth1) + with pytest.raises(ConfigurationError, + match="User switching is not supported"): + connection.re_auth(auth2, None) + + @pytest.mark.parametrize(("method", "args"), ( ("run", ("RETURN 1",)), ("begin", ()), @@ -272,7 +360,7 @@ async def test_hint_recv_timeout_seconds( )) def test_does_not_support_notification_filters(fake_socket, method, args, kwargs): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x3.UNPACKER_CLS) connection = AsyncBolt4x3(address, socket, PoolConfig.max_connection_lifetime) @@ -294,7 +382,7 @@ def test_does_not_support_notification_filters(fake_socket, method, async def test_hello_does_not_support_notification_filters( fake_socket, kwargs ): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x3.UNPACKER_CLS) connection = AsyncBolt4x3( address, socket, PoolConfig.max_connection_lifetime, @@ -302,65 +390,3 @@ async def test_hello_does_not_support_notification_filters( ) with pytest.raises(ConfigurationError, match="Notification filtering"): await connection.hello() - - -class HackedAuth: - def __init__(self, dict_): - self.__dict__ = dict_ - - -@mark_async_test -@pytest.mark.parametrize("auth", ( - ("awesome test user", "safe p4ssw0rd"), - Auth("super nice scheme", "awesome test user", "safe p4ssw0rd"), - Auth("super nice scheme", "awesome test user", "safe p4ssw0rd", - realm="super duper realm"), - Auth("super nice scheme", "awesome test user", "safe p4ssw0rd", - realm="super duper realm"), - Auth("super nice scheme", "awesome test user", "safe p4ssw0rd", - foo="bar"), - HackedAuth({ - "scheme": "super nice scheme", "principal": "awesome test user", - "credentials": "safe p4ssw0rd", "realm": "super duper realm", - "parameters": {"credentials": "should be visible!"}, - }) - -)) -async def test_hello_does_not_log_credentials(fake_socket_pair, caplog, auth): - def items(): - if isinstance(auth, tuple): - yield "scheme", "basic" - yield "principal", auth[0] - yield "credentials", auth[1] - elif isinstance(auth, Auth): - for key in ("scheme", "principal", 
"credentials", "realm", - "parameters"): - value = getattr(auth, key, None) - if value: - yield key, value - elif isinstance(auth, HackedAuth): - yield from auth.__dict__.items() - else: - raise TypeError(auth) - - address = ("127.0.0.1", 7687) - sockets = fake_socket_pair(address, - packer_cls=AsyncBolt4x3.PACKER_CLS, - unpacker_cls=AsyncBolt4x3.UNPACKER_CLS) - await sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) - max_connection_lifetime = 0 - connection = AsyncBolt4x3(address, sockets.client, - max_connection_lifetime, auth=auth) - - with caplog.at_level(logging.DEBUG): - await connection.hello() - - hellos = [m for m in caplog.messages if "C: HELLO" in m] - assert len(hellos) == 1 - hello = hellos[0] - - for key, value in items(): - if key == "credentials": - assert value not in hello - else: - assert str({key: value})[1:-1] in hello diff --git a/tests/unit/async_/io/test_class_bolt4x4.py b/tests/unit/async_/io/test_class_bolt4x4.py index 5dacc747..19cf8b5c 100644 --- a/tests/unit/async_/io/test_class_bolt4x4.py +++ b/tests/unit/async_/io/test_class_bolt4x4.py @@ -17,12 +17,13 @@ import logging +from itertools import permutations import pytest +import neo4j from neo4j._async.io._bolt4 import AsyncBolt4x4 from neo4j._conf import PoolConfig -from neo4j.api import Auth from neo4j.exceptions import ConfigurationError from ...._async_compat import mark_async_test @@ -30,7 +31,7 @@ @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_stale(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = 0 connection = AsyncBolt4x4(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -40,7 +41,7 @@ def test_conn_is_stale(fake_socket, set_stale): @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = -1 connection = AsyncBolt4x4(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -50,7 +51,7 @@ def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_not_stale(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = 999999999 connection = AsyncBolt4x4(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -69,7 +70,7 @@ def test_conn_is_not_stale(fake_socket, set_stale): )) @mark_async_test async def test_extra_in_begin(fake_socket, args, kwargs, expected_fields): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x4.UNPACKER_CLS) connection = AsyncBolt4x4(address, socket, PoolConfig.max_connection_lifetime) connection.begin(*args, **kwargs) @@ -90,7 +91,7 @@ async def test_extra_in_begin(fake_socket, args, kwargs, expected_fields): )) @mark_async_test async def test_extra_in_run(fake_socket, args, kwargs, expected_fields): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x4.UNPACKER_CLS) connection = AsyncBolt4x4(address, socket, PoolConfig.max_connection_lifetime) connection.run(*args, **kwargs) @@ -102,7 +103,7 @@ async def test_extra_in_run(fake_socket, args, kwargs, expected_fields): @mark_async_test async def test_n_extra_in_discard(fake_socket): - address = ("127.0.0.1", 7687) + 
address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x4.UNPACKER_CLS) connection = AsyncBolt4x4(address, socket, PoolConfig.max_connection_lifetime) connection.discard(n=666) @@ -122,7 +123,7 @@ async def test_n_extra_in_discard(fake_socket): ) @mark_async_test async def test_qid_extra_in_discard(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x4.UNPACKER_CLS) connection = AsyncBolt4x4(address, socket, PoolConfig.max_connection_lifetime) connection.discard(qid=test_input) @@ -142,7 +143,7 @@ async def test_qid_extra_in_discard(fake_socket, test_input, expected): ) @mark_async_test async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x4.UNPACKER_CLS) connection = AsyncBolt4x4(address, socket, PoolConfig.max_connection_lifetime) connection.discard(n=666, qid=test_input) @@ -162,7 +163,7 @@ async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): ) @mark_async_test async def test_n_extra_in_pull(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x4.UNPACKER_CLS) connection = AsyncBolt4x4(address, socket, PoolConfig.max_connection_lifetime) connection.pull(n=test_input) @@ -182,7 +183,7 @@ async def test_n_extra_in_pull(fake_socket, test_input, expected): ) @mark_async_test async def test_qid_extra_in_pull(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x4.UNPACKER_CLS) connection = AsyncBolt4x4(address, socket, PoolConfig.max_connection_lifetime) connection.pull(qid=test_input) @@ -195,7 +196,7 @@ async def test_qid_extra_in_pull(fake_socket, test_input, expected): @mark_async_test async def test_n_and_qid_extras_in_pull(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x4.UNPACKER_CLS) connection = AsyncBolt4x4(address, socket, PoolConfig.max_connection_lifetime) connection.pull(n=666, qid=777) @@ -208,7 +209,7 @@ async def test_n_and_qid_extras_in_pull(fake_socket): @mark_async_test async def test_hello_passes_routing_metadata(fake_socket_pair): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair(address, packer_cls=AsyncBolt4x4.PACKER_CLS, unpacker_cls=AsyncBolt4x4.UNPACKER_CLS) @@ -240,7 +241,7 @@ async def test_hello_passes_routing_metadata(fake_socket_pair): async def test_hint_recv_timeout_seconds( fake_socket_pair, hints, valid, caplog, mocker ): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair(address, packer_cls=AsyncBolt4x4.PACKER_CLS, unpacker_cls=AsyncBolt4x4.UNPACKER_CLS) @@ -271,6 +272,92 @@ async def test_hint_recv_timeout_seconds( for msg in caplog.messages) +CREDENTIALS = "+++super-secret-sauce+++" + + +@pytest.mark.parametrize("auth", ( + ("user", CREDENTIALS), + neo4j.basic_auth("user", CREDENTIALS), + neo4j.kerberos_auth(CREDENTIALS), + neo4j.bearer_auth(CREDENTIALS), + neo4j.custom_auth("user", CREDENTIALS, "realm", "scheme"), + neo4j.Auth("scheme", "principal", CREDENTIALS, "realm", foo="bar"), +)) +@mark_async_test +async def test_credentials_are_not_logged( + auth, 
fake_socket_pair, mocker, caplog +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt4x4.PACKER_CLS, + unpacker_cls=AsyncBolt4x4.UNPACKER_CLS) + sockets.client.settimeout = mocker.Mock() + await sockets.server.send_message(b"\x70", {"server": "Neo4j/4.3.4"}) + connection = AsyncBolt4x4( + address, sockets.client, PoolConfig.max_connection_lifetime, auth=auth + ) + with caplog.at_level(logging.DEBUG): + await connection.hello() + + if isinstance(auth, tuple): + auth = neo4j.basic_auth(*auth) + for field in ("scheme", "principal", "realm", "parameters"): + value = getattr(auth, field, None) + if value: + assert repr(value) in caplog.text + assert CREDENTIALS not in caplog.text + + +@pytest.mark.parametrize("message", ("logon", "logoff")) +def test_auth_message_raises_configuration_error(message, fake_socket): + address = neo4j.Address(("127.0.0.1", 7687)) + connection = AsyncBolt4x4(address, fake_socket(address), + PoolConfig.max_connection_lifetime) + with pytest.raises(ConfigurationError, + match="User switching is not supported"): + getattr(connection, message)() + + +@pytest.mark.parametrize("auth", ( + None, + neo4j.Auth("scheme", "principal", "credentials", "realm"), + ("user", "password"), +)) +@mark_async_test +async def test_re_auth_noop(auth, fake_socket, mocker): + address = neo4j.Address(("127.0.0.1", 7687)) + connection = AsyncBolt4x4(address, fake_socket(address), + PoolConfig.max_connection_lifetime, auth=auth) + logon_spy = mocker.spy(connection, "logon") + logoff_spy = mocker.spy(connection, "logoff") + res = connection.re_auth(auth, None) + + assert res is False + logon_spy.assert_not_called() + logoff_spy.assert_not_called() + + +@pytest.mark.parametrize( + ("auth1", "auth2"), + permutations( + ( + None, + neo4j.Auth("scheme", "principal", "credentials", "realm"), + ("user", "password"), + ), + 2 + ) +) +@mark_async_test +async def test_re_auth(auth1, auth2, fake_socket): + address = neo4j.Address(("127.0.0.1", 7687)) + connection = AsyncBolt4x4(address, fake_socket(address), + PoolConfig.max_connection_lifetime, auth=auth1) + with pytest.raises(ConfigurationError, + match="User switching is not supported"): + connection.re_auth(auth2, None) + + @pytest.mark.parametrize(("method", "args"), ( ("run", ("RETURN 1",)), ("begin", ()), @@ -286,7 +373,7 @@ async def test_hint_recv_timeout_seconds( )) def test_does_not_support_notification_filters(fake_socket, method, args, kwargs): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x4.UNPACKER_CLS) connection = AsyncBolt4x4(address, socket, PoolConfig.max_connection_lifetime) @@ -308,7 +395,7 @@ def test_does_not_support_notification_filters(fake_socket, method, async def test_hello_does_not_support_notification_filters( fake_socket, kwargs ): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x4.UNPACKER_CLS) connection = AsyncBolt4x4( address, socket, PoolConfig.max_connection_lifetime, @@ -316,65 +403,3 @@ async def test_hello_does_not_support_notification_filters( ) with pytest.raises(ConfigurationError, match="Notification filtering"): await connection.hello() - - -class HackedAuth: - def __init__(self, dict_): - self.__dict__ = dict_ - - -@mark_async_test -@pytest.mark.parametrize("auth", ( - ("awesome test user", "safe p4ssw0rd"), - Auth("super nice scheme", "awesome test user", "safe p4ssw0rd"), - Auth("super nice scheme", 
"awesome test user", "safe p4ssw0rd", - realm="super duper realm"), - Auth("super nice scheme", "awesome test user", "safe p4ssw0rd", - realm="super duper realm"), - Auth("super nice scheme", "awesome test user", "safe p4ssw0rd", - foo="bar"), - HackedAuth({ - "scheme": "super nice scheme", "principal": "awesome test user", - "credentials": "safe p4ssw0rd", "realm": "super duper realm", - "parameters": {"credentials": "should be visible!"}, - }) - -)) -async def test_hello_does_not_log_credentials(fake_socket_pair, caplog, auth): - def items(): - if isinstance(auth, tuple): - yield "scheme", "basic" - yield "principal", auth[0] - yield "credentials", auth[1] - elif isinstance(auth, Auth): - for key in ("scheme", "principal", "credentials", "realm", - "parameters"): - value = getattr(auth, key, None) - if value: - yield key, value - elif isinstance(auth, HackedAuth): - yield from auth.__dict__.items() - else: - raise TypeError(auth) - - address = ("127.0.0.1", 7687) - sockets = fake_socket_pair(address, - packer_cls=AsyncBolt4x4.PACKER_CLS, - unpacker_cls=AsyncBolt4x4.UNPACKER_CLS) - await sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) - max_connection_lifetime = 0 - connection = AsyncBolt4x4(address, sockets.client, - max_connection_lifetime, auth=auth) - - with caplog.at_level(logging.DEBUG): - await connection.hello() - - hellos = [m for m in caplog.messages if "C: HELLO" in m] - assert len(hellos) == 1 - hello = hellos[0] - - for key, value in items(): - if key == "credentials": - assert value not in hello - else: - assert str({key: value})[1:-1] in hello diff --git a/tests/unit/async_/io/test_class_bolt5x0.py b/tests/unit/async_/io/test_class_bolt5x0.py index bf9cf372..2937d442 100644 --- a/tests/unit/async_/io/test_class_bolt5x0.py +++ b/tests/unit/async_/io/test_class_bolt5x0.py @@ -17,12 +17,13 @@ import logging +from itertools import permutations import pytest +import neo4j from neo4j._async.io._bolt5 import AsyncBolt5x0 from neo4j._conf import PoolConfig -from neo4j.api import Auth from neo4j.exceptions import ConfigurationError from ...._async_compat import mark_async_test @@ -30,7 +31,7 @@ @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_stale(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = 0 connection = AsyncBolt5x0(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -40,7 +41,7 @@ def test_conn_is_stale(fake_socket, set_stale): @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = -1 connection = AsyncBolt5x0(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -50,7 +51,7 @@ def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_not_stale(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = 999999999 connection = AsyncBolt5x0(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -69,7 +70,7 @@ def test_conn_is_not_stale(fake_socket, set_stale): )) @mark_async_test async def test_extra_in_begin(fake_socket, args, kwargs, expected_fields): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x0.UNPACKER_CLS) 
connection = AsyncBolt5x0(address, socket, PoolConfig.max_connection_lifetime) connection.begin(*args, **kwargs) @@ -90,7 +91,7 @@ async def test_extra_in_begin(fake_socket, args, kwargs, expected_fields): )) @mark_async_test async def test_extra_in_run(fake_socket, args, kwargs, expected_fields): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x0.UNPACKER_CLS) connection = AsyncBolt5x0(address, socket, PoolConfig.max_connection_lifetime) connection.run(*args, **kwargs) @@ -102,7 +103,7 @@ async def test_extra_in_run(fake_socket, args, kwargs, expected_fields): @mark_async_test async def test_n_extra_in_discard(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x0.UNPACKER_CLS) connection = AsyncBolt5x0(address, socket, PoolConfig.max_connection_lifetime) connection.discard(n=666) @@ -122,7 +123,7 @@ async def test_n_extra_in_discard(fake_socket): ) @mark_async_test async def test_qid_extra_in_discard(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x0.UNPACKER_CLS) connection = AsyncBolt5x0(address, socket, PoolConfig.max_connection_lifetime) connection.discard(qid=test_input) @@ -142,7 +143,7 @@ async def test_qid_extra_in_discard(fake_socket, test_input, expected): ) @mark_async_test async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x0.UNPACKER_CLS) connection = AsyncBolt5x0(address, socket, PoolConfig.max_connection_lifetime) connection.discard(n=666, qid=test_input) @@ -162,7 +163,7 @@ async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): ) @mark_async_test async def test_n_extra_in_pull(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x0.UNPACKER_CLS) connection = AsyncBolt5x0(address, socket, PoolConfig.max_connection_lifetime) connection.pull(n=test_input) @@ -182,7 +183,7 @@ async def test_n_extra_in_pull(fake_socket, test_input, expected): ) @mark_async_test async def test_qid_extra_in_pull(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x0.UNPACKER_CLS) connection = AsyncBolt5x0(address, socket, PoolConfig.max_connection_lifetime) connection.pull(qid=test_input) @@ -195,7 +196,7 @@ async def test_qid_extra_in_pull(fake_socket, test_input, expected): @mark_async_test async def test_n_and_qid_extras_in_pull(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x0.UNPACKER_CLS) connection = AsyncBolt5x0(address, socket, PoolConfig.max_connection_lifetime) connection.pull(n=666, qid=777) @@ -208,7 +209,7 @@ async def test_n_and_qid_extras_in_pull(fake_socket): @mark_async_test async def test_hello_passes_routing_metadata(fake_socket_pair): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair(address, packer_cls=AsyncBolt5x0.PACKER_CLS, unpacker_cls=AsyncBolt5x0.UNPACKER_CLS) @@ -240,7 +241,7 @@ async def test_hello_passes_routing_metadata(fake_socket_pair): async def test_hint_recv_timeout_seconds( fake_socket_pair, 
hints, valid, caplog, mocker ): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair(address, packer_cls=AsyncBolt5x0.PACKER_CLS, unpacker_cls=AsyncBolt5x0.UNPACKER_CLS) @@ -271,6 +272,92 @@ async def test_hint_recv_timeout_seconds( for msg in caplog.messages) +CREDENTIALS = "+++super-secret-sauce+++" + + +@pytest.mark.parametrize("auth", ( + ("user", CREDENTIALS), + neo4j.basic_auth("user", CREDENTIALS), + neo4j.kerberos_auth(CREDENTIALS), + neo4j.bearer_auth(CREDENTIALS), + neo4j.custom_auth("user", CREDENTIALS, "realm", "scheme"), + neo4j.Auth("scheme", "principal", CREDENTIALS, "realm", foo="bar"), +)) +@mark_async_test +async def test_credentials_are_not_logged( + auth, fake_socket_pair, mocker, caplog +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt5x0.PACKER_CLS, + unpacker_cls=AsyncBolt5x0.UNPACKER_CLS) + sockets.client.settimeout = mocker.Mock() + await sockets.server.send_message(b"\x70", {"server": "Neo4j/4.3.4"}) + connection = AsyncBolt5x0( + address, sockets.client, PoolConfig.max_connection_lifetime, auth=auth + ) + with caplog.at_level(logging.DEBUG): + await connection.hello() + + if isinstance(auth, tuple): + auth = neo4j.basic_auth(*auth) + for field in ("scheme", "principal", "realm", "parameters"): + value = getattr(auth, field, None) + if value: + assert repr(value) in caplog.text + assert CREDENTIALS not in caplog.text + + +@pytest.mark.parametrize("message", ("logon", "logoff")) +def test_auth_message_raises_configuration_error(message, fake_socket): + address = neo4j.Address(("127.0.0.1", 7687)) + connection = AsyncBolt5x0(address, fake_socket(address), + PoolConfig.max_connection_lifetime) + with pytest.raises(ConfigurationError, + match="User switching is not supported"): + getattr(connection, message)() + + +@pytest.mark.parametrize("auth", ( + None, + neo4j.Auth("scheme", "principal", "credentials", "realm"), + ("user", "password"), +)) +@mark_async_test +async def test_re_auth_noop(auth, fake_socket, mocker): + address = neo4j.Address(("127.0.0.1", 7687)) + connection = AsyncBolt5x0(address, fake_socket(address), + PoolConfig.max_connection_lifetime, auth=auth) + logon_spy = mocker.spy(connection, "logon") + logoff_spy = mocker.spy(connection, "logoff") + res = connection.re_auth(auth, None) + + assert res is False + logon_spy.assert_not_called() + logoff_spy.assert_not_called() + + +@pytest.mark.parametrize( + ("auth1", "auth2"), + permutations( + ( + None, + neo4j.Auth("scheme", "principal", "credentials", "realm"), + ("user", "password"), + ), + 2 + ) +) +@mark_async_test +async def test_re_auth(auth1, auth2, fake_socket): + address = neo4j.Address(("127.0.0.1", 7687)) + connection = AsyncBolt5x0(address, fake_socket(address), + PoolConfig.max_connection_lifetime, auth=auth1) + with pytest.raises(ConfigurationError, + match="User switching is not supported"): + connection.re_auth(auth2, None) + + @pytest.mark.parametrize(("method", "args"), ( ("run", ("RETURN 1",)), ("begin", ()), @@ -286,7 +373,7 @@ async def test_hint_recv_timeout_seconds( )) def test_does_not_support_notification_filters(fake_socket, method, args, kwargs): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x0.UNPACKER_CLS) connection = AsyncBolt5x0(address, socket, PoolConfig.max_connection_lifetime) @@ -308,7 +395,7 @@ def test_does_not_support_notification_filters(fake_socket, method, async def 
test_hello_does_not_support_notification_filters( fake_socket, kwargs ): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x0.UNPACKER_CLS) connection = AsyncBolt5x0( address, socket, PoolConfig.max_connection_lifetime, @@ -316,65 +403,3 @@ async def test_hello_does_not_support_notification_filters( ) with pytest.raises(ConfigurationError, match="Notification filtering"): await connection.hello() - - -class HackedAuth: - def __init__(self, dict_): - self.__dict__ = dict_ - - -@mark_async_test -@pytest.mark.parametrize("auth", ( - ("awesome test user", "safe p4ssw0rd"), - Auth("super nice scheme", "awesome test user", "safe p4ssw0rd"), - Auth("super nice scheme", "awesome test user", "safe p4ssw0rd", - realm="super duper realm"), - Auth("super nice scheme", "awesome test user", "safe p4ssw0rd", - realm="super duper realm"), - Auth("super nice scheme", "awesome test user", "safe p4ssw0rd", - foo="bar"), - HackedAuth({ - "scheme": "super nice scheme", "principal": "awesome test user", - "credentials": "safe p4ssw0rd", "realm": "super duper realm", - "parameters": {"credentials": "should be visible!"}, - }) - -)) -async def test_hello_does_not_log_credentials(fake_socket_pair, caplog, auth): - def items(): - if isinstance(auth, tuple): - yield "scheme", "basic" - yield "principal", auth[0] - yield "credentials", auth[1] - elif isinstance(auth, Auth): - for key in ("scheme", "principal", "credentials", "realm", - "parameters"): - value = getattr(auth, key, None) - if value: - yield key, value - elif isinstance(auth, HackedAuth): - yield from auth.__dict__.items() - else: - raise TypeError(auth) - - address = ("127.0.0.1", 7687) - sockets = fake_socket_pair(address, - packer_cls=AsyncBolt5x0.PACKER_CLS, - unpacker_cls=AsyncBolt5x0.UNPACKER_CLS) - await sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) - max_connection_lifetime = 0 - connection = AsyncBolt5x0(address, sockets.client, - max_connection_lifetime, auth=auth) - - with caplog.at_level(logging.DEBUG): - await connection.hello() - - hellos = [m for m in caplog.messages if "C: HELLO" in m] - assert len(hellos) == 1 - hello = hellos[0] - - for key, value in items(): - if key == "credentials": - assert value not in hello - else: - assert str({key: value})[1:-1] in hello diff --git a/tests/unit/async_/io/test_class_bolt5x1.py b/tests/unit/async_/io/test_class_bolt5x1.py index 142b1f91..83dab776 100644 --- a/tests/unit/async_/io/test_class_bolt5x1.py +++ b/tests/unit/async_/io/test_class_bolt5x1.py @@ -20,16 +20,397 @@ import pytest +import neo4j +import neo4j.exceptions from neo4j._async.io._bolt5 import AsyncBolt5x1 from neo4j._conf import PoolConfig -from neo4j.api import Auth +from neo4j.auth_management import AsyncAuthManagers from neo4j.exceptions import ConfigurationError from ...._async_compat import mark_async_test -# TODO: proper testing should come from the re-auth ADR, -# which properly introduces Bolt 5.1 +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_stale(fake_socket, set_stale): + address = neo4j.Address(("127.0.0.1", 7687)) + max_connection_lifetime = 0 + connection = AsyncBolt5x1(address, fake_socket(address), + max_connection_lifetime) + if set_stale: + connection.set_stale() + assert connection.stale() is True + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): + address = neo4j.Address(("127.0.0.1", 7687)) + max_connection_lifetime = -1 + 
connection = AsyncBolt5x1(address, fake_socket(address), + max_connection_lifetime) + if set_stale: + connection.set_stale() + assert connection.stale() is set_stale + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_not_stale(fake_socket, set_stale): + address = neo4j.Address(("127.0.0.1", 7687)) + max_connection_lifetime = 999999999 + connection = AsyncBolt5x1(address, fake_socket(address), + max_connection_lifetime) + if set_stale: + connection.set_stale() + assert connection.stale() is set_stale + + +@pytest.mark.parametrize(("args", "kwargs", "expected_fields"), ( + (("", {}), {"db": "something"}, ({"db": "something"},)), + (("", {}), {"imp_user": "imposter"}, ({"imp_user": "imposter"},)), + ( + ("", {}), + {"db": "something", "imp_user": "imposter"}, + ({"db": "something", "imp_user": "imposter"},) + ), +)) +@mark_async_test +async def test_extra_in_begin(fake_socket, args, kwargs, expected_fields): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x1.UNPACKER_CLS) + connection = AsyncBolt5x1(address, socket, + PoolConfig.max_connection_lifetime) + connection.begin(*args, **kwargs) + await connection.send_all() + tag, is_fields = await socket.pop_message() + assert tag == b"\x11" + assert tuple(is_fields) == expected_fields + + +@pytest.mark.parametrize(("args", "kwargs", "expected_fields"), ( + (("", {}), {"db": "something"}, ("", {}, {"db": "something"})), + (("", {}), {"imp_user": "imposter"}, ("", {}, {"imp_user": "imposter"})), + ( + ("", {}), + {"db": "something", "imp_user": "imposter"}, + ("", {}, {"db": "something", "imp_user": "imposter"}) + ), +)) +@mark_async_test +async def test_extra_in_run(fake_socket, args, kwargs, expected_fields): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x1.UNPACKER_CLS) + connection = AsyncBolt5x1(address, socket, + PoolConfig.max_connection_lifetime) + connection.run(*args, **kwargs) + await connection.send_all() + tag, is_fields = await socket.pop_message() + assert tag == b"\x10" + assert tuple(is_fields) == expected_fields + + +@mark_async_test +async def test_n_extra_in_discard(fake_socket): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x1.UNPACKER_CLS) + connection = AsyncBolt5x1(address, socket, + PoolConfig.max_connection_lifetime) + connection.discard(n=666) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x2F" + assert len(fields) == 1 + assert fields[0] == {"n": 666} + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (666, {"n": -1, "qid": 666}), + (-1, {"n": -1}), + ] +) +@mark_async_test +async def test_qid_extra_in_discard(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x1.UNPACKER_CLS) + connection = AsyncBolt5x1(address, socket, + PoolConfig.max_connection_lifetime) + connection.discard(qid=test_input) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x2F" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (777, {"n": 666, "qid": 777}), + (-1, {"n": 666}), + ] +) +@mark_async_test +async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x1.UNPACKER_CLS) + connection = AsyncBolt5x1(address, socket, + 
PoolConfig.max_connection_lifetime) + connection.discard(n=666, qid=test_input) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x2F" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (666, {"n": 666}), + (-1, {"n": -1}), + ] +) +@mark_async_test +async def test_n_extra_in_pull(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x1.UNPACKER_CLS) + connection = AsyncBolt5x1(address, socket, + PoolConfig.max_connection_lifetime) + connection.pull(n=test_input) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x3F" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (777, {"n": -1, "qid": 777}), + (-1, {"n": -1}), + ] +) +@mark_async_test +async def test_qid_extra_in_pull(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x1.UNPACKER_CLS) + connection = AsyncBolt5x1(address, socket, + PoolConfig.max_connection_lifetime) + connection.pull(qid=test_input) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x3F" + assert len(fields) == 1 + assert fields[0] == expected + + +@mark_async_test +async def test_n_and_qid_extras_in_pull(fake_socket): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x1.UNPACKER_CLS) + connection = AsyncBolt5x1(address, socket, + PoolConfig.max_connection_lifetime) + connection.pull(n=666, qid=777) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x3F" + assert len(fields) == 1 + assert fields[0] == {"n": 666, "qid": 777} + + +@mark_async_test +async def test_hello_passes_routing_metadata(fake_socket_pair): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt5x1.PACKER_CLS, + unpacker_cls=AsyncBolt5x1.UNPACKER_CLS) + await sockets.server.send_message(b"\x70", {"server": "Neo4j/4.4.0"}) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt5x1( + address, sockets.client, PoolConfig.max_connection_lifetime, + routing_context={"foo": "bar"} + ) + await connection.hello() + tag, fields = await sockets.server.pop_message() + assert tag == b"\x01" + assert len(fields) == 1 + assert fields[0]["routing"] == {"foo": "bar"} + + +async def _assert_logon_message(sockets, auth): + tag, fields = await sockets.server.pop_message() + assert tag == b"\x6A" # LOGON + assert len(fields) == 1 + keys = ["scheme", "principal", "credentials"] + assert list(fields[0].keys()) == keys + for key in keys: + assert fields[0][key] == getattr(auth, key) + + +@mark_async_test +async def test_hello_pipelines_logon(fake_socket_pair): + auth = neo4j.Auth("basic", "alice123", "supersecret123") + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt5x1.PACKER_CLS, + unpacker_cls=AsyncBolt5x1.UNPACKER_CLS) + await sockets.server.send_message( + b"\x7F", {"code": "Neo.DatabaseError.General.MadeUpError", + "message": "kthxbye"} + ) + connection = AsyncBolt5x1( + address, sockets.client, PoolConfig.max_connection_lifetime, auth=auth + ) + with pytest.raises(neo4j.exceptions.Neo4jError): + await connection.hello() + tag, fields = await sockets.server.pop_message() + assert tag == b"\x01" # HELLO + assert 
len(fields) == 1 + assert list(fields[0].keys()) == ["user_agent"] + assert auth.credentials not in repr(fields) + await _assert_logon_message(sockets, auth) + + +@mark_async_test +async def test_logon(fake_socket_pair): + auth = neo4j.Auth("basic", "alice123", "supersecret123") + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt5x1.PACKER_CLS, + unpacker_cls=AsyncBolt5x1.UNPACKER_CLS) + connection = AsyncBolt5x1(address, sockets.client, + PoolConfig.max_connection_lifetime, auth=auth) + connection.logon() + await connection.send_all() + await _assert_logon_message(sockets, auth) + + +@mark_async_test +async def test_re_auth(fake_socket_pair, mocker, static_auth): + auth = neo4j.Auth("basic", "alice123", "supersecret123") + auth_manager = static_auth(auth) + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt5x1.PACKER_CLS, + unpacker_cls=AsyncBolt5x1.UNPACKER_CLS) + await sockets.server.send_message( + b"\x7F", {"code": "Neo.DatabaseError.General.MadeUpError", + "message": "kthxbye"} + ) + connection = AsyncBolt5x1(address, sockets.client, + PoolConfig.max_connection_lifetime) + connection.pool = mocker.AsyncMock() + connection.re_auth(auth, auth_manager) + await connection.send_all() + with pytest.raises(neo4j.exceptions.Neo4jError): + await connection.fetch_all() + tag, fields = await sockets.server.pop_message() + assert tag == b"\x6B" # LOGOFF + assert len(fields) == 0 + await _assert_logon_message(sockets, auth) + assert connection.auth is auth + assert connection.auth_manager is auth_manager + + +@mark_async_test +async def test_logoff(fake_socket_pair): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt5x1.PACKER_CLS, + unpacker_cls=AsyncBolt5x1.UNPACKER_CLS) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt5x1(address, sockets.client, + PoolConfig.max_connection_lifetime) + connection.logoff() + assert not sockets.server.recv_buffer # pipelined, so no response yet + await connection.send_all() + assert sockets.server.recv_buffer # now! 
+ tag, fields = await sockets.server.pop_message() + assert tag == b"\x6B" # LOGOFF + assert len(fields) == 0 + + +@pytest.mark.parametrize(("hints", "valid"), ( + ({"connection.recv_timeout_seconds": 1}, True), + ({"connection.recv_timeout_seconds": 42}, True), + ({}, True), + ({"whatever_this_is": "ignore me!"}, True), + ({"connection.recv_timeout_seconds": -1}, False), + ({"connection.recv_timeout_seconds": 0}, False), + ({"connection.recv_timeout_seconds": 2.5}, False), + ({"connection.recv_timeout_seconds": None}, False), + ({"connection.recv_timeout_seconds": False}, False), + ({"connection.recv_timeout_seconds": "1"}, False), +)) +@mark_async_test +async def test_hint_recv_timeout_seconds( + fake_socket_pair, hints, valid, caplog, mocker +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt5x1.PACKER_CLS, + unpacker_cls=AsyncBolt5x1.UNPACKER_CLS) + sockets.client.settimeout = mocker.Mock() + await sockets.server.send_message( + b"\x70", {"server": "Neo4j/4.3.4", "hints": hints} + ) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt5x1( + address, sockets.client, PoolConfig.max_connection_lifetime + ) + with caplog.at_level(logging.INFO): + await connection.hello() + if valid: + if "connection.recv_timeout_seconds" in hints: + sockets.client.settimeout.assert_called_once_with( + hints["connection.recv_timeout_seconds"] + ) + else: + sockets.client.settimeout.assert_not_called() + assert not any("recv_timeout_seconds" in msg + and "invalid" in msg + for msg in caplog.messages) + else: + sockets.client.settimeout.assert_not_called() + assert any(repr(hints["connection.recv_timeout_seconds"]) in msg + and "recv_timeout_seconds" in msg + and "invalid" in msg + for msg in caplog.messages) + + +CREDENTIALS = "+++super-secret-sauce+++" + + +@pytest.mark.parametrize("auth", ( + ("user", CREDENTIALS), + neo4j.basic_auth("user", CREDENTIALS), + neo4j.kerberos_auth(CREDENTIALS), + neo4j.bearer_auth(CREDENTIALS), + neo4j.custom_auth("user", CREDENTIALS, "realm", "scheme"), + neo4j.Auth("scheme", "principal", CREDENTIALS, "realm", foo="bar"), +)) +@mark_async_test +async def test_credentials_are_not_logged(auth, fake_socket_pair, caplog): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt5x1.PACKER_CLS, + unpacker_cls=AsyncBolt5x1.UNPACKER_CLS) + await sockets.server.send_message(b"\x70", {"server": "Neo4j/4.3.4"}) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt5x1( + address, sockets.client, PoolConfig.max_connection_lifetime, auth=auth + ) + with caplog.at_level(logging.DEBUG): + await connection.hello() + + if isinstance(auth, tuple): + auth = neo4j.basic_auth(*auth) + for field in ("scheme", "principal", "realm", "parameters"): + value = getattr(auth, field, None) + if value: + assert repr(value) in caplog.text + assert CREDENTIALS not in caplog.text @pytest.mark.parametrize(("method", "args"), ( @@ -47,7 +428,7 @@ )) def test_does_not_support_notification_filters(fake_socket, method, args, kwargs): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x1.UNPACKER_CLS) connection = AsyncBolt5x1(address, socket, PoolConfig.max_connection_lifetime) @@ -69,7 +450,7 @@ def test_does_not_support_notification_filters(fake_socket, method, async def test_hello_does_not_support_notification_filters( fake_socket, kwargs ): - address = ("127.0.0.1", 7687) + address = 
neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x1.UNPACKER_CLS) connection = AsyncBolt5x1( address, socket, PoolConfig.max_connection_lifetime, @@ -77,66 +458,3 @@ async def test_hello_does_not_support_notification_filters( ) with pytest.raises(ConfigurationError, match="Notification filtering"): await connection.hello() - - -class HackedAuth: - def __init__(self, dict_): - self.__dict__ = dict_ - - -@mark_async_test -@pytest.mark.parametrize("auth", ( - ("awesome test user", "safe p4ssw0rd"), - Auth("super nice scheme", "awesome test user", "safe p4ssw0rd"), - Auth("super nice scheme", "awesome test user", "safe p4ssw0rd", - realm="super duper realm"), - Auth("super nice scheme", "awesome test user", "safe p4ssw0rd", - realm="super duper realm"), - Auth("super nice scheme", "awesome test user", "safe p4ssw0rd", - foo="bar"), - HackedAuth({ - "scheme": "super nice scheme", "principal": "awesome test user", - "credentials": "safe p4ssw0rd", "realm": "super duper realm", - "parameters": {"credentials": "should be visible!"}, - }) - -)) -async def test_hello_does_not_log_credentials(fake_socket_pair, caplog, auth): - def items(): - if isinstance(auth, tuple): - yield "scheme", "basic" - yield "principal", auth[0] - yield "credentials", auth[1] - elif isinstance(auth, Auth): - for key in ("scheme", "principal", "credentials", "realm", - "parameters"): - value = getattr(auth, key, None) - if value: - yield key, value - elif isinstance(auth, HackedAuth): - yield from auth.__dict__.items() - else: - raise TypeError(auth) - - address = ("127.0.0.1", 7687) - sockets = fake_socket_pair(address, - packer_cls=AsyncBolt5x1.PACKER_CLS, - unpacker_cls=AsyncBolt5x1.UNPACKER_CLS) - await sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) - await sockets.server.send_message(b"\x70", {}) - max_connection_lifetime = 0 - connection = AsyncBolt5x1(address, sockets.client, - max_connection_lifetime, auth=auth) - - with caplog.at_level(logging.DEBUG): - await connection.hello() - - logons = [m for m in caplog.messages if "C: LOGON " in m] - assert len(logons) == 1 - logon = logons[0] - - for key, value in items(): - if key == "credentials": - assert value not in logon - else: - assert str({key: value})[1:-1] in logon diff --git a/tests/unit/async_/io/test_class_bolt5x2.py b/tests/unit/async_/io/test_class_bolt5x2.py index 1a00c955..676eb5b9 100644 --- a/tests/unit/async_/io/test_class_bolt5x2.py +++ b/tests/unit/async_/io/test_class_bolt5x2.py @@ -14,21 +14,24 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+ + import itertools import logging import pytest +import neo4j from neo4j._async.io._bolt5 import AsyncBolt5x2 from neo4j._conf import PoolConfig -from neo4j.api import Auth +from neo4j.auth_management import AsyncAuthManagers from ...._async_compat import mark_async_test @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_stale(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = 0 connection = AsyncBolt5x2(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -38,7 +41,7 @@ def test_conn_is_stale(fake_socket, set_stale): @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = -1 connection = AsyncBolt5x2(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -48,7 +51,7 @@ def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_not_stale(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = 999999999 connection = AsyncBolt5x2(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -67,7 +70,7 @@ def test_conn_is_not_stale(fake_socket, set_stale): )) @mark_async_test async def test_extra_in_begin(fake_socket, args, kwargs, expected_fields): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x2.UNPACKER_CLS) connection = AsyncBolt5x2(address, socket, PoolConfig.max_connection_lifetime) connection.begin(*args, **kwargs) @@ -88,7 +91,7 @@ async def test_extra_in_begin(fake_socket, args, kwargs, expected_fields): )) @mark_async_test async def test_extra_in_run(fake_socket, args, kwargs, expected_fields): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x2.UNPACKER_CLS) connection = AsyncBolt5x2(address, socket, PoolConfig.max_connection_lifetime) connection.run(*args, **kwargs) @@ -100,7 +103,7 @@ async def test_extra_in_run(fake_socket, args, kwargs, expected_fields): @mark_async_test async def test_n_extra_in_discard(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x2.UNPACKER_CLS) connection = AsyncBolt5x2(address, socket, PoolConfig.max_connection_lifetime) connection.discard(n=666) @@ -120,7 +123,7 @@ async def test_n_extra_in_discard(fake_socket): ) @mark_async_test async def test_qid_extra_in_discard(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x2.UNPACKER_CLS) connection = AsyncBolt5x2(address, socket, PoolConfig.max_connection_lifetime) connection.discard(qid=test_input) @@ -140,7 +143,7 @@ async def test_qid_extra_in_discard(fake_socket, test_input, expected): ) @mark_async_test async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x2.UNPACKER_CLS) connection = AsyncBolt5x2(address, socket, PoolConfig.max_connection_lifetime) connection.discard(n=666, qid=test_input) @@ -160,7 +163,7 @@ async def test_n_and_qid_extras_in_discard(fake_socket, 
test_input, expected): ) @mark_async_test async def test_n_extra_in_pull(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x2.UNPACKER_CLS) connection = AsyncBolt5x2(address, socket, PoolConfig.max_connection_lifetime) connection.pull(n=test_input) @@ -180,7 +183,7 @@ async def test_n_extra_in_pull(fake_socket, test_input, expected): ) @mark_async_test async def test_qid_extra_in_pull(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x2.UNPACKER_CLS) connection = AsyncBolt5x2(address, socket, PoolConfig.max_connection_lifetime) connection.pull(qid=test_input) @@ -193,7 +196,7 @@ async def test_qid_extra_in_pull(fake_socket, test_input, expected): @mark_async_test async def test_n_and_qid_extras_in_pull(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x2.UNPACKER_CLS) connection = AsyncBolt5x2(address, socket, PoolConfig.max_connection_lifetime) connection.pull(n=666, qid=777) @@ -206,7 +209,7 @@ async def test_n_and_qid_extras_in_pull(fake_socket): @mark_async_test async def test_hello_passes_routing_metadata(fake_socket_pair): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair(address, packer_cls=AsyncBolt5x2.PACKER_CLS, unpacker_cls=AsyncBolt5x2.UNPACKER_CLS) @@ -223,6 +226,99 @@ async def test_hello_passes_routing_metadata(fake_socket_pair): assert fields[0]["routing"] == {"foo": "bar"} +async def _assert_logon_message(sockets, auth): + tag, fields = await sockets.server.pop_message() + assert tag == b"\x6A" # LOGON + assert len(fields) == 1 + keys = ["scheme", "principal", "credentials"] + assert list(fields[0].keys()) == keys + for key in keys: + assert fields[0][key] == getattr(auth, key) + + +@mark_async_test +async def test_hello_pipelines_logon(fake_socket_pair): + auth = neo4j.Auth("basic", "alice123", "supersecret123") + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt5x2.PACKER_CLS, + unpacker_cls=AsyncBolt5x2.UNPACKER_CLS) + await sockets.server.send_message( + b"\x7F", {"code": "Neo.DatabaseError.General.MadeUpError", + "message": "kthxbye"} + ) + connection = AsyncBolt5x2( + address, sockets.client, PoolConfig.max_connection_lifetime, auth=auth + ) + with pytest.raises(neo4j.exceptions.Neo4jError): + await connection.hello() + tag, fields = await sockets.server.pop_message() + assert tag == b"\x01" # HELLO + assert len(fields) == 1 + assert list(fields[0].keys()) == ["user_agent"] + assert auth.credentials not in repr(fields) + await _assert_logon_message(sockets, auth) + + +@mark_async_test +async def test_logon(fake_socket_pair): + auth = neo4j.Auth("basic", "alice123", "supersecret123") + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt5x2.PACKER_CLS, + unpacker_cls=AsyncBolt5x2.UNPACKER_CLS) + connection = AsyncBolt5x2(address, sockets.client, + PoolConfig.max_connection_lifetime, auth=auth) + connection.logon() + await connection.send_all() + await _assert_logon_message(sockets, auth) + + +@mark_async_test +async def test_re_auth(fake_socket_pair, mocker, static_auth): + auth = neo4j.Auth("basic", "alice123", "supersecret123") + auth_manager = static_auth(auth) + address = neo4j.Address(("127.0.0.1", 7687)) + 
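# Illustrative sketch (not part of this patch): the Bolt message tags these
# 5.1+ tests assert on. Authentication moves out of HELLO into a separate
# LOGON message, so HELLO + LOGON can be pipelined on connection setup, and
# LOGOFF + LOGON is used to re-authenticate an already open connection.
BOLT_TAGS = {
    0x01: "HELLO",    # connection metadata only, no credentials
    0x6A: "LOGON",    # scheme / principal / credentials
    0x6B: "LOGOFF",   # drop the current authentication
    0x70: "SUCCESS",  # server summary message
    0x7F: "FAILURE",  # server error message
}


def describe(tag: bytes) -> str:
    """Map a one-byte message tag to its name (placeholder helper)."""
    return BOLT_TAGS.get(tag[0], "UNKNOWN")


assert describe(b"\x6a") == "LOGON"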
sockets = fake_socket_pair(address, + packer_cls=AsyncBolt5x2.PACKER_CLS, + unpacker_cls=AsyncBolt5x2.UNPACKER_CLS) + await sockets.server.send_message( + b"\x7F", {"code": "Neo.DatabaseError.General.MadeUpError", + "message": "kthxbye"} + ) + connection = AsyncBolt5x2(address, sockets.client, + PoolConfig.max_connection_lifetime) + connection.pool = mocker.AsyncMock() + connection.re_auth(auth, auth_manager) + await connection.send_all() + with pytest.raises(neo4j.exceptions.Neo4jError): + await connection.fetch_all() + tag, fields = await sockets.server.pop_message() + assert tag == b"\x6B" # LOGOFF + assert len(fields) == 0 + await _assert_logon_message(sockets, auth) + assert connection.auth is auth + assert connection.auth_manager is auth_manager + + +@mark_async_test +async def test_logoff(fake_socket_pair): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt5x2.PACKER_CLS, + unpacker_cls=AsyncBolt5x2.UNPACKER_CLS) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt5x2(address, sockets.client, + PoolConfig.max_connection_lifetime) + connection.logoff() + assert not sockets.server.recv_buffer # pipelined, so no response yet + await connection.send_all() + assert sockets.server.recv_buffer # now! + tag, fields = await sockets.server.pop_message() + assert tag == b"\x6B" # LOGOFF + assert len(fields) == 0 + + @pytest.mark.parametrize(("hints", "valid"), ( ({"connection.recv_timeout_seconds": 1}, True), ({"connection.recv_timeout_seconds": 42}, True), @@ -239,7 +335,7 @@ async def test_hello_passes_routing_metadata(fake_socket_pair): async def test_hint_recv_timeout_seconds( fake_socket_pair, hints, valid, caplog, mocker ): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair(address, packer_cls=AsyncBolt5x2.PACKER_CLS, unpacker_cls=AsyncBolt5x2.UNPACKER_CLS) @@ -271,6 +367,40 @@ async def test_hint_recv_timeout_seconds( for msg in caplog.messages) +CREDENTIALS = "+++super-secret-sauce+++" + + +@pytest.mark.parametrize("auth", ( + ("user", CREDENTIALS), + neo4j.basic_auth("user", CREDENTIALS), + neo4j.kerberos_auth(CREDENTIALS), + neo4j.bearer_auth(CREDENTIALS), + neo4j.custom_auth("user", CREDENTIALS, "realm", "scheme"), + neo4j.Auth("scheme", "principal", CREDENTIALS, "realm", foo="bar"), +)) +@mark_async_test +async def test_credentials_are_not_logged(auth, fake_socket_pair, caplog): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt5x2.PACKER_CLS, + unpacker_cls=AsyncBolt5x2.UNPACKER_CLS) + await sockets.server.send_message(b"\x70", {"server": "Neo4j/4.3.4"}) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt5x2( + address, sockets.client, PoolConfig.max_connection_lifetime, auth=auth + ) + with caplog.at_level(logging.DEBUG): + await connection.hello() + + if isinstance(auth, tuple): + auth = neo4j.basic_auth(*auth) + for field in ("scheme", "principal", "realm", "parameters"): + value = getattr(auth, field, None) + if value: + assert repr(value) in caplog.text + assert CREDENTIALS not in caplog.text + + def _assert_notifications_in_extra(extra, expected): for key in expected: assert key in extra @@ -295,7 +425,7 @@ async def test_supports_notification_filters( fake_socket, method, args, extra_idx, cls_min_sev, method_min_sev, cls_dis_cats, method_dis_cats ): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, 
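# Illustrative sketch (not part of this patch): the user-facing behaviour the
# credential-logging tests protect. With driver debug logging enabled, the
# HELLO/LOGON exchange is logged, but the credential value itself must never
# appear in the log output. `neo4j.debug.watch` is assumed to be the driver's
# logging helper; URI and password are placeholders, and a reachable server
# is required for verify_connectivity() to succeed.
import sys

import neo4j
from neo4j.debug import watch

watch("neo4j", out=sys.stderr)  # enable DEBUG logging for the driver

with neo4j.GraphDatabase.driver(
    "neo4j://localhost:7687",
    auth=neo4j.basic_auth("neo4j", "+++super-secret-sauce+++"),
) as driver:
    driver.verify_connectivity()  # credentials are sent, but not logged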
AsyncBolt5x2.UNPACKER_CLS) connection = AsyncBolt5x2( address, socket, PoolConfig.max_connection_lifetime, @@ -325,7 +455,7 @@ async def test_supports_notification_filters( async def test_hello_supports_notification_filters( fake_socket_pair, min_sev, dis_cats ): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair(address, packer_cls=AsyncBolt5x2.PACKER_CLS, unpacker_cls=AsyncBolt5x2.UNPACKER_CLS) @@ -347,66 +477,3 @@ async def test_hello_supports_notification_filters( if dis_cats is not None: expected["notifications_disabled_categories"] = dis_cats _assert_notifications_in_extra(extra, expected) - - -class HackedAuth: - def __init__(self, dict_): - self.__dict__ = dict_ - - -@mark_async_test -@pytest.mark.parametrize("auth", ( - ("awesome test user", "safe p4ssw0rd"), - Auth("super nice scheme", "awesome test user", "safe p4ssw0rd"), - Auth("super nice scheme", "awesome test user", "safe p4ssw0rd", - realm="super duper realm"), - Auth("super nice scheme", "awesome test user", "safe p4ssw0rd", - realm="super duper realm"), - Auth("super nice scheme", "awesome test user", "safe p4ssw0rd", - foo="bar"), - HackedAuth({ - "scheme": "super nice scheme", "principal": "awesome test user", - "credentials": "safe p4ssw0rd", "realm": "super duper realm", - "parameters": {"credentials": "should be visible!"}, - }) - -)) -async def test_hello_does_not_log_credentials(fake_socket_pair, caplog, auth): - def items(): - if isinstance(auth, tuple): - yield "scheme", "basic" - yield "principal", auth[0] - yield "credentials", auth[1] - elif isinstance(auth, Auth): - for key in ("scheme", "principal", "credentials", "realm", - "parameters"): - value = getattr(auth, key, None) - if value: - yield key, value - elif isinstance(auth, HackedAuth): - yield from auth.__dict__.items() - else: - raise TypeError(auth) - - address = ("127.0.0.1", 7687) - sockets = fake_socket_pair(address, - packer_cls=AsyncBolt5x2.PACKER_CLS, - unpacker_cls=AsyncBolt5x2.UNPACKER_CLS) - await sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) - await sockets.server.send_message(b"\x70", {}) - max_connection_lifetime = 0 - connection = AsyncBolt5x2(address, sockets.client, - max_connection_lifetime, auth=auth) - - with caplog.at_level(logging.DEBUG): - await connection.hello() - - logons = [m for m in caplog.messages if "C: LOGON " in m] - assert len(logons) == 1 - logon = logons[0] - - for key, value in items(): - if key == "credentials": - assert value not in logon - else: - assert str({key: value})[1:-1] in logon diff --git a/tests/unit/async_/io/test_direct.py b/tests/unit/async_/io/test_direct.py index 541fff6f..953287c6 100644 --- a/tests/unit/async_/io/test_direct.py +++ b/tests/unit/async_/io/test_direct.py @@ -18,6 +18,7 @@ import pytest +from neo4j import PreviewWarning from neo4j._async.io import AsyncBolt from neo4j._async.io._pool import AsyncIOPool from neo4j._conf import ( @@ -26,6 +27,7 @@ WorkspaceConfig, ) from neo4j._deadline import Deadline +from neo4j.auth_management import AsyncAuthManagers from neo4j.exceptions import ( ClientError, ServiceUnavailable, @@ -65,6 +67,9 @@ def stale(self): async def reset(self): pass + def re_auth(self, auth, auth_manager, force=False): + return False + def close(self): self.socket.close() @@ -79,39 +84,46 @@ def timedout(self): class AsyncFakeBoltPool(AsyncIOPool): - def __init__(self, address, *, auth=None, **config): + config["auth"] = static_auth(None) self.pool_config, self.workspace_config = 
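# Illustrative sketch (not part of this patch): the static auth manager that
# the fixtures above wrap. Auth managers are a preview API in this driver
# version, hence the PreviewWarning handling; the URI and credentials are
# placeholders.
import warnings

import neo4j
from neo4j.auth_management import AsyncAuthManagers

with warnings.catch_warnings():
    warnings.simplefilter("ignore", neo4j.PreviewWarning)
    manager = AsyncAuthManagers.static(("neo4j", "secret"))

# an auth manager can be passed wherever a static auth token is accepted
driver = neo4j.AsyncGraphDatabase.driver("neo4j://localhost:7687", auth=manager)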
Config.consume_chain(config, PoolConfig, WorkspaceConfig) if config: raise ValueError("Unexpected config keys: %s" % ", ".join(config.keys())) - async def opener(addr, timeout): + async def opener(addr, auth, timeout): return AsyncQuickConnection(AsyncFakeSocket(addr)) super().__init__(opener, self.pool_config, self.workspace_config) self.address = address async def acquire( - self, access_mode, timeout, database, bookmarks, liveness_check_timeout + self, access_mode, timeout, database, bookmarks, auth, + liveness_check_timeout ): return await self._acquire( - self.address, timeout, liveness_check_timeout + self.address, auth, timeout, liveness_check_timeout ) +def static_auth(auth): + with pytest.warns(PreviewWarning, match="Auth managers"): + return AsyncAuthManagers.static(auth) + +@pytest.fixture +def auth_manager(): + static_auth(("test", "test")) @mark_async_test -async def test_bolt_connection_open(): +async def test_bolt_connection_open(auth_manager): with pytest.raises(ServiceUnavailable): - await AsyncBolt.open( - ("localhost", 9999), auth=("test", "test") - ) + await AsyncBolt.open(("localhost", 9999), auth_manager=auth_manager) @mark_async_test -async def test_bolt_connection_open_timeout(): +async def test_bolt_connection_open_timeout(auth_manager): with pytest.raises(ServiceUnavailable): await AsyncBolt.open( - ("localhost", 9999), auth=("test", "test"), deadline=Deadline(1) + ("localhost", 9999), auth_manager=auth_manager, + deadline=Deadline(1) ) @@ -150,7 +162,7 @@ def assert_pool_size( address, expected_active, expected_inactive, pool): @mark_async_test async def test_pool_can_acquire(pool): address = ("127.0.0.1", 7687) - connection = await pool._acquire(address, Deadline(3), None) + connection = await pool._acquire(address, None, Deadline(3), None) assert connection.address == address assert_pool_size(address, 1, 0, pool) @@ -158,8 +170,8 @@ async def test_pool_can_acquire(pool): @mark_async_test async def test_pool_can_acquire_twice(pool): address = ("127.0.0.1", 7687) - connection_1 = await pool._acquire(address, Deadline(3), None) - connection_2 = await pool._acquire(address, Deadline(3), None) + connection_1 = await pool._acquire(address, None, Deadline(3), None) + connection_2 = await pool._acquire(address, None, Deadline(3), None) assert connection_1.address == address assert connection_2.address == address assert connection_1 is not connection_2 @@ -170,8 +182,8 @@ async def test_pool_can_acquire_twice(pool): async def test_pool_can_acquire_two_addresses(pool): address_1 = ("127.0.0.1", 7687) address_2 = ("127.0.0.1", 7474) - connection_1 = await pool._acquire(address_1, Deadline(3), None) - connection_2 = await pool._acquire(address_2, Deadline(3), None) + connection_1 = await pool._acquire(address_1, None, Deadline(3), None) + connection_2 = await pool._acquire(address_2, None, Deadline(3), None) assert connection_1.address == address_1 assert connection_2.address == address_2 assert_pool_size(address_1, 1, 0, pool) @@ -181,7 +193,7 @@ async def test_pool_can_acquire_two_addresses(pool): @mark_async_test async def test_pool_can_acquire_and_release(pool): address = ("127.0.0.1", 7687) - connection = await pool._acquire(address, Deadline(3), None) + connection = await pool._acquire(address, None, Deadline(3), None) assert_pool_size(address, 1, 0, pool) await pool.release(connection) assert_pool_size(address, 0, 1, pool) @@ -190,7 +202,7 @@ async def test_pool_can_acquire_and_release(pool): @mark_async_test async def test_pool_releasing_twice(pool): address = 
("127.0.0.1", 7687) - connection = await pool._acquire(address, Deadline(3), None) + connection = await pool._acquire(address, None, Deadline(3), None) await pool.release(connection) assert_pool_size(address, 0, 1, pool) await pool.release(connection) @@ -201,7 +213,7 @@ async def test_pool_releasing_twice(pool): async def test_pool_in_use_count(pool): address = ("127.0.0.1", 7687) assert pool.in_use_connection_count(address) == 0 - connection = await pool._acquire(address, Deadline(3), None) + connection = await pool._acquire(address, None, Deadline(3), None) assert pool.in_use_connection_count(address) == 1 await pool.release(connection) assert pool.in_use_connection_count(address) == 0 @@ -211,10 +223,10 @@ async def test_pool_in_use_count(pool): async def test_pool_max_conn_pool_size(pool): async with AsyncFakeBoltPool((), max_connection_pool_size=1) as pool: address = ("127.0.0.1", 7687) - await pool._acquire(address, Deadline(0), None) + await pool._acquire(address, None, Deadline(0), None) assert pool.in_use_connection_count(address) == 1 with pytest.raises(ClientError): - await pool._acquire(address, Deadline(0), None) + await pool._acquire(address, None, Deadline(0), None) assert pool.in_use_connection_count(address) == 1 @@ -232,7 +244,7 @@ async def test_pool_reset_when_released(is_reset, pool, mocker): new_callable=mocker.AsyncMock ) is_reset_mock.return_value = is_reset - connection = await pool._acquire(address, Deadline(3), None) + connection = await pool._acquire(address, None, Deadline(3), None) assert isinstance(connection, AsyncQuickConnection) assert is_reset_mock.call_count == 0 assert reset_mock.call_count == 0 diff --git a/tests/unit/async_/io/test_neo4j_pool.py b/tests/unit/async_/io/test_neo4j_pool.py index ec6352b3..ca27d92f 100644 --- a/tests/unit/async_/io/test_neo4j_pool.py +++ b/tests/unit/async_/io/test_neo4j_pool.py @@ -21,6 +21,7 @@ import pytest from neo4j import ( + PreviewWarning, READ_ACCESS, WRITE_ACCESS, ) @@ -36,6 +37,7 @@ ) from neo4j._deadline import Deadline from neo4j.addressing import ResolvedAddress +from neo4j.auth_management import AsyncAuthManagers from neo4j.exceptions import ( Neo4jError, ServiceUnavailable, @@ -72,10 +74,11 @@ def routing_side_effect(*args, **kwargs): }] raise res - async def open_(addr, timeout): + async def open_(addr, auth, timeout): connection = async_fake_connection_generator() - connection.addr = addr + connection.unresolved_address = addr connection.timeout = timeout + connection.auth = auth route_mock = mocker.AsyncMock() route_mock.side_effect = routing_side_effect @@ -97,48 +100,59 @@ def opener(routing_failure_opener): return routing_failure_opener() +def _pool_config(): + pool_config = PoolConfig() + pool_config.auth = _auth_manager(("user", "pass")) + return pool_config + + +def _auth_manager(auth): + with pytest.warns(PreviewWarning, match="Auth managers"): + return AsyncAuthManagers.static(auth) + + +def _simple_pool(opener) -> AsyncNeo4jPool: + return AsyncNeo4jPool( + opener, _pool_config(), WorkspaceConfig(), ROUTER1_ADDRESS + ) + + @mark_async_test async def test_acquires_new_routing_table_if_deleted(opener): - pool = AsyncNeo4jPool( - opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS - ) - cx = await pool.acquire(READ_ACCESS, 30, "test_db", None, None) + pool = _simple_pool(opener) + cx = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) await pool.release(cx) assert pool.routing_tables.get("test_db") del pool.routing_tables["test_db"] - cx = await pool.acquire(READ_ACCESS, 30, 
"test_db", None, None) + cx = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) await pool.release(cx) assert pool.routing_tables.get("test_db") @mark_async_test async def test_acquires_new_routing_table_if_stale(opener): - pool = AsyncNeo4jPool( - opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS - ) - cx = await pool.acquire(READ_ACCESS, 30, "test_db", None, None) + pool = _simple_pool(opener) + cx = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) await pool.release(cx) assert pool.routing_tables.get("test_db") old_value = pool.routing_tables["test_db"].last_updated_time pool.routing_tables["test_db"].ttl = 0 - cx = await pool.acquire(READ_ACCESS, 30, "test_db", None, None) + cx = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) await pool.release(cx) assert pool.routing_tables["test_db"].last_updated_time > old_value @mark_async_test async def test_removes_old_routing_table(opener): - pool = AsyncNeo4jPool( - opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS - ) - cx = await pool.acquire(READ_ACCESS, 30, "test_db1", None, None) + pool = _simple_pool(opener) + cx = await pool.acquire(READ_ACCESS, 30, "test_db1", None, None, None) await pool.release(cx) assert pool.routing_tables.get("test_db1") - cx = await pool.acquire(READ_ACCESS, 30, "test_db2", None, None) + cx = await pool.acquire(READ_ACCESS, 30, "test_db2", None, None, None) await pool.release(cx) assert pool.routing_tables.get("test_db2") @@ -147,7 +161,7 @@ async def test_removes_old_routing_table(opener): pool.routing_tables["test_db2"].ttl = \ -RoutingConfig.routing_table_purge_delay - cx = await pool.acquire(READ_ACCESS, 30, "test_db1", None, None) + cx = await pool.acquire(READ_ACCESS, 30, "test_db1", None, None, None) await pool.release(cx) assert pool.routing_tables["test_db1"].last_updated_time > old_value assert "test_db2" not in pool.routing_tables @@ -156,28 +170,24 @@ async def test_removes_old_routing_table(opener): @pytest.mark.parametrize("type_", ("r", "w")) @mark_async_test async def test_chooses_right_connection_type(opener, type_): - pool = AsyncNeo4jPool( - opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS - ) + pool = _simple_pool(opener) cx1 = await pool.acquire( READ_ACCESS if type_ == "r" else WRITE_ACCESS, - 30, "test_db", None, None + 30, "test_db", None, None, None ) await pool.release(cx1) if type_ == "r": - assert cx1.addr == READER_ADDRESS + assert cx1.unresolved_address == READER_ADDRESS else: - assert cx1.addr == WRITER_ADDRESS + assert cx1.unresolved_address == WRITER_ADDRESS @mark_async_test async def test_reuses_connection(opener): - pool = AsyncNeo4jPool( - opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS - ) - cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None) + pool = _simple_pool(opener) + cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) await pool.release(cx1) - cx2 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None) + cx2 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) assert cx1 is cx2 @@ -185,75 +195,70 @@ async def test_reuses_connection(opener): @mark_async_test async def test_closes_stale_connections(opener, break_on_close): async def break_connection(): - await pool.deactivate(cx1.addr) + await pool.deactivate(cx1.unresolved_address) if cx_close_mock_side_effect: res = cx_close_mock_side_effect() if inspect.isawaitable(res): return await res - pool = AsyncNeo4jPool( - opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS - ) - cx1 = await 
pool.acquire(READ_ACCESS, 30, "test_db", None, None) + pool = _simple_pool(opener) + cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) await pool.release(cx1) - assert cx1 in pool.connections[cx1.addr] - # simulate connection going stale (e.g. exceeding) and then breaking when - # the pool tries to close the connection + assert cx1 in pool.connections[cx1.unresolved_address] + # simulate connection going stale (e.g. exceeding idle timeout) and then + # breaking when the pool tries to close the connection cx1.stale.return_value = True cx_close_mock = cx1.close if break_on_close: cx_close_mock_side_effect = cx_close_mock.side_effect cx_close_mock.side_effect = break_connection - cx2 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None) + cx2 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) await pool.release(cx2) if break_on_close: cx1.close.assert_called() else: cx1.close.assert_called_once() assert cx2 is not cx1 - assert cx2.addr == cx1.addr - assert cx1 not in pool.connections[cx1.addr] - assert cx2 in pool.connections[cx2.addr] + assert cx2.unresolved_address == cx1.unresolved_address + assert cx1 not in pool.connections[cx1.unresolved_address] + assert cx2 in pool.connections[cx2.unresolved_address] @mark_async_test async def test_does_not_close_stale_connections_in_use(opener): - pool = AsyncNeo4jPool( - opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS - ) - cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None) - assert cx1 in pool.connections[cx1.addr] - # simulate connection going stale (e.g. exceeding) while being in use + pool = _simple_pool(opener) + cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + assert cx1 in pool.connections[cx1.unresolved_address] + # simulate connection going stale (e.g. 
exceeding idle timeout) while being + # in use cx1.stale.return_value = True - cx2 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None) + cx2 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) await pool.release(cx2) cx1.close.assert_not_called() assert cx2 is not cx1 - assert cx2.addr == cx1.addr - assert cx1 in pool.connections[cx1.addr] - assert cx2 in pool.connections[cx2.addr] + assert cx2.unresolved_address == cx1.unresolved_address + assert cx1 in pool.connections[cx1.unresolved_address] + assert cx2 in pool.connections[cx2.unresolved_address] await pool.release(cx1) # now that cx1 is back in the pool and still stale, # it should be closed when trying to acquire the next connection cx1.close.assert_not_called() - cx3 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None) + cx3 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) await pool.release(cx3) cx1.close.assert_called_once() assert cx2 is cx3 - assert cx3.addr == cx1.addr - assert cx1 not in pool.connections[cx1.addr] - assert cx3 in pool.connections[cx2.addr] + assert cx3.unresolved_address == cx1.unresolved_address + assert cx1 not in pool.connections[cx1.unresolved_address] + assert cx3 in pool.connections[cx2.unresolved_address] @mark_async_test async def test_release_resets_connections(opener): - pool = AsyncNeo4jPool( - opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS - ) - cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None) + pool = _simple_pool(opener) + cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) cx1.is_reset_mock.return_value = False cx1.is_reset_mock.reset_mock() await pool.release(cx1) @@ -263,10 +268,8 @@ async def test_release_resets_connections(opener): @mark_async_test async def test_release_does_not_resets_closed_connections(opener): - pool = AsyncNeo4jPool( - opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS - ) - cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None) + pool = _simple_pool(opener) + cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) cx1.closed.return_value = True cx1.closed.reset_mock() cx1.is_reset_mock.reset_mock() @@ -278,10 +281,8 @@ async def test_release_does_not_resets_closed_connections(opener): @mark_async_test async def test_release_does_not_resets_defunct_connections(opener): - pool = AsyncNeo4jPool( - opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS - ) - cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None) + pool = _simple_pool(opener) + cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) cx1.defunct.return_value = True cx1.defunct.reset_mock() cx1.is_reset_mock.reset_mock() @@ -296,11 +297,10 @@ async def test_release_does_not_resets_defunct_connections(opener): async def test_acquire_performs_no_liveness_check_on_fresh_connection( opener, liveness_timeout ): - pool = AsyncNeo4jPool( - opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS - ) - cx1 = await pool._acquire(READER_ADDRESS, Deadline(30), liveness_timeout) - assert cx1.addr == READER_ADDRESS + pool = _simple_pool(opener) + cx1 = await pool._acquire(READER_ADDRESS, None, Deadline(30), + liveness_timeout) + assert cx1.unresolved_address == READER_ADDRESS cx1.reset.assert_not_called() @@ -309,14 +309,13 @@ async def test_acquire_performs_no_liveness_check_on_fresh_connection( async def test_acquire_performs_liveness_check_on_existing_connection( opener, liveness_timeout ): - pool = AsyncNeo4jPool( - opener, PoolConfig(), WorkspaceConfig(), 
ROUTER1_ADDRESS - ) + pool = _simple_pool(opener) # populate the pool with a connection - cx1 = await pool._acquire(READER_ADDRESS, Deadline(30), liveness_timeout) + cx1 = await pool._acquire(READER_ADDRESS, None, Deadline(30), + liveness_timeout) # make sure we assume the right state - assert cx1.addr == READER_ADDRESS + assert cx1.unresolved_address == READER_ADDRESS cx1.is_idle_for.assert_not_called() cx1.reset.assert_not_called() @@ -327,7 +326,8 @@ async def test_acquire_performs_liveness_check_on_existing_connection( cx1.reset.assert_not_called() # then acquire it again and assert the liveness check was performed - cx2 = await pool._acquire(READER_ADDRESS, Deadline(30), liveness_timeout) + cx2 = await pool._acquire(READER_ADDRESS, None, Deadline(30), + liveness_timeout) assert cx1 is cx2 cx1.is_idle_for.assert_called_once_with(liveness_timeout) cx2.reset.assert_awaited_once() @@ -343,14 +343,13 @@ def liveness_side_effect(*args, **kwargs): raise liveness_error("liveness check failed") liveness_timeout = 1 - pool = AsyncNeo4jPool( - opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS - ) + pool = _simple_pool(opener) # populate the pool with a connection - cx1 = await pool._acquire(READER_ADDRESS, Deadline(30), liveness_timeout) + cx1 = await pool._acquire(READER_ADDRESS, None, Deadline(30), + liveness_timeout) # make sure we assume the right state - assert cx1.addr == READER_ADDRESS + assert cx1.unresolved_address == READER_ADDRESS cx1.is_idle_for.assert_not_called() cx1.reset.assert_not_called() @@ -363,13 +362,14 @@ def liveness_side_effect(*args, **kwargs): cx1.reset.assert_not_called() # then acquire it again and assert the liveness check was performed - cx2 = await pool._acquire(READER_ADDRESS, Deadline(30), liveness_timeout) + cx2 = await pool._acquire(READER_ADDRESS, None, Deadline(30), + liveness_timeout) assert cx1 is not cx2 - assert cx1.addr == cx2.addr + assert cx1.unresolved_address == cx2.unresolved_address cx1.is_idle_for.assert_called_once_with(liveness_timeout) cx2.reset.assert_not_called() - assert cx1 not in pool.connections[cx1.addr] - assert cx2 in pool.connections[cx1.addr] + assert cx1 not in pool.connections[cx1.unresolved_address] + assert cx2 in pool.connections[cx1.unresolved_address] @pytest.mark.parametrize("liveness_error", @@ -382,16 +382,16 @@ def liveness_side_effect(*args, **kwargs): raise liveness_error("liveness check failed") liveness_timeout = 1 - pool = AsyncNeo4jPool( - opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS - ) + pool = _simple_pool(opener) # populate the pool with a connection - cx1 = await pool._acquire(READER_ADDRESS, Deadline(30), liveness_timeout) - cx2 = await pool._acquire(READER_ADDRESS, Deadline(30), liveness_timeout) + cx1 = await pool._acquire(READER_ADDRESS, None, Deadline(30), + liveness_timeout) + cx2 = await pool._acquire(READER_ADDRESS, None, Deadline(30), + liveness_timeout) # make sure we assume the right state - assert cx1.addr == READER_ADDRESS - assert cx2.addr == READER_ADDRESS + assert cx1.unresolved_address == READER_ADDRESS + assert cx2.unresolved_address == READER_ADDRESS assert cx1 is not cx2 cx1.is_idle_for.assert_not_called() cx2.is_idle_for.assert_not_called() @@ -409,14 +409,15 @@ def liveness_side_effect(*args, **kwargs): cx2.reset.assert_not_called() # then acquire it again and assert the liveness check was performed - cx3 = await pool._acquire(READER_ADDRESS, Deadline(30), liveness_timeout) + cx3 = await pool._acquire(READER_ADDRESS, None, Deadline(30), + liveness_timeout) assert cx3 is 
cx2 cx1.is_idle_for.assert_called_once_with(liveness_timeout) cx1.reset.assert_awaited_once() cx3.is_idle_for.assert_called_once_with(liveness_timeout) cx3.reset.assert_awaited_once() - assert cx1 not in pool.connections[cx1.addr] - assert cx3 in pool.connections[cx1.addr] + assert cx1 not in pool.connections[cx1.unresolved_address] + assert cx3 in pool.connections[cx1.unresolved_address] @mark_async_test @@ -431,11 +432,9 @@ async def close_side_effect(): "close") # create pool with 2 idle connections - pool = AsyncNeo4jPool( - opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS - ) - cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None) - cx2 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None) + pool = _simple_pool(opener) + cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + cx2 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) await pool.release(cx1) await pool.release(cx2) @@ -447,7 +446,7 @@ async def close_side_effect(): # unreachable cx1.stale.return_value = True - cx3 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None) + cx3 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) assert cx3 is not cx1 assert cx3 is not cx2 @@ -455,26 +454,24 @@ async def close_side_effect(): @mark_async_test async def test_failing_opener_leaves_connections_in_use_alone(opener): - pool = AsyncNeo4jPool( - opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS - ) - cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None) + pool = _simple_pool(opener) + cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) opener.side_effect = ServiceUnavailable("Server overloaded") with pytest.raises((ServiceUnavailable, SessionExpired)): - await pool.acquire(READ_ACCESS, 30, "test_db", None, None) + await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) assert not cx1.closed() @mark_async_test async def test__acquire_new_later_with_room(opener): - config = PoolConfig() + config = _pool_config() config.max_connection_pool_size = 1 pool = AsyncNeo4jPool( opener, config, WorkspaceConfig(), ROUTER1_ADDRESS ) assert pool.connections_reservations[READER_ADDRESS] == 0 - creator = pool._acquire_new_later(READER_ADDRESS, Deadline(1)) + creator = pool._acquire_new_later(READER_ADDRESS, None, Deadline(1)) assert pool.connections_reservations[READER_ADDRESS] == 1 assert callable(creator) if AsyncUtil.is_async_code: @@ -483,15 +480,15 @@ async def test__acquire_new_later_with_room(opener): @mark_async_test async def test__acquire_new_later_without_room(opener): - config = PoolConfig() + config = _pool_config() config.max_connection_pool_size = 1 pool = AsyncNeo4jPool( opener, config, WorkspaceConfig(), ROUTER1_ADDRESS ) - _ = await pool.acquire(READ_ACCESS, 30, "test_db", None, None) + _ = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) # pool is full now assert pool.connections_reservations[READER_ADDRESS] == 0 - creator = pool._acquire_new_later(READER_ADDRESS, Deadline(1)) + creator = pool._acquire_new_later(READER_ADDRESS, None, Deadline(1)) assert pool.connections_reservations[READER_ADDRESS] == 0 assert creator is None @@ -503,11 +500,11 @@ async def test_passes_pool_config_to_connection(mocker): pool_config = PoolConfig() workspace_config = WorkspaceConfig() pool = AsyncNeo4jPool.open( - mocker.Mock, auth=("a", "b"), pool_config=pool_config, workspace_config=workspace_config + mocker.Mock, pool_config=pool_config, workspace_config=workspace_config ) _ = await pool._acquire( - 
mocker.Mock, Deadline.from_timeout_or_deadline(30), None + mocker.Mock, None, Deadline.from_timeout_or_deadline(30), None ) bolt_mock.assert_awaited_once() @@ -528,14 +525,14 @@ async def test_discovery_is_retried(routing_failure_opener, error): error, # will be retried ]) pool = AsyncNeo4jPool( - opener, PoolConfig(), WorkspaceConfig(), + opener, _pool_config(), WorkspaceConfig(), ResolvedAddress(("1.2.3.1", 9999), host_name="host") ) - cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None) + cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) await pool.release(cx1) pool.routing_tables.get("test_db").ttl = 0 - cx2 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None) + cx2 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) await pool.release(cx2) assert pool.routing_tables.get("test_db") @@ -572,15 +569,15 @@ async def test_fast_failing_discovery(routing_failure_opener, error): error, # will be retried ]) pool = AsyncNeo4jPool( - opener, PoolConfig(), WorkspaceConfig(), + opener, _pool_config(), WorkspaceConfig(), ResolvedAddress(("1.2.3.1", 9999), host_name="host") ) - cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None) + cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) await pool.release(cx1) pool.routing_tables.get("test_db").ttl = 0 with pytest.raises(error.__class__) as exc: - await pool.acquire(READ_ACCESS, 30, "test_db", None, None) + await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) assert exc.value is error @@ -588,3 +585,66 @@ async def test_fast_failing_discovery(routing_failure_opener, error): # reader # failed router assert len(opener.connections) == 3 + + + +@pytest.mark.parametrize( + ("error", "marks_unauthenticated", "fetches_new"), + ( + (Neo4jError.hydrate("message", args[0]), *args[1:]) + for args in ( + ("Neo.ClientError.Database.DatabaseNotFound", False, False), + ("Neo.ClientError.Statement.TypeError", False, False), + ("Neo.ClientError.Statement.ArgumentError", False, False), + ("Neo.ClientError.Request.Invalid", False, False), + ("Neo.ClientError.Security.AuthenticationRateLimit", False, False), + ("Neo.ClientError.Security.CredentialsExpired", False, False), + ("Neo.ClientError.Security.Forbidden", False, False), + ("Neo.ClientError.Security.Unauthorized", False, False), + ("Neo.ClientError.Security.MadeUpError", False, False), + ("Neo.ClientError.Security.TokenExpired", False, True), + ("Neo.ClientError.Security.AuthorizationExpired", True, False), + ) + ) +) +@mark_async_test +async def test_connection_error_callback( + opener, error, marks_unauthenticated, fetches_new, mocker +): + config = _pool_config() + auth_manager = _auth_manager(("user", "auth")) + on_auth_expired_mock = mocker.patch.object(auth_manager, "on_auth_expired", + autospec=True) + config.auth = auth_manager + pool = AsyncNeo4jPool( + opener, config, WorkspaceConfig(), ROUTER1_ADDRESS + ) + cxs_read = [ + await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + for _ in range(5) + ] + cxs_write = [ + await pool.acquire(WRITE_ACCESS, 30, "test_db", None, None, None) + for _ in range(5) + ] + + on_auth_expired_mock.assert_not_called() + for cx in cxs_read + cxs_write: + cx.mark_unauthenticated.assert_not_called() + + await pool.on_neo4j_error(error, cxs_read[0]) + + if fetches_new: + cxs_read[0].auth_manager.on_auth_expired.assert_awaited_once() + else: + on_auth_expired_mock.assert_not_called() + for cx in cxs_read: + cx.auth_manager.on_auth_expired.assert_not_called() + + for cx in 
cxs_read: + if marks_unauthenticated: + cx.mark_unauthenticated.assert_called_once() + else: + cx.mark_unauthenticated.assert_not_called() + for cx in cxs_write: + cx.mark_unauthenticated.assert_not_called() diff --git a/tests/unit/async_/test_addressing.py b/tests/unit/async_/test_addressing.py index ca5eb0f0..54113cb2 100644 --- a/tests/unit/async_/test_addressing.py +++ b/tests/unit/async_/test_addressing.py @@ -16,10 +16,7 @@ # limitations under the License. -from socket import ( - AF_INET, - AF_INET6, -) +from socket import AF_INET import pytest diff --git a/tests/unit/async_/test_auth_manager.py b/tests/unit/async_/test_auth_manager.py new file mode 100644 index 00000000..e29abacf --- /dev/null +++ b/tests/unit/async_/test_auth_manager.py @@ -0,0 +1,153 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import itertools +import typing as t + +import pytest +from freezegun import freeze_time +from freezegun.api import FrozenDateTimeFactory + +from neo4j import ( + Auth, + basic_auth, + PreviewWarning, +) +from neo4j._meta import copy_signature +from neo4j.auth_management import ( + AsyncAuthManager, + AsyncAuthManagers, + ExpiringAuth, +) + +from ..._async_compat import mark_async_test + + +SAMPLE_AUTHS = ( + None, + ("user", "password"), + basic_auth("foo", "bar"), + basic_auth("foo", "bar", "baz"), + Auth("scheme", "principal", "credentials", "realm", para="meter"), +) + + +@copy_signature(AsyncAuthManagers.static) +def static_auth_manager(*args, **kwargs): + with pytest.warns(PreviewWarning, match="Auth managers"): + return AsyncAuthManagers.static(*args, **kwargs) + +@copy_signature(AsyncAuthManagers.expiration_based) +def expiration_based_auth_manager(*args, **kwargs): + with pytest.warns(PreviewWarning, match="Auth managers"): + return AsyncAuthManagers.expiration_based(*args, **kwargs) + + +@copy_signature(ExpiringAuth) +def expiring_auth(*args, **kwargs): + with pytest.warns(PreviewWarning, match="Auth managers"): + return ExpiringAuth(*args, **kwargs) + + +@mark_async_test +@pytest.mark.parametrize("auth", SAMPLE_AUTHS) +async def test_static_manager( + auth +) -> None: + manager: AsyncAuthManager = static_auth_manager(auth) + assert await manager.get_auth() is auth + + await manager.on_auth_expired(("something", "else")) + assert await manager.get_auth() is auth + + await manager.on_auth_expired(auth) + assert await manager.get_auth() is auth + + +@mark_async_test +@pytest.mark.parametrize(("auth1", "auth2"), + itertools.product(SAMPLE_AUTHS, repeat=2)) +@pytest.mark.parametrize("expires_in", (None, -1, 1., 1, 1000.)) +async def test_expiration_based_manager_manual_expiry( + auth1: t.Union[t.Tuple[str, str], Auth, None], + auth2: t.Union[t.Tuple[str, str], Auth, None], + expires_in: t.Union[float, int], + mocker +) -> None: + if expires_in is None or expires_in >= 0: + temporal_auth = expiring_auth(auth1, expires_in) + else: + temporal_auth = expiring_auth(auth1) + provider = 
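# Illustrative sketch (not part of this patch): a hand-rolled async auth
# manager with the same two-method surface these tests exercise. get_auth()
# returns the current token; on_auth_expired() is the driver's hint that the
# token it handed out was rejected. fetch_token() is a placeholder for
# whatever produces fresh credentials.
import neo4j
from neo4j.auth_management import AsyncAuthManager


class RefreshingAuthManager(AsyncAuthManager):
    def __init__(self):
        self._auth = None

    async def get_auth(self):
        if self._auth is None:
            self._auth = await fetch_token()  # placeholder coroutine
        return self._auth

    async def on_auth_expired(self, auth):
        # only refresh if the rejected token is the one we handed out
        if auth == self._auth:
            self._auth = None


async def fetch_token() -> neo4j.Auth:
    # placeholder: e.g. exchange an SSO ticket for a bearer token
    return neo4j.bearer_auth("opaque-token-from-identity-provider")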
mocker.AsyncMock(return_value=temporal_auth) + manager: AsyncAuthManager = expiration_based_auth_manager(provider) + + provider.assert_not_called() + assert await manager.get_auth() is auth1 + provider.assert_awaited_once() + provider.reset_mock() + + provider.return_value = expiring_auth(auth2) + + await manager.on_auth_expired(("something", "else")) + assert await manager.get_auth() is auth1 + provider.assert_not_called() + + await manager.on_auth_expired(auth1) + provider.assert_awaited_once() + provider.reset_mock() + assert await manager.get_auth() is auth2 + provider.assert_not_called() + + +@mark_async_test +@pytest.mark.parametrize(("auth1", "auth2"), + itertools.product(SAMPLE_AUTHS, repeat=2)) +@pytest.mark.parametrize("expires_in", (None, -1, 1., 1, 1000.)) +async def test_expiration_based_manager_time_expiry( + auth1: t.Union[t.Tuple[str, str], Auth, None], + auth2: t.Union[t.Tuple[str, str], Auth, None], + expires_in: t.Union[float, int, None], + mocker +) -> None: + with freeze_time() as frozen_time: + assert isinstance(frozen_time, FrozenDateTimeFactory) + if expires_in is None or expires_in >= 0: + temporal_auth = expiring_auth(auth1, expires_in) + else: + temporal_auth = expiring_auth(auth1) + provider = mocker.AsyncMock(return_value=temporal_auth) + manager: AsyncAuthManager = expiration_based_auth_manager(provider) + + provider.assert_not_called() + assert await manager.get_auth() is auth1 + provider.assert_awaited_once() + provider.reset_mock() + + provider.return_value = expiring_auth(auth2) + + if expires_in is None or expires_in < 0: + frozen_time.tick(1_000_000) + assert await manager.get_auth() is auth1 + provider.assert_not_called() + else: + frozen_time.tick(expires_in - 0.000001) + assert await manager.get_auth() is auth1 + provider.assert_not_called() + frozen_time.tick(0.000002) + assert await manager.get_auth() is auth2 + provider.assert_awaited_once() diff --git a/tests/unit/async_/test_driver.py b/tests/unit/async_/test_driver.py index 04ef4e97..62848ebb 100644 --- a/tests/unit/async_/test_driver.py +++ b/tests/unit/async_/test_driver.py @@ -232,6 +232,7 @@ async def test_driver_opens_write_session_by_default(uri, fake_pool, mocker): timeout=mocker.ANY, database=mocker.ANY, bookmarks=mocker.ANY, + auth=mocker.ANY, liveness_check_timeout=mocker.ANY ) tx._begin.assert_awaited_once_with( @@ -970,7 +971,8 @@ async def test_execute_query_result_transformer( with assert_warns_execute_query_bmm_experimental(): bmm = driver.query_bookmark_manager res_custom = await driver.execute_query( - "", None, "w", None, None, bmm, result_transformer + "", None, "w", None, None, bmm, None, + result_transformer ) else: res_custom = await driver.execute_query( @@ -987,3 +989,18 @@ async def test_execute_query_result_transformer( _work, mocker.ANY, mocker.ANY, result_transformer ) assert res is session_executor_mock.return_value + + +@mark_async_test +async def test_supports_session_auth(mocker) -> None: + driver = AsyncGraphDatabase.driver("bolt://localhost") + session_cls_mock = mocker.patch("neo4j._async.driver.AsyncSession", + autospec=True) + async with driver as driver: + res = await driver.supports_session_auth() + + session_cls_mock.assert_called_once() + session_cls_mock.return_value.__aenter__.assert_awaited_once() + session_mock = session_cls_mock.return_value.__aenter__.return_value + connection_mock = session_mock._connection + assert res is connection_mock.supports_re_auth diff --git a/tests/unit/async_/work/test_session.py b/tests/unit/async_/work/test_session.py 
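# Illustrative sketch (not part of this patch): the user-facing features the
# driver tests above cover. supports_session_auth() reports whether the server
# speaks Bolt >= 5.1 (LOGON/LOGOFF), and session-level auth - a preview
# feature that may emit a PreviewWarning - lets a single session run under
# different credentials than the driver. URI and credentials are placeholders.
import neo4j


async def audit_query(driver: neo4j.AsyncDriver) -> None:
    if not await driver.supports_session_auth():
        raise RuntimeError("server too old for per-session authentication")
    async with driver.session(
        auth=neo4j.basic_auth("auditor", "auditor-password"),
        database="neo4j",
    ) as session:
        result = await session.run("SHOW CURRENT USER")
        print(await result.single())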
index a66ba16b..fdb51e36 100644 --- a/tests/unit/async_/work/test_session.py +++ b/tests/unit/async_/work/test_session.py @@ -409,7 +409,7 @@ async def test_with_bookmark_manager( additional_session_bookmarks, mocker ): async def update_routing_table_side_effect( - database, imp_user, bookmarks, acquisition_timeout=None, + database, imp_user, bookmarks, auth=None, acquisition_timeout=None, database_callback=None ): if home_db_gets_resolved: diff --git a/tests/unit/common/test_conf.py b/tests/unit/common/test_conf.py index 0a18c535..6377f74a 100644 --- a/tests/unit/common/test_conf.py +++ b/tests/unit/common/test_conf.py @@ -56,6 +56,7 @@ "user_agent": "test", "trusted_certificates": TrustSystemCAs(), "ssl_context": None, + "auth": None, "notifications_min_severity": None, "notifications_disabled_categories": None, } @@ -72,6 +73,7 @@ "impersonated_user": None, "fetch_size": 100, "bookmark_manager": object(), + "auth": None, "notifications_min_severity": None, "notifications_disabled_categories": None, } diff --git a/tests/unit/mixed/io/_common.py b/tests/unit/mixed/io/_common.py new file mode 100644 index 00000000..7cba228a --- /dev/null +++ b/tests/unit/mixed/io/_common.py @@ -0,0 +1,124 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import asyncio +import time +from asyncio import ( + Condition as AsyncCondition, + Lock as AsyncLock, +) +from threading import ( + Condition, + Lock, +) + +from neo4j._async_compat.shims import wait_for + + +class MultiEvent: + # Adopted from threading.Event + + def __init__(self): + super().__init__() + self._cond = Condition(Lock()) + self._counter = 0 + + def _reset_internal_locks(self): + # private! called by Thread._reset_internal_locks by _after_fork() + self._cond.__init__(Lock()) + + def counter(self): + return self._counter + + def increment(self): + with self._cond: + self._counter += 1 + self._cond.notify_all() + + def decrement(self): + with self._cond: + self._counter -= 1 + self._cond.notify_all() + + def clear(self): + with self._cond: + self._counter = 0 + self._cond.notify_all() + + def wait(self, value=0, timeout=None): + with self._cond: + t_start = time.time() + while True: + if value == self._counter: + return True + if timeout is None: + time_left = None + else: + time_left = timeout - (time.time() - t_start) + if time_left <= 0: + return False + if not self._cond.wait(time_left): + return False + + +class AsyncMultiEvent: + # Adopted from threading.Event + + def __init__(self): + super().__init__() + self._cond = AsyncCondition() + self._counter = 0 + + def _reset_internal_locks(self): + # private! 
called by Thread._reset_internal_locks by _after_fork() + self._cond.__init__(AsyncLock()) + + def counter(self): + return self._counter + + async def increment(self): + async with self._cond: + self._counter += 1 + self._cond.notify_all() + + async def decrement(self): + async with self._cond: + self._counter -= 1 + self._cond.notify_all() + + async def clear(self): + async with self._cond: + self._counter = 0 + self._cond.notify_all() + + async def wait(self, value=0, timeout=None): + async with self._cond: + t_start = time.time() + while True: + if value == self._counter: + return True + if timeout is None: + time_left = None + else: + time_left = timeout - (time.time() - t_start) + if time_left <= 0: + return False + try: + await wait_for(self._cond.wait(), time_left) + except asyncio.TimeoutError: + return False diff --git a/tests/unit/mixed/io/test_direct.py b/tests/unit/mixed/io/test_direct.py index be28f810..ec1e9bb3 100644 --- a/tests/unit/mixed/io/test_direct.py +++ b/tests/unit/mixed/io/test_direct.py @@ -17,14 +17,8 @@ import asyncio -import time -from asyncio import ( - Condition as AsyncCondition, - Event as AsyncEvent, - Lock as AsyncLock, -) +from asyncio import Event as AsyncEvent from threading import ( - Condition, Event, Lock, Thread, @@ -32,105 +26,16 @@ import pytest -from neo4j._async_compat.shims import wait_for +from neo4j._async.io._pool import AcquireAuth as AsyncAcquireAuth from neo4j._deadline import Deadline +from neo4j._sync.io._pool import AcquireAuth from ...async_.io.test_direct import AsyncFakeBoltPool from ...sync.io.test_direct import FakeBoltPool - - -class MultiEvent: - # Adopted from threading.Event - - def __init__(self): - super().__init__() - self._cond = Condition(Lock()) - self._counter = 0 - - def _reset_internal_locks(self): - # private! called by Thread._reset_internal_locks by _after_fork() - self._cond.__init__(Lock()) - - def counter(self): - return self._counter - - def increment(self): - with self._cond: - self._counter += 1 - self._cond.notify_all() - - def decrement(self): - with self._cond: - self._counter -= 1 - self._cond.notify_all() - - def clear(self): - with self._cond: - self._counter = 0 - self._cond.notify_all() - - def wait(self, value=0, timeout=None): - with self._cond: - t_start = time.time() - while True: - if value == self._counter: - return True - if timeout is None: - time_left = None - else: - time_left = timeout - (time.time() - t_start) - if time_left <= 0: - return False - if not self._cond.wait(time_left): - return False - - -class AsyncMultiEvent: - # Adopted from threading.Event - - def __init__(self): - super().__init__() - self._cond = AsyncCondition() - self._counter = 0 - - def _reset_internal_locks(self): - # private! 
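# Illustrative sketch (not part of this patch): how the MultiEvent helper that
# moved into _common.py is meant to be used - a counter-based event that
# worker threads bump and a coordinator waits on until it reaches a target
# value (or times out). The import path assumes the test package layout added
# above.
from threading import Thread

from tests.unit.mixed.io._common import MultiEvent

event = MultiEvent()
workers = [Thread(target=event.increment) for _ in range(3)]
for w in workers:
    w.start()
assert event.wait(value=3, timeout=5)  # all three workers checked in
for w in workers:
    w.join()
event.clear()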
called by Thread._reset_internal_locks by _after_fork() - self._cond.__init__(AsyncLock()) - - def counter(self): - return self._counter - - async def increment(self): - async with self._cond: - self._counter += 1 - self._cond.notify_all() - - async def decrement(self): - async with self._cond: - self._counter -= 1 - self._cond.notify_all() - - async def clear(self): - async with self._cond: - self._counter = 0 - self._cond.notify_all() - - async def wait(self, value=0, timeout=None): - async with self._cond: - t_start = time.time() - while True: - if value == self._counter: - return True - if timeout is None: - time_left = None - else: - time_left = timeout - (time.time() - t_start) - if time_left <= 0: - return False - try: - await wait_for(self._cond.wait(), time_left) - except asyncio.TimeoutError: - return False +from ._common import ( + AsyncMultiEvent, + MultiEvent, +) class TestMixedConnectionPoolTestCase: @@ -156,7 +61,7 @@ def test_multithread(self, pre_populated): def acquire_release_conn(pool_, address_, acquired_counter_, release_event_): nonlocal connections, connections_lock - conn_ = pool_._acquire(address_, Deadline(3), None) + conn_ = pool_._acquire(address_, None, Deadline(3), None) with connections_lock: if connections is not None: connections.append(conn_) @@ -171,7 +76,7 @@ def acquire_release_conn(pool_, address_, acquired_counter_, # pre-populate the pool with connections for _ in range(pre_populated): - conn = pool._acquire(address, Deadline(3), None) + conn = pool._acquire(address, None, Deadline(3), None) pre_populated_connections.append(conn) for conn in pre_populated_connections: pool.release(conn) @@ -217,7 +122,7 @@ async def test_multi_coroutine(self, pre_populated): async def acquire_release_conn(pool_, address_, acquired_counter_, release_event_): nonlocal connections - conn_ = await pool_._acquire(address_, Deadline(3), None) + conn_ = await pool_._acquire(address_, None, Deadline(3), None) if connections is not None: connections.append(conn_) await acquired_counter_.increment() @@ -251,7 +156,7 @@ async def waiter(pool_, acquired_counter_, release_event_): # pre-populate the pool with connections for _ in range(pre_populated): - conn = await pool._acquire(address, Deadline(3), None) + conn = await pool._acquire(address, None, Deadline(3), None) pre_populated_connections.append(conn) for conn in pre_populated_connections: await pool.release(conn) diff --git a/tests/unit/sync/fixtures/fake_connection.py b/tests/unit/sync/fixtures/fake_connection.py index ad26c730..659daebe 100644 --- a/tests/unit/sync/fixtures/fake_connection.py +++ b/tests/unit/sync/fixtures/fake_connection.py @@ -23,6 +23,7 @@ from neo4j import ServerInfo from neo4j._deadline import Deadline from neo4j._sync.io import Bolt +from neo4j.auth_management import AuthManager __all__ = [ @@ -50,7 +51,10 @@ def __init__(self, *args, **kwargs): self.attach_mock(mock.Mock(return_value=False), "stale") self.attach_mock(mock.Mock(return_value=False), "closed") self.attach_mock(mock.Mock(return_value=False), "socket") - self.attach_mock(mock.Mock(), "unresolved_address") + self.attach_mock(mock.Mock(return_value=False), "re_auth") + self.attach_mock(mock.MagicMock(spec=AuthManager), + "auth_manager") + self.unresolved_address = next(iter(args), "localhost") def close_side_effect(): self.closed.return_value = True diff --git a/tests/unit/sync/io/conftest.py b/tests/unit/sync/io/conftest.py index a06cf3c3..dc286dbd 100644 --- a/tests/unit/sync/io/conftest.py +++ b/tests/unit/sync/io/conftest.py @@ 
-24,10 +24,12 @@ import pytest +from neo4j import PreviewWarning from neo4j._sync.io._common import ( Inbox, Outbox, ) +from neo4j.auth_management import AuthManagers class FakeSocket: @@ -103,6 +105,9 @@ def sendall(self, data): def close(self): return + def kill(self): + return + def inject(self, data): self.recv_buffer += data @@ -142,3 +147,16 @@ def fake_socket_2(): @pytest.fixture def fake_socket_pair(): return FakeSocketPair + +@pytest.fixture +def static_auth(): + def inner(auth): + with pytest.warns(PreviewWarning, match="Auth managers"): + return AuthManagers.static(auth) + + return inner + + +@pytest.fixture +def none_auth(static_auth): + return static_auth(None) diff --git a/tests/unit/sync/io/test_class_bolt.py b/tests/unit/sync/io/test_class_bolt.py index f963834a..f0306f96 100644 --- a/tests/unit/sync/io/test_class_bolt.py +++ b/tests/unit/sync/io/test_class_bolt.py @@ -20,6 +20,7 @@ import pytest +import neo4j.auth_management from neo4j._async_compat.network import BoltSocket from neo4j._exceptions import BoltHandshakeError from neo4j._sync.io import Bolt @@ -94,7 +95,7 @@ def test_magic_preamble(): @TestDecorators.mark_async_only_test -def test_cancel_hello_in_open(mocker): +def test_cancel_hello_in_open(mocker, none_auth): address = ("localhost", 7687) socket_mock = mocker.MagicMock(spec=BoltSocket) @@ -112,7 +113,7 @@ def test_cancel_hello_in_open(mocker): bolt_mock.local_port = 1234 with pytest.raises(asyncio.CancelledError): - Bolt.open(address) + Bolt.open(address, auth_manager=none_auth) bolt_mock.kill.assert_called_once_with() @@ -132,7 +133,9 @@ def test_cancel_hello_in_open(mocker): ), ) @mark_sync_test -def test_version_negotiation(mocker, bolt_version, bolt_cls_path): +def test_version_negotiation( + mocker, bolt_version, bolt_cls_path, none_auth +): address = ("localhost", 7687) socket_mock = mocker.MagicMock(spec=BoltSocket) @@ -147,7 +150,7 @@ def test_version_negotiation(mocker, bolt_version, bolt_cls_path): bolt_mock = bolt_cls_mock.return_value bolt_mock.socket = socket_mock - connection = Bolt.open(address) + connection = Bolt.open(address, auth_manager=none_auth) bolt_cls_mock.assert_called_once() assert connection is bolt_mock @@ -163,7 +166,7 @@ def test_version_negotiation(mocker, bolt_version, bolt_cls_path): (6, 0), )) @mark_sync_test -def test_failing_version_negotiation(mocker, bolt_version): +def test_failing_version_negotiation(mocker, bolt_version, none_auth): supported_protocols = \ "('3.0', '4.1', '4.2', '4.3', '4.4', '5.0', '5.1', '5.2')" @@ -178,6 +181,64 @@ def test_failing_version_negotiation(mocker, bolt_version): socket_mock.getpeername.return_value = address with pytest.raises(BoltHandshakeError) as exc: - Bolt.open(address) + Bolt.open(address, auth_manager=none_auth) assert exc.match(supported_protocols) + + + +@TestDecorators.mark_async_only_test +def test_cancel_manager_in_open(mocker): + address = ("localhost", 7687) + socket_mock = mocker.MagicMock(spec=BoltSocket) + + socket_cls_mock = mocker.patch("neo4j._sync.io._bolt.BoltSocket", + autospec=True) + socket_cls_mock.connect.return_value = ( + socket_mock, (5, 0), None, None + ) + socket_mock.getpeername.return_value = address + bolt_cls_mock = mocker.patch("neo4j._sync.io._bolt5.Bolt5x0", + autospec=True) + bolt_mock = bolt_cls_mock.return_value + bolt_mock.socket = socket_mock + bolt_mock.local_port = 1234 + + auth_manager = mocker.MagicMock( + spec=neo4j.auth_management.AuthManager + ) + auth_manager.get_auth.side_effect = asyncio.CancelledError() + + with 
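# Illustrative sketch (not part of this patch): how a new unit test would pick
# up the conftest fixtures added above - `none_auth` yields a static auth
# manager wrapping None, while `static_auth` is a factory for arbitrary
# tokens. The port is intentionally unused so open() fails cleanly; the
# credentials are placeholders.
import pytest

from neo4j._sync.io import Bolt
from neo4j.exceptions import ServiceUnavailable


def test_open_with_explicit_token(static_auth):
    auth_manager = static_auth(("neo4j", "secret"))
    with pytest.raises(ServiceUnavailable):
        # nothing is listening on this port, so open() must fail cleanly
        Bolt.open(("localhost", 9999), auth_manager=auth_manager)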
pytest.raises(asyncio.CancelledError): + Bolt.open(address, auth_manager=auth_manager) + + socket_mock.kill.assert_called_once_with() + + +@TestDecorators.mark_async_only_test +def test_fail_manager_in_open(mocker): + address = ("localhost", 7687) + socket_mock = mocker.MagicMock(spec=BoltSocket) + + socket_cls_mock = mocker.patch("neo4j._sync.io._bolt.BoltSocket", + autospec=True) + socket_cls_mock.connect.return_value = ( + socket_mock, (5, 0), None, None + ) + socket_mock.getpeername.return_value = address + bolt_cls_mock = mocker.patch("neo4j._sync.io._bolt5.Bolt5x0", + autospec=True) + bolt_mock = bolt_cls_mock.return_value + bolt_mock.socket = socket_mock + bolt_mock.local_port = 1234 + + auth_manager = mocker.MagicMock( + spec=neo4j.auth_management.AuthManager + ) + auth_manager.get_auth.side_effect = RuntimeError("token fetching failed") + + with pytest.raises(RuntimeError) as exc: + Bolt.open(address, auth_manager=auth_manager) + assert exc.value is auth_manager.get_auth.side_effect + + socket_mock.close.assert_called_once_with() diff --git a/tests/unit/sync/io/test_class_bolt3.py b/tests/unit/sync/io/test_class_bolt3.py index 71f0de81..73999db0 100644 --- a/tests/unit/sync/io/test_class_bolt3.py +++ b/tests/unit/sync/io/test_class_bolt3.py @@ -17,12 +17,13 @@ import logging +from itertools import permutations import pytest +import neo4j from neo4j._conf import PoolConfig from neo4j._sync.io._bolt3 import Bolt3 -from neo4j.api import Auth from neo4j.exceptions import ConfigurationError from ...._async_compat import mark_sync_test @@ -30,7 +31,7 @@ @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_stale(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = 0 connection = Bolt3(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -40,7 +41,7 @@ def test_conn_is_stale(fake_socket, set_stale): @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = -1 connection = Bolt3(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -50,7 +51,7 @@ def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_not_stale(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = 999999999 connection = Bolt3(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -59,14 +60,14 @@ def test_conn_is_not_stale(fake_socket, set_stale): def test_db_extra_not_supported_in_begin(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) connection = Bolt3(address, fake_socket(address), PoolConfig.max_connection_lifetime) with pytest.raises(ConfigurationError): connection.begin(db="something") def test_db_extra_not_supported_in_run(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) connection = Bolt3(address, fake_socket(address), PoolConfig.max_connection_lifetime) with pytest.raises(ConfigurationError): connection.run("", db="something") @@ -74,7 +75,7 @@ def test_db_extra_not_supported_in_run(fake_socket): @mark_sync_test def test_simple_discard(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, 
Bolt3.UNPACKER_CLS) connection = Bolt3(address, socket, PoolConfig.max_connection_lifetime) connection.discard() @@ -86,7 +87,7 @@ def test_simple_discard(fake_socket): @mark_sync_test def test_simple_pull(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt3.UNPACKER_CLS) connection = Bolt3(address, socket, PoolConfig.max_connection_lifetime) connection.pull() @@ -101,7 +102,7 @@ def test_simple_pull(fake_socket): def test_hint_recv_timeout_seconds_gets_ignored( fake_socket_pair, recv_timeout, mocker ): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair( address, Bolt3.PACKER_CLS, Bolt3.UNPACKER_CLS ) @@ -117,6 +118,92 @@ def test_hint_recv_timeout_seconds_gets_ignored( sockets.client.settimeout.assert_not_called() +CREDENTIALS = "+++super-secret-sauce+++" + + +@pytest.mark.parametrize("auth", ( + ("user", CREDENTIALS), + neo4j.basic_auth("user", CREDENTIALS), + neo4j.kerberos_auth(CREDENTIALS), + neo4j.bearer_auth(CREDENTIALS), + neo4j.custom_auth("user", CREDENTIALS, "realm", "scheme"), + neo4j.Auth("scheme", "principal", CREDENTIALS, "realm", foo="bar"), +)) +@mark_sync_test +def test_credentials_are_not_logged( + auth, fake_socket_pair, mocker, caplog +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=Bolt3.PACKER_CLS, + unpacker_cls=Bolt3.UNPACKER_CLS) + sockets.client.settimeout = mocker.Mock() + sockets.server.send_message(b"\x70", {"server": "Neo4j/4.3.4"}) + connection = Bolt3( + address, sockets.client, PoolConfig.max_connection_lifetime, auth=auth + ) + with caplog.at_level(logging.DEBUG): + connection.hello() + + if isinstance(auth, tuple): + auth = neo4j.basic_auth(*auth) + for field in ("scheme", "principal", "realm", "parameters"): + value = getattr(auth, field, None) + if value: + assert repr(value) in caplog.text + assert CREDENTIALS not in caplog.text + + +@pytest.mark.parametrize("message", ("logon", "logoff")) +def test_auth_message_raises_configuration_error(message, fake_socket): + address = neo4j.Address(("127.0.0.1", 7687)) + connection = Bolt3(address, fake_socket(address), + PoolConfig.max_connection_lifetime) + with pytest.raises(ConfigurationError, + match="User switching is not supported"): + getattr(connection, message)() + + +@pytest.mark.parametrize("auth", ( + None, + neo4j.Auth("scheme", "principal", "credentials", "realm"), + ("user", "password"), +)) +@mark_sync_test +def test_re_auth_noop(auth, fake_socket, mocker): + address = neo4j.Address(("127.0.0.1", 7687)) + connection = Bolt3(address, fake_socket(address), + PoolConfig.max_connection_lifetime, auth=auth) + logon_spy = mocker.spy(connection, "logon") + logoff_spy = mocker.spy(connection, "logoff") + res = connection.re_auth(auth, None) + + assert res is False + logon_spy.assert_not_called() + logoff_spy.assert_not_called() + + +@pytest.mark.parametrize( + ("auth1", "auth2"), + permutations( + ( + None, + neo4j.Auth("scheme", "principal", "credentials", "realm"), + ("user", "password"), + ), + 2 + ) +) +@mark_sync_test +def test_re_auth(auth1, auth2, fake_socket): + address = neo4j.Address(("127.0.0.1", 7687)) + connection = Bolt3(address, fake_socket(address), + PoolConfig.max_connection_lifetime, auth=auth1) + with pytest.raises(ConfigurationError, + match="User switching is not supported"): + connection.re_auth(auth2, None) + + @pytest.mark.parametrize(("method", "args"), ( ("run", ("RETURN 1",)), ("begin", ()), @@ 
-132,7 +219,7 @@ def test_hint_recv_timeout_seconds_gets_ignored( )) def test_does_not_support_notification_filters(fake_socket, method, args, kwargs): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt3.UNPACKER_CLS) connection = Bolt3(address, socket, PoolConfig.max_connection_lifetime) @@ -154,7 +241,7 @@ def test_does_not_support_notification_filters(fake_socket, method, def test_hello_does_not_support_notification_filters( fake_socket, kwargs ): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt3.UNPACKER_CLS) connection = Bolt3( address, socket, PoolConfig.max_connection_lifetime, @@ -162,65 +249,3 @@ def test_hello_does_not_support_notification_filters( ) with pytest.raises(ConfigurationError, match="Notification filtering"): connection.hello() - - -class HackedAuth: - def __init__(self, dict_): - self.__dict__ = dict_ - - -@mark_sync_test -@pytest.mark.parametrize("auth", ( - ("awesome test user", "safe p4ssw0rd"), - Auth("super nice scheme", "awesome test user", "safe p4ssw0rd"), - Auth("super nice scheme", "awesome test user", "safe p4ssw0rd", - realm="super duper realm"), - Auth("super nice scheme", "awesome test user", "safe p4ssw0rd", - realm="super duper realm"), - Auth("super nice scheme", "awesome test user", "safe p4ssw0rd", - foo="bar"), - HackedAuth({ - "scheme": "super nice scheme", "principal": "awesome test user", - "credentials": "safe p4ssw0rd", "realm": "super duper realm", - "parameters": {"credentials": "should be visible!"}, - }) - -)) -def test_hello_does_not_log_credentials(fake_socket_pair, caplog, auth): - def items(): - if isinstance(auth, tuple): - yield "scheme", "basic" - yield "principal", auth[0] - yield "credentials", auth[1] - elif isinstance(auth, Auth): - for key in ("scheme", "principal", "credentials", "realm", - "parameters"): - value = getattr(auth, key, None) - if value: - yield key, value - elif isinstance(auth, HackedAuth): - yield from auth.__dict__.items() - else: - raise TypeError(auth) - - address = ("127.0.0.1", 7687) - sockets = fake_socket_pair(address, - packer_cls=Bolt3.PACKER_CLS, - unpacker_cls=Bolt3.UNPACKER_CLS) - sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) - max_connection_lifetime = 0 - connection = Bolt3(address, sockets.client, - max_connection_lifetime, auth=auth) - - with caplog.at_level(logging.DEBUG): - connection.hello() - - hellos = [m for m in caplog.messages if "C: HELLO" in m] - assert len(hellos) == 1 - hello = hellos[0] - - for key, value in items(): - if key == "credentials": - assert value not in hello - else: - assert str({key: value})[1:-1] in hello diff --git a/tests/unit/sync/io/test_class_bolt4x0.py b/tests/unit/sync/io/test_class_bolt4x0.py index d0d62368..6fb416e1 100644 --- a/tests/unit/sync/io/test_class_bolt4x0.py +++ b/tests/unit/sync/io/test_class_bolt4x0.py @@ -17,12 +17,13 @@ import logging +from itertools import permutations import pytest +import neo4j from neo4j._conf import PoolConfig from neo4j._sync.io._bolt4 import Bolt4x0 -from neo4j.api import Auth from neo4j.exceptions import ConfigurationError from ...._async_compat import mark_sync_test @@ -30,7 +31,7 @@ @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_stale(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = 0 connection = Bolt4x0(address, fake_socket(address), max_connection_lifetime) if 
set_stale: @@ -40,7 +41,7 @@ def test_conn_is_stale(fake_socket, set_stale): @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = -1 connection = Bolt4x0(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -50,7 +51,7 @@ def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_not_stale(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = 999999999 connection = Bolt4x0(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -60,7 +61,7 @@ def test_conn_is_not_stale(fake_socket, set_stale): @mark_sync_test def test_db_extra_in_begin(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x0.UNPACKER_CLS) connection = Bolt4x0(address, socket, PoolConfig.max_connection_lifetime) connection.begin(db="something") @@ -73,7 +74,7 @@ def test_db_extra_in_begin(fake_socket): @mark_sync_test def test_db_extra_in_run(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x0.UNPACKER_CLS) connection = Bolt4x0(address, socket, PoolConfig.max_connection_lifetime) connection.run("", {}, db="something") @@ -88,7 +89,7 @@ def test_db_extra_in_run(fake_socket): @mark_sync_test def test_n_extra_in_discard(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x0.UNPACKER_CLS) connection = Bolt4x0(address, socket, PoolConfig.max_connection_lifetime) connection.discard(n=666) @@ -108,7 +109,7 @@ def test_n_extra_in_discard(fake_socket): ) @mark_sync_test def test_qid_extra_in_discard(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x0.UNPACKER_CLS) connection = Bolt4x0(address, socket, PoolConfig.max_connection_lifetime) connection.discard(qid=test_input) @@ -128,7 +129,7 @@ def test_qid_extra_in_discard(fake_socket, test_input, expected): ) @mark_sync_test def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x0.UNPACKER_CLS) connection = Bolt4x0(address, socket, PoolConfig.max_connection_lifetime) connection.discard(n=666, qid=test_input) @@ -148,7 +149,7 @@ def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): ) @mark_sync_test def test_n_extra_in_pull(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x0.UNPACKER_CLS) connection = Bolt4x0(address, socket, PoolConfig.max_connection_lifetime) connection.pull(n=test_input) @@ -168,7 +169,7 @@ def test_n_extra_in_pull(fake_socket, test_input, expected): ) @mark_sync_test def test_qid_extra_in_pull(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x0.UNPACKER_CLS) connection = Bolt4x0(address, socket, PoolConfig.max_connection_lifetime) connection.pull(qid=test_input) @@ -181,7 +182,7 @@ def test_qid_extra_in_pull(fake_socket, test_input, 
expected): @mark_sync_test def test_n_and_qid_extras_in_pull(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x0.UNPACKER_CLS) connection = Bolt4x0(address, socket, PoolConfig.max_connection_lifetime) connection.pull(n=666, qid=777) @@ -197,7 +198,7 @@ def test_n_and_qid_extras_in_pull(fake_socket): def test_hint_recv_timeout_seconds_gets_ignored( fake_socket_pair, recv_timeout, mocker ): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair(address, packer_cls=Bolt4x0.PACKER_CLS, unpacker_cls=Bolt4x0.UNPACKER_CLS) @@ -213,6 +214,92 @@ def test_hint_recv_timeout_seconds_gets_ignored( sockets.client.settimeout.assert_not_called() +CREDENTIALS = "+++super-secret-sauce+++" + + +@pytest.mark.parametrize("auth", ( + ("user", CREDENTIALS), + neo4j.basic_auth("user", CREDENTIALS), + neo4j.kerberos_auth(CREDENTIALS), + neo4j.bearer_auth(CREDENTIALS), + neo4j.custom_auth("user", CREDENTIALS, "realm", "scheme"), + neo4j.Auth("scheme", "principal", CREDENTIALS, "realm", foo="bar"), +)) +@mark_sync_test +def test_credentials_are_not_logged( + auth, fake_socket_pair, mocker, caplog +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=Bolt4x0.PACKER_CLS, + unpacker_cls=Bolt4x0.UNPACKER_CLS) + sockets.client.settimeout = mocker.Mock() + sockets.server.send_message(b"\x70", {"server": "Neo4j/4.3.4"}) + connection = Bolt4x0( + address, sockets.client, PoolConfig.max_connection_lifetime, auth=auth + ) + with caplog.at_level(logging.DEBUG): + connection.hello() + + if isinstance(auth, tuple): + auth = neo4j.basic_auth(*auth) + for field in ("scheme", "principal", "realm", "parameters"): + value = getattr(auth, field, None) + if value: + assert repr(value) in caplog.text + assert CREDENTIALS not in caplog.text + + +@pytest.mark.parametrize("message", ("logon", "logoff")) +def test_auth_message_raises_configuration_error(message, fake_socket): + address = neo4j.Address(("127.0.0.1", 7687)) + connection = Bolt4x0(address, fake_socket(address), + PoolConfig.max_connection_lifetime) + with pytest.raises(ConfigurationError, + match="User switching is not supported"): + getattr(connection, message)() + + +@pytest.mark.parametrize("auth", ( + None, + neo4j.Auth("scheme", "principal", "credentials", "realm"), + ("user", "password"), +)) +@mark_sync_test +def test_re_auth_noop(auth, fake_socket, mocker): + address = neo4j.Address(("127.0.0.1", 7687)) + connection = Bolt4x0(address, fake_socket(address), + PoolConfig.max_connection_lifetime, auth=auth) + logon_spy = mocker.spy(connection, "logon") + logoff_spy = mocker.spy(connection, "logoff") + res = connection.re_auth(auth, None) + + assert res is False + logon_spy.assert_not_called() + logoff_spy.assert_not_called() + + +@pytest.mark.parametrize( + ("auth1", "auth2"), + permutations( + ( + None, + neo4j.Auth("scheme", "principal", "credentials", "realm"), + ("user", "password"), + ), + 2 + ) +) +@mark_sync_test +def test_re_auth(auth1, auth2, fake_socket): + address = neo4j.Address(("127.0.0.1", 7687)) + connection = Bolt4x0(address, fake_socket(address), + PoolConfig.max_connection_lifetime, auth=auth1) + with pytest.raises(ConfigurationError, + match="User switching is not supported"): + connection.re_auth(auth2, None) + + @pytest.mark.parametrize(("method", "args"), ( ("run", ("RETURN 1",)), ("begin", ()), @@ -228,7 +315,7 @@ def test_hint_recv_timeout_seconds_gets_ignored( )) def 
test_does_not_support_notification_filters(fake_socket, method, args, kwargs): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x0.UNPACKER_CLS) connection = Bolt4x0(address, socket, PoolConfig.max_connection_lifetime) @@ -250,7 +337,7 @@ def test_does_not_support_notification_filters(fake_socket, method, def test_hello_does_not_support_notification_filters( fake_socket, kwargs ): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x0.UNPACKER_CLS) connection = Bolt4x0( address, socket, PoolConfig.max_connection_lifetime, @@ -258,65 +345,3 @@ def test_hello_does_not_support_notification_filters( ) with pytest.raises(ConfigurationError, match="Notification filtering"): connection.hello() - - -class HackedAuth: - def __init__(self, dict_): - self.__dict__ = dict_ - - -@mark_sync_test -@pytest.mark.parametrize("auth", ( - ("awesome test user", "safe p4ssw0rd"), - Auth("super nice scheme", "awesome test user", "safe p4ssw0rd"), - Auth("super nice scheme", "awesome test user", "safe p4ssw0rd", - realm="super duper realm"), - Auth("super nice scheme", "awesome test user", "safe p4ssw0rd", - realm="super duper realm"), - Auth("super nice scheme", "awesome test user", "safe p4ssw0rd", - foo="bar"), - HackedAuth({ - "scheme": "super nice scheme", "principal": "awesome test user", - "credentials": "safe p4ssw0rd", "realm": "super duper realm", - "parameters": {"credentials": "should be visible!"}, - }) - -)) -def test_hello_does_not_log_credentials(fake_socket_pair, caplog, auth): - def items(): - if isinstance(auth, tuple): - yield "scheme", "basic" - yield "principal", auth[0] - yield "credentials", auth[1] - elif isinstance(auth, Auth): - for key in ("scheme", "principal", "credentials", "realm", - "parameters"): - value = getattr(auth, key, None) - if value: - yield key, value - elif isinstance(auth, HackedAuth): - yield from auth.__dict__.items() - else: - raise TypeError(auth) - - address = ("127.0.0.1", 7687) - sockets = fake_socket_pair(address, - packer_cls=Bolt4x0.PACKER_CLS, - unpacker_cls=Bolt4x0.UNPACKER_CLS) - sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) - max_connection_lifetime = 0 - connection = Bolt4x0(address, sockets.client, - max_connection_lifetime, auth=auth) - - with caplog.at_level(logging.DEBUG): - connection.hello() - - hellos = [m for m in caplog.messages if "C: HELLO" in m] - assert len(hellos) == 1 - hello = hellos[0] - - for key, value in items(): - if key == "credentials": - assert value not in hello - else: - assert str({key: value})[1:-1] in hello diff --git a/tests/unit/sync/io/test_class_bolt4x1.py b/tests/unit/sync/io/test_class_bolt4x1.py index 76b4facc..3ed433d9 100644 --- a/tests/unit/sync/io/test_class_bolt4x1.py +++ b/tests/unit/sync/io/test_class_bolt4x1.py @@ -17,12 +17,13 @@ import logging +from itertools import permutations import pytest +import neo4j from neo4j._conf import PoolConfig from neo4j._sync.io._bolt4 import Bolt4x1 -from neo4j.api import Auth from neo4j.exceptions import ConfigurationError from ...._async_compat import mark_sync_test @@ -30,7 +31,7 @@ @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_stale(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = 0 connection = Bolt4x1(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -40,7 +41,7 @@ def test_conn_is_stale(fake_socket, 
set_stale): @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = -1 connection = Bolt4x1(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -50,7 +51,7 @@ def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_not_stale(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = 999999999 connection = Bolt4x1(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -60,7 +61,7 @@ def test_conn_is_not_stale(fake_socket, set_stale): @mark_sync_test def test_db_extra_in_begin(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x1.UNPACKER_CLS) connection = Bolt4x1(address, socket, PoolConfig.max_connection_lifetime) connection.begin(db="something") @@ -73,7 +74,7 @@ def test_db_extra_in_begin(fake_socket): @mark_sync_test def test_db_extra_in_run(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x1.UNPACKER_CLS) connection = Bolt4x1(address, socket, PoolConfig.max_connection_lifetime) connection.run("", {}, db="something") @@ -88,7 +89,7 @@ def test_db_extra_in_run(fake_socket): @mark_sync_test def test_n_extra_in_discard(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x1.UNPACKER_CLS) connection = Bolt4x1(address, socket, PoolConfig.max_connection_lifetime) connection.discard(n=666) @@ -108,7 +109,7 @@ def test_n_extra_in_discard(fake_socket): ) @mark_sync_test def test_qid_extra_in_discard(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x1.UNPACKER_CLS) connection = Bolt4x1(address, socket, PoolConfig.max_connection_lifetime) connection.discard(qid=test_input) @@ -128,7 +129,7 @@ def test_qid_extra_in_discard(fake_socket, test_input, expected): ) @mark_sync_test def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x1.UNPACKER_CLS) connection = Bolt4x1(address, socket, PoolConfig.max_connection_lifetime) connection.discard(n=666, qid=test_input) @@ -148,7 +149,7 @@ def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): ) @mark_sync_test def test_n_extra_in_pull(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x1.UNPACKER_CLS) connection = Bolt4x1(address, socket, PoolConfig.max_connection_lifetime) connection.pull(n=test_input) @@ -168,7 +169,7 @@ def test_n_extra_in_pull(fake_socket, test_input, expected): ) @mark_sync_test def test_qid_extra_in_pull(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x1.UNPACKER_CLS) connection = Bolt4x1(address, socket, PoolConfig.max_connection_lifetime) connection.pull(qid=test_input) @@ -181,7 +182,7 @@ def test_qid_extra_in_pull(fake_socket, test_input, expected): @mark_sync_test def 
test_n_and_qid_extras_in_pull(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x1.UNPACKER_CLS) connection = Bolt4x1(address, socket, PoolConfig.max_connection_lifetime) connection.pull(n=666, qid=777) @@ -194,7 +195,7 @@ def test_n_and_qid_extras_in_pull(fake_socket): @mark_sync_test def test_hello_passes_routing_metadata(fake_socket_pair): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair(address, packer_cls=Bolt4x1.PACKER_CLS, unpacker_cls=Bolt4x1.UNPACKER_CLS) @@ -215,7 +216,7 @@ def test_hello_passes_routing_metadata(fake_socket_pair): def test_hint_recv_timeout_seconds_gets_ignored( fake_socket_pair, recv_timeout, mocker ): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair(address, packer_cls=Bolt4x1.PACKER_CLS, unpacker_cls=Bolt4x1.UNPACKER_CLS) @@ -230,6 +231,92 @@ def test_hint_recv_timeout_seconds_gets_ignored( sockets.client.settimeout.assert_not_called() +CREDENTIALS = "+++super-secret-sauce+++" + + +@pytest.mark.parametrize("auth", ( + ("user", CREDENTIALS), + neo4j.basic_auth("user", CREDENTIALS), + neo4j.kerberos_auth(CREDENTIALS), + neo4j.bearer_auth(CREDENTIALS), + neo4j.custom_auth("user", CREDENTIALS, "realm", "scheme"), + neo4j.Auth("scheme", "principal", CREDENTIALS, "realm", foo="bar"), +)) +@mark_sync_test +def test_credentials_are_not_logged( + auth, fake_socket_pair, mocker, caplog +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=Bolt4x1.PACKER_CLS, + unpacker_cls=Bolt4x1.UNPACKER_CLS) + sockets.client.settimeout = mocker.Mock() + sockets.server.send_message(b"\x70", {"server": "Neo4j/4.3.4"}) + connection = Bolt4x1( + address, sockets.client, PoolConfig.max_connection_lifetime, auth=auth + ) + with caplog.at_level(logging.DEBUG): + connection.hello() + + if isinstance(auth, tuple): + auth = neo4j.basic_auth(*auth) + for field in ("scheme", "principal", "realm", "parameters"): + value = getattr(auth, field, None) + if value: + assert repr(value) in caplog.text + assert CREDENTIALS not in caplog.text + + +@pytest.mark.parametrize("message", ("logon", "logoff")) +def test_auth_message_raises_configuration_error(message, fake_socket): + address = neo4j.Address(("127.0.0.1", 7687)) + connection = Bolt4x1(address, fake_socket(address), + PoolConfig.max_connection_lifetime) + with pytest.raises(ConfigurationError, + match="User switching is not supported"): + getattr(connection, message)() + + +@pytest.mark.parametrize("auth", ( + None, + neo4j.Auth("scheme", "principal", "credentials", "realm"), + ("user", "password"), +)) +@mark_sync_test +def test_re_auth_noop(auth, fake_socket, mocker): + address = neo4j.Address(("127.0.0.1", 7687)) + connection = Bolt4x1(address, fake_socket(address), + PoolConfig.max_connection_lifetime, auth=auth) + logon_spy = mocker.spy(connection, "logon") + logoff_spy = mocker.spy(connection, "logoff") + res = connection.re_auth(auth, None) + + assert res is False + logon_spy.assert_not_called() + logoff_spy.assert_not_called() + + +@pytest.mark.parametrize( + ("auth1", "auth2"), + permutations( + ( + None, + neo4j.Auth("scheme", "principal", "credentials", "realm"), + ("user", "password"), + ), + 2 + ) +) +@mark_sync_test +def test_re_auth(auth1, auth2, fake_socket): + address = neo4j.Address(("127.0.0.1", 7687)) + connection = Bolt4x1(address, fake_socket(address), + PoolConfig.max_connection_lifetime, 
auth=auth1) + with pytest.raises(ConfigurationError, + match="User switching is not supported"): + connection.re_auth(auth2, None) + + @pytest.mark.parametrize(("method", "args"), ( ("run", ("RETURN 1",)), ("begin", ()), @@ -245,7 +332,7 @@ def test_hint_recv_timeout_seconds_gets_ignored( )) def test_does_not_support_notification_filters(fake_socket, method, args, kwargs): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x1.UNPACKER_CLS) connection = Bolt4x1(address, socket, PoolConfig.max_connection_lifetime) @@ -267,7 +354,7 @@ def test_does_not_support_notification_filters(fake_socket, method, def test_hello_does_not_support_notification_filters( fake_socket, kwargs ): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x1.UNPACKER_CLS) connection = Bolt4x1( address, socket, PoolConfig.max_connection_lifetime, @@ -275,65 +362,3 @@ def test_hello_does_not_support_notification_filters( ) with pytest.raises(ConfigurationError, match="Notification filtering"): connection.hello() - - -class HackedAuth: - def __init__(self, dict_): - self.__dict__ = dict_ - - -@mark_sync_test -@pytest.mark.parametrize("auth", ( - ("awesome test user", "safe p4ssw0rd"), - Auth("super nice scheme", "awesome test user", "safe p4ssw0rd"), - Auth("super nice scheme", "awesome test user", "safe p4ssw0rd", - realm="super duper realm"), - Auth("super nice scheme", "awesome test user", "safe p4ssw0rd", - realm="super duper realm"), - Auth("super nice scheme", "awesome test user", "safe p4ssw0rd", - foo="bar"), - HackedAuth({ - "scheme": "super nice scheme", "principal": "awesome test user", - "credentials": "safe p4ssw0rd", "realm": "super duper realm", - "parameters": {"credentials": "should be visible!"}, - }) - -)) -def test_hello_does_not_log_credentials(fake_socket_pair, caplog, auth): - def items(): - if isinstance(auth, tuple): - yield "scheme", "basic" - yield "principal", auth[0] - yield "credentials", auth[1] - elif isinstance(auth, Auth): - for key in ("scheme", "principal", "credentials", "realm", - "parameters"): - value = getattr(auth, key, None) - if value: - yield key, value - elif isinstance(auth, HackedAuth): - yield from auth.__dict__.items() - else: - raise TypeError(auth) - - address = ("127.0.0.1", 7687) - sockets = fake_socket_pair(address, - packer_cls=Bolt4x1.PACKER_CLS, - unpacker_cls=Bolt4x1.UNPACKER_CLS) - sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) - max_connection_lifetime = 0 - connection = Bolt4x1(address, sockets.client, - max_connection_lifetime, auth=auth) - - with caplog.at_level(logging.DEBUG): - connection.hello() - - hellos = [m for m in caplog.messages if "C: HELLO" in m] - assert len(hellos) == 1 - hello = hellos[0] - - for key, value in items(): - if key == "credentials": - assert value not in hello - else: - assert str({key: value})[1:-1] in hello diff --git a/tests/unit/sync/io/test_class_bolt4x2.py b/tests/unit/sync/io/test_class_bolt4x2.py index 90dd813d..d9172bad 100644 --- a/tests/unit/sync/io/test_class_bolt4x2.py +++ b/tests/unit/sync/io/test_class_bolt4x2.py @@ -17,12 +17,13 @@ import logging +from itertools import permutations import pytest +import neo4j from neo4j._conf import PoolConfig from neo4j._sync.io._bolt4 import Bolt4x2 -from neo4j.api import Auth from neo4j.exceptions import ConfigurationError from ...._async_compat import mark_sync_test @@ -30,7 +31,7 @@ @pytest.mark.parametrize("set_stale", (True, False)) def 
test_conn_is_stale(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = 0 connection = Bolt4x2(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -40,7 +41,7 @@ def test_conn_is_stale(fake_socket, set_stale): @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = -1 connection = Bolt4x2(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -50,7 +51,7 @@ def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_not_stale(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = 999999999 connection = Bolt4x2(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -60,7 +61,7 @@ def test_conn_is_not_stale(fake_socket, set_stale): @mark_sync_test def test_db_extra_in_begin(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x2.UNPACKER_CLS) connection = Bolt4x2(address, socket, PoolConfig.max_connection_lifetime) connection.begin(db="something") @@ -73,7 +74,7 @@ def test_db_extra_in_begin(fake_socket): @mark_sync_test def test_db_extra_in_run(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x2.UNPACKER_CLS) connection = Bolt4x2(address, socket, PoolConfig.max_connection_lifetime) connection.run("", {}, db="something") @@ -88,7 +89,7 @@ def test_db_extra_in_run(fake_socket): @mark_sync_test def test_n_extra_in_discard(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x2.UNPACKER_CLS) connection = Bolt4x2(address, socket, PoolConfig.max_connection_lifetime) connection.discard(n=666) @@ -108,7 +109,7 @@ def test_n_extra_in_discard(fake_socket): ) @mark_sync_test def test_qid_extra_in_discard(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x2.UNPACKER_CLS) connection = Bolt4x2(address, socket, PoolConfig.max_connection_lifetime) connection.discard(qid=test_input) @@ -128,7 +129,7 @@ def test_qid_extra_in_discard(fake_socket, test_input, expected): ) @mark_sync_test def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x2.UNPACKER_CLS) connection = Bolt4x2(address, socket, PoolConfig.max_connection_lifetime) connection.discard(n=666, qid=test_input) @@ -148,7 +149,7 @@ def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): ) @mark_sync_test def test_n_extra_in_pull(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x2.UNPACKER_CLS) connection = Bolt4x2(address, socket, PoolConfig.max_connection_lifetime) connection.pull(n=test_input) @@ -168,7 +169,7 @@ def test_n_extra_in_pull(fake_socket, test_input, expected): ) @mark_sync_test def test_qid_extra_in_pull(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) 
socket = fake_socket(address, Bolt4x2.UNPACKER_CLS) connection = Bolt4x2(address, socket, PoolConfig.max_connection_lifetime) connection.pull(qid=test_input) @@ -181,7 +182,7 @@ def test_qid_extra_in_pull(fake_socket, test_input, expected): @mark_sync_test def test_n_and_qid_extras_in_pull(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x2.UNPACKER_CLS) connection = Bolt4x2(address, socket, PoolConfig.max_connection_lifetime) connection.pull(n=666, qid=777) @@ -194,7 +195,7 @@ def test_n_and_qid_extras_in_pull(fake_socket): @mark_sync_test def test_hello_passes_routing_metadata(fake_socket_pair): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair(address, packer_cls=Bolt4x2.PACKER_CLS, unpacker_cls=Bolt4x2.UNPACKER_CLS) @@ -215,7 +216,7 @@ def test_hello_passes_routing_metadata(fake_socket_pair): def test_hint_recv_timeout_seconds_gets_ignored( fake_socket_pair, recv_timeout, mocker ): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair(address, packer_cls=Bolt4x2.PACKER_CLS, unpacker_cls=Bolt4x2.UNPACKER_CLS) @@ -231,6 +232,92 @@ def test_hint_recv_timeout_seconds_gets_ignored( sockets.client.settimeout.assert_not_called() +CREDENTIALS = "+++super-secret-sauce+++" + + +@pytest.mark.parametrize("auth", ( + ("user", CREDENTIALS), + neo4j.basic_auth("user", CREDENTIALS), + neo4j.kerberos_auth(CREDENTIALS), + neo4j.bearer_auth(CREDENTIALS), + neo4j.custom_auth("user", CREDENTIALS, "realm", "scheme"), + neo4j.Auth("scheme", "principal", CREDENTIALS, "realm", foo="bar"), +)) +@mark_sync_test +def test_credentials_are_not_logged( + auth, fake_socket_pair, mocker, caplog +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=Bolt4x2.PACKER_CLS, + unpacker_cls=Bolt4x2.UNPACKER_CLS) + sockets.client.settimeout = mocker.Mock() + sockets.server.send_message(b"\x70", {"server": "Neo4j/4.3.4"}) + connection = Bolt4x2( + address, sockets.client, PoolConfig.max_connection_lifetime, auth=auth + ) + with caplog.at_level(logging.DEBUG): + connection.hello() + + if isinstance(auth, tuple): + auth = neo4j.basic_auth(*auth) + for field in ("scheme", "principal", "realm", "parameters"): + value = getattr(auth, field, None) + if value: + assert repr(value) in caplog.text + assert CREDENTIALS not in caplog.text + + +@pytest.mark.parametrize("message", ("logon", "logoff")) +def test_auth_message_raises_configuration_error(message, fake_socket): + address = neo4j.Address(("127.0.0.1", 7687)) + connection = Bolt4x2(address, fake_socket(address), + PoolConfig.max_connection_lifetime) + with pytest.raises(ConfigurationError, + match="User switching is not supported"): + getattr(connection, message)() + + +@pytest.mark.parametrize("auth", ( + None, + neo4j.Auth("scheme", "principal", "credentials", "realm"), + ("user", "password"), +)) +@mark_sync_test +def test_re_auth_noop(auth, fake_socket, mocker): + address = neo4j.Address(("127.0.0.1", 7687)) + connection = Bolt4x2(address, fake_socket(address), + PoolConfig.max_connection_lifetime, auth=auth) + logon_spy = mocker.spy(connection, "logon") + logoff_spy = mocker.spy(connection, "logoff") + res = connection.re_auth(auth, None) + + assert res is False + logon_spy.assert_not_called() + logoff_spy.assert_not_called() + + +@pytest.mark.parametrize( + ("auth1", "auth2"), + permutations( + ( + None, + neo4j.Auth("scheme", "principal", 
"credentials", "realm"), + ("user", "password"), + ), + 2 + ) +) +@mark_sync_test +def test_re_auth(auth1, auth2, fake_socket): + address = neo4j.Address(("127.0.0.1", 7687)) + connection = Bolt4x2(address, fake_socket(address), + PoolConfig.max_connection_lifetime, auth=auth1) + with pytest.raises(ConfigurationError, + match="User switching is not supported"): + connection.re_auth(auth2, None) + + @pytest.mark.parametrize(("method", "args"), ( ("run", ("RETURN 1",)), ("begin", ()), @@ -246,7 +333,7 @@ def test_hint_recv_timeout_seconds_gets_ignored( )) def test_does_not_support_notification_filters(fake_socket, method, args, kwargs): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x2.UNPACKER_CLS) connection = Bolt4x2(address, socket, PoolConfig.max_connection_lifetime) @@ -268,7 +355,7 @@ def test_does_not_support_notification_filters(fake_socket, method, def test_hello_does_not_support_notification_filters( fake_socket, kwargs ): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x2.UNPACKER_CLS) connection = Bolt4x2( address, socket, PoolConfig.max_connection_lifetime, @@ -276,65 +363,3 @@ def test_hello_does_not_support_notification_filters( ) with pytest.raises(ConfigurationError, match="Notification filtering"): connection.hello() - - -class HackedAuth: - def __init__(self, dict_): - self.__dict__ = dict_ - - -@mark_sync_test -@pytest.mark.parametrize("auth", ( - ("awesome test user", "safe p4ssw0rd"), - Auth("super nice scheme", "awesome test user", "safe p4ssw0rd"), - Auth("super nice scheme", "awesome test user", "safe p4ssw0rd", - realm="super duper realm"), - Auth("super nice scheme", "awesome test user", "safe p4ssw0rd", - realm="super duper realm"), - Auth("super nice scheme", "awesome test user", "safe p4ssw0rd", - foo="bar"), - HackedAuth({ - "scheme": "super nice scheme", "principal": "awesome test user", - "credentials": "safe p4ssw0rd", "realm": "super duper realm", - "parameters": {"credentials": "should be visible!"}, - }) - -)) -def test_hello_does_not_log_credentials(fake_socket_pair, caplog, auth): - def items(): - if isinstance(auth, tuple): - yield "scheme", "basic" - yield "principal", auth[0] - yield "credentials", auth[1] - elif isinstance(auth, Auth): - for key in ("scheme", "principal", "credentials", "realm", - "parameters"): - value = getattr(auth, key, None) - if value: - yield key, value - elif isinstance(auth, HackedAuth): - yield from auth.__dict__.items() - else: - raise TypeError(auth) - - address = ("127.0.0.1", 7687) - sockets = fake_socket_pair(address, - packer_cls=Bolt4x2.PACKER_CLS, - unpacker_cls=Bolt4x2.UNPACKER_CLS) - sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) - max_connection_lifetime = 0 - connection = Bolt4x2(address, sockets.client, - max_connection_lifetime, auth=auth) - - with caplog.at_level(logging.DEBUG): - connection.hello() - - hellos = [m for m in caplog.messages if "C: HELLO" in m] - assert len(hellos) == 1 - hello = hellos[0] - - for key, value in items(): - if key == "credentials": - assert value not in hello - else: - assert str({key: value})[1:-1] in hello diff --git a/tests/unit/sync/io/test_class_bolt4x3.py b/tests/unit/sync/io/test_class_bolt4x3.py index 427fb90d..c89d31d9 100644 --- a/tests/unit/sync/io/test_class_bolt4x3.py +++ b/tests/unit/sync/io/test_class_bolt4x3.py @@ -17,12 +17,13 @@ import logging +from itertools import permutations import pytest +import neo4j from 
neo4j._conf import PoolConfig from neo4j._sync.io._bolt4 import Bolt4x3 -from neo4j.api import Auth from neo4j.exceptions import ConfigurationError from ...._async_compat import mark_sync_test @@ -30,7 +31,7 @@ @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_stale(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = 0 connection = Bolt4x3(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -40,7 +41,7 @@ def test_conn_is_stale(fake_socket, set_stale): @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = -1 connection = Bolt4x3(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -50,7 +51,7 @@ def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_not_stale(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = 999999999 connection = Bolt4x3(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -60,7 +61,7 @@ def test_conn_is_not_stale(fake_socket, set_stale): @mark_sync_test def test_db_extra_in_begin(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x3.UNPACKER_CLS) connection = Bolt4x3(address, socket, PoolConfig.max_connection_lifetime) connection.begin(db="something") @@ -73,7 +74,7 @@ def test_db_extra_in_begin(fake_socket): @mark_sync_test def test_db_extra_in_run(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x3.UNPACKER_CLS) connection = Bolt4x3(address, socket, PoolConfig.max_connection_lifetime) connection.run("", {}, db="something") @@ -88,7 +89,7 @@ def test_db_extra_in_run(fake_socket): @mark_sync_test def test_n_extra_in_discard(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x3.UNPACKER_CLS) connection = Bolt4x3(address, socket, PoolConfig.max_connection_lifetime) connection.discard(n=666) @@ -108,7 +109,7 @@ def test_n_extra_in_discard(fake_socket): ) @mark_sync_test def test_qid_extra_in_discard(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x3.UNPACKER_CLS) connection = Bolt4x3(address, socket, PoolConfig.max_connection_lifetime) connection.discard(qid=test_input) @@ -128,7 +129,7 @@ def test_qid_extra_in_discard(fake_socket, test_input, expected): ) @mark_sync_test def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x3.UNPACKER_CLS) connection = Bolt4x3(address, socket, PoolConfig.max_connection_lifetime) connection.discard(n=666, qid=test_input) @@ -148,7 +149,7 @@ def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): ) @mark_sync_test def test_n_extra_in_pull(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x3.UNPACKER_CLS) connection = Bolt4x3(address, socket, PoolConfig.max_connection_lifetime) 
connection.pull(n=test_input) @@ -168,7 +169,7 @@ def test_n_extra_in_pull(fake_socket, test_input, expected): ) @mark_sync_test def test_qid_extra_in_pull(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x3.UNPACKER_CLS) connection = Bolt4x3(address, socket, PoolConfig.max_connection_lifetime) connection.pull(qid=test_input) @@ -181,7 +182,7 @@ def test_qid_extra_in_pull(fake_socket, test_input, expected): @mark_sync_test def test_n_and_qid_extras_in_pull(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x3.UNPACKER_CLS) connection = Bolt4x3(address, socket, PoolConfig.max_connection_lifetime) connection.pull(n=666, qid=777) @@ -194,7 +195,7 @@ def test_n_and_qid_extras_in_pull(fake_socket): @mark_sync_test def test_hello_passes_routing_metadata(fake_socket_pair): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair(address, packer_cls=Bolt4x3.PACKER_CLS, unpacker_cls=Bolt4x3.UNPACKER_CLS) @@ -226,7 +227,7 @@ def test_hello_passes_routing_metadata(fake_socket_pair): def test_hint_recv_timeout_seconds( fake_socket_pair, hints, valid, caplog, mocker ): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair(address, packer_cls=Bolt4x3.PACKER_CLS, unpacker_cls=Bolt4x3.UNPACKER_CLS) @@ -257,6 +258,93 @@ def test_hint_recv_timeout_seconds( for msg in caplog.messages) +CREDENTIALS = "+++super-secret-sauce+++" + + +@pytest.mark.parametrize("auth", ( + ("user", CREDENTIALS), + neo4j.basic_auth("user", CREDENTIALS), + neo4j.kerberos_auth(CREDENTIALS), + neo4j.bearer_auth(CREDENTIALS), + neo4j.custom_auth("user", CREDENTIALS, "realm", "scheme"), + neo4j.Auth("scheme", "principal", CREDENTIALS, "realm", foo="bar"), +)) +@mark_sync_test +def test_credentials_are_not_logged( + auth, fake_socket_pair, mocker, caplog +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=Bolt4x3.PACKER_CLS, + unpacker_cls=Bolt4x3.UNPACKER_CLS) + sockets.client.settimeout = mocker.Mock() + sockets.server.send_message(b"\x70", {"server": "Neo4j/4.3.4"}) + connection = Bolt4x3( + address, sockets.client, PoolConfig.max_connection_lifetime, auth=auth + ) + with caplog.at_level(logging.DEBUG): + connection.hello() + + if isinstance(auth, tuple): + auth = neo4j.basic_auth(*auth) + for field in ("scheme", "principal", "realm", "parameters"): + value = getattr(auth, field, None) + if value: + assert repr(value) in caplog.text + assert CREDENTIALS not in caplog.text + + + +@pytest.mark.parametrize("message", ("logon", "logoff")) +def test_auth_message_raises_configuration_error(message, fake_socket): + address = neo4j.Address(("127.0.0.1", 7687)) + connection = Bolt4x3(address, fake_socket(address), + PoolConfig.max_connection_lifetime) + with pytest.raises(ConfigurationError, + match="User switching is not supported"): + getattr(connection, message)() + + +@pytest.mark.parametrize("auth", ( + None, + neo4j.Auth("scheme", "principal", "credentials", "realm"), + ("user", "password"), +)) +@mark_sync_test +def test_re_auth_noop(auth, fake_socket, mocker): + address = neo4j.Address(("127.0.0.1", 7687)) + connection = Bolt4x3(address, fake_socket(address), + PoolConfig.max_connection_lifetime, auth=auth) + logon_spy = mocker.spy(connection, "logon") + logoff_spy = mocker.spy(connection, "logoff") + res = 
connection.re_auth(auth, None) + + assert res is False + logon_spy.assert_not_called() + logoff_spy.assert_not_called() + + +@pytest.mark.parametrize( + ("auth1", "auth2"), + permutations( + ( + None, + neo4j.Auth("scheme", "principal", "credentials", "realm"), + ("user", "password"), + ), + 2 + ) +) +@mark_sync_test +def test_re_auth(auth1, auth2, fake_socket): + address = neo4j.Address(("127.0.0.1", 7687)) + connection = Bolt4x3(address, fake_socket(address), + PoolConfig.max_connection_lifetime, auth=auth1) + with pytest.raises(ConfigurationError, + match="User switching is not supported"): + connection.re_auth(auth2, None) + + @pytest.mark.parametrize(("method", "args"), ( ("run", ("RETURN 1",)), ("begin", ()), @@ -272,7 +360,7 @@ def test_hint_recv_timeout_seconds( )) def test_does_not_support_notification_filters(fake_socket, method, args, kwargs): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x3.UNPACKER_CLS) connection = Bolt4x3(address, socket, PoolConfig.max_connection_lifetime) @@ -294,7 +382,7 @@ def test_does_not_support_notification_filters(fake_socket, method, def test_hello_does_not_support_notification_filters( fake_socket, kwargs ): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x3.UNPACKER_CLS) connection = Bolt4x3( address, socket, PoolConfig.max_connection_lifetime, @@ -302,65 +390,3 @@ def test_hello_does_not_support_notification_filters( ) with pytest.raises(ConfigurationError, match="Notification filtering"): connection.hello() - - -class HackedAuth: - def __init__(self, dict_): - self.__dict__ = dict_ - - -@mark_sync_test -@pytest.mark.parametrize("auth", ( - ("awesome test user", "safe p4ssw0rd"), - Auth("super nice scheme", "awesome test user", "safe p4ssw0rd"), - Auth("super nice scheme", "awesome test user", "safe p4ssw0rd", - realm="super duper realm"), - Auth("super nice scheme", "awesome test user", "safe p4ssw0rd", - realm="super duper realm"), - Auth("super nice scheme", "awesome test user", "safe p4ssw0rd", - foo="bar"), - HackedAuth({ - "scheme": "super nice scheme", "principal": "awesome test user", - "credentials": "safe p4ssw0rd", "realm": "super duper realm", - "parameters": {"credentials": "should be visible!"}, - }) - -)) -def test_hello_does_not_log_credentials(fake_socket_pair, caplog, auth): - def items(): - if isinstance(auth, tuple): - yield "scheme", "basic" - yield "principal", auth[0] - yield "credentials", auth[1] - elif isinstance(auth, Auth): - for key in ("scheme", "principal", "credentials", "realm", - "parameters"): - value = getattr(auth, key, None) - if value: - yield key, value - elif isinstance(auth, HackedAuth): - yield from auth.__dict__.items() - else: - raise TypeError(auth) - - address = ("127.0.0.1", 7687) - sockets = fake_socket_pair(address, - packer_cls=Bolt4x3.PACKER_CLS, - unpacker_cls=Bolt4x3.UNPACKER_CLS) - sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) - max_connection_lifetime = 0 - connection = Bolt4x3(address, sockets.client, - max_connection_lifetime, auth=auth) - - with caplog.at_level(logging.DEBUG): - connection.hello() - - hellos = [m for m in caplog.messages if "C: HELLO" in m] - assert len(hellos) == 1 - hello = hellos[0] - - for key, value in items(): - if key == "credentials": - assert value not in hello - else: - assert str({key: value})[1:-1] in hello diff --git a/tests/unit/sync/io/test_class_bolt4x4.py b/tests/unit/sync/io/test_class_bolt4x4.py index 
bdd126d0..be7cc754 100644 --- a/tests/unit/sync/io/test_class_bolt4x4.py +++ b/tests/unit/sync/io/test_class_bolt4x4.py @@ -17,12 +17,13 @@ import logging +from itertools import permutations import pytest +import neo4j from neo4j._conf import PoolConfig from neo4j._sync.io._bolt4 import Bolt4x4 -from neo4j.api import Auth from neo4j.exceptions import ConfigurationError from ...._async_compat import mark_sync_test @@ -30,7 +31,7 @@ @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_stale(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = 0 connection = Bolt4x4(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -40,7 +41,7 @@ def test_conn_is_stale(fake_socket, set_stale): @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = -1 connection = Bolt4x4(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -50,7 +51,7 @@ def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_not_stale(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = 999999999 connection = Bolt4x4(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -69,7 +70,7 @@ def test_conn_is_not_stale(fake_socket, set_stale): )) @mark_sync_test def test_extra_in_begin(fake_socket, args, kwargs, expected_fields): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x4.UNPACKER_CLS) connection = Bolt4x4(address, socket, PoolConfig.max_connection_lifetime) connection.begin(*args, **kwargs) @@ -90,7 +91,7 @@ def test_extra_in_begin(fake_socket, args, kwargs, expected_fields): )) @mark_sync_test def test_extra_in_run(fake_socket, args, kwargs, expected_fields): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x4.UNPACKER_CLS) connection = Bolt4x4(address, socket, PoolConfig.max_connection_lifetime) connection.run(*args, **kwargs) @@ -102,7 +103,7 @@ def test_extra_in_run(fake_socket, args, kwargs, expected_fields): @mark_sync_test def test_n_extra_in_discard(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x4.UNPACKER_CLS) connection = Bolt4x4(address, socket, PoolConfig.max_connection_lifetime) connection.discard(n=666) @@ -122,7 +123,7 @@ def test_n_extra_in_discard(fake_socket): ) @mark_sync_test def test_qid_extra_in_discard(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x4.UNPACKER_CLS) connection = Bolt4x4(address, socket, PoolConfig.max_connection_lifetime) connection.discard(qid=test_input) @@ -142,7 +143,7 @@ def test_qid_extra_in_discard(fake_socket, test_input, expected): ) @mark_sync_test def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x4.UNPACKER_CLS) connection = Bolt4x4(address, socket, PoolConfig.max_connection_lifetime) connection.discard(n=666, qid=test_input) @@ -162,7 +163,7 @@ def 
test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): ) @mark_sync_test def test_n_extra_in_pull(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x4.UNPACKER_CLS) connection = Bolt4x4(address, socket, PoolConfig.max_connection_lifetime) connection.pull(n=test_input) @@ -182,7 +183,7 @@ def test_n_extra_in_pull(fake_socket, test_input, expected): ) @mark_sync_test def test_qid_extra_in_pull(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x4.UNPACKER_CLS) connection = Bolt4x4(address, socket, PoolConfig.max_connection_lifetime) connection.pull(qid=test_input) @@ -195,7 +196,7 @@ def test_qid_extra_in_pull(fake_socket, test_input, expected): @mark_sync_test def test_n_and_qid_extras_in_pull(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x4.UNPACKER_CLS) connection = Bolt4x4(address, socket, PoolConfig.max_connection_lifetime) connection.pull(n=666, qid=777) @@ -208,7 +209,7 @@ def test_n_and_qid_extras_in_pull(fake_socket): @mark_sync_test def test_hello_passes_routing_metadata(fake_socket_pair): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair(address, packer_cls=Bolt4x4.PACKER_CLS, unpacker_cls=Bolt4x4.UNPACKER_CLS) @@ -240,7 +241,7 @@ def test_hello_passes_routing_metadata(fake_socket_pair): def test_hint_recv_timeout_seconds( fake_socket_pair, hints, valid, caplog, mocker ): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair(address, packer_cls=Bolt4x4.PACKER_CLS, unpacker_cls=Bolt4x4.UNPACKER_CLS) @@ -271,6 +272,92 @@ def test_hint_recv_timeout_seconds( for msg in caplog.messages) +CREDENTIALS = "+++super-secret-sauce+++" + + +@pytest.mark.parametrize("auth", ( + ("user", CREDENTIALS), + neo4j.basic_auth("user", CREDENTIALS), + neo4j.kerberos_auth(CREDENTIALS), + neo4j.bearer_auth(CREDENTIALS), + neo4j.custom_auth("user", CREDENTIALS, "realm", "scheme"), + neo4j.Auth("scheme", "principal", CREDENTIALS, "realm", foo="bar"), +)) +@mark_sync_test +def test_credentials_are_not_logged( + auth, fake_socket_pair, mocker, caplog +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=Bolt4x4.PACKER_CLS, + unpacker_cls=Bolt4x4.UNPACKER_CLS) + sockets.client.settimeout = mocker.Mock() + sockets.server.send_message(b"\x70", {"server": "Neo4j/4.3.4"}) + connection = Bolt4x4( + address, sockets.client, PoolConfig.max_connection_lifetime, auth=auth + ) + with caplog.at_level(logging.DEBUG): + connection.hello() + + if isinstance(auth, tuple): + auth = neo4j.basic_auth(*auth) + for field in ("scheme", "principal", "realm", "parameters"): + value = getattr(auth, field, None) + if value: + assert repr(value) in caplog.text + assert CREDENTIALS not in caplog.text + + +@pytest.mark.parametrize("message", ("logon", "logoff")) +def test_auth_message_raises_configuration_error(message, fake_socket): + address = neo4j.Address(("127.0.0.1", 7687)) + connection = Bolt4x4(address, fake_socket(address), + PoolConfig.max_connection_lifetime) + with pytest.raises(ConfigurationError, + match="User switching is not supported"): + getattr(connection, message)() + + +@pytest.mark.parametrize("auth", ( + None, + neo4j.Auth("scheme", "principal", "credentials", "realm"), + 
("user", "password"), +)) +@mark_sync_test +def test_re_auth_noop(auth, fake_socket, mocker): + address = neo4j.Address(("127.0.0.1", 7687)) + connection = Bolt4x4(address, fake_socket(address), + PoolConfig.max_connection_lifetime, auth=auth) + logon_spy = mocker.spy(connection, "logon") + logoff_spy = mocker.spy(connection, "logoff") + res = connection.re_auth(auth, None) + + assert res is False + logon_spy.assert_not_called() + logoff_spy.assert_not_called() + + +@pytest.mark.parametrize( + ("auth1", "auth2"), + permutations( + ( + None, + neo4j.Auth("scheme", "principal", "credentials", "realm"), + ("user", "password"), + ), + 2 + ) +) +@mark_sync_test +def test_re_auth(auth1, auth2, fake_socket): + address = neo4j.Address(("127.0.0.1", 7687)) + connection = Bolt4x4(address, fake_socket(address), + PoolConfig.max_connection_lifetime, auth=auth1) + with pytest.raises(ConfigurationError, + match="User switching is not supported"): + connection.re_auth(auth2, None) + + @pytest.mark.parametrize(("method", "args"), ( ("run", ("RETURN 1",)), ("begin", ()), @@ -286,7 +373,7 @@ def test_hint_recv_timeout_seconds( )) def test_does_not_support_notification_filters(fake_socket, method, args, kwargs): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x4.UNPACKER_CLS) connection = Bolt4x4(address, socket, PoolConfig.max_connection_lifetime) @@ -308,7 +395,7 @@ def test_does_not_support_notification_filters(fake_socket, method, def test_hello_does_not_support_notification_filters( fake_socket, kwargs ): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x4.UNPACKER_CLS) connection = Bolt4x4( address, socket, PoolConfig.max_connection_lifetime, @@ -316,65 +403,3 @@ def test_hello_does_not_support_notification_filters( ) with pytest.raises(ConfigurationError, match="Notification filtering"): connection.hello() - - -class HackedAuth: - def __init__(self, dict_): - self.__dict__ = dict_ - - -@mark_sync_test -@pytest.mark.parametrize("auth", ( - ("awesome test user", "safe p4ssw0rd"), - Auth("super nice scheme", "awesome test user", "safe p4ssw0rd"), - Auth("super nice scheme", "awesome test user", "safe p4ssw0rd", - realm="super duper realm"), - Auth("super nice scheme", "awesome test user", "safe p4ssw0rd", - realm="super duper realm"), - Auth("super nice scheme", "awesome test user", "safe p4ssw0rd", - foo="bar"), - HackedAuth({ - "scheme": "super nice scheme", "principal": "awesome test user", - "credentials": "safe p4ssw0rd", "realm": "super duper realm", - "parameters": {"credentials": "should be visible!"}, - }) - -)) -def test_hello_does_not_log_credentials(fake_socket_pair, caplog, auth): - def items(): - if isinstance(auth, tuple): - yield "scheme", "basic" - yield "principal", auth[0] - yield "credentials", auth[1] - elif isinstance(auth, Auth): - for key in ("scheme", "principal", "credentials", "realm", - "parameters"): - value = getattr(auth, key, None) - if value: - yield key, value - elif isinstance(auth, HackedAuth): - yield from auth.__dict__.items() - else: - raise TypeError(auth) - - address = ("127.0.0.1", 7687) - sockets = fake_socket_pair(address, - packer_cls=Bolt4x4.PACKER_CLS, - unpacker_cls=Bolt4x4.UNPACKER_CLS) - sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) - max_connection_lifetime = 0 - connection = Bolt4x4(address, sockets.client, - max_connection_lifetime, auth=auth) - - with caplog.at_level(logging.DEBUG): - connection.hello() - 
- hellos = [m for m in caplog.messages if "C: HELLO" in m] - assert len(hellos) == 1 - hello = hellos[0] - - for key, value in items(): - if key == "credentials": - assert value not in hello - else: - assert str({key: value})[1:-1] in hello diff --git a/tests/unit/sync/io/test_class_bolt5x0.py b/tests/unit/sync/io/test_class_bolt5x0.py index 2f570c9e..cf9dab99 100644 --- a/tests/unit/sync/io/test_class_bolt5x0.py +++ b/tests/unit/sync/io/test_class_bolt5x0.py @@ -17,12 +17,13 @@ import logging +from itertools import permutations import pytest +import neo4j from neo4j._conf import PoolConfig from neo4j._sync.io._bolt5 import Bolt5x0 -from neo4j.api import Auth from neo4j.exceptions import ConfigurationError from ...._async_compat import mark_sync_test @@ -30,7 +31,7 @@ @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_stale(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = 0 connection = Bolt5x0(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -40,7 +41,7 @@ def test_conn_is_stale(fake_socket, set_stale): @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = -1 connection = Bolt5x0(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -50,7 +51,7 @@ def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_not_stale(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = 999999999 connection = Bolt5x0(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -69,7 +70,7 @@ def test_conn_is_not_stale(fake_socket, set_stale): )) @mark_sync_test def test_extra_in_begin(fake_socket, args, kwargs, expected_fields): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt5x0.UNPACKER_CLS) connection = Bolt5x0(address, socket, PoolConfig.max_connection_lifetime) connection.begin(*args, **kwargs) @@ -90,7 +91,7 @@ def test_extra_in_begin(fake_socket, args, kwargs, expected_fields): )) @mark_sync_test def test_extra_in_run(fake_socket, args, kwargs, expected_fields): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt5x0.UNPACKER_CLS) connection = Bolt5x0(address, socket, PoolConfig.max_connection_lifetime) connection.run(*args, **kwargs) @@ -102,7 +103,7 @@ def test_extra_in_run(fake_socket, args, kwargs, expected_fields): @mark_sync_test def test_n_extra_in_discard(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt5x0.UNPACKER_CLS) connection = Bolt5x0(address, socket, PoolConfig.max_connection_lifetime) connection.discard(n=666) @@ -122,7 +123,7 @@ def test_n_extra_in_discard(fake_socket): ) @mark_sync_test def test_qid_extra_in_discard(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt5x0.UNPACKER_CLS) connection = Bolt5x0(address, socket, PoolConfig.max_connection_lifetime) connection.discard(qid=test_input) @@ -142,7 +143,7 @@ def test_qid_extra_in_discard(fake_socket, test_input, expected): ) @mark_sync_test def 
test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt5x0.UNPACKER_CLS) connection = Bolt5x0(address, socket, PoolConfig.max_connection_lifetime) connection.discard(n=666, qid=test_input) @@ -162,7 +163,7 @@ def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): ) @mark_sync_test def test_n_extra_in_pull(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt5x0.UNPACKER_CLS) connection = Bolt5x0(address, socket, PoolConfig.max_connection_lifetime) connection.pull(n=test_input) @@ -182,7 +183,7 @@ def test_n_extra_in_pull(fake_socket, test_input, expected): ) @mark_sync_test def test_qid_extra_in_pull(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt5x0.UNPACKER_CLS) connection = Bolt5x0(address, socket, PoolConfig.max_connection_lifetime) connection.pull(qid=test_input) @@ -195,7 +196,7 @@ def test_qid_extra_in_pull(fake_socket, test_input, expected): @mark_sync_test def test_n_and_qid_extras_in_pull(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt5x0.UNPACKER_CLS) connection = Bolt5x0(address, socket, PoolConfig.max_connection_lifetime) connection.pull(n=666, qid=777) @@ -208,7 +209,7 @@ def test_n_and_qid_extras_in_pull(fake_socket): @mark_sync_test def test_hello_passes_routing_metadata(fake_socket_pair): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair(address, packer_cls=Bolt5x0.PACKER_CLS, unpacker_cls=Bolt5x0.UNPACKER_CLS) @@ -240,7 +241,7 @@ def test_hello_passes_routing_metadata(fake_socket_pair): def test_hint_recv_timeout_seconds( fake_socket_pair, hints, valid, caplog, mocker ): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair(address, packer_cls=Bolt5x0.PACKER_CLS, unpacker_cls=Bolt5x0.UNPACKER_CLS) @@ -271,6 +272,92 @@ def test_hint_recv_timeout_seconds( for msg in caplog.messages) +CREDENTIALS = "+++super-secret-sauce+++" + + +@pytest.mark.parametrize("auth", ( + ("user", CREDENTIALS), + neo4j.basic_auth("user", CREDENTIALS), + neo4j.kerberos_auth(CREDENTIALS), + neo4j.bearer_auth(CREDENTIALS), + neo4j.custom_auth("user", CREDENTIALS, "realm", "scheme"), + neo4j.Auth("scheme", "principal", CREDENTIALS, "realm", foo="bar"), +)) +@mark_sync_test +def test_credentials_are_not_logged( + auth, fake_socket_pair, mocker, caplog +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=Bolt5x0.PACKER_CLS, + unpacker_cls=Bolt5x0.UNPACKER_CLS) + sockets.client.settimeout = mocker.Mock() + sockets.server.send_message(b"\x70", {"server": "Neo4j/4.3.4"}) + connection = Bolt5x0( + address, sockets.client, PoolConfig.max_connection_lifetime, auth=auth + ) + with caplog.at_level(logging.DEBUG): + connection.hello() + + if isinstance(auth, tuple): + auth = neo4j.basic_auth(*auth) + for field in ("scheme", "principal", "realm", "parameters"): + value = getattr(auth, field, None) + if value: + assert repr(value) in caplog.text + assert CREDENTIALS not in caplog.text + + +@pytest.mark.parametrize("message", ("logon", "logoff")) +def test_auth_message_raises_configuration_error(message, fake_socket): + address = 
neo4j.Address(("127.0.0.1", 7687)) + connection = Bolt5x0(address, fake_socket(address), + PoolConfig.max_connection_lifetime) + with pytest.raises(ConfigurationError, + match="User switching is not supported"): + getattr(connection, message)() + + +@pytest.mark.parametrize("auth", ( + None, + neo4j.Auth("scheme", "principal", "credentials", "realm"), + ("user", "password"), +)) +@mark_sync_test +def test_re_auth_noop(auth, fake_socket, mocker): + address = neo4j.Address(("127.0.0.1", 7687)) + connection = Bolt5x0(address, fake_socket(address), + PoolConfig.max_connection_lifetime, auth=auth) + logon_spy = mocker.spy(connection, "logon") + logoff_spy = mocker.spy(connection, "logoff") + res = connection.re_auth(auth, None) + + assert res is False + logon_spy.assert_not_called() + logoff_spy.assert_not_called() + + +@pytest.mark.parametrize( + ("auth1", "auth2"), + permutations( + ( + None, + neo4j.Auth("scheme", "principal", "credentials", "realm"), + ("user", "password"), + ), + 2 + ) +) +@mark_sync_test +def test_re_auth(auth1, auth2, fake_socket): + address = neo4j.Address(("127.0.0.1", 7687)) + connection = Bolt5x0(address, fake_socket(address), + PoolConfig.max_connection_lifetime, auth=auth1) + with pytest.raises(ConfigurationError, + match="User switching is not supported"): + connection.re_auth(auth2, None) + + @pytest.mark.parametrize(("method", "args"), ( ("run", ("RETURN 1",)), ("begin", ()), @@ -286,7 +373,7 @@ def test_hint_recv_timeout_seconds( )) def test_does_not_support_notification_filters(fake_socket, method, args, kwargs): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt5x0.UNPACKER_CLS) connection = Bolt5x0(address, socket, PoolConfig.max_connection_lifetime) @@ -308,7 +395,7 @@ def test_does_not_support_notification_filters(fake_socket, method, def test_hello_does_not_support_notification_filters( fake_socket, kwargs ): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt5x0.UNPACKER_CLS) connection = Bolt5x0( address, socket, PoolConfig.max_connection_lifetime, @@ -316,65 +403,3 @@ def test_hello_does_not_support_notification_filters( ) with pytest.raises(ConfigurationError, match="Notification filtering"): connection.hello() - - -class HackedAuth: - def __init__(self, dict_): - self.__dict__ = dict_ - - -@mark_sync_test -@pytest.mark.parametrize("auth", ( - ("awesome test user", "safe p4ssw0rd"), - Auth("super nice scheme", "awesome test user", "safe p4ssw0rd"), - Auth("super nice scheme", "awesome test user", "safe p4ssw0rd", - realm="super duper realm"), - Auth("super nice scheme", "awesome test user", "safe p4ssw0rd", - realm="super duper realm"), - Auth("super nice scheme", "awesome test user", "safe p4ssw0rd", - foo="bar"), - HackedAuth({ - "scheme": "super nice scheme", "principal": "awesome test user", - "credentials": "safe p4ssw0rd", "realm": "super duper realm", - "parameters": {"credentials": "should be visible!"}, - }) - -)) -def test_hello_does_not_log_credentials(fake_socket_pair, caplog, auth): - def items(): - if isinstance(auth, tuple): - yield "scheme", "basic" - yield "principal", auth[0] - yield "credentials", auth[1] - elif isinstance(auth, Auth): - for key in ("scheme", "principal", "credentials", "realm", - "parameters"): - value = getattr(auth, key, None) - if value: - yield key, value - elif isinstance(auth, HackedAuth): - yield from auth.__dict__.items() - else: - raise TypeError(auth) - - address = ("127.0.0.1", 
7687) - sockets = fake_socket_pair(address, - packer_cls=Bolt5x0.PACKER_CLS, - unpacker_cls=Bolt5x0.UNPACKER_CLS) - sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) - max_connection_lifetime = 0 - connection = Bolt5x0(address, sockets.client, - max_connection_lifetime, auth=auth) - - with caplog.at_level(logging.DEBUG): - connection.hello() - - hellos = [m for m in caplog.messages if "C: HELLO" in m] - assert len(hellos) == 1 - hello = hellos[0] - - for key, value in items(): - if key == "credentials": - assert value not in hello - else: - assert str({key: value})[1:-1] in hello diff --git a/tests/unit/sync/io/test_class_bolt5x1.py b/tests/unit/sync/io/test_class_bolt5x1.py index da50daea..9ef06704 100644 --- a/tests/unit/sync/io/test_class_bolt5x1.py +++ b/tests/unit/sync/io/test_class_bolt5x1.py @@ -20,16 +20,397 @@ import pytest +import neo4j +import neo4j.exceptions from neo4j._conf import PoolConfig from neo4j._sync.io._bolt5 import Bolt5x1 -from neo4j.api import Auth +from neo4j.auth_management import AuthManagers from neo4j.exceptions import ConfigurationError from ...._async_compat import mark_sync_test -# TODO: proper testing should come from the re-auth ADR, -# which properly introduces Bolt 5.1 +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_stale(fake_socket, set_stale): + address = neo4j.Address(("127.0.0.1", 7687)) + max_connection_lifetime = 0 + connection = Bolt5x1(address, fake_socket(address), + max_connection_lifetime) + if set_stale: + connection.set_stale() + assert connection.stale() is True + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): + address = neo4j.Address(("127.0.0.1", 7687)) + max_connection_lifetime = -1 + connection = Bolt5x1(address, fake_socket(address), + max_connection_lifetime) + if set_stale: + connection.set_stale() + assert connection.stale() is set_stale + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_not_stale(fake_socket, set_stale): + address = neo4j.Address(("127.0.0.1", 7687)) + max_connection_lifetime = 999999999 + connection = Bolt5x1(address, fake_socket(address), + max_connection_lifetime) + if set_stale: + connection.set_stale() + assert connection.stale() is set_stale + + +@pytest.mark.parametrize(("args", "kwargs", "expected_fields"), ( + (("", {}), {"db": "something"}, ({"db": "something"},)), + (("", {}), {"imp_user": "imposter"}, ({"imp_user": "imposter"},)), + ( + ("", {}), + {"db": "something", "imp_user": "imposter"}, + ({"db": "something", "imp_user": "imposter"},) + ), +)) +@mark_sync_test +def test_extra_in_begin(fake_socket, args, kwargs, expected_fields): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x1.UNPACKER_CLS) + connection = Bolt5x1(address, socket, + PoolConfig.max_connection_lifetime) + connection.begin(*args, **kwargs) + connection.send_all() + tag, is_fields = socket.pop_message() + assert tag == b"\x11" + assert tuple(is_fields) == expected_fields + + +@pytest.mark.parametrize(("args", "kwargs", "expected_fields"), ( + (("", {}), {"db": "something"}, ("", {}, {"db": "something"})), + (("", {}), {"imp_user": "imposter"}, ("", {}, {"imp_user": "imposter"})), + ( + ("", {}), + {"db": "something", "imp_user": "imposter"}, + ("", {}, {"db": "something", "imp_user": "imposter"}) + ), +)) +@mark_sync_test +def test_extra_in_run(fake_socket, args, kwargs, expected_fields): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = 
fake_socket(address, Bolt5x1.UNPACKER_CLS) + connection = Bolt5x1(address, socket, + PoolConfig.max_connection_lifetime) + connection.run(*args, **kwargs) + connection.send_all() + tag, is_fields = socket.pop_message() + assert tag == b"\x10" + assert tuple(is_fields) == expected_fields + + +@mark_sync_test +def test_n_extra_in_discard(fake_socket): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x1.UNPACKER_CLS) + connection = Bolt5x1(address, socket, + PoolConfig.max_connection_lifetime) + connection.discard(n=666) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x2F" + assert len(fields) == 1 + assert fields[0] == {"n": 666} + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (666, {"n": -1, "qid": 666}), + (-1, {"n": -1}), + ] +) +@mark_sync_test +def test_qid_extra_in_discard(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x1.UNPACKER_CLS) + connection = Bolt5x1(address, socket, + PoolConfig.max_connection_lifetime) + connection.discard(qid=test_input) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x2F" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (777, {"n": 666, "qid": 777}), + (-1, {"n": 666}), + ] +) +@mark_sync_test +def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x1.UNPACKER_CLS) + connection = Bolt5x1(address, socket, + PoolConfig.max_connection_lifetime) + connection.discard(n=666, qid=test_input) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x2F" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (666, {"n": 666}), + (-1, {"n": -1}), + ] +) +@mark_sync_test +def test_n_extra_in_pull(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x1.UNPACKER_CLS) + connection = Bolt5x1(address, socket, + PoolConfig.max_connection_lifetime) + connection.pull(n=test_input) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x3F" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (777, {"n": -1, "qid": 777}), + (-1, {"n": -1}), + ] +) +@mark_sync_test +def test_qid_extra_in_pull(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x1.UNPACKER_CLS) + connection = Bolt5x1(address, socket, + PoolConfig.max_connection_lifetime) + connection.pull(qid=test_input) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x3F" + assert len(fields) == 1 + assert fields[0] == expected + + +@mark_sync_test +def test_n_and_qid_extras_in_pull(fake_socket): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x1.UNPACKER_CLS) + connection = Bolt5x1(address, socket, + PoolConfig.max_connection_lifetime) + connection.pull(n=666, qid=777) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x3F" + assert len(fields) == 1 + assert fields[0] == {"n": 666, "qid": 777} + + +@mark_sync_test +def test_hello_passes_routing_metadata(fake_socket_pair): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + 
packer_cls=Bolt5x1.PACKER_CLS, + unpacker_cls=Bolt5x1.UNPACKER_CLS) + sockets.server.send_message(b"\x70", {"server": "Neo4j/4.4.0"}) + sockets.server.send_message(b"\x70", {}) + connection = Bolt5x1( + address, sockets.client, PoolConfig.max_connection_lifetime, + routing_context={"foo": "bar"} + ) + connection.hello() + tag, fields = sockets.server.pop_message() + assert tag == b"\x01" + assert len(fields) == 1 + assert fields[0]["routing"] == {"foo": "bar"} + + +def _assert_logon_message(sockets, auth): + tag, fields = sockets.server.pop_message() + assert tag == b"\x6A" # LOGON + assert len(fields) == 1 + keys = ["scheme", "principal", "credentials"] + assert list(fields[0].keys()) == keys + for key in keys: + assert fields[0][key] == getattr(auth, key) + + +@mark_sync_test +def test_hello_pipelines_logon(fake_socket_pair): + auth = neo4j.Auth("basic", "alice123", "supersecret123") + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=Bolt5x1.PACKER_CLS, + unpacker_cls=Bolt5x1.UNPACKER_CLS) + sockets.server.send_message( + b"\x7F", {"code": "Neo.DatabaseError.General.MadeUpError", + "message": "kthxbye"} + ) + connection = Bolt5x1( + address, sockets.client, PoolConfig.max_connection_lifetime, auth=auth + ) + with pytest.raises(neo4j.exceptions.Neo4jError): + connection.hello() + tag, fields = sockets.server.pop_message() + assert tag == b"\x01" # HELLO + assert len(fields) == 1 + assert list(fields[0].keys()) == ["user_agent"] + assert auth.credentials not in repr(fields) + _assert_logon_message(sockets, auth) + + +@mark_sync_test +def test_logon(fake_socket_pair): + auth = neo4j.Auth("basic", "alice123", "supersecret123") + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=Bolt5x1.PACKER_CLS, + unpacker_cls=Bolt5x1.UNPACKER_CLS) + connection = Bolt5x1(address, sockets.client, + PoolConfig.max_connection_lifetime, auth=auth) + connection.logon() + connection.send_all() + _assert_logon_message(sockets, auth) + + +@mark_sync_test +def test_re_auth(fake_socket_pair, mocker, static_auth): + auth = neo4j.Auth("basic", "alice123", "supersecret123") + auth_manager = static_auth(auth) + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=Bolt5x1.PACKER_CLS, + unpacker_cls=Bolt5x1.UNPACKER_CLS) + sockets.server.send_message( + b"\x7F", {"code": "Neo.DatabaseError.General.MadeUpError", + "message": "kthxbye"} + ) + connection = Bolt5x1(address, sockets.client, + PoolConfig.max_connection_lifetime) + connection.pool = mocker.MagicMock() + connection.re_auth(auth, auth_manager) + connection.send_all() + with pytest.raises(neo4j.exceptions.Neo4jError): + connection.fetch_all() + tag, fields = sockets.server.pop_message() + assert tag == b"\x6B" # LOGOFF + assert len(fields) == 0 + _assert_logon_message(sockets, auth) + assert connection.auth is auth + assert connection.auth_manager is auth_manager + + +@mark_sync_test +def test_logoff(fake_socket_pair): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=Bolt5x1.PACKER_CLS, + unpacker_cls=Bolt5x1.UNPACKER_CLS) + sockets.server.send_message(b"\x70", {}) + connection = Bolt5x1(address, sockets.client, + PoolConfig.max_connection_lifetime) + connection.logoff() + assert not sockets.server.recv_buffer # pipelined, so no response yet + connection.send_all() + assert sockets.server.recv_buffer # now! 
+ tag, fields = sockets.server.pop_message() + assert tag == b"\x6B" # LOGOFF + assert len(fields) == 0 + + +@pytest.mark.parametrize(("hints", "valid"), ( + ({"connection.recv_timeout_seconds": 1}, True), + ({"connection.recv_timeout_seconds": 42}, True), + ({}, True), + ({"whatever_this_is": "ignore me!"}, True), + ({"connection.recv_timeout_seconds": -1}, False), + ({"connection.recv_timeout_seconds": 0}, False), + ({"connection.recv_timeout_seconds": 2.5}, False), + ({"connection.recv_timeout_seconds": None}, False), + ({"connection.recv_timeout_seconds": False}, False), + ({"connection.recv_timeout_seconds": "1"}, False), +)) +@mark_sync_test +def test_hint_recv_timeout_seconds( + fake_socket_pair, hints, valid, caplog, mocker +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=Bolt5x1.PACKER_CLS, + unpacker_cls=Bolt5x1.UNPACKER_CLS) + sockets.client.settimeout = mocker.Mock() + sockets.server.send_message( + b"\x70", {"server": "Neo4j/4.3.4", "hints": hints} + ) + sockets.server.send_message(b"\x70", {}) + connection = Bolt5x1( + address, sockets.client, PoolConfig.max_connection_lifetime + ) + with caplog.at_level(logging.INFO): + connection.hello() + if valid: + if "connection.recv_timeout_seconds" in hints: + sockets.client.settimeout.assert_called_once_with( + hints["connection.recv_timeout_seconds"] + ) + else: + sockets.client.settimeout.assert_not_called() + assert not any("recv_timeout_seconds" in msg + and "invalid" in msg + for msg in caplog.messages) + else: + sockets.client.settimeout.assert_not_called() + assert any(repr(hints["connection.recv_timeout_seconds"]) in msg + and "recv_timeout_seconds" in msg + and "invalid" in msg + for msg in caplog.messages) + + +CREDENTIALS = "+++super-secret-sauce+++" + + +@pytest.mark.parametrize("auth", ( + ("user", CREDENTIALS), + neo4j.basic_auth("user", CREDENTIALS), + neo4j.kerberos_auth(CREDENTIALS), + neo4j.bearer_auth(CREDENTIALS), + neo4j.custom_auth("user", CREDENTIALS, "realm", "scheme"), + neo4j.Auth("scheme", "principal", CREDENTIALS, "realm", foo="bar"), +)) +@mark_sync_test +def test_credentials_are_not_logged(auth, fake_socket_pair, caplog): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=Bolt5x1.PACKER_CLS, + unpacker_cls=Bolt5x1.UNPACKER_CLS) + sockets.server.send_message(b"\x70", {"server": "Neo4j/4.3.4"}) + sockets.server.send_message(b"\x70", {}) + connection = Bolt5x1( + address, sockets.client, PoolConfig.max_connection_lifetime, auth=auth + ) + with caplog.at_level(logging.DEBUG): + connection.hello() + + if isinstance(auth, tuple): + auth = neo4j.basic_auth(*auth) + for field in ("scheme", "principal", "realm", "parameters"): + value = getattr(auth, field, None) + if value: + assert repr(value) in caplog.text + assert CREDENTIALS not in caplog.text @pytest.mark.parametrize(("method", "args"), ( @@ -47,7 +428,7 @@ )) def test_does_not_support_notification_filters(fake_socket, method, args, kwargs): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt5x1.UNPACKER_CLS) connection = Bolt5x1(address, socket, PoolConfig.max_connection_lifetime) @@ -69,7 +450,7 @@ def test_does_not_support_notification_filters(fake_socket, method, def test_hello_does_not_support_notification_filters( fake_socket, kwargs ): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt5x1.UNPACKER_CLS) connection = Bolt5x1( 
address, socket, PoolConfig.max_connection_lifetime, @@ -77,66 +458,3 @@ def test_hello_does_not_support_notification_filters( ) with pytest.raises(ConfigurationError, match="Notification filtering"): connection.hello() - - -class HackedAuth: - def __init__(self, dict_): - self.__dict__ = dict_ - - -@mark_sync_test -@pytest.mark.parametrize("auth", ( - ("awesome test user", "safe p4ssw0rd"), - Auth("super nice scheme", "awesome test user", "safe p4ssw0rd"), - Auth("super nice scheme", "awesome test user", "safe p4ssw0rd", - realm="super duper realm"), - Auth("super nice scheme", "awesome test user", "safe p4ssw0rd", - realm="super duper realm"), - Auth("super nice scheme", "awesome test user", "safe p4ssw0rd", - foo="bar"), - HackedAuth({ - "scheme": "super nice scheme", "principal": "awesome test user", - "credentials": "safe p4ssw0rd", "realm": "super duper realm", - "parameters": {"credentials": "should be visible!"}, - }) - -)) -def test_hello_does_not_log_credentials(fake_socket_pair, caplog, auth): - def items(): - if isinstance(auth, tuple): - yield "scheme", "basic" - yield "principal", auth[0] - yield "credentials", auth[1] - elif isinstance(auth, Auth): - for key in ("scheme", "principal", "credentials", "realm", - "parameters"): - value = getattr(auth, key, None) - if value: - yield key, value - elif isinstance(auth, HackedAuth): - yield from auth.__dict__.items() - else: - raise TypeError(auth) - - address = ("127.0.0.1", 7687) - sockets = fake_socket_pair(address, - packer_cls=Bolt5x1.PACKER_CLS, - unpacker_cls=Bolt5x1.UNPACKER_CLS) - sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) - sockets.server.send_message(b"\x70", {}) - max_connection_lifetime = 0 - connection = Bolt5x1(address, sockets.client, - max_connection_lifetime, auth=auth) - - with caplog.at_level(logging.DEBUG): - connection.hello() - - logons = [m for m in caplog.messages if "C: LOGON " in m] - assert len(logons) == 1 - logon = logons[0] - - for key, value in items(): - if key == "credentials": - assert value not in logon - else: - assert str({key: value})[1:-1] in logon diff --git a/tests/unit/sync/io/test_class_bolt5x2.py b/tests/unit/sync/io/test_class_bolt5x2.py index 084424de..5eb2092c 100644 --- a/tests/unit/sync/io/test_class_bolt5x2.py +++ b/tests/unit/sync/io/test_class_bolt5x2.py @@ -14,21 +14,24 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+ + import itertools import logging import pytest +import neo4j from neo4j._conf import PoolConfig from neo4j._sync.io._bolt5 import Bolt5x2 -from neo4j.api import Auth +from neo4j.auth_management import AuthManagers from ...._async_compat import mark_sync_test @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_stale(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = 0 connection = Bolt5x2(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -38,7 +41,7 @@ def test_conn_is_stale(fake_socket, set_stale): @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = -1 connection = Bolt5x2(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -48,7 +51,7 @@ def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_not_stale(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = 999999999 connection = Bolt5x2(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -67,7 +70,7 @@ def test_conn_is_not_stale(fake_socket, set_stale): )) @mark_sync_test def test_extra_in_begin(fake_socket, args, kwargs, expected_fields): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt5x2.UNPACKER_CLS) connection = Bolt5x2(address, socket, PoolConfig.max_connection_lifetime) connection.begin(*args, **kwargs) @@ -88,7 +91,7 @@ def test_extra_in_begin(fake_socket, args, kwargs, expected_fields): )) @mark_sync_test def test_extra_in_run(fake_socket, args, kwargs, expected_fields): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt5x2.UNPACKER_CLS) connection = Bolt5x2(address, socket, PoolConfig.max_connection_lifetime) connection.run(*args, **kwargs) @@ -100,7 +103,7 @@ def test_extra_in_run(fake_socket, args, kwargs, expected_fields): @mark_sync_test def test_n_extra_in_discard(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt5x2.UNPACKER_CLS) connection = Bolt5x2(address, socket, PoolConfig.max_connection_lifetime) connection.discard(n=666) @@ -120,7 +123,7 @@ def test_n_extra_in_discard(fake_socket): ) @mark_sync_test def test_qid_extra_in_discard(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt5x2.UNPACKER_CLS) connection = Bolt5x2(address, socket, PoolConfig.max_connection_lifetime) connection.discard(qid=test_input) @@ -140,7 +143,7 @@ def test_qid_extra_in_discard(fake_socket, test_input, expected): ) @mark_sync_test def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt5x2.UNPACKER_CLS) connection = Bolt5x2(address, socket, PoolConfig.max_connection_lifetime) connection.discard(n=666, qid=test_input) @@ -160,7 +163,7 @@ def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): ) @mark_sync_test def test_n_extra_in_pull(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = 
neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt5x2.UNPACKER_CLS) connection = Bolt5x2(address, socket, PoolConfig.max_connection_lifetime) connection.pull(n=test_input) @@ -180,7 +183,7 @@ def test_n_extra_in_pull(fake_socket, test_input, expected): ) @mark_sync_test def test_qid_extra_in_pull(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt5x2.UNPACKER_CLS) connection = Bolt5x2(address, socket, PoolConfig.max_connection_lifetime) connection.pull(qid=test_input) @@ -193,7 +196,7 @@ def test_qid_extra_in_pull(fake_socket, test_input, expected): @mark_sync_test def test_n_and_qid_extras_in_pull(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt5x2.UNPACKER_CLS) connection = Bolt5x2(address, socket, PoolConfig.max_connection_lifetime) connection.pull(n=666, qid=777) @@ -206,7 +209,7 @@ def test_n_and_qid_extras_in_pull(fake_socket): @mark_sync_test def test_hello_passes_routing_metadata(fake_socket_pair): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair(address, packer_cls=Bolt5x2.PACKER_CLS, unpacker_cls=Bolt5x2.UNPACKER_CLS) @@ -223,6 +226,99 @@ def test_hello_passes_routing_metadata(fake_socket_pair): assert fields[0]["routing"] == {"foo": "bar"} +def _assert_logon_message(sockets, auth): + tag, fields = sockets.server.pop_message() + assert tag == b"\x6A" # LOGON + assert len(fields) == 1 + keys = ["scheme", "principal", "credentials"] + assert list(fields[0].keys()) == keys + for key in keys: + assert fields[0][key] == getattr(auth, key) + + +@mark_sync_test +def test_hello_pipelines_logon(fake_socket_pair): + auth = neo4j.Auth("basic", "alice123", "supersecret123") + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=Bolt5x2.PACKER_CLS, + unpacker_cls=Bolt5x2.UNPACKER_CLS) + sockets.server.send_message( + b"\x7F", {"code": "Neo.DatabaseError.General.MadeUpError", + "message": "kthxbye"} + ) + connection = Bolt5x2( + address, sockets.client, PoolConfig.max_connection_lifetime, auth=auth + ) + with pytest.raises(neo4j.exceptions.Neo4jError): + connection.hello() + tag, fields = sockets.server.pop_message() + assert tag == b"\x01" # HELLO + assert len(fields) == 1 + assert list(fields[0].keys()) == ["user_agent"] + assert auth.credentials not in repr(fields) + _assert_logon_message(sockets, auth) + + +@mark_sync_test +def test_logon(fake_socket_pair): + auth = neo4j.Auth("basic", "alice123", "supersecret123") + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=Bolt5x2.PACKER_CLS, + unpacker_cls=Bolt5x2.UNPACKER_CLS) + connection = Bolt5x2(address, sockets.client, + PoolConfig.max_connection_lifetime, auth=auth) + connection.logon() + connection.send_all() + _assert_logon_message(sockets, auth) + + +@mark_sync_test +def test_re_auth(fake_socket_pair, mocker, static_auth): + auth = neo4j.Auth("basic", "alice123", "supersecret123") + auth_manager = static_auth(auth) + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=Bolt5x2.PACKER_CLS, + unpacker_cls=Bolt5x2.UNPACKER_CLS) + sockets.server.send_message( + b"\x7F", {"code": "Neo.DatabaseError.General.MadeUpError", + "message": "kthxbye"} + ) + connection = Bolt5x2(address, sockets.client, + PoolConfig.max_connection_lifetime) + connection.pool = 
mocker.MagicMock() + connection.re_auth(auth, auth_manager) + connection.send_all() + with pytest.raises(neo4j.exceptions.Neo4jError): + connection.fetch_all() + tag, fields = sockets.server.pop_message() + assert tag == b"\x6B" # LOGOFF + assert len(fields) == 0 + _assert_logon_message(sockets, auth) + assert connection.auth is auth + assert connection.auth_manager is auth_manager + + +@mark_sync_test +def test_logoff(fake_socket_pair): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=Bolt5x2.PACKER_CLS, + unpacker_cls=Bolt5x2.UNPACKER_CLS) + sockets.server.send_message(b"\x70", {}) + connection = Bolt5x2(address, sockets.client, + PoolConfig.max_connection_lifetime) + connection.logoff() + assert not sockets.server.recv_buffer # pipelined, so no response yet + connection.send_all() + assert sockets.server.recv_buffer # now! + tag, fields = sockets.server.pop_message() + assert tag == b"\x6B" # LOGOFF + assert len(fields) == 0 + + @pytest.mark.parametrize(("hints", "valid"), ( ({"connection.recv_timeout_seconds": 1}, True), ({"connection.recv_timeout_seconds": 42}, True), @@ -239,7 +335,7 @@ def test_hello_passes_routing_metadata(fake_socket_pair): def test_hint_recv_timeout_seconds( fake_socket_pair, hints, valid, caplog, mocker ): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair(address, packer_cls=Bolt5x2.PACKER_CLS, unpacker_cls=Bolt5x2.UNPACKER_CLS) @@ -271,6 +367,40 @@ def test_hint_recv_timeout_seconds( for msg in caplog.messages) +CREDENTIALS = "+++super-secret-sauce+++" + + +@pytest.mark.parametrize("auth", ( + ("user", CREDENTIALS), + neo4j.basic_auth("user", CREDENTIALS), + neo4j.kerberos_auth(CREDENTIALS), + neo4j.bearer_auth(CREDENTIALS), + neo4j.custom_auth("user", CREDENTIALS, "realm", "scheme"), + neo4j.Auth("scheme", "principal", CREDENTIALS, "realm", foo="bar"), +)) +@mark_sync_test +def test_credentials_are_not_logged(auth, fake_socket_pair, caplog): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=Bolt5x2.PACKER_CLS, + unpacker_cls=Bolt5x2.UNPACKER_CLS) + sockets.server.send_message(b"\x70", {"server": "Neo4j/4.3.4"}) + sockets.server.send_message(b"\x70", {}) + connection = Bolt5x2( + address, sockets.client, PoolConfig.max_connection_lifetime, auth=auth + ) + with caplog.at_level(logging.DEBUG): + connection.hello() + + if isinstance(auth, tuple): + auth = neo4j.basic_auth(*auth) + for field in ("scheme", "principal", "realm", "parameters"): + value = getattr(auth, field, None) + if value: + assert repr(value) in caplog.text + assert CREDENTIALS not in caplog.text + + def _assert_notifications_in_extra(extra, expected): for key in expected: assert key in extra @@ -295,7 +425,7 @@ def test_supports_notification_filters( fake_socket, method, args, extra_idx, cls_min_sev, method_min_sev, cls_dis_cats, method_dis_cats ): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt5x2.UNPACKER_CLS) connection = Bolt5x2( address, socket, PoolConfig.max_connection_lifetime, @@ -325,7 +455,7 @@ def test_supports_notification_filters( def test_hello_supports_notification_filters( fake_socket_pair, min_sev, dis_cats ): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair(address, packer_cls=Bolt5x2.PACKER_CLS, unpacker_cls=Bolt5x2.UNPACKER_CLS) @@ -347,66 +477,3 @@ def test_hello_supports_notification_filters( if 
dis_cats is not None: expected["notifications_disabled_categories"] = dis_cats _assert_notifications_in_extra(extra, expected) - - -class HackedAuth: - def __init__(self, dict_): - self.__dict__ = dict_ - - -@mark_sync_test -@pytest.mark.parametrize("auth", ( - ("awesome test user", "safe p4ssw0rd"), - Auth("super nice scheme", "awesome test user", "safe p4ssw0rd"), - Auth("super nice scheme", "awesome test user", "safe p4ssw0rd", - realm="super duper realm"), - Auth("super nice scheme", "awesome test user", "safe p4ssw0rd", - realm="super duper realm"), - Auth("super nice scheme", "awesome test user", "safe p4ssw0rd", - foo="bar"), - HackedAuth({ - "scheme": "super nice scheme", "principal": "awesome test user", - "credentials": "safe p4ssw0rd", "realm": "super duper realm", - "parameters": {"credentials": "should be visible!"}, - }) - -)) -def test_hello_does_not_log_credentials(fake_socket_pair, caplog, auth): - def items(): - if isinstance(auth, tuple): - yield "scheme", "basic" - yield "principal", auth[0] - yield "credentials", auth[1] - elif isinstance(auth, Auth): - for key in ("scheme", "principal", "credentials", "realm", - "parameters"): - value = getattr(auth, key, None) - if value: - yield key, value - elif isinstance(auth, HackedAuth): - yield from auth.__dict__.items() - else: - raise TypeError(auth) - - address = ("127.0.0.1", 7687) - sockets = fake_socket_pair(address, - packer_cls=Bolt5x2.PACKER_CLS, - unpacker_cls=Bolt5x2.UNPACKER_CLS) - sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) - sockets.server.send_message(b"\x70", {}) - max_connection_lifetime = 0 - connection = Bolt5x2(address, sockets.client, - max_connection_lifetime, auth=auth) - - with caplog.at_level(logging.DEBUG): - connection.hello() - - logons = [m for m in caplog.messages if "C: LOGON " in m] - assert len(logons) == 1 - logon = logons[0] - - for key, value in items(): - if key == "credentials": - assert value not in logon - else: - assert str({key: value})[1:-1] in logon diff --git a/tests/unit/sync/io/test_direct.py b/tests/unit/sync/io/test_direct.py index 45a901b8..5349f5a2 100644 --- a/tests/unit/sync/io/test_direct.py +++ b/tests/unit/sync/io/test_direct.py @@ -18,6 +18,7 @@ import pytest +from neo4j import PreviewWarning from neo4j._conf import ( Config, PoolConfig, @@ -26,6 +27,7 @@ from neo4j._deadline import Deadline from neo4j._sync.io import Bolt from neo4j._sync.io._pool import IOPool +from neo4j.auth_management import AuthManagers from neo4j.exceptions import ( ClientError, ServiceUnavailable, @@ -65,6 +67,9 @@ def stale(self): def reset(self): pass + def re_auth(self, auth, auth_manager, force=False): + return False + def close(self): self.socket.close() @@ -79,39 +84,46 @@ def timedout(self): class FakeBoltPool(IOPool): - def __init__(self, address, *, auth=None, **config): + config["auth"] = static_auth(None) self.pool_config, self.workspace_config = Config.consume_chain(config, PoolConfig, WorkspaceConfig) if config: raise ValueError("Unexpected config keys: %s" % ", ".join(config.keys())) - def opener(addr, timeout): + def opener(addr, auth, timeout): return QuickConnection(FakeSocket(addr)) super().__init__(opener, self.pool_config, self.workspace_config) self.address = address def acquire( - self, access_mode, timeout, database, bookmarks, liveness_check_timeout + self, access_mode, timeout, database, bookmarks, auth, + liveness_check_timeout ): return self._acquire( - self.address, timeout, liveness_check_timeout + self.address, auth, timeout, 
liveness_check_timeout ) +def static_auth(auth): + with pytest.warns(PreviewWarning, match="Auth managers"): + return AuthManagers.static(auth) + +@pytest.fixture +def auth_manager(): + static_auth(("test", "test")) @mark_sync_test -def test_bolt_connection_open(): +def test_bolt_connection_open(auth_manager): with pytest.raises(ServiceUnavailable): - Bolt.open( - ("localhost", 9999), auth=("test", "test") - ) + Bolt.open(("localhost", 9999), auth_manager=auth_manager) @mark_sync_test -def test_bolt_connection_open_timeout(): +def test_bolt_connection_open_timeout(auth_manager): with pytest.raises(ServiceUnavailable): Bolt.open( - ("localhost", 9999), auth=("test", "test"), deadline=Deadline(1) + ("localhost", 9999), auth_manager=auth_manager, + deadline=Deadline(1) ) @@ -150,7 +162,7 @@ def assert_pool_size( address, expected_active, expected_inactive, pool): @mark_sync_test def test_pool_can_acquire(pool): address = ("127.0.0.1", 7687) - connection = pool._acquire(address, Deadline(3), None) + connection = pool._acquire(address, None, Deadline(3), None) assert connection.address == address assert_pool_size(address, 1, 0, pool) @@ -158,8 +170,8 @@ def test_pool_can_acquire(pool): @mark_sync_test def test_pool_can_acquire_twice(pool): address = ("127.0.0.1", 7687) - connection_1 = pool._acquire(address, Deadline(3), None) - connection_2 = pool._acquire(address, Deadline(3), None) + connection_1 = pool._acquire(address, None, Deadline(3), None) + connection_2 = pool._acquire(address, None, Deadline(3), None) assert connection_1.address == address assert connection_2.address == address assert connection_1 is not connection_2 @@ -170,8 +182,8 @@ def test_pool_can_acquire_twice(pool): def test_pool_can_acquire_two_addresses(pool): address_1 = ("127.0.0.1", 7687) address_2 = ("127.0.0.1", 7474) - connection_1 = pool._acquire(address_1, Deadline(3), None) - connection_2 = pool._acquire(address_2, Deadline(3), None) + connection_1 = pool._acquire(address_1, None, Deadline(3), None) + connection_2 = pool._acquire(address_2, None, Deadline(3), None) assert connection_1.address == address_1 assert connection_2.address == address_2 assert_pool_size(address_1, 1, 0, pool) @@ -181,7 +193,7 @@ def test_pool_can_acquire_two_addresses(pool): @mark_sync_test def test_pool_can_acquire_and_release(pool): address = ("127.0.0.1", 7687) - connection = pool._acquire(address, Deadline(3), None) + connection = pool._acquire(address, None, Deadline(3), None) assert_pool_size(address, 1, 0, pool) pool.release(connection) assert_pool_size(address, 0, 1, pool) @@ -190,7 +202,7 @@ def test_pool_can_acquire_and_release(pool): @mark_sync_test def test_pool_releasing_twice(pool): address = ("127.0.0.1", 7687) - connection = pool._acquire(address, Deadline(3), None) + connection = pool._acquire(address, None, Deadline(3), None) pool.release(connection) assert_pool_size(address, 0, 1, pool) pool.release(connection) @@ -201,7 +213,7 @@ def test_pool_releasing_twice(pool): def test_pool_in_use_count(pool): address = ("127.0.0.1", 7687) assert pool.in_use_connection_count(address) == 0 - connection = pool._acquire(address, Deadline(3), None) + connection = pool._acquire(address, None, Deadline(3), None) assert pool.in_use_connection_count(address) == 1 pool.release(connection) assert pool.in_use_connection_count(address) == 0 @@ -211,10 +223,10 @@ def test_pool_in_use_count(pool): def test_pool_max_conn_pool_size(pool): with FakeBoltPool((), max_connection_pool_size=1) as pool: address = ("127.0.0.1", 7687) - 
pool._acquire(address, Deadline(0), None) + pool._acquire(address, None, Deadline(0), None) assert pool.in_use_connection_count(address) == 1 with pytest.raises(ClientError): - pool._acquire(address, Deadline(0), None) + pool._acquire(address, None, Deadline(0), None) assert pool.in_use_connection_count(address) == 1 @@ -232,7 +244,7 @@ def test_pool_reset_when_released(is_reset, pool, mocker): new_callable=mocker.MagicMock ) is_reset_mock.return_value = is_reset - connection = pool._acquire(address, Deadline(3), None) + connection = pool._acquire(address, None, Deadline(3), None) assert isinstance(connection, QuickConnection) assert is_reset_mock.call_count == 0 assert reset_mock.call_count == 0 diff --git a/tests/unit/sync/io/test_neo4j_pool.py b/tests/unit/sync/io/test_neo4j_pool.py index ff0ed9e9..cfaaf1f3 100644 --- a/tests/unit/sync/io/test_neo4j_pool.py +++ b/tests/unit/sync/io/test_neo4j_pool.py @@ -21,6 +21,7 @@ import pytest from neo4j import ( + PreviewWarning, READ_ACCESS, WRITE_ACCESS, ) @@ -36,6 +37,7 @@ Neo4jPool, ) from neo4j.addressing import ResolvedAddress +from neo4j.auth_management import AuthManagers from neo4j.exceptions import ( Neo4jError, ServiceUnavailable, @@ -72,10 +74,11 @@ def routing_side_effect(*args, **kwargs): }] raise res - def open_(addr, timeout): + def open_(addr, auth, timeout): connection = fake_connection_generator() - connection.addr = addr + connection.unresolved_address = addr connection.timeout = timeout + connection.auth = auth route_mock = mocker.MagicMock() route_mock.side_effect = routing_side_effect @@ -97,48 +100,59 @@ def opener(routing_failure_opener): return routing_failure_opener() +def _pool_config(): + pool_config = PoolConfig() + pool_config.auth = _auth_manager(("user", "pass")) + return pool_config + + +def _auth_manager(auth): + with pytest.warns(PreviewWarning, match="Auth managers"): + return AuthManagers.static(auth) + + +def _simple_pool(opener) -> Neo4jPool: + return Neo4jPool( + opener, _pool_config(), WorkspaceConfig(), ROUTER1_ADDRESS + ) + + @mark_sync_test def test_acquires_new_routing_table_if_deleted(opener): - pool = Neo4jPool( - opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS - ) - cx = pool.acquire(READ_ACCESS, 30, "test_db", None, None) + pool = _simple_pool(opener) + cx = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) pool.release(cx) assert pool.routing_tables.get("test_db") del pool.routing_tables["test_db"] - cx = pool.acquire(READ_ACCESS, 30, "test_db", None, None) + cx = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) pool.release(cx) assert pool.routing_tables.get("test_db") @mark_sync_test def test_acquires_new_routing_table_if_stale(opener): - pool = Neo4jPool( - opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS - ) - cx = pool.acquire(READ_ACCESS, 30, "test_db", None, None) + pool = _simple_pool(opener) + cx = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) pool.release(cx) assert pool.routing_tables.get("test_db") old_value = pool.routing_tables["test_db"].last_updated_time pool.routing_tables["test_db"].ttl = 0 - cx = pool.acquire(READ_ACCESS, 30, "test_db", None, None) + cx = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) pool.release(cx) assert pool.routing_tables["test_db"].last_updated_time > old_value @mark_sync_test def test_removes_old_routing_table(opener): - pool = Neo4jPool( - opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS - ) - cx = pool.acquire(READ_ACCESS, 30, "test_db1", None, None) + pool = _simple_pool(opener) + cx 
= pool.acquire(READ_ACCESS, 30, "test_db1", None, None, None)
     pool.release(cx)
     assert pool.routing_tables.get("test_db1")
 
-    cx = pool.acquire(READ_ACCESS, 30, "test_db2", None, None)
+    cx = pool.acquire(READ_ACCESS, 30, "test_db2", None, None, None)
     pool.release(cx)
     assert pool.routing_tables.get("test_db2")
@@ -147,7 +161,7 @@ def test_removes_old_routing_table(opener):
     pool.routing_tables["test_db2"].ttl = \
         -RoutingConfig.routing_table_purge_delay
 
-    cx = pool.acquire(READ_ACCESS, 30, "test_db1", None, None)
+    cx = pool.acquire(READ_ACCESS, 30, "test_db1", None, None, None)
     pool.release(cx)
     assert pool.routing_tables["test_db1"].last_updated_time > old_value
     assert "test_db2" not in pool.routing_tables
@@ -156,28 +170,24 @@ def test_removes_old_routing_table(opener):
 @pytest.mark.parametrize("type_", ("r", "w"))
 @mark_sync_test
 def test_chooses_right_connection_type(opener, type_):
-    pool = Neo4jPool(
-        opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS
-    )
+    pool = _simple_pool(opener)
     cx1 = pool.acquire(
         READ_ACCESS if type_ == "r" else WRITE_ACCESS,
-        30, "test_db", None, None
+        30, "test_db", None, None, None
     )
     pool.release(cx1)
     if type_ == "r":
-        assert cx1.addr == READER_ADDRESS
+        assert cx1.unresolved_address == READER_ADDRESS
     else:
-        assert cx1.addr == WRITER_ADDRESS
+        assert cx1.unresolved_address == WRITER_ADDRESS
 
 
 @mark_sync_test
 def test_reuses_connection(opener):
-    pool = Neo4jPool(
-        opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS
-    )
-    cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None, None)
+    pool = _simple_pool(opener)
+    cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None)
     pool.release(cx1)
-    cx2 = pool.acquire(READ_ACCESS, 30, "test_db", None, None)
+    cx2 = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None)
     assert cx1 is cx2
@@ -185,75 +195,70 @@ def test_reuses_connection(opener):
 @mark_sync_test
 def test_closes_stale_connections(opener, break_on_close):
     def break_connection():
-        pool.deactivate(cx1.addr)
+        pool.deactivate(cx1.unresolved_address)
 
         if cx_close_mock_side_effect:
             res = cx_close_mock_side_effect()
             if inspect.isawaitable(res):
                 return res
 
-    pool = Neo4jPool(
-        opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS
-    )
-    cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None, None)
+    pool = _simple_pool(opener)
+    cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None)
     pool.release(cx1)
-    assert cx1 in pool.connections[cx1.addr]
-    # simulate connection going stale (e.g. exceeding) and then breaking when
-    # the pool tries to close the connection
+    assert cx1 in pool.connections[cx1.unresolved_address]
+    # simulate connection going stale (e.g. exceeding idle timeout) and then
+    # breaking when the pool tries to close the connection
     cx1.stale.return_value = True
     cx_close_mock = cx1.close
     if break_on_close:
         cx_close_mock_side_effect = cx_close_mock.side_effect
         cx_close_mock.side_effect = break_connection
-    cx2 = pool.acquire(READ_ACCESS, 30, "test_db", None, None)
+    cx2 = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None)
     pool.release(cx2)
     if break_on_close:
         cx1.close.assert_called()
     else:
         cx1.close.assert_called_once()
     assert cx2 is not cx1
-    assert cx2.addr == cx1.addr
-    assert cx1 not in pool.connections[cx1.addr]
-    assert cx2 in pool.connections[cx2.addr]
+    assert cx2.unresolved_address == cx1.unresolved_address
+    assert cx1 not in pool.connections[cx1.unresolved_address]
+    assert cx2 in pool.connections[cx2.unresolved_address]
 
 
 @mark_sync_test
 def test_does_not_close_stale_connections_in_use(opener):
-    pool = Neo4jPool(
-        opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS
-    )
-    cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None, None)
-    assert cx1 in pool.connections[cx1.addr]
-    # simulate connection going stale (e.g. exceeding) while being in use
+    pool = _simple_pool(opener)
+    cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None)
+    assert cx1 in pool.connections[cx1.unresolved_address]
+    # simulate connection going stale (e.g. exceeding idle timeout) while being
+    # in use
     cx1.stale.return_value = True
-    cx2 = pool.acquire(READ_ACCESS, 30, "test_db", None, None)
+    cx2 = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None)
     pool.release(cx2)
     cx1.close.assert_not_called()
     assert cx2 is not cx1
-    assert cx2.addr == cx1.addr
-    assert cx1 in pool.connections[cx1.addr]
-    assert cx2 in pool.connections[cx2.addr]
+    assert cx2.unresolved_address == cx1.unresolved_address
+    assert cx1 in pool.connections[cx1.unresolved_address]
+    assert cx2 in pool.connections[cx2.unresolved_address]
 
     pool.release(cx1)
     # now that cx1 is back in the pool and still stale,
     # it should be closed when trying to acquire the next connection
     cx1.close.assert_not_called()
 
-    cx3 = pool.acquire(READ_ACCESS, 30, "test_db", None, None)
+    cx3 = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None)
     pool.release(cx3)
     cx1.close.assert_called_once()
     assert cx2 is cx3
-    assert cx3.addr == cx1.addr
-    assert cx1 not in pool.connections[cx1.addr]
-    assert cx3 in pool.connections[cx2.addr]
+    assert cx3.unresolved_address == cx1.unresolved_address
+    assert cx1 not in pool.connections[cx1.unresolved_address]
+    assert cx3 in pool.connections[cx2.unresolved_address]
 
 
 @mark_sync_test
 def test_release_resets_connections(opener):
-    pool = Neo4jPool(
-        opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS
-    )
-    cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None, None)
+    pool = _simple_pool(opener)
+    cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None)
     cx1.is_reset_mock.return_value = False
     cx1.is_reset_mock.reset_mock()
     pool.release(cx1)
@@ -263,10 +268,8 @@ def test_release_resets_connections(opener):
 @mark_sync_test
 def test_release_does_not_resets_closed_connections(opener):
-    pool = Neo4jPool(
-        opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS
-    )
-    cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None, None)
+    pool = _simple_pool(opener)
+    cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None)
     cx1.closed.return_value = True
     cx1.closed.reset_mock()
     cx1.is_reset_mock.reset_mock()
@@ -278,10 +281,8 @@ def test_release_does_not_resets_closed_connections(opener):
 @mark_sync_test
 def test_release_does_not_resets_defunct_connections(opener):
-    pool = Neo4jPool(
-        opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS
-    )
-    cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None, None)
+    pool = _simple_pool(opener)
+    cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None)
     cx1.defunct.return_value = True
     cx1.defunct.reset_mock()
     cx1.is_reset_mock.reset_mock()
@@ -296,11 +297,10 @@ def test_release_does_not_resets_defunct_connections(opener):
 def test_acquire_performs_no_liveness_check_on_fresh_connection(
     opener, liveness_timeout
 ):
-    pool = Neo4jPool(
-        opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS
-    )
-    cx1 = pool._acquire(READER_ADDRESS, Deadline(30), liveness_timeout)
-    assert cx1.addr == READER_ADDRESS
+    pool = _simple_pool(opener)
+    cx1 = pool._acquire(READER_ADDRESS, None, Deadline(30),
+                        liveness_timeout)
+    assert cx1.unresolved_address == READER_ADDRESS
     cx1.reset.assert_not_called()
@@ -309,14 +309,13 @@ def test_acquire_performs_no_liveness_check_on_fresh_connection(
 def test_acquire_performs_liveness_check_on_existing_connection(
     opener, liveness_timeout
 ):
-    pool = Neo4jPool(
-        opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS
-    )
+    pool = _simple_pool(opener)
     # populate the pool with a connection
-    cx1 = pool._acquire(READER_ADDRESS, Deadline(30), liveness_timeout)
+    cx1 = pool._acquire(READER_ADDRESS, None, Deadline(30),
+                        liveness_timeout)
 
     # make sure we assume the right state
-    assert cx1.addr == READER_ADDRESS
+    assert cx1.unresolved_address == READER_ADDRESS
     cx1.is_idle_for.assert_not_called()
     cx1.reset.assert_not_called()
@@ -327,7 +326,8 @@ def test_acquire_performs_liveness_check_on_existing_connection(
     cx1.reset.assert_not_called()
 
     # then acquire it again and assert the liveness check was performed
-    cx2 = pool._acquire(READER_ADDRESS, Deadline(30), liveness_timeout)
+    cx2 = pool._acquire(READER_ADDRESS, None, Deadline(30),
+                        liveness_timeout)
     assert cx1 is cx2
     cx1.is_idle_for.assert_called_once_with(liveness_timeout)
     cx2.reset.assert_called_once()
@@ -343,14 +343,13 @@ def liveness_side_effect(*args, **kwargs):
         raise liveness_error("liveness check failed")
 
     liveness_timeout = 1
-    pool = Neo4jPool(
-        opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS
-    )
+    pool = _simple_pool(opener)
 
     # populate the pool with a connection
-    cx1 = pool._acquire(READER_ADDRESS, Deadline(30), liveness_timeout)
+    cx1 = pool._acquire(READER_ADDRESS, None, Deadline(30),
+                        liveness_timeout)
 
     # make sure we assume the right state
-    assert cx1.addr == READER_ADDRESS
+    assert cx1.unresolved_address == READER_ADDRESS
     cx1.is_idle_for.assert_not_called()
     cx1.reset.assert_not_called()
@@ -363,13 +362,14 @@ def liveness_side_effect(*args, **kwargs):
     cx1.reset.assert_not_called()
 
     # then acquire it again and assert the liveness check was performed
-    cx2 = pool._acquire(READER_ADDRESS, Deadline(30), liveness_timeout)
+    cx2 = pool._acquire(READER_ADDRESS, None, Deadline(30),
+                        liveness_timeout)
     assert cx1 is not cx2
-    assert cx1.addr == cx2.addr
+    assert cx1.unresolved_address == cx2.unresolved_address
     cx1.is_idle_for.assert_called_once_with(liveness_timeout)
     cx2.reset.assert_not_called()
-    assert cx1 not in pool.connections[cx1.addr]
-    assert cx2 in pool.connections[cx1.addr]
+    assert cx1 not in pool.connections[cx1.unresolved_address]
+    assert cx2 in pool.connections[cx1.unresolved_address]
 
 
 @pytest.mark.parametrize("liveness_error",
@@ -382,16 +382,16 @@ def liveness_side_effect(*args, **kwargs):
         raise liveness_error("liveness check failed")
 
     liveness_timeout = 1
-    pool = Neo4jPool(
-        opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS
-    )
+    pool = _simple_pool(opener)
 
     # populate the pool with a connection
-    cx1 = pool._acquire(READER_ADDRESS, Deadline(30), liveness_timeout)
-    cx2 = pool._acquire(READER_ADDRESS, Deadline(30), liveness_timeout)
+    cx1 = pool._acquire(READER_ADDRESS, None, Deadline(30),
+                        liveness_timeout)
+    cx2 = pool._acquire(READER_ADDRESS, None, Deadline(30),
+                        liveness_timeout)
 
     # make sure we assume the right state
-    assert cx1.addr == READER_ADDRESS
-    assert cx2.addr == READER_ADDRESS
+    assert cx1.unresolved_address == READER_ADDRESS
+    assert cx2.unresolved_address == READER_ADDRESS
     assert cx1 is not cx2
     cx1.is_idle_for.assert_not_called()
     cx2.is_idle_for.assert_not_called()
@@ -409,14 +409,15 @@ def liveness_side_effect(*args, **kwargs):
     cx2.reset.assert_not_called()
 
     # then acquire it again and assert the liveness check was performed
-    cx3 = pool._acquire(READER_ADDRESS, Deadline(30), liveness_timeout)
+    cx3 = pool._acquire(READER_ADDRESS, None, Deadline(30),
+                        liveness_timeout)
     assert cx3 is cx2
     cx1.is_idle_for.assert_called_once_with(liveness_timeout)
     cx1.reset.assert_called_once()
     cx3.is_idle_for.assert_called_once_with(liveness_timeout)
     cx3.reset.assert_called_once()
-    assert cx1 not in pool.connections[cx1.addr]
-    assert cx3 in pool.connections[cx1.addr]
+    assert cx1 not in pool.connections[cx1.unresolved_address]
+    assert cx3 in pool.connections[cx1.unresolved_address]
@@ -431,11 +432,9 @@ def close_side_effect():
                                         "close")
 
     # create pool with 2 idle connections
-    pool = Neo4jPool(
-        opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS
-    )
-    cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None, None)
-    cx2 = pool.acquire(READ_ACCESS, 30, "test_db", None, None)
+    pool = _simple_pool(opener)
+    cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None)
+    cx2 = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None)
     pool.release(cx1)
     pool.release(cx2)
@@ -447,7 +446,7 @@ def close_side_effect():
     # unreachable
     cx1.stale.return_value = True
 
-    cx3 = pool.acquire(READ_ACCESS, 30, "test_db", None, None)
+    cx3 = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None)
 
     assert cx3 is not cx1
     assert cx3 is not cx2
@@ -455,26 +454,24 @@
 @mark_sync_test
 def test_failing_opener_leaves_connections_in_use_alone(opener):
-    pool = Neo4jPool(
-        opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS
-    )
-    cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None, None)
+    pool = _simple_pool(opener)
+    cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None)
 
     opener.side_effect = ServiceUnavailable("Server overloaded")
     with pytest.raises((ServiceUnavailable, SessionExpired)):
-        pool.acquire(READ_ACCESS, 30, "test_db", None, None)
+        pool.acquire(READ_ACCESS, 30, "test_db", None, None, None)
     assert not cx1.closed()
 
 
 @mark_sync_test
 def test__acquire_new_later_with_room(opener):
-    config = PoolConfig()
+    config = _pool_config()
    config.max_connection_pool_size = 1
     pool = Neo4jPool(
         opener, config, WorkspaceConfig(), ROUTER1_ADDRESS
     )
     assert pool.connections_reservations[READER_ADDRESS] == 0
-    creator = pool._acquire_new_later(READER_ADDRESS, Deadline(1))
+    creator = pool._acquire_new_later(READER_ADDRESS, None, Deadline(1))
     assert pool.connections_reservations[READER_ADDRESS] == 1
     assert callable(creator)
     if Util.is_async_code:
@@ -483,15 +480,15 @@ def test__acquire_new_later_with_room(opener):
 @mark_sync_test
 def test__acquire_new_later_without_room(opener):
-    config = PoolConfig()
+    config = _pool_config()
     config.max_connection_pool_size = 1
     pool = Neo4jPool(
         opener, config, WorkspaceConfig(), ROUTER1_ADDRESS
     )
-    _ = pool.acquire(READ_ACCESS, 30, "test_db", None, None)
+    _ = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None)
     # pool is full now
     assert pool.connections_reservations[READER_ADDRESS] == 0
-    creator = pool._acquire_new_later(READER_ADDRESS, Deadline(1))
+    creator = pool._acquire_new_later(READER_ADDRESS, None, Deadline(1))
     assert pool.connections_reservations[READER_ADDRESS] == 0
     assert creator is None
@@ -503,11 +500,11 @@ def test_passes_pool_config_to_connection(mocker):
     pool_config = PoolConfig()
     workspace_config = WorkspaceConfig()
     pool = Neo4jPool.open(
-        mocker.Mock, auth=("a", "b"), pool_config=pool_config, workspace_config=workspace_config
+        mocker.Mock, pool_config=pool_config, workspace_config=workspace_config
     )
 
     _ = pool._acquire(
-        mocker.Mock, Deadline.from_timeout_or_deadline(30), None
+        mocker.Mock, None, Deadline.from_timeout_or_deadline(30), None
     )
 
     bolt_mock.assert_called_once()
@@ -528,14 +525,14 @@ def test_discovery_is_retried(routing_failure_opener, error):
         error,  # will be retried
     ])
     pool = Neo4jPool(
-        opener, PoolConfig(), WorkspaceConfig(),
+        opener, _pool_config(), WorkspaceConfig(),
         ResolvedAddress(("1.2.3.1", 9999), host_name="host")
     )
-    cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None, None)
+    cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None)
     pool.release(cx1)
     pool.routing_tables.get("test_db").ttl = 0
 
-    cx2 = pool.acquire(READ_ACCESS, 30, "test_db", None, None)
+    cx2 = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None)
     pool.release(cx2)
 
     assert pool.routing_tables.get("test_db")
@@ -572,15 +569,15 @@ def test_fast_failing_discovery(routing_failure_opener, error):
         error,  # will be retried
     ])
     pool = Neo4jPool(
-        opener, PoolConfig(), WorkspaceConfig(),
+        opener, _pool_config(), WorkspaceConfig(),
         ResolvedAddress(("1.2.3.1", 9999), host_name="host")
     )
-    cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None, None)
+    cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None)
     pool.release(cx1)
     pool.routing_tables.get("test_db").ttl = 0
 
     with pytest.raises(error.__class__) as exc:
-        pool.acquire(READ_ACCESS, 30, "test_db", None, None)
+        pool.acquire(READ_ACCESS, 30, "test_db", None, None, None)
 
     assert exc.value is error
@@ -588,3 +585,66 @@ def test_fast_failing_discovery(routing_failure_opener, error):
     # reader
     # failed router
     assert len(opener.connections) == 3
+
+
+
+@pytest.mark.parametrize(
+    ("error", "marks_unauthenticated", "fetches_new"),
+    (
+        (Neo4jError.hydrate("message", args[0]), *args[1:])
+        for args in (
+            ("Neo.ClientError.Database.DatabaseNotFound", False, False),
+            ("Neo.ClientError.Statement.TypeError", False, False),
+            ("Neo.ClientError.Statement.ArgumentError", False, False),
+            ("Neo.ClientError.Request.Invalid", False, False),
+            ("Neo.ClientError.Security.AuthenticationRateLimit", False, False),
+            ("Neo.ClientError.Security.CredentialsExpired", False, False),
+            ("Neo.ClientError.Security.Forbidden", False, False),
+            ("Neo.ClientError.Security.Unauthorized", False, False),
+            ("Neo.ClientError.Security.MadeUpError", False, False),
+            ("Neo.ClientError.Security.TokenExpired", False, True),
+            ("Neo.ClientError.Security.AuthorizationExpired", True, False),
+        )
+    )
+)
+@mark_sync_test
+def test_connection_error_callback(
+    opener, error, marks_unauthenticated, fetches_new, mocker
+):
+    config = _pool_config()
+    auth_manager = _auth_manager(("user", "auth"))
+    on_auth_expired_mock = mocker.patch.object(auth_manager, "on_auth_expired",
+                                               autospec=True)
+    config.auth = auth_manager
+    pool = Neo4jPool(
+        opener, config, WorkspaceConfig(), ROUTER1_ADDRESS
+    )
+    cxs_read = [
+        pool.acquire(READ_ACCESS, 30, "test_db", None, None, None)
+        for _ in range(5)
+    ]
+    cxs_write = [
+        pool.acquire(WRITE_ACCESS, 30, "test_db", None, None, None)
+        for _ in range(5)
+    ]
+
+    on_auth_expired_mock.assert_not_called()
+    for cx in cxs_read + cxs_write:
+        cx.mark_unauthenticated.assert_not_called()
+
+    pool.on_neo4j_error(error, cxs_read[0])
+
+    if fetches_new:
+        cxs_read[0].auth_manager.on_auth_expired.assert_called_once()
+    else:
+        on_auth_expired_mock.assert_not_called()
+        for cx in cxs_read:
+            cx.auth_manager.on_auth_expired.assert_not_called()
+
+    for cx in cxs_read:
+        if marks_unauthenticated:
+            cx.mark_unauthenticated.assert_called_once()
+        else:
+            cx.mark_unauthenticated.assert_not_called()
+    for cx in cxs_write:
+        cx.mark_unauthenticated.assert_not_called()
diff --git a/tests/unit/sync/test_addressing.py b/tests/unit/sync/test_addressing.py
index 190ac416..9274f449 100644
--- a/tests/unit/sync/test_addressing.py
+++ b/tests/unit/sync/test_addressing.py
@@ -16,10 +16,7 @@
 # limitations under the License.
 
 
-from socket import (
-    AF_INET,
-    AF_INET6,
-)
+from socket import AF_INET
 
 import pytest
diff --git a/tests/unit/sync/test_auth_manager.py b/tests/unit/sync/test_auth_manager.py
new file mode 100644
index 00000000..15d8b598
--- /dev/null
+++ b/tests/unit/sync/test_auth_manager.py
@@ -0,0 +1,153 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [https://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import itertools
+import typing as t
+
+import pytest
+from freezegun import freeze_time
+from freezegun.api import FrozenDateTimeFactory
+
+from neo4j import (
+    Auth,
+    basic_auth,
+    PreviewWarning,
+)
+from neo4j._meta import copy_signature
+from neo4j.auth_management import (
+    AuthManager,
+    AuthManagers,
+    ExpiringAuth,
+)
+
+from ..._async_compat import mark_sync_test
+
+
+SAMPLE_AUTHS = (
+    None,
+    ("user", "password"),
+    basic_auth("foo", "bar"),
+    basic_auth("foo", "bar", "baz"),
+    Auth("scheme", "principal", "credentials", "realm", para="meter"),
+)
+
+
+@copy_signature(AuthManagers.static)
+def static_auth_manager(*args, **kwargs):
+    with pytest.warns(PreviewWarning, match="Auth managers"):
+        return AuthManagers.static(*args, **kwargs)
+
+@copy_signature(AuthManagers.expiration_based)
+def expiration_based_auth_manager(*args, **kwargs):
+    with pytest.warns(PreviewWarning, match="Auth managers"):
+        return AuthManagers.expiration_based(*args, **kwargs)
+
+
+@copy_signature(ExpiringAuth)
+def expiring_auth(*args, **kwargs):
+    with pytest.warns(PreviewWarning, match="Auth managers"):
+        return ExpiringAuth(*args, **kwargs)
+
+
+@mark_sync_test
+@pytest.mark.parametrize("auth", SAMPLE_AUTHS)
+def test_static_manager(
+    auth
+) -> None:
+    manager: AuthManager = static_auth_manager(auth)
+    assert manager.get_auth() is auth
+
+    manager.on_auth_expired(("something", "else"))
+    assert manager.get_auth() is auth
+
+    manager.on_auth_expired(auth)
+    assert manager.get_auth() is auth
+
+
+@mark_sync_test
+@pytest.mark.parametrize(("auth1", "auth2"),
+                         itertools.product(SAMPLE_AUTHS, repeat=2))
+@pytest.mark.parametrize("expires_in", (None, -1, 1., 1, 1000.))
+def test_expiration_based_manager_manual_expiry(
+    auth1: t.Union[t.Tuple[str, str], Auth, None],
+    auth2: t.Union[t.Tuple[str, str], Auth, None],
+    expires_in: t.Union[float, int],
+    mocker
+) -> None:
+    if expires_in is None or expires_in >= 0:
+        temporal_auth = expiring_auth(auth1, expires_in)
+    else:
+        temporal_auth = expiring_auth(auth1)
+    provider = mocker.MagicMock(return_value=temporal_auth)
+    manager: AuthManager = expiration_based_auth_manager(provider)
+
+    provider.assert_not_called()
+    assert manager.get_auth() is auth1
+    provider.assert_called_once()
+    provider.reset_mock()
+
+    provider.return_value = expiring_auth(auth2)
+
+    manager.on_auth_expired(("something", "else"))
+    assert manager.get_auth() is auth1
+    provider.assert_not_called()
+
+    manager.on_auth_expired(auth1)
+    provider.assert_called_once()
+    provider.reset_mock()
+    assert manager.get_auth() is auth2
+    provider.assert_not_called()
+
+
+@mark_sync_test
+@pytest.mark.parametrize(("auth1", "auth2"),
+                         itertools.product(SAMPLE_AUTHS, repeat=2))
+@pytest.mark.parametrize("expires_in", (None, -1, 1., 1, 1000.))
+def test_expiration_based_manager_time_expiry(
+    auth1: t.Union[t.Tuple[str, str], Auth, None],
+    auth2: t.Union[t.Tuple[str, str], Auth, None],
+    expires_in: t.Union[float, int, None],
+    mocker
+) -> None:
+    with freeze_time() as frozen_time:
+        assert isinstance(frozen_time, FrozenDateTimeFactory)
+        if expires_in is None or expires_in >= 0:
+            temporal_auth = expiring_auth(auth1, expires_in)
+        else:
+            temporal_auth = expiring_auth(auth1)
+        provider = mocker.MagicMock(return_value=temporal_auth)
+        manager: AuthManager = expiration_based_auth_manager(provider)
+
+        provider.assert_not_called()
+        assert manager.get_auth() is auth1
+        provider.assert_called_once()
+        provider.reset_mock()
+
+        provider.return_value = expiring_auth(auth2)
+
+        if expires_in is None or expires_in < 0:
+            frozen_time.tick(1_000_000)
+            assert manager.get_auth() is auth1
+            provider.assert_not_called()
+        else:
+            frozen_time.tick(expires_in - 0.000001)
+            assert manager.get_auth() is auth1
+            provider.assert_not_called()
+            frozen_time.tick(0.000002)
+            assert manager.get_auth() is auth2
+            provider.assert_called_once()
diff --git a/tests/unit/sync/test_driver.py b/tests/unit/sync/test_driver.py
index bf59fb4d..10b7c6e1 100644
--- a/tests/unit/sync/test_driver.py
+++ b/tests/unit/sync/test_driver.py
@@ -231,6 +231,7 @@ def test_driver_opens_write_session_by_default(uri, fake_pool, mocker):
         timeout=mocker.ANY,
         database=mocker.ANY,
         bookmarks=mocker.ANY,
+        auth=mocker.ANY,
         liveness_check_timeout=mocker.ANY
     )
     tx._begin.assert_called_once_with(
@@ -969,7 +970,8 @@ def test_execute_query_result_transformer(
         with assert_warns_execute_query_bmm_experimental():
             bmm = driver.query_bookmark_manager
         res_custom = driver.execute_query(
-            "", None, "w", None, None, bmm, result_transformer
+            "", None, "w", None, None, bmm, None,
+            result_transformer
         )
     else:
         res_custom = driver.execute_query(
@@ -986,3 +988,18 @@ def test_execute_query_result_transformer(
         _work, mocker.ANY, mocker.ANY, result_transformer
     )
     assert res is session_executor_mock.return_value
+
+
+@mark_sync_test
+def test_supports_session_auth(mocker) -> None:
+    driver = GraphDatabase.driver("bolt://localhost")
+    session_cls_mock = mocker.patch("neo4j._sync.driver.Session",
+                                    autospec=True)
+    with driver as driver:
+        res = driver.supports_session_auth()
+
+    session_cls_mock.assert_called_once()
+    session_cls_mock.return_value.__enter__.assert_called_once()
+    session_mock = session_cls_mock.return_value.__enter__.return_value
+    connection_mock = session_mock._connection
+    assert res is connection_mock.supports_re_auth
diff --git a/tests/unit/sync/work/test_session.py b/tests/unit/sync/work/test_session.py
index 4717620e..d54581bd 100644
--- a/tests/unit/sync/work/test_session.py
+++ b/tests/unit/sync/work/test_session.py
@@ -409,7 +409,7 @@ def test_with_bookmark_manager(
     additional_session_bookmarks, mocker
 ):
     def update_routing_table_side_effect(
-        database, imp_user, bookmarks, acquisition_timeout=None,
+        database, imp_user, bookmarks, auth=None, acquisition_timeout=None,
         database_callback=None
     ):
         if home_db_gets_resolved: