From 694c4fab49daf381aa59938e024f90e66019f287 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Thu, 29 Jun 2023 15:50:06 +0200 Subject: [PATCH] Avoid unclean interpreter shutdown When not closing the driver before shutting down the interpreter, the GC might call `__del__` on the driver, which currently (will be changed with 6.0) closes the driver. So any code that's run on driver closure might be execute just before the interpreter shuts down. This can lead to funky errors when trying to import things dynamically. So this should be avoided. This got raised here: https://community.neo4j.com/t/error-with-the-neo4j-python-driver-importerror-sys-meta-path-is-none-python-is-likely-shutting-down/63096/ --- src/neo4j/__init__.py | 3 +-- src/neo4j/_async/io/_bolt.py | 5 ++--- src/neo4j/_async/io/_pool.py | 10 ++++++++++ src/neo4j/_sync/io/_bolt.py | 5 ++--- src/neo4j/_sync/io/_pool.py | 10 ++++++++++ src/neo4j/time/_clock_implementations.py | 2 +- testkitbackend/_async/requests.py | 2 +- testkitbackend/_sync/requests.py | 2 +- tests/conftest.py | 11 ++--------- tests/unit/async_/io/test_direct.py | 2 ++ tests/unit/common/test_api.py | 6 +----- tests/unit/sync/io/test_direct.py | 2 ++ 12 files changed, 35 insertions(+), 25 deletions(-) diff --git a/src/neo4j/__init__.py b/src/neo4j/__init__.py index b18803d5..5b3b34b6 100644 --- a/src/neo4j/__init__.py +++ b/src/neo4j/__init__.py @@ -176,8 +176,7 @@ def __getattr__(name): if name in ( "log", "Config", "PoolConfig", "SessionConfig", "WorkspaceConfig" ): - from ._meta import deprecation_warn - deprecation_warn( + _deprecation_warn( "Importing {} from neo4j is deprecated without replacement. It's " "internal and will be removed in a future version." .format(name), diff --git a/src/neo4j/_async/io/_bolt.py b/src/neo4j/_async/io/_bolt.py index fd4afbc6..696486e8 100644 --- a/src/neo4j/_async/io/_bolt.py +++ b/src/neo4j/_async/io/_bolt.py @@ -37,6 +37,7 @@ from ..._meta import USER_AGENT from ...addressing import ResolvedAddress from ...api import ( + Auth, ServerInfo, Version, ) @@ -178,7 +179,6 @@ def _to_auth_dict(cls, auth): if not auth: return {} elif isinstance(auth, tuple) and 2 <= len(auth) <= 3: - from ...api import Auth return vars(Auth("basic", *auth)) else: try: @@ -837,8 +837,7 @@ async def _set_defunct_write(self, error=None, silent=False): await self._set_defunct(message, error=error, silent=silent) async def _set_defunct(self, message, error=None, silent=False): - from ._pool import AsyncBoltPool - direct_driver = isinstance(self.pool, AsyncBoltPool) + direct_driver = getattr(self.pool, "is_direct_pool", False) user_cancelled = isinstance(error, asyncio.CancelledError) if error: diff --git a/src/neo4j/_async/io/_pool.py b/src/neo4j/_async/io/_pool.py index 714a9dfc..5f8468c4 100644 --- a/src/neo4j/_async/io/_pool.py +++ b/src/neo4j/_async/io/_pool.py @@ -100,6 +100,11 @@ def __init__(self, opener, pool_config, workspace_config): self.lock = AsyncCooperativeRLock() self.cond = AsyncCondition(self.lock) + @property + @abc.abstractmethod + def is_direct_pool(self) -> bool: + ... + async def __aenter__(self): return self @@ -490,6 +495,8 @@ async def close(self): class AsyncBoltPool(AsyncIOPool): + is_direct_pool = True + @classmethod def open(cls, address, *, pool_config, workspace_config): """Create a new BoltPool @@ -536,6 +543,8 @@ class AsyncNeo4jPool(AsyncIOPool): """ Connection pool with routing table. """ + is_direct_pool = False + @classmethod def open(cls, *addresses, pool_config, workspace_config, routing_context=None): @@ -578,6 +587,7 @@ def __init__(self, opener, pool_config, workspace_config, address): self.address = address self.routing_tables = {} self.refresh_lock = AsyncRLock() + self.is_direct_pool = False def __repr__(self): """ The representation shows the initial routing addresses. diff --git a/src/neo4j/_sync/io/_bolt.py b/src/neo4j/_sync/io/_bolt.py index abba0dce..9d050e93 100644 --- a/src/neo4j/_sync/io/_bolt.py +++ b/src/neo4j/_sync/io/_bolt.py @@ -37,6 +37,7 @@ from ..._meta import USER_AGENT from ...addressing import ResolvedAddress from ...api import ( + Auth, ServerInfo, Version, ) @@ -178,7 +179,6 @@ def _to_auth_dict(cls, auth): if not auth: return {} elif isinstance(auth, tuple) and 2 <= len(auth) <= 3: - from ...api import Auth return vars(Auth("basic", *auth)) else: try: @@ -837,8 +837,7 @@ def _set_defunct_write(self, error=None, silent=False): self._set_defunct(message, error=error, silent=silent) def _set_defunct(self, message, error=None, silent=False): - from ._pool import BoltPool - direct_driver = isinstance(self.pool, BoltPool) + direct_driver = getattr(self.pool, "is_direct_pool", False) user_cancelled = isinstance(error, asyncio.CancelledError) if error: diff --git a/src/neo4j/_sync/io/_pool.py b/src/neo4j/_sync/io/_pool.py index b0af10b5..946d2348 100644 --- a/src/neo4j/_sync/io/_pool.py +++ b/src/neo4j/_sync/io/_pool.py @@ -97,6 +97,11 @@ def __init__(self, opener, pool_config, workspace_config): self.lock = CooperativeRLock() self.cond = Condition(self.lock) + @property + @abc.abstractmethod + def is_direct_pool(self) -> bool: + ... + def __enter__(self): return self @@ -487,6 +492,8 @@ def close(self): class BoltPool(IOPool): + is_direct_pool = True + @classmethod def open(cls, address, *, pool_config, workspace_config): """Create a new BoltPool @@ -533,6 +540,8 @@ class Neo4jPool(IOPool): """ Connection pool with routing table. """ + is_direct_pool = False + @classmethod def open(cls, *addresses, pool_config, workspace_config, routing_context=None): @@ -575,6 +584,7 @@ def __init__(self, opener, pool_config, workspace_config, address): self.address = address self.routing_tables = {} self.refresh_lock = RLock() + self.is_direct_pool = False def __repr__(self): """ The representation shows the initial routing addresses. diff --git a/src/neo4j/time/_clock_implementations.py b/src/neo4j/time/_clock_implementations.py index 60f82cc8..e742cd98 100644 --- a/src/neo4j/time/_clock_implementations.py +++ b/src/neo4j/time/_clock_implementations.py @@ -24,6 +24,7 @@ Structure, ) from platform import uname +from time import time from . import ( Clock, @@ -53,7 +54,6 @@ def available(cls): return True def utc_time(self): - from time import time seconds, nanoseconds = nano_divmod(int(time() * 1000000), 1000000) return ClockTime(seconds, nanoseconds * 1000) diff --git a/testkitbackend/_async/requests.py b/testkitbackend/_async/requests.py index dfe891a1..6711187f 100644 --- a/testkitbackend/_async/requests.py +++ b/testkitbackend/_async/requests.py @@ -19,6 +19,7 @@ import datetime import json import re +import ssl import warnings from os import path @@ -56,7 +57,6 @@ def load_config(): config = json.load(fd) skips = config["skips"] features = [k for k, v in config["features"].items() if v is True] - import ssl if ssl.HAS_TLSv1_3: features += ["Feature:TLS:1.3"] return skips, features diff --git a/testkitbackend/_sync/requests.py b/testkitbackend/_sync/requests.py index bd96f751..2d22d9d8 100644 --- a/testkitbackend/_sync/requests.py +++ b/testkitbackend/_sync/requests.py @@ -19,6 +19,7 @@ import datetime import json import re +import ssl import warnings from os import path @@ -56,7 +57,6 @@ def load_config(): config = json.load(fd) skips = config["skips"] features = [k for k, v in config["features"].items() if v is True] - import ssl if ssl.HAS_TLSv1_3: features += ["Feature:TLS:1.3"] return skips, features diff --git a/tests/conftest.py b/tests/conftest.py index eebfcf0d..8bdb4bf9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -17,21 +17,17 @@ import asyncio -import warnings +import sys from functools import wraps -from os import environ import pytest import pytest_asyncio from neo4j import ( AsyncGraphDatabase, - ExperimentalWarning, GraphDatabase, ) -from neo4j._exceptions import BoltHandshakeError -from neo4j._sync.io import Bolt -from neo4j.exceptions import ServiceUnavailable +from neo4j.debug import watch from . import env @@ -189,8 +185,5 @@ def _(): @pytest.fixture def watcher(): - import sys - - from neo4j.debug import watch with watch("neo4j", out=sys.stdout, colour=True): yield diff --git a/tests/unit/async_/io/test_direct.py b/tests/unit/async_/io/test_direct.py index 6be98a2e..711ec711 100644 --- a/tests/unit/async_/io/test_direct.py +++ b/tests/unit/async_/io/test_direct.py @@ -85,6 +85,8 @@ def timedout(self): class AsyncFakeBoltPool(AsyncIOPool): + is_direct_pool = False + def __init__(self, address, *, auth=None, **config): config["auth"] = static_auth(None) self.pool_config, self.workspace_config = Config.consume_chain(config, PoolConfig, WorkspaceConfig) diff --git a/tests/unit/common/test_api.py b/tests/unit/common/test_api.py index d484d10c..cb934e05 100644 --- a/tests/unit/common/test_api.py +++ b/tests/unit/common/test_api.py @@ -24,6 +24,7 @@ import pytest import neo4j.api +from neo4j.addressing import Address from neo4j.exceptions import ConfigurationError @@ -276,9 +277,6 @@ def test_version_to_bytes_with_valid_bolt_version( def test_serverinfo_initialization() -> None: - - from neo4j.addressing import Address - address = Address(("bolt://localhost", 7687)) version = neo4j.Version(3, 0) @@ -301,8 +299,6 @@ def test_serverinfo_initialization() -> None: def test_serverinfo_with_metadata( test_input, expected_agent, protocol_version ) -> None: - from neo4j.addressing import Address - address = Address(("bolt://localhost", 7687)) version = neo4j.Version(*protocol_version) diff --git a/tests/unit/sync/io/test_direct.py b/tests/unit/sync/io/test_direct.py index aa294c2f..4a4ef6a7 100644 --- a/tests/unit/sync/io/test_direct.py +++ b/tests/unit/sync/io/test_direct.py @@ -85,6 +85,8 @@ def timedout(self): class FakeBoltPool(IOPool): + is_direct_pool = False + def __init__(self, address, *, auth=None, **config): config["auth"] = static_auth(None) self.pool_config, self.workspace_config = Config.consume_chain(config, PoolConfig, WorkspaceConfig)