Skip to content

Unset the needs-notify bit in a Notifier when a Future is fetched #1851

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
172 changes: 116 additions & 56 deletions lightning/src/util/wakers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

use alloc::sync::Arc;
use core::mem;
use crate::sync::{Condvar, Mutex};
use crate::sync::{Condvar, Mutex, MutexGuard};

use crate::prelude::*;

Expand All @@ -33,6 +33,20 @@ pub(crate) struct Notifier {
condvar: Condvar,
}

macro_rules! check_woken {
($guard: expr, $retval: expr) => { {
if $guard.0 {
$guard.0 = false;
if $guard.1.as_ref().map(|l| l.lock().unwrap().complete).unwrap_or(false) {
// If we're about to return as woken, and the future state is marked complete, wipe
// the future state and let the next future wait until we get a new notify.
$guard.1.take();
}
return $retval;
}
} }
}

impl Notifier {
pub(crate) fn new() -> Self {
Self {
Expand All @@ -41,45 +55,47 @@ impl Notifier {
}
}

fn propagate_future_state_to_notify_flag(&self) -> MutexGuard<(bool, Option<Arc<Mutex<FutureState>>>)> {
let mut lock = self.notify_pending.lock().unwrap();
if let Some(existing_state) = &lock.1 {
if existing_state.lock().unwrap().callbacks_made {
// If the existing `FutureState` has completed and actually made callbacks,
// consider the notification flag to have been cleared and reset the future state.
lock.1.take();
lock.0 = false;
}
}
lock
}

pub(crate) fn wait(&self) {
loop {
let mut guard = self.notify_pending.lock().unwrap();
if guard.0 {
guard.0 = false;
return;
}
let mut guard = self.propagate_future_state_to_notify_flag();
check_woken!(guard, ());
guard = self.condvar.wait(guard).unwrap();
let result = guard.0;
if result {
guard.0 = false;
return
}
check_woken!(guard, ());
}
}

#[cfg(any(test, feature = "std"))]
pub(crate) fn wait_timeout(&self, max_wait: Duration) -> bool {
let current_time = Instant::now();
loop {
let mut guard = self.notify_pending.lock().unwrap();
if guard.0 {
guard.0 = false;
return true;
}
let mut guard = self.propagate_future_state_to_notify_flag();
check_woken!(guard, true);
guard = self.condvar.wait_timeout(guard, max_wait).unwrap().0;
check_woken!(guard, true);
// Due to spurious wakeups that can happen on `wait_timeout`, here we need to check if the
// desired wait time has actually passed, and if not then restart the loop with a reduced wait
// time. Note that this logic can be highly simplified through the use of
// `Condvar::wait_while` and `Condvar::wait_timeout_while`, if and when our MSRV is raised to
// 1.42.0.
let elapsed = current_time.elapsed();
let result = guard.0;
if result || elapsed >= max_wait {
guard.0 = false;
return result;
if elapsed >= max_wait {
return false;
}
match max_wait.checked_sub(elapsed) {
None => return result,
None => return false,
Some(_) => continue
}
}
Expand All @@ -88,17 +104,8 @@ impl Notifier {
/// Wake waiters, tracking that wake needs to occur even if there are currently no waiters.
pub(crate) fn notify(&self) {
let mut lock = self.notify_pending.lock().unwrap();
let mut future_probably_generated_calls = false;
if let Some(future_state) = lock.1.take() {
future_probably_generated_calls |= future_state.lock().unwrap().complete();
future_probably_generated_calls |= Arc::strong_count(&future_state) > 1;
}
if future_probably_generated_calls {
// If a future made some callbacks or has not yet been drop'd (i.e. the state has more
// than the one reference we hold), assume the user was notified and skip setting the
// notification-required flag. This will not cause the `wait` functions above to return
// and avoid any future `Future`s starting in a completed state.
return;
if let Some(future_state) = &lock.1 {
future_state.lock().unwrap().complete();
}
lock.0 = true;
mem::drop(lock);
Expand All @@ -107,20 +114,14 @@ impl Notifier {

/// Gets a [`Future`] that will get woken up with any waiters
pub(crate) fn get_future(&self) -> Future {
let mut lock = self.notify_pending.lock().unwrap();
if lock.0 {
Future {
state: Arc::new(Mutex::new(FutureState {
callbacks: Vec::new(),
complete: true,
}))
}
} else if let Some(existing_state) = &lock.1 {
let mut lock = self.propagate_future_state_to_notify_flag();
if let Some(existing_state) = &lock.1 {
Future { state: Arc::clone(&existing_state) }
} else {
let state = Arc::new(Mutex::new(FutureState {
callbacks: Vec::new(),
complete: false,
complete: lock.0,
callbacks_made: false,
}));
lock.1 = Some(Arc::clone(&state));
Future { state }
Expand Down Expand Up @@ -151,19 +152,21 @@ impl<F: Fn() + Send> FutureCallback for F {
}

pub(crate) struct FutureState {
callbacks: Vec<Box<dyn FutureCallback>>,
// When we're tracking whether a callback counts as having woken the user's code, we check the
// first bool - set to false if we're just calling a Waker, and true if we're calling an actual
// user-provided function.
callbacks: Vec<(bool, Box<dyn FutureCallback>)>,
complete: bool,
callbacks_made: bool,
}

impl FutureState {
fn complete(&mut self) -> bool {
let mut made_calls = false;
for callback in self.callbacks.drain(..) {
fn complete(&mut self) {
for (counts_as_call, callback) in self.callbacks.drain(..) {
callback.call();
made_calls = true;
self.callbacks_made |= counts_as_call;
}
self.complete = true;
made_calls
}
}

Expand All @@ -180,10 +183,11 @@ impl Future {
pub fn register_callback(&self, callback: Box<dyn FutureCallback>) {
let mut state = self.state.lock().unwrap();
if state.complete {
state.callbacks_made = true;
mem::drop(state);
callback.call();
} else {
state.callbacks.push(callback);
state.callbacks.push((true, callback));
}
}

Expand All @@ -198,12 +202,10 @@ impl Future {
}
}

mod std_future {
use core::task::Waker;
pub struct StdWaker(pub Waker);
impl super::FutureCallback for StdWaker {
fn call(&self) { self.0.wake_by_ref() }
}
use core::task::Waker;
struct StdWaker(pub Waker);
impl FutureCallback for StdWaker {
fn call(&self) { self.0.wake_by_ref() }
}

/// (C-not exported) as Rust Futures aren't usable in language bindings.
Expand All @@ -213,10 +215,11 @@ impl<'a> StdFuture for Future {
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut state = self.state.lock().unwrap();
if state.complete {
state.callbacks_made = true;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are we doing this? A completed future does not necessarily mean callbacks were made. Removing this line also don't have any effect on the tests.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oops, sorry, failed to include that in the test. As far as we're concerned, the future polling Ready is calling the user code to persist. When we hit the waker, we tell the runtime (eg tokio) that its time to poll the future again. Once it does that (and we return Ready) that tells the runtime its time to do whatever was waiting on the future (probably user code to persist the ChannelManager.

Poll::Ready(())
} else {
let waker = cx.waker().clone();
state.callbacks.push(Box::new(std_future::StdWaker(waker)));
state.callbacks.push((false, Box::new(StdWaker(waker))));
Poll::Pending
}
}
Expand Down Expand Up @@ -285,6 +288,28 @@ mod tests {
assert!(!callback.load(Ordering::SeqCst));
}

#[test]
fn new_future_wipes_notify_bit() {
// Previously, if we were only using the `Future` interface to learn when a `Notifier` has
// been notified, we'd never mark the notifier as not-awaiting-notify if a `Future` is
// fetched after the notify bit has been set.
let notifier = Notifier::new();
notifier.notify();

let callback = Arc::new(AtomicBool::new(false));
let callback_ref = Arc::clone(&callback);
notifier.get_future().register_callback(Box::new(move || assert!(!callback_ref.fetch_or(true, Ordering::SeqCst))));
assert!(callback.load(Ordering::SeqCst));

let callback = Arc::new(AtomicBool::new(false));
let callback_ref = Arc::clone(&callback);
notifier.get_future().register_callback(Box::new(move || assert!(!callback_ref.fetch_or(true, Ordering::SeqCst))));
assert!(!callback.load(Ordering::SeqCst));

notifier.notify();
assert!(callback.load(Ordering::SeqCst));
}

#[cfg(feature = "std")]
#[test]
fn test_wait_timeout() {
Expand Down Expand Up @@ -336,6 +361,7 @@ mod tests {
state: Arc::new(Mutex::new(FutureState {
callbacks: Vec::new(),
complete: false,
callbacks_made: false,
}))
};
let callback = Arc::new(AtomicBool::new(false));
Expand All @@ -354,6 +380,7 @@ mod tests {
state: Arc::new(Mutex::new(FutureState {
callbacks: Vec::new(),
complete: false,
callbacks_made: false,
}))
};
future.state.lock().unwrap().complete();
Expand Down Expand Up @@ -391,6 +418,7 @@ mod tests {
state: Arc::new(Mutex::new(FutureState {
callbacks: Vec::new(),
complete: false,
callbacks_made: false,
}))
};
let mut second_future = Future { state: Arc::clone(&future.state) };
Expand All @@ -409,4 +437,36 @@ mod tests {
assert_eq!(Pin::new(&mut future).poll(&mut Context::from_waker(&waker)), Poll::Ready(()));
assert_eq!(Pin::new(&mut second_future).poll(&mut Context::from_waker(&second_waker)), Poll::Ready(()));
}

#[test]
fn test_dropped_future_doesnt_count() {
// Tests that if a Future gets drop'd before it is poll()ed `Ready` it doesn't count as
// having been woken, leaving the notify-required flag set.
let notifier = Notifier::new();
notifier.notify();

// If we get a future and don't touch it we're definitely still notify-required.
notifier.get_future();
assert!(notifier.wait_timeout(Duration::from_millis(1)));
assert!(!notifier.wait_timeout(Duration::from_millis(1)));

// Even if we poll'd once but didn't observe a `Ready`, we should be notify-required.
let mut future = notifier.get_future();
let (woken, waker) = create_waker();
assert_eq!(Pin::new(&mut future).poll(&mut Context::from_waker(&waker)), Poll::Pending);

notifier.notify();
assert!(woken.load(Ordering::SeqCst));
assert!(notifier.wait_timeout(Duration::from_millis(1)));

// However, once we do poll `Ready` it should wipe the notify-required flag.
let mut future = notifier.get_future();
let (woken, waker) = create_waker();
assert_eq!(Pin::new(&mut future).poll(&mut Context::from_waker(&waker)), Poll::Pending);

notifier.notify();
assert!(woken.load(Ordering::SeqCst));
assert_eq!(Pin::new(&mut future).poll(&mut Context::from_waker(&waker)), Poll::Ready(()));
assert!(!notifier.wait_timeout(Duration::from_millis(1)));
}
}