Skip to content

Commit 7527e4b

Browse files
committed
Unset the needs-notify bit in a Notifier when a Future is fetched
If a `Notifier` gets `notify()`ed and the a `Future` is fetched, even though the `Future` is marked completed from the start and the user may pass callbacks which are called, we'll never wipe the needs-notify bit in the `Notifier`. The solution is to keep track of the `FutureState` in the returned `Future` even though its `complete` from the start, adding a new flag in the `FutureState` which indicates callbacks have been made and checking that flag when waiting or returning a second `Future`.
1 parent bcf8687 commit 7527e4b

File tree

1 file changed

+51
-28
lines changed

1 file changed

+51
-28
lines changed

lightning/src/util/wakers.rs

Lines changed: 51 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
1616
use alloc::sync::Arc;
1717
use core::mem;
18-
use crate::sync::{Condvar, Mutex};
18+
use crate::sync::{Condvar, Mutex, MutexGuard};
1919

2020
use crate::prelude::*;
2121

@@ -41,9 +41,22 @@ impl Notifier {
4141
}
4242
}
4343

44+
fn propagate_future_state_to_notify_flag(&self) -> MutexGuard<(bool, Option<Arc<Mutex<FutureState>>>)> {
45+
let mut lock = self.notify_pending.lock().unwrap();
46+
if let Some(existing_state) = &lock.1 {
47+
if existing_state.lock().unwrap().callbacks_made {
48+
// If the existing `FutureState` has completed and actually made callbacks,
49+
// consider the notification flag to have been cleared and reset the future state.
50+
lock.1.take();
51+
lock.0 = false;
52+
}
53+
}
54+
lock
55+
}
56+
4457
pub(crate) fn wait(&self) {
4558
loop {
46-
let mut guard = self.notify_pending.lock().unwrap();
59+
let mut guard = self.propagate_future_state_to_notify_flag();
4760
if guard.0 {
4861
guard.0 = false;
4962
return;
@@ -61,7 +74,7 @@ impl Notifier {
6174
pub(crate) fn wait_timeout(&self, max_wait: Duration) -> bool {
6275
let current_time = Instant::now();
6376
loop {
64-
let mut guard = self.notify_pending.lock().unwrap();
77+
let mut guard = self.propagate_future_state_to_notify_flag();
6578
if guard.0 {
6679
guard.0 = false;
6780
return true;
@@ -88,17 +101,8 @@ impl Notifier {
88101
/// Wake waiters, tracking that wake needs to occur even if there are currently no waiters.
89102
pub(crate) fn notify(&self) {
90103
let mut lock = self.notify_pending.lock().unwrap();
91-
let mut future_probably_generated_calls = false;
92-
if let Some(future_state) = lock.1.take() {
93-
future_probably_generated_calls |= future_state.lock().unwrap().complete();
94-
future_probably_generated_calls |= Arc::strong_count(&future_state) > 1;
95-
}
96-
if future_probably_generated_calls {
97-
// If a future made some callbacks or has not yet been drop'd (i.e. the state has more
98-
// than the one reference we hold), assume the user was notified and skip setting the
99-
// notification-required flag. This will not cause the `wait` functions above to return
100-
// and avoid any future `Future`s starting in a completed state.
101-
return;
104+
if let Some(future_state) = &lock.1 {
105+
future_state.lock().unwrap().complete();
102106
}
103107
lock.0 = true;
104108
mem::drop(lock);
@@ -107,20 +111,14 @@ impl Notifier {
107111

108112
/// Gets a [`Future`] that will get woken up with any waiters
109113
pub(crate) fn get_future(&self) -> Future {
110-
let mut lock = self.notify_pending.lock().unwrap();
111-
if lock.0 {
112-
Future {
113-
state: Arc::new(Mutex::new(FutureState {
114-
callbacks: Vec::new(),
115-
complete: true,
116-
}))
117-
}
118-
} else if let Some(existing_state) = &lock.1 {
114+
let mut lock = self.propagate_future_state_to_notify_flag();
115+
if let Some(existing_state) = &lock.1 {
119116
Future { state: Arc::clone(&existing_state) }
120117
} else {
121118
let state = Arc::new(Mutex::new(FutureState {
122119
callbacks: Vec::new(),
123-
complete: false,
120+
complete: lock.0,
121+
callbacks_made: false,
124122
}));
125123
lock.1 = Some(Arc::clone(&state));
126124
Future { state }
@@ -153,17 +151,16 @@ impl<F: Fn() + Send> FutureCallback for F {
153151
pub(crate) struct FutureState {
154152
callbacks: Vec<Box<dyn FutureCallback>>,
155153
complete: bool,
154+
callbacks_made: bool,
156155
}
157156

158157
impl FutureState {
159-
fn complete(&mut self) -> bool {
160-
let mut made_calls = false;
158+
fn complete(&mut self) {
161159
for callback in self.callbacks.drain(..) {
162160
callback.call();
163-
made_calls = true;
161+
self.callbacks_made = true;
164162
}
165163
self.complete = true;
166-
made_calls
167164
}
168165
}
169166

@@ -180,6 +177,7 @@ impl Future {
180177
pub fn register_callback(&self, callback: Box<dyn FutureCallback>) {
181178
let mut state = self.state.lock().unwrap();
182179
if state.complete {
180+
state.callbacks_made = true;
183181
mem::drop(state);
184182
callback.call();
185183
} else {
@@ -283,6 +281,28 @@ mod tests {
283281
assert!(!callback.load(Ordering::SeqCst));
284282
}
285283

284+
#[test]
285+
fn new_future_wipes_notify_bit() {
286+
// Previously, if we were only using the `Future` interface to learn when a `Notifier` has
287+
// been notified, we'd never mark the notifier as not-awaiting-notify if a `Future` is
288+
// fetched after the notify bit has been set.
289+
let notifier = Notifier::new();
290+
notifier.notify();
291+
292+
let callback = Arc::new(AtomicBool::new(false));
293+
let callback_ref = Arc::clone(&callback);
294+
notifier.get_future().register_callback(Box::new(move || assert!(!callback_ref.fetch_or(true, Ordering::SeqCst))));
295+
assert!(callback.load(Ordering::SeqCst));
296+
297+
let callback = Arc::new(AtomicBool::new(false));
298+
let callback_ref = Arc::clone(&callback);
299+
notifier.get_future().register_callback(Box::new(move || assert!(!callback_ref.fetch_or(true, Ordering::SeqCst))));
300+
assert!(!callback.load(Ordering::SeqCst));
301+
302+
notifier.notify();
303+
assert!(callback.load(Ordering::SeqCst));
304+
}
305+
286306
#[cfg(feature = "std")]
287307
#[test]
288308
fn test_wait_timeout() {
@@ -334,6 +354,7 @@ mod tests {
334354
state: Arc::new(Mutex::new(FutureState {
335355
callbacks: Vec::new(),
336356
complete: false,
357+
callbacks_made: false,
337358
}))
338359
};
339360
let callback = Arc::new(AtomicBool::new(false));
@@ -352,6 +373,7 @@ mod tests {
352373
state: Arc::new(Mutex::new(FutureState {
353374
callbacks: Vec::new(),
354375
complete: false,
376+
callbacks_made: false,
355377
}))
356378
};
357379
future.state.lock().unwrap().complete();
@@ -389,6 +411,7 @@ mod tests {
389411
state: Arc::new(Mutex::new(FutureState {
390412
callbacks: Vec::new(),
391413
complete: false,
414+
callbacks_made: false,
392415
}))
393416
};
394417
let mut second_future = Future { state: Arc::clone(&future.state) };

0 commit comments

Comments
 (0)