|
19 | 19 | import threading
|
20 | 20 | import time
|
21 | 21 | import weakref
|
22 |
| -from typing import Any, Optional |
| 22 | +from typing import Any, Callable, Optional, TypeVar |
23 | 23 |
|
24 | 24 | _HAS_REGISTER_AT_FORK = hasattr(os, "register_at_fork")
|
25 | 25 |
|
26 | 26 | # References to instances of _create_lock
|
27 | 27 | _forkable_locks: weakref.WeakSet[threading.Lock] = weakref.WeakSet()
|
28 | 28 |
|
| 29 | +_T = TypeVar("_T") |
| 30 | + |
29 | 31 |
|
30 | 32 | def _create_lock() -> threading.Lock:
|
31 | 33 | """Represents a lock that is tracked upon instantiation using a WeakSet and
|
@@ -163,6 +165,20 @@ async def wait(self, timeout: Optional[float] = None) -> bool:
|
163 | 165 | self.notify(1)
|
164 | 166 | raise
|
165 | 167 |
|
| 168 | + async def wait_for(self, predicate: Callable[[], _T]) -> _T: |
| 169 | + """Wait until a predicate becomes true. |
| 170 | +
|
| 171 | + The predicate should be a callable whose result will be |
| 172 | + interpreted as a boolean value. The method will repeatedly |
| 173 | + wait() until it evaluates to true. The final predicate value is |
| 174 | + the return value. |
| 175 | + """ |
| 176 | + result = predicate() |
| 177 | + while not result: |
| 178 | + await self.wait() |
| 179 | + result = predicate() |
| 180 | + return result |
| 181 | + |
166 | 182 | def notify(self, n: int = 1) -> None:
|
167 | 183 | """By default, wake up one coroutine waiting on this condition, if any.
|
168 | 184 | If the calling coroutine has not acquired the lock when this method
|
@@ -204,6 +220,10 @@ def notify_all(self) -> None:
|
204 | 220 | """
|
205 | 221 | self.notify(len(self._waiters))
|
206 | 222 |
|
| 223 | + def locked(self) -> bool: |
| 224 | + """Only needed for tests in test_locks.""" |
| 225 | + return self._condition._lock.locked() # type: ignore[attr-defined] |
| 226 | + |
207 | 227 | def release(self) -> None:
|
208 | 228 | self._condition.release()
|
209 | 229 |
|
|
0 commit comments