Skip to content

Commit d7dc659

Browse files
committed
PYTHON-4782 Fix deadlock and blocking behavior in _ACondition.wait
1 parent e03f8f2 commit d7dc659

File tree

1 file changed

+116
-25
lines changed

1 file changed

+116
-25
lines changed

pymongo/lock.py

Lines changed: 116 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,12 @@
1414
from __future__ import annotations
1515

1616
import asyncio
17+
import collections
1718
import os
1819
import threading
1920
import time
2021
import weakref
21-
from typing import Any, Callable, Optional
22+
from typing import Any, Optional
2223

2324
_HAS_REGISTER_AT_FORK = hasattr(os, "register_at_fork")
2425

@@ -44,6 +45,8 @@ def _release_locks() -> None:
4445

4546

4647
class _ALock:
48+
__slots__ = ("_lock",)
49+
4750
def __init__(self, lock: threading.Lock) -> None:
4851
self._lock = lock
4952

@@ -82,8 +85,11 @@ async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
8285

8386

8487
class _ACondition:
88+
__slots__ = ("_condition", "_waiters")
89+
8590
def __init__(self, condition: threading.Condition) -> None:
8691
self._condition = condition
92+
self._waiters = collections.deque()
8793

8894
async def acquire(self, blocking: bool = True, timeout: float = -1) -> bool:
8995
if timeout > 0:
@@ -99,30 +105,115 @@ async def acquire(self, blocking: bool = True, timeout: float = -1) -> bool:
99105
await asyncio.sleep(0)
100106

101107
async def wait(self, timeout: Optional[float] = None) -> bool:
102-
if timeout is not None:
103-
tstart = time.monotonic()
104-
while True:
105-
notified = self._condition.wait(0.001)
106-
if notified:
107-
return True
108-
if timeout is not None and (time.monotonic() - tstart) > timeout:
109-
return False
110-
111-
async def wait_for(self, predicate: Callable, timeout: Optional[float] = None) -> bool:
112-
if timeout is not None:
113-
tstart = time.monotonic()
114-
while True:
115-
notified = self._condition.wait_for(predicate, 0.001)
116-
if notified:
117-
return True
118-
if timeout is not None and (time.monotonic() - tstart) > timeout:
119-
return False
120-
121-
def notify(self, n: int = 1) -> None:
122-
self._condition.notify(n)
123-
124-
def notify_all(self) -> None:
125-
self._condition.notify_all()
108+
"""Wait until notified.
109+
110+
If the calling task has not acquired the lock when this
111+
method is called, a RuntimeError is raised.
112+
113+
This method releases the underlying lock, and then blocks
114+
until it is awakened by a notify() or notify_all() call for
115+
the same condition variable in another task. Once
116+
awakened, it re-acquires the lock and returns True.
117+
118+
This method may return spuriously,
119+
which is why the caller should always
120+
re-check the state and be prepared to wait() again.
121+
"""
122+
loop = asyncio.get_running_loop()
123+
fut = loop.create_future()
124+
self._waiters.append((loop, fut))
125+
self.release()
126+
try:
127+
try:
128+
try:
129+
await asyncio.wait_for(fut, timeout)
130+
return True
131+
except asyncio.TimeoutError:
132+
return False # Return false on timeout for sync pool compat.
133+
finally:
134+
# Must re-acquire lock even if wait is cancelled.
135+
# We only catch CancelledError here, since we don't want any
136+
# other (fatal) errors with the future to cause us to spin.
137+
err = None
138+
while True:
139+
try:
140+
await self.acquire()
141+
break
142+
except asyncio.exceptions.CancelledError as e:
143+
err = e
144+
145+
self._waiters.remove((loop, fut))
146+
if err is not None:
147+
try:
148+
raise err # Re-raise most recent exception instance.
149+
finally:
150+
err = None # Break reference cycles.
151+
except BaseException:
152+
# Any error raised out of here _may_ have occurred after this Task
153+
# believed to have been successfully notified.
154+
# Make sure to notify another Task instead. This may result
155+
# in a "spurious wakeup", which is allowed as part of the
156+
# Condition Variable protocol.
157+
self.notify(1)
158+
raise
159+
160+
async def wait_for(self, predicate):
161+
"""Wait until a predicate becomes true.
162+
163+
The predicate should be a callable which result will be
164+
interpreted as a boolean value. The final predicate value is
165+
the return value.
166+
"""
167+
result = predicate()
168+
while not result:
169+
await self.wait()
170+
result = predicate()
171+
return result
172+
173+
def notify(self, n=1):
174+
"""By default, wake up one coroutine waiting on this condition, if any.
175+
If the calling coroutine has not acquired the lock when this method
176+
is called, a RuntimeError is raised.
177+
178+
This method wakes up at most n of the coroutines waiting for the
179+
condition variable; it is a no-op if no coroutines are waiting.
180+
181+
Note: an awakened coroutine does not actually return from its
182+
wait() call until it can reacquire the lock. Since notify() does
183+
not release the lock, its caller should.
184+
"""
185+
idx = 0
186+
to_remove = []
187+
for loop, fut in self._waiters:
188+
if idx >= n:
189+
break
190+
191+
if fut.done():
192+
continue
193+
194+
def safe_set_result(fut):
195+
if not fut.done():
196+
fut.set_result(False)
197+
198+
try:
199+
loop.call_soon_threadsafe(safe_set_result, fut)
200+
except RuntimeError:
201+
# Loop was closed, ignore.
202+
to_remove.append((loop, fut))
203+
continue
204+
205+
idx += 1
206+
207+
for waiter in to_remove:
208+
self._waiters.remove(waiter)
209+
210+
def notify_all(self):
211+
"""Wake up all threads waiting on this condition. This method acts
212+
like notify(), but wakes up all waiting threads instead of one. If the
213+
calling thread has not acquired the lock when this method is called,
214+
a RuntimeError is raised.
215+
"""
216+
self.notify(len(self._waiters))
126217

127218
def release(self) -> None:
128219
self._condition.release()

0 commit comments

Comments
 (0)