Skip to content

Commit 385ae68

Browse files
committed
More rigourous detection of notification for condvar
1 parent c8dd264 commit 385ae68

File tree

2 files changed

+70
-19
lines changed

2 files changed

+70
-19
lines changed

src/sync/condvar.rs

Lines changed: 53 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use std::pin::Pin;
2-
use std::sync::atomic::{AtomicBool, Ordering};
2+
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
3+
use std::sync::Arc;
34
use std::time::Duration;
45

56
use futures_timer::Delay;
@@ -61,7 +62,19 @@ impl WaitTimeoutResult {
6162
#[derive(Debug)]
6263
pub struct Condvar {
6364
has_blocked: AtomicBool,
64-
blocked: std::sync::Mutex<Slab<Option<Waker>>>,
65+
blocked: std::sync::Mutex<Slab<WaitEntry>>,
66+
}
67+
68+
/// Flag to mark if the task was notified
69+
const NOTIFIED: usize = 1;
70+
/// State if the task was notified with `notify_once`
71+
/// so it should notify another task if the future is dropped without waking.
72+
const NOTIFIED_ONCE: usize = 0b11;
73+
74+
#[derive(Debug)]
75+
struct WaitEntry {
76+
state: Arc<AtomicUsize>,
77+
waker: Option<Waker>,
6578
}
6679

6780
impl Condvar {
@@ -126,6 +139,7 @@ impl Condvar {
126139
AwaitNotify {
127140
cond: self,
128141
guard: Some(guard),
142+
state: Arc::new(AtomicUsize::new(0)),
129143
key: None,
130144
}
131145
}
@@ -262,12 +276,8 @@ impl Condvar {
262276
/// ```
263277
pub fn notify_one(&self) {
264278
if self.has_blocked.load(Ordering::Acquire) {
265-
let mut blocked = self.blocked.lock().unwrap();
266-
if let Some((_, opt_waker)) = blocked.iter_mut().next() {
267-
if let Some(w) = opt_waker.take() {
268-
w.wake();
269-
}
270-
}
279+
let blocked = self.blocked.lock().unwrap();
280+
notify(blocked, false);
271281
}
272282
}
273283

@@ -305,11 +315,21 @@ impl Condvar {
305315
/// ```
306316
pub fn notify_all(&self) {
307317
if self.has_blocked.load(Ordering::Acquire) {
308-
let mut blocked = self.blocked.lock().unwrap();
309-
for (_, opt_waker) in blocked.iter_mut() {
310-
if let Some(w) = opt_waker.take() {
311-
w.wake();
312-
}
318+
let blocked = self.blocked.lock().unwrap();
319+
notify(blocked, true);
320+
}
321+
}
322+
}
323+
324+
#[inline]
325+
fn notify(mut blocked: std::sync::MutexGuard<'_, Slab<WaitEntry>>, all: bool) {
326+
let state = if all { NOTIFIED } else { NOTIFIED_ONCE };
327+
for (_, entry) in blocked.iter_mut() {
328+
if let Some(w) = entry.waker.take() {
329+
entry.state.store(state, Ordering::Release);
330+
w.wake();
331+
if !all {
332+
return;
313333
}
314334
}
315335
}
@@ -318,6 +338,7 @@ impl Condvar {
318338
struct AwaitNotify<'a, 'b, T> {
319339
cond: &'a Condvar,
320340
guard: Option<MutexGuard<'b, T>>,
341+
state: Arc<AtomicUsize>,
321342
key: Option<usize>,
322343
}
323344

@@ -329,15 +350,24 @@ impl<'a, 'b, T> Future for AwaitNotify<'a, 'b, T> {
329350
Some(_) => {
330351
let mut blocked = self.cond.blocked.lock().unwrap();
331352
let w = cx.waker().clone();
332-
self.key = Some(blocked.insert(Some(w)));
353+
self.key = Some(blocked.insert(WaitEntry {
354+
state: self.state.clone(),
355+
waker: Some(w),
356+
}));
333357

334358
if blocked.len() == 1 {
335359
self.cond.has_blocked.store(true, Ordering::Relaxed);
336360
}
337361
// the guard is dropped when we return, which frees the lock
338362
Poll::Pending
339363
}
340-
None => Poll::Ready(()),
364+
None => {
365+
if self.state.fetch_and(!NOTIFIED, Ordering::AcqRel) & NOTIFIED != 0 {
366+
Poll::Ready(())
367+
} else {
368+
Poll::Pending
369+
}
370+
}
341371
}
342372
}
343373
}
@@ -350,6 +380,10 @@ impl<'a, 'b, T> Drop for AwaitNotify<'a, 'b, T> {
350380

351381
if blocked.is_empty() {
352382
self.cond.has_blocked.store(false, Ordering::Relaxed);
383+
} else if self.state.load(Ordering::Acquire) == NOTIFIED_ONCE {
384+
// we got a notification form notify_once but didn't handle it,
385+
// so send it to a different task
386+
notify(blocked, false);
353387
}
354388
}
355389
}
@@ -369,12 +403,12 @@ impl<'a, 'b, T> Future for TimeoutWaitFuture<'a, 'b, T> {
369403
type Output = WaitTimeoutResult;
370404

371405
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
372-
match self.as_mut().await_notify().poll(cx) {
373-
Poll::Ready(_) => Poll::Ready(WaitTimeoutResult(false)),
374-
Poll::Pending => match self.delay().poll(cx) {
375-
Poll::Ready(_) => Poll::Ready(WaitTimeoutResult(true)),
406+
match self.as_mut().delay().poll(cx) {
407+
Poll::Pending => match self.await_notify().poll(cx) {
408+
Poll::Ready(_) => Poll::Ready(WaitTimeoutResult(false)),
376409
Poll::Pending => Poll::Pending,
377410
},
411+
Poll::Ready(_) => Poll::Ready(WaitTimeoutResult(true)),
378412
}
379413
}
380414
}

tests/condvar.rs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
use std::sync::Arc;
2+
use std::time::Duration;
3+
4+
use async_std::sync::{Condvar, Mutex};
5+
use async_std::task;
6+
7+
#[test]
8+
fn wait_timeout() {
9+
task::block_on(async {
10+
let m = Mutex::new(());
11+
let c = Condvar::new();
12+
let (_, wait_result) = c
13+
.wait_timeout(m.lock().await, Duration::from_millis(10))
14+
.await;
15+
assert!(wait_result.timed_out());
16+
})
17+
}

0 commit comments

Comments
 (0)