14
14
from __future__ import annotations
15
15
16
16
import asyncio
17
+ import collections
17
18
import os
18
19
import threading
19
20
import time
20
21
import weakref
21
- from typing import Any , Callable , Optional
22
+ from typing import Any , Optional
22
23
23
24
_HAS_REGISTER_AT_FORK = hasattr (os , "register_at_fork" )
24
25
@@ -44,6 +45,8 @@ def _release_locks() -> None:
44
45
45
46
46
47
class _ALock :
48
+ __slots__ = ("_lock" ,)
49
+
47
50
def __init__ (self , lock : threading .Lock ) -> None :
48
51
self ._lock = lock
49
52
@@ -82,8 +85,11 @@ async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
82
85
83
86
84
87
class _ACondition :
88
+ __slots__ = ("_condition" , "_waiters" )
89
+
85
90
def __init__ (self , condition : threading .Condition ) -> None :
86
91
self ._condition = condition
92
+ self ._waiters = collections .deque ()
87
93
88
94
async def acquire (self , blocking : bool = True , timeout : float = - 1 ) -> bool :
89
95
if timeout > 0 :
@@ -99,30 +105,115 @@ async def acquire(self, blocking: bool = True, timeout: float = -1) -> bool:
99
105
await asyncio .sleep (0 )
100
106
101
107
async def wait (self , timeout : Optional [float ] = None ) -> bool :
102
- if timeout is not None :
103
- tstart = time .monotonic ()
104
- while True :
105
- notified = self ._condition .wait (0.001 )
106
- if notified :
107
- return True
108
- if timeout is not None and (time .monotonic () - tstart ) > timeout :
109
- return False
110
-
111
- async def wait_for (self , predicate : Callable , timeout : Optional [float ] = None ) -> bool :
112
- if timeout is not None :
113
- tstart = time .monotonic ()
114
- while True :
115
- notified = self ._condition .wait_for (predicate , 0.001 )
116
- if notified :
117
- return True
118
- if timeout is not None and (time .monotonic () - tstart ) > timeout :
119
- return False
120
-
121
- def notify (self , n : int = 1 ) -> None :
122
- self ._condition .notify (n )
123
-
124
- def notify_all (self ) -> None :
125
- self ._condition .notify_all ()
108
+ """Wait until notified.
109
+
110
+ If the calling task has not acquired the lock when this
111
+ method is called, a RuntimeError is raised.
112
+
113
+ This method releases the underlying lock, and then blocks
114
+ until it is awakened by a notify() or notify_all() call for
115
+ the same condition variable in another task. Once
116
+ awakened, it re-acquires the lock and returns True.
117
+
118
+ This method may return spuriously,
119
+ which is why the caller should always
120
+ re-check the state and be prepared to wait() again.
121
+ """
122
+ loop = asyncio .get_running_loop ()
123
+ fut = loop .create_future ()
124
+ self ._waiters .append ((loop , fut ))
125
+ self .release ()
126
+ try :
127
+ try :
128
+ try :
129
+ await asyncio .wait_for (fut , timeout )
130
+ return True
131
+ except asyncio .TimeoutError :
132
+ return False # Return false on timeout for sync pool compat.
133
+ finally :
134
+ # Must re-acquire lock even if wait is cancelled.
135
+ # We only catch CancelledError here, since we don't want any
136
+ # other (fatal) errors with the future to cause us to spin.
137
+ err = None
138
+ while True :
139
+ try :
140
+ await self .acquire ()
141
+ break
142
+ except asyncio .exceptions .CancelledError as e :
143
+ err = e
144
+
145
+ self ._waiters .remove ((loop , fut ))
146
+ if err is not None :
147
+ try :
148
+ raise err # Re-raise most recent exception instance.
149
+ finally :
150
+ err = None # Break reference cycles.
151
+ except BaseException :
152
+ # Any error raised out of here _may_ have occurred after this Task
153
+ # believed to have been successfully notified.
154
+ # Make sure to notify another Task instead. This may result
155
+ # in a "spurious wakeup", which is allowed as part of the
156
+ # Condition Variable protocol.
157
+ self .notify (1 )
158
+ raise
159
+
160
+ async def wait_for (self , predicate ):
161
+ """Wait until a predicate becomes true.
162
+
163
+ The predicate should be a callable which result will be
164
+ interpreted as a boolean value. The final predicate value is
165
+ the return value.
166
+ """
167
+ result = predicate ()
168
+ while not result :
169
+ await self .wait ()
170
+ result = predicate ()
171
+ return result
172
+
173
+ def notify (self , n = 1 ):
174
+ """By default, wake up one coroutine waiting on this condition, if any.
175
+ If the calling coroutine has not acquired the lock when this method
176
+ is called, a RuntimeError is raised.
177
+
178
+ This method wakes up at most n of the coroutines waiting for the
179
+ condition variable; it is a no-op if no coroutines are waiting.
180
+
181
+ Note: an awakened coroutine does not actually return from its
182
+ wait() call until it can reacquire the lock. Since notify() does
183
+ not release the lock, its caller should.
184
+ """
185
+ idx = 0
186
+ to_remove = []
187
+ for loop , fut in self ._waiters :
188
+ if idx >= n :
189
+ break
190
+
191
+ if fut .done ():
192
+ continue
193
+
194
+ def safe_set_result (fut ):
195
+ if not fut .done ():
196
+ fut .set_result (False )
197
+
198
+ try :
199
+ loop .call_soon_threadsafe (safe_set_result , fut )
200
+ except RuntimeError :
201
+ # Loop was closed, ignore.
202
+ to_remove .append ((loop , fut ))
203
+ continue
204
+
205
+ idx += 1
206
+
207
+ for waiter in to_remove :
208
+ self ._waiters .remove (waiter )
209
+
210
+ def notify_all (self ):
211
+ """Wake up all threads waiting on this condition. This method acts
212
+ like notify(), but wakes up all waiting threads instead of one. If the
213
+ calling thread has not acquired the lock when this method is called,
214
+ a RuntimeError is raised.
215
+ """
216
+ self .notify (len (self ._waiters ))
126
217
127
218
def release (self ) -> None :
128
219
self ._condition .release ()
0 commit comments