1
1
use std:: pin:: Pin ;
2
- use std:: sync:: atomic:: { AtomicBool , Ordering } ;
2
+ use std:: sync:: atomic:: { AtomicBool , AtomicUsize , Ordering } ;
3
+ use std:: sync:: Arc ;
3
4
use std:: time:: Duration ;
4
5
5
6
use futures_timer:: Delay ;
@@ -61,7 +62,19 @@ impl WaitTimeoutResult {
61
62
#[ derive( Debug ) ]
62
63
pub struct Condvar {
63
64
has_blocked : AtomicBool ,
64
- blocked : std:: sync:: Mutex < Slab < Option < Waker > > > ,
65
+ blocked : std:: sync:: Mutex < Slab < WaitEntry > > ,
66
+ }
67
+
68
+ /// Flag to mark if the task was notified
69
+ const NOTIFIED : usize = 1 ;
70
+ /// State if the task was notified with `notify_once`
71
+ /// so it should notify another task if the future is dropped without waking.
72
+ const NOTIFIED_ONCE : usize = 0b11 ;
73
+
74
+ #[ derive( Debug ) ]
75
+ struct WaitEntry {
76
+ state : Arc < AtomicUsize > ,
77
+ waker : Option < Waker > ,
65
78
}
66
79
67
80
impl Condvar {
@@ -126,6 +139,7 @@ impl Condvar {
126
139
AwaitNotify {
127
140
cond : self ,
128
141
guard : Some ( guard) ,
142
+ state : Arc :: new ( AtomicUsize :: new ( 0 ) ) ,
129
143
key : None ,
130
144
}
131
145
}
@@ -262,12 +276,8 @@ impl Condvar {
262
276
/// ```
263
277
pub fn notify_one ( & self ) {
264
278
if self . has_blocked . load ( Ordering :: Acquire ) {
265
- let mut blocked = self . blocked . lock ( ) . unwrap ( ) ;
266
- if let Some ( ( _, opt_waker) ) = blocked. iter_mut ( ) . next ( ) {
267
- if let Some ( w) = opt_waker. take ( ) {
268
- w. wake ( ) ;
269
- }
270
- }
279
+ let blocked = self . blocked . lock ( ) . unwrap ( ) ;
280
+ notify ( blocked, false ) ;
271
281
}
272
282
}
273
283
@@ -305,11 +315,21 @@ impl Condvar {
305
315
/// ```
306
316
pub fn notify_all ( & self ) {
307
317
if self . has_blocked . load ( Ordering :: Acquire ) {
308
- let mut blocked = self . blocked . lock ( ) . unwrap ( ) ;
309
- for ( _, opt_waker) in blocked. iter_mut ( ) {
310
- if let Some ( w) = opt_waker. take ( ) {
311
- w. wake ( ) ;
312
- }
318
+ let blocked = self . blocked . lock ( ) . unwrap ( ) ;
319
+ notify ( blocked, true ) ;
320
+ }
321
+ }
322
+ }
323
+
324
+ #[ inline]
325
+ fn notify ( mut blocked : std:: sync:: MutexGuard < ' _ , Slab < WaitEntry > > , all : bool ) {
326
+ let state = if all { NOTIFIED } else { NOTIFIED_ONCE } ;
327
+ for ( _, entry) in blocked. iter_mut ( ) {
328
+ if let Some ( w) = entry. waker . take ( ) {
329
+ entry. state . store ( state, Ordering :: Release ) ;
330
+ w. wake ( ) ;
331
+ if !all {
332
+ return ;
313
333
}
314
334
}
315
335
}
@@ -318,6 +338,7 @@ impl Condvar {
318
338
struct AwaitNotify < ' a , ' b , T > {
319
339
cond : & ' a Condvar ,
320
340
guard : Option < MutexGuard < ' b , T > > ,
341
+ state : Arc < AtomicUsize > ,
321
342
key : Option < usize > ,
322
343
}
323
344
@@ -329,15 +350,24 @@ impl<'a, 'b, T> Future for AwaitNotify<'a, 'b, T> {
329
350
Some ( _) => {
330
351
let mut blocked = self . cond . blocked . lock ( ) . unwrap ( ) ;
331
352
let w = cx. waker ( ) . clone ( ) ;
332
- self . key = Some ( blocked. insert ( Some ( w) ) ) ;
353
+ self . key = Some ( blocked. insert ( WaitEntry {
354
+ state : self . state . clone ( ) ,
355
+ waker : Some ( w) ,
356
+ } ) ) ;
333
357
334
358
if blocked. len ( ) == 1 {
335
359
self . cond . has_blocked . store ( true , Ordering :: Relaxed ) ;
336
360
}
337
361
// the guard is dropped when we return, which frees the lock
338
362
Poll :: Pending
339
363
}
340
- None => Poll :: Ready ( ( ) ) ,
364
+ None => {
365
+ if self . state . fetch_and ( !NOTIFIED , Ordering :: AcqRel ) & NOTIFIED != 0 {
366
+ Poll :: Ready ( ( ) )
367
+ } else {
368
+ Poll :: Pending
369
+ }
370
+ }
341
371
}
342
372
}
343
373
}
@@ -350,6 +380,10 @@ impl<'a, 'b, T> Drop for AwaitNotify<'a, 'b, T> {
350
380
351
381
if blocked. is_empty ( ) {
352
382
self . cond . has_blocked . store ( false , Ordering :: Relaxed ) ;
383
+ } else if self . state . load ( Ordering :: Acquire ) == NOTIFIED_ONCE {
384
+ // we got a notification form notify_once but didn't handle it,
385
+ // so send it to a different task
386
+ notify ( blocked, false ) ;
353
387
}
354
388
}
355
389
}
@@ -369,12 +403,12 @@ impl<'a, 'b, T> Future for TimeoutWaitFuture<'a, 'b, T> {
369
403
type Output = WaitTimeoutResult ;
370
404
371
405
fn poll ( mut self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Self :: Output > {
372
- match self . as_mut ( ) . await_notify ( ) . poll ( cx) {
373
- Poll :: Ready ( _) => Poll :: Ready ( WaitTimeoutResult ( false ) ) ,
374
- Poll :: Pending => match self . delay ( ) . poll ( cx) {
375
- Poll :: Ready ( _) => Poll :: Ready ( WaitTimeoutResult ( true ) ) ,
406
+ match self . as_mut ( ) . delay ( ) . poll ( cx) {
407
+ Poll :: Pending => match self . await_notify ( ) . poll ( cx) {
408
+ Poll :: Ready ( _) => Poll :: Ready ( WaitTimeoutResult ( false ) ) ,
376
409
Poll :: Pending => Poll :: Pending ,
377
410
} ,
411
+ Poll :: Ready ( _) => Poll :: Ready ( WaitTimeoutResult ( true ) ) ,
378
412
}
379
413
}
380
414
}
0 commit comments