From 1a6a3a9ce7947201cb5b8b14b95ba68081b187cc Mon Sep 17 00:00:00 2001 From: slinkydeveloper Date: Fri, 7 Mar 2025 18:11:33 +0100 Subject: [PATCH 1/2] First pass at a select statement --- src/context/macro_support.rs | 8 + src/context/mod.rs | 18 +- src/context/request.rs | 10 +- src/context/select.rs | 557 ++++++++++++++++++ src/endpoint/context.rs | 309 ++++++---- src/endpoint/futures/durable_future_impl.rs | 59 ++ src/endpoint/futures/mod.rs | 2 + src/endpoint/futures/select_poll.rs | 151 +++++ src/endpoint/futures/trap.rs | 4 +- test-services/exclusions.yaml | 5 - .../src/virtual_object_command_interpreter.rs | 16 +- 11 files changed, 1002 insertions(+), 137 deletions(-) create mode 100644 src/context/macro_support.rs create mode 100644 src/context/select.rs create mode 100644 src/endpoint/futures/durable_future_impl.rs create mode 100644 src/endpoint/futures/select_poll.rs diff --git a/src/context/macro_support.rs b/src/context/macro_support.rs new file mode 100644 index 0000000..2479f69 --- /dev/null +++ b/src/context/macro_support.rs @@ -0,0 +1,8 @@ +use crate::endpoint::ContextInternal; +use restate_sdk_shared_core::NotificationHandle; + +// Sealed future trait, used by select statement +pub trait SealedDurableFuture { + fn inner_context(&self) -> ContextInternal; + fn handle(&self) -> NotificationHandle; +} diff --git a/src/context/mod.rs b/src/context/mod.rs index 4a0129e..2432f09 100644 --- a/src/context/mod.rs +++ b/src/context/mod.rs @@ -6,8 +6,12 @@ use crate::serde::{Deserialize, Serialize}; use std::future::Future; use std::time::Duration; +#[doc(hidden)] +pub mod macro_support; mod request; mod run; +mod select; + pub use request::{CallFuture, InvocationHandle, Request, RequestTarget}; pub use run::{RunClosure, RunFuture, RunRetryPolicy}; @@ -249,7 +253,10 @@ impl<'ctx> WorkflowContext<'ctx> { /// pub trait ContextTimers<'ctx>: private::SealedContext<'ctx> { /// Sleep using Restate - fn sleep(&self, duration: Duration) -> impl Future> + 'ctx { + fn sleep( + &self, + duration: Duration, + ) -> impl DurableFuture> + 'ctx { private::SealedContext::inner_context(self).sleep(duration) } } @@ -632,7 +639,7 @@ pub trait ContextAwakeables<'ctx>: private::SealedContext<'ctx> { &self, ) -> ( String, - impl Future> + Send + 'ctx, + impl DurableFuture> + Send + 'ctx, ) { self.inner_context().awakeable() } @@ -918,7 +925,7 @@ pub trait ContextPromises<'ctx>: private::SealedContext<'ctx> { fn promise( &'ctx self, key: &str, - ) -> impl Future> + 'ctx { + ) -> impl DurableFuture> + 'ctx { self.inner_context().promise(key) } @@ -946,9 +953,10 @@ impl<'ctx, CTX: private::SealedContext<'ctx> + private::SealedCanUsePromises> Co { } -mod private { - use super::*; +pub trait DurableFuture: Future + macro_support::SealedDurableFuture {} +pub(crate) mod private { + use super::*; pub trait SealedContext<'ctx> { fn inner_context(&self) -> &'ctx ContextInternal; diff --git a/src/context/request.rs b/src/context/request.rs index 0284628..4478135 100644 --- a/src/context/request.rs +++ b/src/context/request.rs @@ -1,3 +1,5 @@ +use super::DurableFuture; + use crate::endpoint::ContextInternal; use crate::errors::TerminalError; use crate::serde::{Deserialize, Serialize}; @@ -95,7 +97,7 @@ impl<'a, Req, Res> Request<'a, Req, Res> { } /// Call a service. This returns a future encapsulating the response. - pub fn call(self) -> impl CallFuture> + Send + pub fn call(self) -> impl CallFuture + Send where Req: Serialize + 'static, Res: Deserialize + 'static, @@ -132,4 +134,8 @@ pub trait InvocationHandle { fn cancel(&self) -> impl Future> + Send; } -pub trait CallFuture: Future + InvocationHandle {} +pub trait CallFuture: + DurableFuture> + InvocationHandle +{ + type Response; +} diff --git a/src/context/select.rs b/src/context/select.rs new file mode 100644 index 0000000..3c1ad6c --- /dev/null +++ b/src/context/select.rs @@ -0,0 +1,557 @@ +/// Select macro, alike tokio::select: +/// +/// ```rust +/// # use restate_sdk::prelude::*; +/// # use std::convert::Infallible; +/// # use std::time::Duration; +/// # +/// # async fn handle(ctx: Context<'_>) -> Result<(), HandlerError> { +/// # let (_, awakeable) = ctx.awakeable::(); +/// # let (_, call_result) = ctx.awakeable::(); +/// restate_sdk::select! { +/// // Bind res to the awakeable result +/// res = awakeable => { +/// // Handle awakeable result +/// }, +/// _ = ctx.sleep(Duration::from_secs(10)) => { +/// // Handle sleep +/// }, +/// // You can also pattern match +/// Ok(success_result) = call_result => { +/// // Handle success result +/// }, +/// else => { +/// // Optional: handle cases when pattern matching doesn't match a future result +/// // If unspecified, select panics when there is no match, e.g. in the above select arm, +/// // if call_result returns Err, it would panic unless you specify an else arm. +/// }, +/// on_cancel => { +/// // Optional: handle when the invocation gets cancelled during this select. +/// // If unspecified, it just propagates the TerminalError +/// } +/// } +/// # Ok(()) +/// # } +/// ``` +/// +/// Note: This API is experimental and subject to changes. +#[macro_export] +macro_rules! select { + // The macro is structured as a tt-muncher. All branches are processed and + // normalized. Once the input is normalized, it is passed to the top-most + // rule. When entering the macro, `@{ }` is inserted at the front. This is + // used to collect the normalized input. + // + // The macro only recurses once per branch. This allows using `select!` + // without requiring the user to increase the recursion limit. + + // All input is normalized, now transform. + (@ { + // One `_` for each branch in the `select!` macro. Passing this to + // `count!` converts $skip to an integer. + ( $($count:tt)* ) + + // Normalized select branches. `( $skip )` is a set of `_` characters. + // There is one `_` for each select branch **before** this one. Given + // that all input futures are stored in a tuple, $skip is useful for + // generating a pattern to reference the future for the current branch. + // $skip is also used as an argument to `count!`, returning the index of + // the current select branch. + $( ( $($skip:tt)* ) $bind:pat = $fut:expr => $handle:expr, )+ + + // Expression used to special handle cancellation when awaiting select + ; $on_cancel:expr + + // Fallback expression used when all select branches have been disabled. + ; $else:expr + }) => {{ + use $crate::context::DurableFuture; + use $crate::context::macro_support::SealedDurableFuture; + + let futures_init = ($( $fut, )+); + let handles = vec![$( + $crate::count_field!(futures_init.$($skip)*).handle() + ,)+]; + let select_fut = futures_init.0.inner_context().select(handles); + + match select_fut.await { + $( + Ok($crate::count!( $($skip)* )) => { + match $crate::count_field!(futures_init.$($skip)*).await { + $bind => { + $handle + } + _ => { + $else + } + } + } + )* + Ok(_) => { + unreachable!("Select fut returned index out of bounds") + } + Err(_) => { + $on_cancel + } + _ => unreachable!("reaching this means there probably is an off by one bug"), + } + }}; + + // ==== Normalize ===== + + // These rules match a single `select!` branch and normalize it for + // processing by the first rule. + + (@ { $($t:tt)* } ) => { + // No `else` branch + $crate::select!(@{ $($t)*; { Err(TerminalError::new_with_code(409, "cancelled"))? }; panic!("No else branch is defined")}) + }; + (@ { $($t:tt)* } on_cancel => $on_cancel:expr $(,)?) => { + // on_cancel branch + $crate::select!(@{ $($t)*; $on_cancel; panic!("No else branch is defined") }) + }; + (@ { $($t:tt)* } else => $else:expr $(,)?) => { + // on_cancel branch + $crate::select!(@{ $($t)*; { Err(TerminalError::new_with_code(409, "cancelled"))? }; $else }) + }; + (@ { $($t:tt)* } on_cancel => $on_cancel:expr, else => $else:expr $(,)?) => { + // on_cancel branch + $crate::select!(@{ $($t)*; $on_cancel; $else }) + }; + (@ { $($t:tt)* } else => $else:expr, on_cancel => $on_cancel:expr $(,)?) => { + // on_cancel branch + $crate::select!(@{ $($t)*; $on_cancel; $else }) + }; + (@ { ( $($s:tt)* ) $($t:tt)* } $p:pat = $f:expr => $h:block, $($r:tt)* ) => { + $crate::select!(@{ ($($s)* _) $($t)* ($($s)*) $p = $f => $h, } $($r)*) + }; + (@ { ( $($s:tt)* ) $($t:tt)* } $p:pat = $f:expr => $h:expr, $($r:tt)* ) => { + $crate::select!(@{ ($($s)* _) $($t)* ($($s)*) $p = $f => $h, } $($r)*) + }; + (@ { ( $($s:tt)* ) $($t:tt)* } $p:pat = $f:expr => $h:block ) => { + $crate::select!(@{ ($($s)* _) $($t)* ($($s)*) $p = $f => $h, }) + }; + (@ { ( $($s:tt)* ) $($t:tt)* } $p:pat = $f:expr => $h:expr ) => { + $crate::select!(@{ ($($s)* _) $($t)* ($($s)*) $p = $f => $h, }) + }; + + // ===== Entry point ===== + + (on_cancel => $on_cancel:expr $(,)? ) => {{ + compile_error!("select! cannot contain only on_cancel branch.") + }}; + (else => $else:expr $(,)? ) => {{ + compile_error!("select! cannot contain only else branch.") + }}; + + ( $p:pat = $($t:tt)* ) => { + $crate::select!(@{ () } $p = $($t)*) + }; + + () => { + compile_error!("select! requires at least one branch.") + }; +} + +// And here... we manually list out matches for up to 64 branches... I'm not +// happy about it either, but this is how we manage to use a declarative macro! + +#[macro_export] +#[doc(hidden)] +macro_rules! count { + () => { + 0 + }; + (_) => { + 1 + }; + (_ _) => { + 2 + }; + (_ _ _) => { + 3 + }; + (_ _ _ _) => { + 4 + }; + (_ _ _ _ _) => { + 5 + }; + (_ _ _ _ _ _) => { + 6 + }; + (_ _ _ _ _ _ _) => { + 7 + }; + (_ _ _ _ _ _ _ _) => { + 8 + }; + (_ _ _ _ _ _ _ _ _) => { + 9 + }; + (_ _ _ _ _ _ _ _ _ _) => { + 10 + }; + (_ _ _ _ _ _ _ _ _ _ _) => { + 11 + }; + (_ _ _ _ _ _ _ _ _ _ _ _) => { + 12 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _) => { + 13 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 14 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 15 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 16 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 17 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 18 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 19 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 20 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 21 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 22 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 23 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 24 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 25 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 26 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 27 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 28 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 29 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 30 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 31 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 32 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 33 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 34 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 35 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 36 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 37 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 38 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 39 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 40 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 41 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 42 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 43 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 44 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 45 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 46 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 47 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 48 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 49 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 50 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 51 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 52 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 53 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 54 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 55 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 56 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 57 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 58 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 59 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 60 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 61 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 62 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 63 + }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 64 + }; +} + +#[macro_export] +#[doc(hidden)] +macro_rules! count_field { + ($var:ident. ) => { + $var.0 + }; + ($var:ident. _) => { + $var.1 + }; + ($var:ident. _ _) => { + $var.2 + }; + ($var:ident. _ _ _) => { + $var.3 + }; + ($var:ident. _ _ _ _) => { + $var.4 + }; + ($var:ident. _ _ _ _ _) => { + $var.5 + }; + ($var:ident. _ _ _ _ _ _) => { + $var.6 + }; + ($var:ident. _ _ _ _ _ _ _) => { + $var.7 + }; + ($var:ident. _ _ _ _ _ _ _ _) => { + $var.8 + }; + ($var:ident. _ _ _ _ _ _ _ _ _) => { + $var.9 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _) => { + $var.10 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _) => { + $var.11 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.12 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.13 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.14 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.15 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.16 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.17 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.18 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.19 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.20 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.21 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.22 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.23 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.24 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.25 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.26 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.27 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.28 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.29 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.30 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.31 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.32 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.33 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.34 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.35 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.36 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.37 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.38 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.39 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.40 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.41 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.42 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.43 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.44 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.45 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.46 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.47 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.48 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.49 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.50 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.51 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.52 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.53 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.54 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.55 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.56 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.57 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.58 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.59 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.60 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.61 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.62 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.63 + }; + ($var:ident. _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + $var.64 + }; +} diff --git a/src/endpoint/context.rs b/src/endpoint/context.rs index 66e2bfb..3505b53 100644 --- a/src/endpoint/context.rs +++ b/src/endpoint/context.rs @@ -1,8 +1,11 @@ use crate::context::{ - CallFuture, InvocationHandle, Request, RequestTarget, RunClosure, RunFuture, RunRetryPolicy, + CallFuture, DurableFuture, InvocationHandle, Request, RequestTarget, RunClosure, RunFuture, + RunRetryPolicy, }; use crate::endpoint::futures::async_result_poll::VmAsyncResultPollFuture; +use crate::endpoint::futures::durable_future_impl::DurableFutureImpl; use crate::endpoint::futures::intercept_error::InterceptErrorFuture; +use crate::endpoint::futures::select_poll::VmSelectAsyncResultPollFuture; use crate::endpoint::futures::trap::TrapFuture; use crate::endpoint::handler_state::HandlerStateNotifier; use crate::endpoint::{Error, ErrorInner, InputReceiver, OutputSender}; @@ -73,38 +76,6 @@ impl ContextInternalInner { } } -/// Internal context interface. -/// -/// For the high level interfaces, look at [`crate::context`]. -#[derive(Clone)] -pub struct ContextInternal { - svc_name: String, - handler_name: String, - inner: Arc>, -} - -impl ContextInternal { - pub(super) fn new( - vm: CoreVM, - svc_name: String, - handler_name: String, - read: InputReceiver, - write: OutputSender, - handler_state: HandlerStateNotifier, - ) -> Self { - Self { - svc_name, - handler_name, - inner: Arc::new(Mutex::new(ContextInternalInner::new( - vm, - read, - write, - handler_state, - ))), - } - } -} - #[allow(unused)] const fn is_send_sync() {} const _: () = is_send_sync::(); @@ -127,6 +98,22 @@ macro_rules! unwrap_or_trap { }; } +macro_rules! unwrap_or_trap_durable_future { + ($ctx:expr, $inner_lock:expr, $res:expr) => { + match $res { + Ok(t) => t, + Err(e) => { + $inner_lock.fail(e.into()); + return DurableFutureImpl::new( + $ctx.clone(), + NotificationHandle::from(u32::MAX), + Either::Right(TrapFuture::default()), + ); + } + } + }; +} + #[derive(Debug, Eq, PartialEq)] pub struct InputMetadata { pub invocation_id: String, @@ -163,7 +150,37 @@ impl From for Target { } } +/// Internal context interface. +/// +/// For the high level interfaces, look at [`crate::context`]. +#[derive(Clone)] +pub struct ContextInternal { + svc_name: String, + handler_name: String, + inner: Arc>, +} + impl ContextInternal { + pub(super) fn new( + vm: CoreVM, + svc_name: String, + handler_name: String, + read: InputReceiver, + write: OutputSender, + handler_state: HandlerStateNotifier, + ) -> Self { + Self { + svc_name, + handler_name, + inner: Arc::new(Mutex::new(ContextInternalInner::new( + vm, + read, + write, + handler_state, + ))), + } + } + pub fn service_name(&self) -> &str { &self.svc_name } @@ -305,15 +322,26 @@ impl ContextInternal { inner_lock.maybe_flip_span_replaying_field(); } + pub fn select( + &self, + handles: Vec, + ) -> impl Future> + Send { + InterceptErrorFuture::new( + self.clone(), + VmSelectAsyncResultPollFuture::new(self.inner.clone(), handles).map_err(Error::from), + ) + } + pub fn sleep( &self, sleep_duration: Duration, - ) -> impl Future> + Send { + ) -> impl DurableFuture> + Send { let now = SystemTime::now() .duration_since(SystemTime::UNIX_EPOCH) .expect("Duration since unix epoch cannot fail"); let mut inner_lock = must_lock!(self.inner); - let handle = unwrap_or_trap!( + let handle = unwrap_or_trap_durable_future!( + self, inner_lock, inner_lock.vm.sys_sleep(now + sleep_duration, Some(now)) ); @@ -330,7 +358,7 @@ impl ContextInternal { Err(e) => Err(e), }); - Either::Left(InterceptErrorFuture::new(self.clone(), poll_future)) + DurableFutureImpl::new(self.clone(), handle, Either::Left(poll_future)) } pub fn request(&self, request_target: RequestTarget, req: Req) -> Request { @@ -342,17 +370,27 @@ impl ContextInternal { request_target: RequestTarget, idempotency_key: Option, req: Req, - ) -> impl CallFuture> + Send { + ) -> impl CallFuture + Send { let mut inner_lock = must_lock!(self.inner); let mut target: Target = request_target.into(); target.idempotency_key = idempotency_key; - let input = unwrap_or_trap!( - inner_lock, - Req::serialize(&req).map_err(|e| Error::serialization("call", e)) - ); + let call_result = Req::serialize(&req) + .map_err(|e| Error::serialization("call", e)) + .and_then(|input| inner_lock.vm.sys_call(target, input).map_err(Into::into)); - let call_handle = unwrap_or_trap!(inner_lock, inner_lock.vm.sys_call(target, input)); + let call_handle = match call_result { + Ok(t) => t, + Err(e) => { + inner_lock.fail(e); + return CallFutureImpl { + invocation_id_future: Either::Right(TrapFuture::default()).shared(), + result_future: Either::Right(TrapFuture::default()), + call_notification_handle: NotificationHandle::from(u32::MAX), + ctx: self.clone(), + }; + } + }; inner_lock.maybe_flip_span_replaying_field(); drop(inner_lock); @@ -374,31 +412,29 @@ impl ContextInternal { Err(e) => Err(e), }), ); - let result_future = InterceptErrorFuture::new( - self.clone(), - get_async_result( - Arc::clone(&self.inner), - call_handle.call_notification_handle, - ) - .map(|res| match res { - Ok(Value::Success(mut s)) => Ok(Ok( - Res::deserialize(&mut s).map_err(|e| Error::deserialization("call", e))? - )), - Ok(Value::Failure(f)) => Ok(Err(TerminalError::from(f))), - Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall { - variant: <&'static str>::from(v), - syscall: "call", - } - .into()), - Err(e) => Err(e), - }), - ); + let result_future = get_async_result( + Arc::clone(&self.inner), + call_handle.call_notification_handle, + ) + .map(|res| match res { + Ok(Value::Success(mut s)) => Ok(Ok( + Res::deserialize(&mut s).map_err(|e| Error::deserialization("call", e))? + )), + Ok(Value::Failure(f)) => Ok(Err(TerminalError::from(f))), + Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall { + variant: <&'static str>::from(v), + syscall: "call", + } + .into()), + Err(e) => Err(e), + }); - Either::Left(CallFutureImpl { - invocation_id_future: invocation_id_fut.shared(), - result_future, + CallFutureImpl { + invocation_id_future: Either::Left(invocation_id_fut).shared(), + result_future: Either::Left(result_future), + call_notification_handle: call_handle.call_notification_handle, ctx: self.clone(), - }) + } } pub fn send( @@ -475,7 +511,7 @@ impl ContextInternal { &self, ) -> ( String, - impl Future> + Send, + impl DurableFuture> + Send, ) { let mut inner_lock = must_lock!(self.inner); let maybe_awakeable_id_and_handle = inner_lock.vm.sys_awakeable(); @@ -489,7 +525,11 @@ impl ContextInternal { // TODO NOW this is REALLY BAD. The reason for this is that we would need to return a future of a future instead, which is not nice. // we assume for the time being this works because no user should use the awakeable without doing any other syscall first, which will prevent this invalid awakeable id to work in the first place. "invalid".to_owned(), - Either::Right(TrapFuture::default()), + DurableFutureImpl::new( + self.clone(), + NotificationHandle::from(u32::MAX), + Either::Right(TrapFuture::default()), + ), ); } }; @@ -510,7 +550,7 @@ impl ContextInternal { ( awakeable_id, - Either::Left(InterceptErrorFuture::new(self.clone(), poll_future)), + DurableFutureImpl::new(self.clone(), handle, Either::Left(poll_future)), ) } @@ -537,9 +577,13 @@ impl ContextInternal { pub fn promise( &self, name: &str, - ) -> impl Future> + Send { + ) -> impl DurableFuture> + Send { let mut inner_lock = must_lock!(self.inner); - let handle = unwrap_or_trap!(inner_lock, inner_lock.vm.sys_get_promise(name.to_owned())); + let handle = unwrap_or_trap_durable_future!( + self, + inner_lock, + inner_lock.vm.sys_get_promise(name.to_owned()) + ); inner_lock.maybe_flip_span_replaying_field(); drop(inner_lock); @@ -557,7 +601,7 @@ impl ContextInternal { Err(e) => Err(e), }); - Either::Left(InterceptErrorFuture::new(self.clone(), poll_future)) + DurableFutureImpl::new(self.clone(), handle, Either::Left(poll_future)) } pub fn peek_promise( @@ -871,38 +915,13 @@ where } } -struct SendRequestHandle { - invocation_id_future: Shared, - ctx: ContextInternal, -} - -impl> + Send> InvocationHandle - for SendRequestHandle -{ - fn invocation_id(&self) -> impl Future> + Send { - Shared::clone(&self.invocation_id_future) - } - - fn cancel(&self) -> impl Future> + Send { - let cloned_invocation_id_fut = Shared::clone(&self.invocation_id_future); - let cloned_ctx = Arc::clone(&self.ctx.inner); - async move { - let inv_id = cloned_invocation_id_fut.await?; - let mut inner_lock = must_lock!(cloned_ctx); - let _ = inner_lock.vm.sys_cancel_invocation(inv_id); - inner_lock.maybe_flip_span_replaying_field(); - drop(inner_lock); - Ok(()) - } - } -} - pin_project! { struct CallFutureImpl { #[pin] invocation_id_future: Shared, #[pin] result_future: ResultFut, + call_notification_handle: NotificationHandle, ctx: ContextInternal, } } @@ -910,13 +929,25 @@ pin_project! { impl Future for CallFutureImpl where InvIdFut: Future> + Send, - ResultFut: Future> + Send, + ResultFut: Future, Error>> + Send, { - type Output = ResultFut::Output; + type Output = Result; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.project(); - this.result_future.poll(cx) + let result = ready!(this.result_future.poll(cx)); + + match result { + Ok(r) => Poll::Ready(r), + Err(e) => { + this.ctx.fail(e); + + // Here is the secret sauce. This will immediately cause the whole future chain to be polled, + // but the poll here will be intercepted by HandlerStateAwareFuture + cx.waker().wake_by_ref(); + Poll::Pending + } + } } } @@ -942,39 +973,59 @@ where } } -impl CallFuture> - for CallFutureImpl +impl CallFuture for CallFutureImpl where InvIdFut: Future> + Send, - ResultFut: Future> + Send, + ResultFut: Future, Error>> + Send, { + type Response = Res; } -impl InvocationHandle for Either +impl crate::context::macro_support::SealedDurableFuture + for CallFutureImpl where - A: InvocationHandle, - B: InvocationHandle, + InvIdFut: Future, { - fn invocation_id(&self) -> impl Future> + Send { - match self { - Either::Left(l) => Either::Left(l.invocation_id()), - Either::Right(r) => Either::Right(r.invocation_id()), - } + fn inner_context(&self) -> ContextInternal { + self.ctx.clone() } - fn cancel(&self) -> impl Future> + Send { - match self { - Either::Left(l) => Either::Left(l.cancel()), - Either::Right(r) => Either::Right(r.cancel()), - } + fn handle(&self) -> NotificationHandle { + self.call_notification_handle } } -impl CallFuture for Either +impl DurableFuture for CallFutureImpl where - A: CallFuture, - B: CallFuture, + InvIdFut: Future> + Send, + ResultFut: Future, Error>> + Send, +{ +} + +struct SendRequestHandle { + invocation_id_future: Shared, + ctx: ContextInternal, +} + +impl> + Send> InvocationHandle + for SendRequestHandle { + fn invocation_id(&self) -> impl Future> + Send { + Shared::clone(&self.invocation_id_future) + } + + fn cancel(&self) -> impl Future> + Send { + let cloned_invocation_id_fut = Shared::clone(&self.invocation_id_future); + let cloned_ctx = Arc::clone(&self.ctx.inner); + async move { + let inv_id = cloned_invocation_id_fut.await?; + let mut inner_lock = must_lock!(cloned_ctx); + let _ = inner_lock.vm.sys_cancel_invocation(inv_id); + inner_lock.maybe_flip_span_replaying_field(); + drop(inner_lock); + Ok(()) + } + } } struct InvocationIdBackedInvocationHandle { @@ -996,6 +1047,26 @@ impl InvocationHandle for InvocationIdBackedInvocationHandle { } } +impl InvocationHandle for Either +where + A: InvocationHandle, + B: InvocationHandle, +{ + fn invocation_id(&self) -> impl Future> + Send { + match self { + Either::Left(l) => Either::Left(l.invocation_id()), + Either::Right(r) => Either::Right(r.invocation_id()), + } + } + + fn cancel(&self) -> impl Future> + Send { + match self { + Either::Left(l) => Either::Left(l.cancel()), + Either::Right(r) => Either::Right(r.cancel()), + } + } +} + impl Error { fn serialization( syscall: &'static str, diff --git a/src/endpoint/futures/durable_future_impl.rs b/src/endpoint/futures/durable_future_impl.rs new file mode 100644 index 0000000..7e5881a --- /dev/null +++ b/src/endpoint/futures/durable_future_impl.rs @@ -0,0 +1,59 @@ +use crate::context::DurableFuture; +use crate::endpoint::{ContextInternal, Error}; +use pin_project_lite::pin_project; +use restate_sdk_shared_core::NotificationHandle; +use std::future::Future; +use std::pin::Pin; +use std::task::{ready, Context, Poll}; + +pin_project! { + /// Future that intercepts errors of inner future, and passes them to ContextInternal + pub struct DurableFutureImpl{ + #[pin] + fut: F, + handle: NotificationHandle, + ctx: ContextInternal + } +} + +impl DurableFutureImpl { + pub fn new(ctx: ContextInternal, handle: NotificationHandle, fut: F) -> Self { + Self { fut, handle, ctx } + } +} + +impl Future for DurableFutureImpl +where + F: Future>, +{ + type Output = R; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + let result = ready!(this.fut.poll(cx)); + + match result { + Ok(r) => Poll::Ready(r), + Err(e) => { + this.ctx.fail(e); + + // Here is the secret sauce. This will immediately cause the whole future chain to be polled, + // but the poll here will be intercepted by HandlerStateAwareFuture + cx.waker().wake_by_ref(); + Poll::Pending + } + } + } +} + +impl crate::context::macro_support::SealedDurableFuture for DurableFutureImpl { + fn inner_context(&self) -> ContextInternal { + self.ctx.clone() + } + + fn handle(&self) -> NotificationHandle { + self.handle + } +} + +impl DurableFuture for DurableFutureImpl where F: Future> {} diff --git a/src/endpoint/futures/mod.rs b/src/endpoint/futures/mod.rs index 9f0b03f..a9d7946 100644 --- a/src/endpoint/futures/mod.rs +++ b/src/endpoint/futures/mod.rs @@ -1,4 +1,6 @@ pub mod async_result_poll; +pub mod durable_future_impl; pub mod handler_state_aware; pub mod intercept_error; +pub mod select_poll; pub mod trap; diff --git a/src/endpoint/futures/select_poll.rs b/src/endpoint/futures/select_poll.rs new file mode 100644 index 0000000..2732e24 --- /dev/null +++ b/src/endpoint/futures/select_poll.rs @@ -0,0 +1,151 @@ +use crate::endpoint::context::ContextInternalInner; +use crate::endpoint::ErrorInner; +use crate::errors::TerminalError; +use restate_sdk_shared_core::{ + DoProgressResponse, Error as CoreError, NotificationHandle, TakeOutputResult, TerminalFailure, + VM, +}; +use std::future::Future; +use std::pin::Pin; +use std::sync::{Arc, Mutex}; +use std::task::Poll; + +pub(crate) struct VmSelectAsyncResultPollFuture { + state: Option, +} + +impl VmSelectAsyncResultPollFuture { + pub fn new(ctx: Arc>, handles: Vec) -> Self { + VmSelectAsyncResultPollFuture { + state: Some(VmSelectAsyncResultPollState::Init { ctx, handles }), + } + } +} + +enum VmSelectAsyncResultPollState { + Init { + ctx: Arc>, + handles: Vec, + }, + PollProgress { + ctx: Arc>, + handles: Vec, + }, + WaitingInput { + ctx: Arc>, + handles: Vec, + }, +} + +macro_rules! must_lock { + ($mutex:expr) => { + $mutex.try_lock().expect("You're trying to await two futures at the same time and/or trying to perform some operation on the restate context while awaiting a future. This is not supported!") + }; +} + +impl Future for VmSelectAsyncResultPollFuture { + type Output = Result, ErrorInner>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { + loop { + match self + .state + .take() + .expect("Future should not be polled after Poll::Ready") + { + VmSelectAsyncResultPollState::Init { ctx, handles } => { + let mut inner_lock = must_lock!(ctx); + + // Let's consume some output to begin with + let out = inner_lock.vm.take_output(); + match out { + TakeOutputResult::Buffer(b) => { + if !inner_lock.write.send(b) { + return Poll::Ready(Err(ErrorInner::Suspended)); + } + } + TakeOutputResult::EOF => { + return Poll::Ready(Err(ErrorInner::UnexpectedOutputClosed)) + } + } + + // We can now start polling + drop(inner_lock); + self.state = Some(VmSelectAsyncResultPollState::PollProgress { ctx, handles }); + } + VmSelectAsyncResultPollState::WaitingInput { ctx, handles } => { + let mut inner_lock = must_lock!(ctx); + + let read_result = match inner_lock.read.poll_recv(cx) { + Poll::Ready(t) => t, + Poll::Pending => { + // Still need to wait for input + drop(inner_lock); + self.state = + Some(VmSelectAsyncResultPollState::WaitingInput { ctx, handles }); + return Poll::Pending; + } + }; + + // Pass read result to VM + match read_result { + Some(Ok(b)) => inner_lock.vm.notify_input(b), + Some(Err(e)) => inner_lock.vm.notify_error( + CoreError::new(500u16, format!("Error when reading the body {e:?}",)), + None, + ), + None => inner_lock.vm.notify_input_closed(), + } + + // It's time to poll progress again + drop(inner_lock); + self.state = Some(VmSelectAsyncResultPollState::PollProgress { ctx, handles }); + } + VmSelectAsyncResultPollState::PollProgress { ctx, handles } => { + let mut inner_lock = must_lock!(ctx); + + match inner_lock.vm.do_progress(handles.clone()) { + Ok(DoProgressResponse::AnyCompleted) => { + // We're good, we got the response + } + Ok(DoProgressResponse::ReadFromInput) => { + drop(inner_lock); + self.state = + Some(VmSelectAsyncResultPollState::WaitingInput { ctx, handles }); + continue; + } + Ok(DoProgressResponse::ExecuteRun(_)) => { + unimplemented!() + } + Ok(DoProgressResponse::WaitingPendingRun) => { + unimplemented!() + } + Ok(DoProgressResponse::CancelSignalReceived) => { + return Poll::Ready(Ok(Err(TerminalFailure { + code: 409, + message: "cancelled".to_string(), + } + .into()))) + } + Err(e) => { + return Poll::Ready(Err(e.into())); + } + }; + + // DoProgress might cause a flip of the replaying state + inner_lock.maybe_flip_span_replaying_field(); + + // At this point let's try to take the notification + for (idx, handle) in handles.iter().enumerate() { + if inner_lock.vm.is_completed(*handle) { + return Poll::Ready(Ok(Ok(idx))); + } + } + panic!( + "This is not supposed to happen, none of the given handles were completed even though poll progress completed with AnyCompleted" + ) + } + } + } + } +} diff --git a/src/endpoint/futures/trap.rs b/src/endpoint/futures/trap.rs index b0a4ae9..17e0ac4 100644 --- a/src/endpoint/futures/trap.rs +++ b/src/endpoint/futures/trap.rs @@ -1,4 +1,4 @@ -use crate::context::{CallFuture, InvocationHandle}; +use crate::context::InvocationHandle; use crate::errors::TerminalError; use std::future::Future; use std::marker::PhantomData; @@ -32,5 +32,3 @@ impl InvocationHandle for TrapFuture { TrapFuture::default() } } - -impl CallFuture for TrapFuture {} diff --git a/test-services/exclusions.yaml b/test-services/exclusions.yaml index 5c9150a..7a41b81 100644 --- a/test-services/exclusions.yaml +++ b/test-services/exclusions.yaml @@ -1,21 +1,16 @@ exclusions: "alwaysSuspending": - "dev.restate.sdktesting.tests.Combinators.awakeableOrTimeoutUsingAwaitAny" - - "dev.restate.sdktesting.tests.Combinators.awakeableOrTimeoutUsingAwakeableTimeoutCommand" - "dev.restate.sdktesting.tests.Combinators.firstSuccessfulCompletedAwakeable" "default": - "dev.restate.sdktesting.tests.Combinators.awakeableOrTimeoutUsingAwaitAny" - - "dev.restate.sdktesting.tests.Combinators.awakeableOrTimeoutUsingAwakeableTimeoutCommand" - "dev.restate.sdktesting.tests.Combinators.firstSuccessfulCompletedAwakeable" "singleThreadSinglePartition": - "dev.restate.sdktesting.tests.Combinators.awakeableOrTimeoutUsingAwaitAny" - - "dev.restate.sdktesting.tests.Combinators.awakeableOrTimeoutUsingAwakeableTimeoutCommand" - "dev.restate.sdktesting.tests.Combinators.firstSuccessfulCompletedAwakeable" "threeNodes": - "dev.restate.sdktesting.tests.Combinators.awakeableOrTimeoutUsingAwaitAny" - - "dev.restate.sdktesting.tests.Combinators.awakeableOrTimeoutUsingAwakeableTimeoutCommand" - "dev.restate.sdktesting.tests.Combinators.firstSuccessfulCompletedAwakeable" "threeNodesAlwaysSuspending": - "dev.restate.sdktesting.tests.Combinators.awakeableOrTimeoutUsingAwaitAny" - - "dev.restate.sdktesting.tests.Combinators.awakeableOrTimeoutUsingAwakeableTimeoutCommand" - "dev.restate.sdktesting.tests.Combinators.firstSuccessfulCompletedAwakeable" diff --git a/test-services/src/virtual_object_command_interpreter.rs b/test-services/src/virtual_object_command_interpreter.rs index 994d20c..f269ab2 100644 --- a/test-services/src/virtual_object_command_interpreter.rs +++ b/test-services/src/virtual_object_command_interpreter.rs @@ -106,9 +106,19 @@ impl VirtualObjectCommandInterpreter for VirtualObjectCommandInterpreterImpl { Command::AwaitAnySuccessful { .. } => Err(anyhow!( "AwaitAnySuccessful is currently unsupported in the Rust SDK" ))?, - Command::AwaitAwakeableOrTimeout { .. } => Err(anyhow!( - "AwaitAwakeableOrTimeout is currently unsupported in the Rust SDK" - ))?, + Command::AwaitAwakeableOrTimeout { awakeable_key, timeout_millis } => { + let (awakeable_id, awk_fut) = context.awakeable::(); + context.set::(&format!("awk-{awakeable_key}"), awakeable_id); + + last_result = restate_sdk::select! { + res = awk_fut => { + res + }, + _ = context.sleep(Duration::from_millis(timeout_millis)) => { + Err(TerminalError::new("await-timeout")) + } + }?; + }, Command::AwaitOne { command } => { last_result = match command { AwaitableCommand::CreateAwakeable { awakeable_key } => { From cdea04e534159b4a7d9bb4d443406fc13eb33a32 Mon Sep 17 00:00:00 2001 From: slinkydeveloper Date: Fri, 7 Mar 2025 18:13:20 +0100 Subject: [PATCH 2/2] First pass at a select statement --- src/context/macro_support.rs | 1 + src/context/select.rs | 4 ++++ test-services/src/virtual_object_command_interpreter.rs | 7 +++++-- 3 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/context/macro_support.rs b/src/context/macro_support.rs index 2479f69..7c015dd 100644 --- a/src/context/macro_support.rs +++ b/src/context/macro_support.rs @@ -2,6 +2,7 @@ use crate::endpoint::ContextInternal; use restate_sdk_shared_core::NotificationHandle; // Sealed future trait, used by select statement +#[doc(hidden)] pub trait SealedDurableFuture { fn inner_context(&self) -> ContextInternal; fn handle(&self) -> NotificationHandle; diff --git a/src/context/select.rs b/src/context/select.rs index 3c1ad6c..f66dec8 100644 --- a/src/context/select.rs +++ b/src/context/select.rs @@ -1,3 +1,7 @@ +// Thanks tokio for the help! +// https://github.com/tokio-rs/tokio/blob/a258bff7018940b438e5de3fb846588454df4e4d/tokio/src/macros/select.rs +// MIT License + /// Select macro, alike tokio::select: /// /// ```rust diff --git a/test-services/src/virtual_object_command_interpreter.rs b/test-services/src/virtual_object_command_interpreter.rs index f269ab2..d401c91 100644 --- a/test-services/src/virtual_object_command_interpreter.rs +++ b/test-services/src/virtual_object_command_interpreter.rs @@ -106,7 +106,10 @@ impl VirtualObjectCommandInterpreter for VirtualObjectCommandInterpreterImpl { Command::AwaitAnySuccessful { .. } => Err(anyhow!( "AwaitAnySuccessful is currently unsupported in the Rust SDK" ))?, - Command::AwaitAwakeableOrTimeout { awakeable_key, timeout_millis } => { + Command::AwaitAwakeableOrTimeout { + awakeable_key, + timeout_millis, + } => { let (awakeable_id, awk_fut) = context.awakeable::(); context.set::(&format!("awk-{awakeable_key}"), awakeable_id); @@ -118,7 +121,7 @@ impl VirtualObjectCommandInterpreter for VirtualObjectCommandInterpreterImpl { Err(TerminalError::new("await-timeout")) } }?; - }, + } Command::AwaitOne { command } => { last_result = match command { AwaitableCommand::CreateAwakeable { awakeable_key } => {