diff --git a/Cargo.toml b/Cargo.toml index 8c611bf..ea06cf8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,6 +23,7 @@ atomic-waker = "1" easy-parallel = "3" flaky_test = "0.1" flume = { version = "0.10", default-features = false } +futures-lite = "1.12.0" once_cell = "1" smol = "1" diff --git a/src/header.rs b/src/header.rs index 9747c5d..3dd35e7 100644 --- a/src/header.rs +++ b/src/header.rs @@ -31,6 +31,10 @@ pub(crate) struct Header { /// /// This metadata may be provided to the user. pub(crate) metadata: M, + + /// Whether or not a panic that occurs in the task should be propagated. + #[cfg(feature = "std")] + pub(crate) propagate_panic: bool, } impl Header { diff --git a/src/raw.rs b/src/raw.rs index 4af7b27..4bba757 100644 --- a/src/raw.rs +++ b/src/raw.rs @@ -13,6 +13,12 @@ use crate::state::*; use crate::utils::{abort, abort_on_panic, max, Layout}; use crate::Runnable; +#[cfg(feature = "std")] +pub(crate) type Panic = alloc::boxed::Box; + +#[cfg(not(feature = "std"))] +pub(crate) type Panic = core::convert::Infallible; + /// The vtable for a task. pub(crate) struct TaskVTable { /// Schedules the task. @@ -76,7 +82,7 @@ pub(crate) struct RawTask { pub(crate) future: *mut F, /// The output of the future. - pub(crate) output: *mut T, + pub(crate) output: *mut Result, } impl Copy for RawTask {} @@ -97,7 +103,7 @@ impl RawTask { let layout_header = Layout::new::>(); let layout_s = Layout::new::(); let layout_f = Layout::new::(); - let layout_r = Layout::new::(); + let layout_r = Layout::new::>(); // Compute the layout for `union { F, T }`. let size_union = max(layout_f.size(), layout_r.size()); @@ -138,7 +144,7 @@ where pub(crate) fn allocate<'a, Gen: FnOnce(&'a M) -> F>( future: Gen, schedule: S, - metadata: M, + builder: crate::Builder, ) -> NonNull<()> where F: 'a, @@ -158,6 +164,12 @@ where let raw = Self::from_ptr(ptr.as_ptr()); + let crate::Builder { + metadata, + #[cfg(feature = "std")] + propagate_panic, + } = builder; + // Write the header as the first field of the task. (raw.header as *mut Header).write(Header { state: AtomicUsize::new(SCHEDULED | TASK | REFERENCE), @@ -173,6 +185,8 @@ where layout_info: &Self::TASK_LAYOUT, }, metadata, + #[cfg(feature = "std")] + propagate_panic, }); // Write the schedule function as the third field of the task. @@ -199,7 +213,7 @@ where header: p as *const Header, schedule: p.add(task_layout.offset_s) as *const S, future: p.add(task_layout.offset_f) as *mut F, - output: p.add(task_layout.offset_r) as *mut T, + output: p.add(task_layout.offset_r) as *mut Result, } } } @@ -525,8 +539,30 @@ where // Poll the inner future, but surround it with a guard that closes the task in case polling // panics. + // If available, we should also try to catch the panic so that it is propagated correctly. let guard = Guard(raw); - let poll = ::poll(Pin::new_unchecked(&mut *raw.future), cx); + + // Panic propagation is not available for no_std. + #[cfg(not(feature = "std"))] + let poll = ::poll(Pin::new_unchecked(&mut *raw.future), cx).map(Ok); + + #[cfg(feature = "std")] + let poll = { + // Check if we should propagate panics. + if (*raw.header).propagate_panic { + // Use catch_unwind to catch the panic. + match std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { + ::poll(Pin::new_unchecked(&mut *raw.future), cx) + })) { + Ok(Poll::Ready(v)) => Poll::Ready(Ok(v)), + Ok(Poll::Pending) => Poll::Pending, + Err(e) => Poll::Ready(Err(e)), + } + } else { + ::poll(Pin::new_unchecked(&mut *raw.future), cx).map(Ok) + } + }; + mem::forget(guard); match poll { diff --git a/src/runnable.rs b/src/runnable.rs index c10177d..e371176 100644 --- a/src/runnable.rs +++ b/src/runnable.rs @@ -17,7 +17,11 @@ use crate::Task; #[derive(Debug)] pub struct Builder { /// The metadata associated with the task. - metadata: M, + pub(crate) metadata: M, + + /// Whether or not a panic that occurs in the task should be propagated. + #[cfg(feature = "std")] + pub(crate) propagate_panic: bool, } impl Default for Builder { @@ -40,7 +44,11 @@ impl Builder<()> { /// let (runnable, task) = Builder::new().spawn(|()| async {}, |_| {}); /// ``` pub fn new() -> Builder<()> { - Builder { metadata: () } + Builder { + metadata: (), + #[cfg(feature = "std")] + propagate_panic: false, + } } /// Adds metadata to the task. @@ -123,11 +131,63 @@ impl Builder<()> { /// # }); /// ``` pub fn metadata(self, metadata: M) -> Builder { - Builder { metadata } + Builder { + metadata, + #[cfg(feature = "std")] + propagate_panic: self.propagate_panic, + } } } impl Builder { + /// Propagates panics that occur in the task. + /// + /// When this is `true`, panics that occur in the task will be propagated to the caller of + /// the [`Task`]. When this is false, no special action is taken when a panic occurs in the + /// task, meaning that the caller of [`Runnable::run`] will observe a panic. + /// + /// This is only available when the `std` feature is enabled. By default, this is `false`. + /// + /// # Examples + /// + /// ``` + /// use async_task::Builder; + /// use futures_lite::future::poll_fn; + /// use std::future::Future; + /// use std::panic; + /// use std::pin::Pin; + /// use std::task::{Context, Poll}; + /// + /// fn did_panic(f: F) -> bool { + /// panic::catch_unwind(panic::AssertUnwindSafe(f)).is_err() + /// } + /// + /// # smol::future::block_on(async { + /// let (runnable1, mut task1) = Builder::new() + /// .propagate_panic(true) + /// .spawn(|()| async move { panic!() }, |_| {}); + /// + /// let (runnable2, mut task2) = Builder::new() + /// .propagate_panic(false) + /// .spawn(|()| async move { panic!() }, |_| {}); + /// + /// assert!(!did_panic(|| { runnable1.run(); })); + /// assert!(did_panic(|| { runnable2.run(); })); + /// + /// let waker = poll_fn(|cx| Poll::Ready(cx.waker().clone())).await; + /// let mut cx = Context::from_waker(&waker); + /// assert!(did_panic(|| { let _ = Pin::new(&mut task1).poll(&mut cx); })); + /// assert!(did_panic(|| { let _ = Pin::new(&mut task2).poll(&mut cx); })); + /// # }); + /// ``` + #[cfg(feature = "std")] + pub fn propagate_panic(self, propagate_panic: bool) -> Builder { + Builder { + metadata: self.metadata, + propagate_panic, + } + } + /// Creates a new task. /// /// The returned [`Runnable`] is used to poll the `future`, and the [`Task`] is used to await its @@ -313,8 +373,6 @@ impl Builder { S: Fn(Runnable), M: 'a, { - let Self { metadata } = self; - // Allocate large futures on the heap. let ptr = if mem::size_of::() >= 2048 { let future = |meta| { @@ -322,9 +380,9 @@ impl Builder { Box::pin(future) }; - RawTask::<_, Fut::Output, S, M>::allocate(future, schedule, metadata) + RawTask::<_, Fut::Output, S, M>::allocate(future, schedule, self) } else { - RawTask::::allocate(future, schedule, metadata) + RawTask::::allocate(future, schedule, self) }; let runnable = Runnable { diff --git a/src/task.rs b/src/task.rs index 49ba501..5bf8b46 100644 --- a/src/task.rs +++ b/src/task.rs @@ -8,6 +8,7 @@ use core::sync::atomic::Ordering; use core::task::{Context, Poll}; use crate::header::Header; +use crate::raw::Panic; use crate::state::*; /// A spawned task. @@ -226,7 +227,7 @@ impl Task { } /// Puts the task in detached state. - fn set_detached(&mut self) -> Option { + fn set_detached(&mut self) -> Option> { let ptr = self.ptr.as_ptr(); let header = ptr as *const Header; @@ -256,8 +257,10 @@ impl Task { ) { Ok(_) => { // Read the output. - output = - Some((((*header).vtable.get_output)(ptr) as *mut T).read()); + output = Some( + (((*header).vtable.get_output)(ptr) as *mut Result) + .read(), + ); // Update the state variable because we're continuing the loop. state |= CLOSED; @@ -382,8 +385,22 @@ impl Task { } // Take the output from the task. - let output = ((*header).vtable.get_output)(ptr) as *mut T; - return Poll::Ready(Some(output.read())); + let output = ((*header).vtable.get_output)(ptr) as *mut Result; + let output = output.read(); + + // Propagate the panic if the task panicked. + let output = match output { + Ok(output) => output, + Err(panic) => { + #[cfg(feature = "std")] + std::panic::resume_unwind(panic); + + #[cfg(not(feature = "std"))] + match panic {} + } + }; + + return Poll::Ready(Some(output)); } Err(s) => state = s, }