Skip to content

Commit b42e7f1

Browse files
authored
Fix data race in routing pool (#852)
A race condition could lead to `KeyError` where the pool assumes a routing table for a database exists, while another concurrent execution path purged it.
1 parent 1c55054 commit b42e7f1

File tree

4 files changed

+64
-108
lines changed

4 files changed

+64
-108
lines changed

neo4j/_async/driver.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -669,7 +669,7 @@ def open(cls, *targets, auth=None, routing_context=None, **config):
669669
return cls(pool, default_workspace_config)
670670

671671
def __init__(self, pool, default_workspace_config):
672-
_Routing.__init__(self, pool.get_default_database_initial_router_addresses())
672+
_Routing.__init__(self, [pool.address])
673673
AsyncDriver.__init__(self, pool, default_workspace_config)
674674

675675
if not t.TYPE_CHECKING:

neo4j/_async/io/_pool.py

Lines changed: 31 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -288,12 +288,9 @@ def in_use_connection_count(self, address):
288288
""" Count the number of connections currently in use to a given
289289
address.
290290
"""
291-
try:
292-
connections = self.connections[address]
293-
except KeyError:
294-
return 0
295-
else:
296-
return sum(1 if connection.in_use else 0 for connection in connections)
291+
with self.lock:
292+
connections = self.connections.get(address, ())
293+
return sum(connection.in_use for connection in connections)
297294

298295
async def mark_all_stale(self):
299296
with self.lock:
@@ -447,7 +444,7 @@ def __init__(self, opener, pool_config, workspace_config, address):
447444
# Each database have a routing table, the default database is a special case.
448445
log.debug("[#0000] C: <NEO4J POOL> routing address %r", address)
449446
self.address = address
450-
self.routing_tables = {workspace_config.database: RoutingTable(database=workspace_config.database, routers=[address])}
447+
self.routing_tables = {}
451448
self.refresh_lock = AsyncRLock()
452449

453450
def __repr__(self):
@@ -456,37 +453,15 @@ def __repr__(self):
456453
:return: The representation
457454
:rtype: str
458455
"""
459-
return "<{} addresses={!r}>".format(self.__class__.__name__, self.get_default_database_initial_router_addresses())
460-
461-
@property
462-
def first_initial_routing_address(self):
463-
return self.get_default_database_initial_router_addresses()[0]
464-
465-
def get_default_database_initial_router_addresses(self):
466-
""" Get the initial router addresses for the default database.
467-
468-
:return:
469-
:rtype: OrderedSet
470-
"""
471-
return self.get_routing_table_for_default_database().initial_routers
472-
473-
def get_default_database_router_addresses(self):
474-
""" Get the router addresses for the default database.
475-
476-
:return:
477-
:rtype: OrderedSet
478-
"""
479-
return self.get_routing_table_for_default_database().routers
480-
481-
def get_routing_table_for_default_database(self):
482-
return self.routing_tables[self.workspace_config.database]
456+
return "<{} address={!r}>".format(self.__class__.__name__,
457+
self.address)
483458

484459
async def get_or_create_routing_table(self, database):
485460
async with self.refresh_lock:
486461
if database not in self.routing_tables:
487462
self.routing_tables[database] = RoutingTable(
488463
database=database,
489-
routers=self.get_default_database_initial_router_addresses()
464+
routers=[self.address]
490465
)
491466
return self.routing_tables[database]
492467

@@ -651,15 +626,15 @@ async def update_routing_table(
651626
if prefer_initial_routing_address:
652627
# TODO: Test this state
653628
if await self._update_routing_table_from(
654-
self.first_initial_routing_address, database=database,
629+
self.address, database=database,
655630
imp_user=imp_user, bookmarks=bookmarks,
656631
acquisition_timeout=acquisition_timeout,
657632
database_callback=database_callback
658633
):
659634
# Why is only the first initial routing address used?
660635
return
661636
if await self._update_routing_table_from(
662-
*(existing_routers - {self.first_initial_routing_address}),
637+
*(existing_routers - {self.address}),
663638
database=database, imp_user=imp_user, bookmarks=bookmarks,
664639
acquisition_timeout=acquisition_timeout,
665640
database_callback=database_callback
@@ -668,7 +643,7 @@ async def update_routing_table(
668643

669644
if not prefer_initial_routing_address:
670645
if await self._update_routing_table_from(
671-
self.first_initial_routing_address, database=database,
646+
self.address, database=database,
672647
imp_user=imp_user, bookmarks=bookmarks,
673648
acquisition_timeout=acquisition_timeout,
674649
database_callback=database_callback
@@ -705,6 +680,14 @@ async def ensure_routing_table_is_fresh(
705680
"""
706681
from neo4j.api import READ_ACCESS
707682
async with self.refresh_lock:
683+
for database_ in list(self.routing_tables.keys()):
684+
# Remove unused databases in the routing table
685+
# Remove the routing table after a timeout = TTL + 30s
686+
log.debug("[#0000] C: <ROUTING AGED> database=%s", database_)
687+
routing_table = self.routing_tables[database_]
688+
if routing_table.should_be_purged_from_memory():
689+
del self.routing_tables[database_]
690+
708691
routing_table = await self.get_or_create_routing_table(database)
709692
if routing_table.is_fresh(readonly=(access_mode == READ_ACCESS)):
710693
# Readers are fresh.
@@ -717,25 +700,21 @@ async def ensure_routing_table_is_fresh(
717700
)
718701
await self.update_connection_pool(database=database)
719702

720-
for database in list(self.routing_tables.keys()):
721-
# Remove unused databases in the routing table
722-
# Remove the routing table after a timeout = TTL + 30s
723-
log.debug("[#0000] C: <ROUTING AGED> database=%s", database)
724-
if (self.routing_tables[database].should_be_purged_from_memory()
725-
and database != self.workspace_config.database):
726-
del self.routing_tables[database]
727-
728703
return True
729704

730705
async def _select_address(self, *, access_mode, database):
731706
from ...api import READ_ACCESS
732707
""" Selects the address with the fewest in-use connections.
733708
"""
734709
async with self.refresh_lock:
735-
if access_mode == READ_ACCESS:
736-
addresses = self.routing_tables[database].readers
710+
routing_table = self.routing_tables.get(database)
711+
if routing_table:
712+
if access_mode == READ_ACCESS:
713+
addresses = routing_table.readers
714+
else:
715+
addresses = routing_table.writers
737716
else:
738-
addresses = self.routing_tables[database].writers
717+
addresses = ()
739718
addresses_by_usage = {}
740719
for address in addresses:
741720
addresses_by_usage.setdefault(
@@ -763,13 +742,12 @@ async def acquire(
763742

764743
from neo4j.api import check_access_mode
765744
access_mode = check_access_mode(access_mode)
766-
async with self.refresh_lock:
767-
log.debug("[#0000] C: <ROUTING TABLE ENSURE FRESH> %r",
768-
self.routing_tables)
769-
await self.ensure_routing_table_is_fresh(
770-
access_mode=access_mode, database=database, imp_user=None,
771-
bookmarks=bookmarks, acquisition_timeout=timeout
772-
)
745+
746+
await self.ensure_routing_table_is_fresh(
747+
access_mode=access_mode, database=database,
748+
imp_user=None, bookmarks=bookmarks,
749+
acquisition_timeout=timeout
750+
)
773751

774752
while True:
775753
try:

neo4j/_sync/driver.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -668,7 +668,7 @@ def open(cls, *targets, auth=None, routing_context=None, **config):
668668
return cls(pool, default_workspace_config)
669669

670670
def __init__(self, pool, default_workspace_config):
671-
_Routing.__init__(self, pool.get_default_database_initial_router_addresses())
671+
_Routing.__init__(self, [pool.address])
672672
Driver.__init__(self, pool, default_workspace_config)
673673

674674
if not t.TYPE_CHECKING:

neo4j/_sync/io/_pool.py

Lines changed: 31 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -288,12 +288,9 @@ def in_use_connection_count(self, address):
288288
""" Count the number of connections currently in use to a given
289289
address.
290290
"""
291-
try:
292-
connections = self.connections[address]
293-
except KeyError:
294-
return 0
295-
else:
296-
return sum(1 if connection.in_use else 0 for connection in connections)
291+
with self.lock:
292+
connections = self.connections.get(address, ())
293+
return sum(connection.in_use for connection in connections)
297294

298295
def mark_all_stale(self):
299296
with self.lock:
@@ -447,7 +444,7 @@ def __init__(self, opener, pool_config, workspace_config, address):
447444
# Each database have a routing table, the default database is a special case.
448445
log.debug("[#0000] C: <NEO4J POOL> routing address %r", address)
449446
self.address = address
450-
self.routing_tables = {workspace_config.database: RoutingTable(database=workspace_config.database, routers=[address])}
447+
self.routing_tables = {}
451448
self.refresh_lock = RLock()
452449

453450
def __repr__(self):
@@ -456,37 +453,15 @@ def __repr__(self):
456453
:return: The representation
457454
:rtype: str
458455
"""
459-
return "<{} addresses={!r}>".format(self.__class__.__name__, self.get_default_database_initial_router_addresses())
460-
461-
@property
462-
def first_initial_routing_address(self):
463-
return self.get_default_database_initial_router_addresses()[0]
464-
465-
def get_default_database_initial_router_addresses(self):
466-
""" Get the initial router addresses for the default database.
467-
468-
:return:
469-
:rtype: OrderedSet
470-
"""
471-
return self.get_routing_table_for_default_database().initial_routers
472-
473-
def get_default_database_router_addresses(self):
474-
""" Get the router addresses for the default database.
475-
476-
:return:
477-
:rtype: OrderedSet
478-
"""
479-
return self.get_routing_table_for_default_database().routers
480-
481-
def get_routing_table_for_default_database(self):
482-
return self.routing_tables[self.workspace_config.database]
456+
return "<{} address={!r}>".format(self.__class__.__name__,
457+
self.address)
483458

484459
def get_or_create_routing_table(self, database):
485460
with self.refresh_lock:
486461
if database not in self.routing_tables:
487462
self.routing_tables[database] = RoutingTable(
488463
database=database,
489-
routers=self.get_default_database_initial_router_addresses()
464+
routers=[self.address]
490465
)
491466
return self.routing_tables[database]
492467

@@ -651,15 +626,15 @@ def update_routing_table(
651626
if prefer_initial_routing_address:
652627
# TODO: Test this state
653628
if self._update_routing_table_from(
654-
self.first_initial_routing_address, database=database,
629+
self.address, database=database,
655630
imp_user=imp_user, bookmarks=bookmarks,
656631
acquisition_timeout=acquisition_timeout,
657632
database_callback=database_callback
658633
):
659634
# Why is only the first initial routing address used?
660635
return
661636
if self._update_routing_table_from(
662-
*(existing_routers - {self.first_initial_routing_address}),
637+
*(existing_routers - {self.address}),
663638
database=database, imp_user=imp_user, bookmarks=bookmarks,
664639
acquisition_timeout=acquisition_timeout,
665640
database_callback=database_callback
@@ -668,7 +643,7 @@ def update_routing_table(
668643

669644
if not prefer_initial_routing_address:
670645
if self._update_routing_table_from(
671-
self.first_initial_routing_address, database=database,
646+
self.address, database=database,
672647
imp_user=imp_user, bookmarks=bookmarks,
673648
acquisition_timeout=acquisition_timeout,
674649
database_callback=database_callback
@@ -705,6 +680,14 @@ def ensure_routing_table_is_fresh(
705680
"""
706681
from neo4j.api import READ_ACCESS
707682
with self.refresh_lock:
683+
for database_ in list(self.routing_tables.keys()):
684+
# Remove unused databases in the routing table
685+
# Remove the routing table after a timeout = TTL + 30s
686+
log.debug("[#0000] C: <ROUTING AGED> database=%s", database_)
687+
routing_table = self.routing_tables[database_]
688+
if routing_table.should_be_purged_from_memory():
689+
del self.routing_tables[database_]
690+
708691
routing_table = self.get_or_create_routing_table(database)
709692
if routing_table.is_fresh(readonly=(access_mode == READ_ACCESS)):
710693
# Readers are fresh.
@@ -717,25 +700,21 @@ def ensure_routing_table_is_fresh(
717700
)
718701
self.update_connection_pool(database=database)
719702

720-
for database in list(self.routing_tables.keys()):
721-
# Remove unused databases in the routing table
722-
# Remove the routing table after a timeout = TTL + 30s
723-
log.debug("[#0000] C: <ROUTING AGED> database=%s", database)
724-
if (self.routing_tables[database].should_be_purged_from_memory()
725-
and database != self.workspace_config.database):
726-
del self.routing_tables[database]
727-
728703
return True
729704

730705
def _select_address(self, *, access_mode, database):
731706
from ...api import READ_ACCESS
732707
""" Selects the address with the fewest in-use connections.
733708
"""
734709
with self.refresh_lock:
735-
if access_mode == READ_ACCESS:
736-
addresses = self.routing_tables[database].readers
710+
routing_table = self.routing_tables.get(database)
711+
if routing_table:
712+
if access_mode == READ_ACCESS:
713+
addresses = routing_table.readers
714+
else:
715+
addresses = routing_table.writers
737716
else:
738-
addresses = self.routing_tables[database].writers
717+
addresses = ()
739718
addresses_by_usage = {}
740719
for address in addresses:
741720
addresses_by_usage.setdefault(
@@ -763,13 +742,12 @@ def acquire(
763742

764743
from neo4j.api import check_access_mode
765744
access_mode = check_access_mode(access_mode)
766-
with self.refresh_lock:
767-
log.debug("[#0000] C: <ROUTING TABLE ENSURE FRESH> %r",
768-
self.routing_tables)
769-
self.ensure_routing_table_is_fresh(
770-
access_mode=access_mode, database=database, imp_user=None,
771-
bookmarks=bookmarks, acquisition_timeout=timeout
772-
)
745+
746+
self.ensure_routing_table_is_fresh(
747+
access_mode=access_mode, database=database,
748+
imp_user=None, bookmarks=bookmarks,
749+
acquisition_timeout=timeout
750+
)
773751

774752
while True:
775753
try:

0 commit comments

Comments
 (0)