Skip to content

Fix suppression of warnings for internal API usage #961

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 4 additions & 10 deletions src/neo4j/_async/auth_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@


import typing as t
import warnings
from logging import getLogger

from .._async_compat.concurrency import AsyncLock
Expand All @@ -31,10 +30,7 @@
expiring_auth_has_expired,
ExpiringAuth,
)
from .._meta import (
preview,
PreviewWarning,
)
from .._meta import preview

# work around for https://github.com/sphinx-doc/sphinx/pull/10880
# make sure TAuth is resolved in the docs, else they're pretty useless
Expand Down Expand Up @@ -215,11 +211,9 @@ async def auth_provider():
handled_codes = frozenset(("Neo.ClientError.Security.Unauthorized",))

async def wrapped_provider() -> ExpiringAuth:
with warnings.catch_warnings():
warnings.filterwarnings("ignore",
message=r"^Auth managers\b.*",
category=PreviewWarning)
return ExpiringAuth(await provider())
return ExpiringAuth._without_warning( # type: ignore
await provider()
)

return AsyncNeo4jAuthTokenManager(wrapped_provider, handled_codes)

Expand Down
141 changes: 61 additions & 80 deletions src/neo4j/_async/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

import asyncio
import typing as t
import warnings


if t.TYPE_CHECKING:
Expand All @@ -47,7 +46,6 @@
experimental_warn,
preview,
preview_warn,
PreviewWarning,
unclosed_resource_warn,
)
from .._work import EagerResult
Expand Down Expand Up @@ -196,12 +194,7 @@ def driver(
driver_type, security_type, parsed = parse_neo4j_uri(uri)

if not isinstance(auth, AsyncAuthManager):
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore", message=r".*\bAuth managers\b.*",
category=PreviewWarning
)
auth = AsyncAuthManagers.static(auth)
auth = AsyncAuthManagers.static._without_warning(auth)
else:
preview_warn("Auth managers are a preview feature.",
stack_level=2)
Expand Down Expand Up @@ -501,13 +494,6 @@ def encrypted(self) -> bool:
"""Indicate whether the driver was configured to use encryption."""
return bool(self._pool.pool_config.encrypted)

def _prepare_session_config(self, **config):
if "auth" in config:
preview_warn("User switching is a preview feature.",
stack_level=3)
_normalize_notifications_config(config)
return config

if t.TYPE_CHECKING:

def session(
Expand Down Expand Up @@ -549,7 +535,25 @@ def session(self, **config) -> AsyncSession:

:returns: new :class:`neo4j.AsyncSession` object
"""
raise NotImplementedError
session_config = self._read_session_config(config)
return self._session(session_config)

def _session(self, session_config) -> AsyncSession:
return AsyncSession(self._pool, session_config)

def _read_session_config(self, config_kwargs, preview_check=True):
config = self._prepare_session_config(preview_check, config_kwargs)
session_config = SessionConfig(self._default_workspace_config,
config)
return session_config

@classmethod
def _prepare_session_config(cls, preview_check, config_kwargs):
if preview_check and "auth" in config_kwargs:
preview_warn("User switching is a preview feature.",
stack_level=5)
_normalize_notifications_config(config_kwargs)
return config_kwargs

async def close(self) -> None:
""" Shut down, closing any open connections in the pool.
Expand Down Expand Up @@ -844,14 +848,16 @@ async def example(driver: neo4j.AsyncDriver) -> neo4j.Record::
bookmark_manager_ = self._query_bookmark_manager
assert bookmark_manager_ is not _default

with warnings.catch_warnings():
warnings.filterwarnings("ignore",
message=r"^User switching\b.*",
category=PreviewWarning)
session = self.session(database=database_,
impersonated_user=impersonated_user_,
bookmark_manager=bookmark_manager_,
auth=auth_)
session_config = self._read_session_config(
{
"database": database_,
"impersonated_user": impersonated_user_,
"bookmark_manager": bookmark_manager_,
"auth": auth_,
},
preview_check=False
)
session = self._session(session_config)
async with session:
if routing_ == RoutingControl.WRITE:
executor = session.execute_write
Expand Down Expand Up @@ -963,7 +969,8 @@ async def verify_connectivity(self, **config) -> None:
"changed or removed in any future version without prior "
"notice."
)
await self._get_server_info()
session_config = self._read_session_config(config)
await self._get_server_info(session_config)

if t.TYPE_CHECKING:

Expand Down Expand Up @@ -1034,7 +1041,8 @@ async def get_server_info(self, **config) -> ServerInfo:
"changed or removed in any future version without prior "
"notice."
)
return await self._get_server_info()
session_config = self._read_session_config(config)
return await self._get_server_info(session_config)

async def supports_multi_db(self) -> bool:
""" Check if the server or cluster supports multi-databases.
Expand All @@ -1049,7 +1057,8 @@ async def supports_multi_db(self) -> bool:
won't throw a :exc:`ConfigurationError` when trying to use this
driver feature.
"""
async with self.session() as session:
session_config = self._read_session_config({}, preview_check=False)
async with self._session(session_config) as session:
await session._connect(READ_ACCESS)
assert session._connection
return session._connection.supports_multiple_databases
Expand Down Expand Up @@ -1130,30 +1139,24 @@ async def verify_authentication(
"changed or removed in any future version without prior "
"notice."
)
config["auth"] = auth
if "database" not in config:
config["database"] = "system"
try:
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore", message=r"^User switching\b.*",
category=PreviewWarning
)
session = self.session(**config)
async with session as session:
session_config = self._read_session_config(config)
session_config = SessionConfig(session_config, {"auth": auth})
async with self._session(session_config) as session:
try:
await session._verify_authentication()
except Neo4jError as exc:
if exc.code in (
"Neo.ClientError.Security.CredentialsExpired",
"Neo.ClientError.Security.Forbidden",
"Neo.ClientError.Security.TokenExpired",
"Neo.ClientError.Security.Unauthorized",
):
return False
raise
except Neo4jError as exc:
if exc.code in (
"Neo.ClientError.Security.CredentialsExpired",
"Neo.ClientError.Security.Forbidden",
"Neo.ClientError.Security.TokenExpired",
"Neo.ClientError.Security.Unauthorized",
):
return False
raise
return True


async def supports_session_auth(self) -> bool:
"""Check if the remote supports connection re-authentication.

Expand All @@ -1170,13 +1173,14 @@ async def supports_session_auth(self) -> bool:

.. versionadded:: 5.8
"""
async with self.session() as session:
session_config = self._read_session_config({}, preview_check=False)
async with self._session(session_config) as session:
await session._connect(READ_ACCESS)
assert session._connection
return session._connection.supports_re_auth

async def _get_server_info(self, **config) -> ServerInfo:
async with self.session(**config) as session:
async def _get_server_info(self, session_config) -> ServerInfo:
async with self._session(session_config) as session:
return await session._get_server_info()


Expand Down Expand Up @@ -1225,21 +1229,6 @@ def __init__(self, pool, default_workspace_config):
AsyncDriver.__init__(self, pool, default_workspace_config)
self._default_workspace_config = default_workspace_config

if not t.TYPE_CHECKING:

def session(self, **config) -> AsyncSession:
"""
:param config: The values that can be specified are found in
:class: `neo4j.SessionConfig`

:returns:
:rtype: :class: `neo4j.AsyncSession`
"""
config = self._prepare_session_config(**config)
session_config = SessionConfig(self._default_workspace_config,
config)
return AsyncSession(self._pool, session_config)


class AsyncNeo4jDriver(_Routing, AsyncDriver):
""":class:`.AsyncNeo4jDriver` is instantiated for ``neo4j`` URIs. The
Expand All @@ -1264,23 +1253,15 @@ def __init__(self, pool, default_workspace_config):
_Routing.__init__(self, [pool.address])
AsyncDriver.__init__(self, pool, default_workspace_config)

if not t.TYPE_CHECKING:

def session(self, **config) -> AsyncSession:
config = self._prepare_session_config(**config)
session_config = SessionConfig(self._default_workspace_config,
config)
return AsyncSession(self._pool, session_config)


def _normalize_notifications_config(config):
if config.get("notifications_disabled_categories") is not None:
config["notifications_disabled_categories"] = [
def _normalize_notifications_config(config_kwargs):
if config_kwargs.get("notifications_disabled_categories") is not None:
config_kwargs["notifications_disabled_categories"] = [
getattr(e, "value", e)
for e in config["notifications_disabled_categories"]
for e in config_kwargs["notifications_disabled_categories"]
]
if config.get("notifications_min_severity") is not None:
config["notifications_min_severity"] = getattr(
config["notifications_min_severity"], "value",
config["notifications_min_severity"]
if config_kwargs.get("notifications_min_severity") is not None:
config_kwargs["notifications_min_severity"] = getattr(
config_kwargs["notifications_min_severity"], "value",
config_kwargs["notifications_min_severity"]
)
17 changes: 4 additions & 13 deletions src/neo4j/_async/work/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,14 @@

import asyncio
import typing as t
import warnings
from logging import getLogger
from random import random
from time import perf_counter

from ..._async_compat import async_sleep
from ..._async_compat.util import AsyncUtil
from ..._conf import SessionConfig
from ..._meta import (
deprecated,
PreviewWarning,
)
from ..._meta import deprecated
from ..._util import ContextBool
from ..._work import Query
from ...api import (
Expand Down Expand Up @@ -108,14 +104,9 @@ class AsyncSession(AsyncWorkspace):
def __init__(self, pool, session_config):
assert isinstance(session_config, SessionConfig)
if session_config.auth is not None:
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore", message=r".*\bAuth managers\b.*",
category=PreviewWarning
)
session_config.auth = AsyncAuthManagers.static(
session_config.auth
)
session_config.auth = AsyncAuthManagers.static._without_warning(
session_config.auth
)
super().__init__(pool, session_config)
self._config = session_config
self._initialize_bookmarks(session_config.bookmarks)
Expand Down
14 changes: 4 additions & 10 deletions src/neo4j/_auth_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,9 @@
import abc
import time
import typing as t
import warnings
from dataclasses import dataclass

from ._meta import (
preview,
PreviewWarning,
)
from ._meta import preview
from .api import _TAuth
from .exceptions import Neo4jError

Expand Down Expand Up @@ -89,11 +85,9 @@ def expires_in(self, seconds: float) -> "ExpiringAuth":

.. versionadded:: 5.9
"""
with warnings.catch_warnings():
warnings.filterwarnings("ignore",
message=r"^Auth managers\b.*",
category=PreviewWarning)
return ExpiringAuth(self.auth, time.time() + seconds)
return ExpiringAuth._without_warning( # type: ignore
self.auth, time.time() + seconds
)


def expiring_auth_has_expired(auth: ExpiringAuth) -> bool:
Expand Down
Loading