From f1a580427f4457541d6d7eb38cf6092a669ff610 Mon Sep 17 00:00:00 2001 From: slinkydeveloper Date: Thu, 22 Aug 2024 18:07:58 +0200 Subject: [PATCH 1/7] First pass at the implementation of side effect retry --- Cargo.toml | 2 +- src/context/mod.rs | 19 +++- src/context/run.rs | 70 ++++++++++++ src/endpoint/context.rs | 203 ++++++++++++++++++++--------------- src/endpoint/mod.rs | 23 ++-- src/lib.rs | 2 +- test-services/src/failing.rs | 27 +++-- 7 files changed, 231 insertions(+), 115 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 2270314..4fa5bd5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,7 +22,7 @@ pin-project-lite = "0.2" rand = { version = "0.8.5", optional = true } regress = "0.10" restate-sdk-macros = { version = "0.2.1", path = "macros" } -restate-sdk-shared-core = { version = "0.0.5" } +restate-sdk-shared-core = { git = "https://github.com/restatedev/sdk-shared-core.git", branch = "side-effect-retry" } serde = "1.0" serde_json = "1.0" thiserror = "1.0.63" diff --git a/src/context/mod.rs b/src/context/mod.rs index 69ca125..c2ce3aa 100644 --- a/src/context/mod.rs +++ b/src/context/mod.rs @@ -10,7 +10,7 @@ mod request; mod run; pub use request::{Request, RequestTarget}; -pub use run::RunClosure; +pub use run::{RunClosure, RunRetryPolicy}; pub type HeaderMap = http::HeaderMap; /// Service handler context. @@ -385,6 +385,23 @@ pub trait ContextSideEffects<'ctx>: private::SealedContext<'ctx> { self.inner_context().run(name, run_closure) } + /// Run a non-deterministic operation and record its result. + /// + fn run_with_retry( + &self, + name: &'ctx str, + retry_policy: RunRetryPolicy, + run_closure: R, + ) -> impl Future> + 'ctx + where + R: RunClosure + Send + Sync + 'ctx, + T: Serialize + Deserialize, + F: Future> + Send + Sync + 'ctx, + { + self.inner_context() + .run_with_retry(name, retry_policy, run_closure) + } + /// Return a random seed inherently predictable, based on the invocation id, which is not secret. /// /// This value is stable during the invocation lifecycle, thus across retries. diff --git a/src/context/run.rs b/src/context/run.rs index 1d90fc4..1c70583 100644 --- a/src/context/run.rs +++ b/src/context/run.rs @@ -1,6 +1,7 @@ use crate::errors::HandlerResult; use crate::serde::{Deserialize, Serialize}; use std::future::Future; +use std::time::Duration; /// Run closure trait pub trait RunClosure { @@ -23,3 +24,72 @@ where self() } } + +/// This struct represents the policy to execute retries for run closures. +#[derive(Debug, Clone)] +pub struct RunRetryPolicy { + pub(crate) initial_interval: Duration, + pub(crate) factor: f32, + pub(crate) max_interval: Option, + pub(crate) max_attempts: Option, + pub(crate) max_duration: Option, +} + +impl Default for RunRetryPolicy { + fn default() -> Self { + Self { + initial_interval: Duration::from_millis(100), + factor: 2.0, + max_interval: Some(Duration::from_secs(2)), + max_attempts: None, + max_duration: Some(Duration::from_secs(50)), + } + } +} + +impl RunRetryPolicy { + /// Create a new retry policy. + pub fn new() -> Self { + Self { + initial_interval: Duration::from_millis(100), + factor: 1.0, + max_interval: None, + max_attempts: None, + max_duration: None, + } + } + + /// Initial interval for the first retry attempt. + pub fn with_initial_interval(mut self, initial_interval: Duration) -> Self { + self.initial_interval = initial_interval; + self + } + + /// Maximum interval between retries. + pub fn with_factor(mut self, factor: f32) -> Self { + self.factor = factor; + self + } + + /// Maximum interval between retries. + pub fn with_max_interval(mut self, max_interval: Duration) -> Self { + self.max_interval = Some(max_interval); + self + } + + /// Gives up retrying when either this number of attempts is reached, + /// or `max_duration` (if set) is reached first. + /// Infinite retries if this field and `max_duration` are unset. + pub fn with_max_attempts(mut self, max_attempts: u32) -> Self { + self.max_attempts = Some(max_attempts); + self + } + + /// Gives up retrying when either the retry loop lasted for this given max duration, + /// or `max_attempts` (if set) is reached first. + /// Infinite retries if this field and `max_attempts` are unset. + pub fn with_max_duration(mut self, max_duration: Duration) -> Self { + self.max_duration = Some(max_duration); + self + } +} diff --git a/src/endpoint/context.rs b/src/endpoint/context.rs index 29eb8f1..56bbd68 100644 --- a/src/endpoint/context.rs +++ b/src/endpoint/context.rs @@ -1,22 +1,21 @@ -use crate::context::{Request, RequestTarget, RunClosure}; +use crate::context::{Request, RequestTarget, RunClosure, RunRetryPolicy}; use crate::endpoint::futures::{InterceptErrorFuture, TrapFuture}; use crate::endpoint::handler_state::HandlerStateNotifier; use crate::endpoint::{Error, ErrorInner, InputReceiver, OutputSender}; use crate::errors::{HandlerErrorInner, HandlerResult, TerminalError}; use crate::serde::{Deserialize, Serialize}; -use bytes::Bytes; use futures::future::Either; use futures::{FutureExt, TryFutureExt}; use restate_sdk_shared_core::{ - AsyncResultHandle, CoreVM, NonEmptyValue, RunEnterResult, SuspendedOrVMError, TakeOutputResult, - Target, VMError, Value, VM, + AsyncResultHandle, CoreVM, Failure, NonEmptyValue, RetryPolicy, RunEnterResult, RunExitResult, + SuspendedOrVMError, TakeOutputResult, Target, VMError, Value, VM, }; use std::collections::HashMap; use std::future::{ready, Future}; use std::pin::Pin; use std::sync::{Arc, Mutex}; use std::task::Poll; -use std::time::Duration; +use std::time::{Duration, Instant}; pub struct ContextInternalInner { vm: CoreVM, @@ -42,7 +41,7 @@ impl ContextInternalInner { pub(super) fn fail(&mut self, e: Error) { self.vm - .notify_error(e.0.to_string().into(), format!("{:#}", e.0).into()); + .notify_error(e.0.to_string().into(), format!("{:#}", e.0).into(), None); self.handler_state.mark_error(e); } } @@ -130,37 +129,38 @@ impl ContextInternal { pub fn input(&self) -> impl Future { let mut inner_lock = must_lock!(self.inner); - let input_result = inner_lock - .vm - .sys_input() - .map_err(ErrorInner::VM) - .map(|raw_input| { - let headers = http::HeaderMap::::try_from( - &raw_input - .headers - .into_iter() - .map(|h| (h.key.to_string(), h.value.to_string())) - .collect::>(), - ) - .map_err(|e| { - TerminalError::new_with_code(400, format!("Cannot decode headers: {e:?}")) - })?; + let input_result = + inner_lock + .vm + .sys_input() + .map_err(ErrorInner::VM) + .map(|mut raw_input| { + let headers = http::HeaderMap::::try_from( + &raw_input + .headers + .into_iter() + .map(|h| (h.key.to_string(), h.value.to_string())) + .collect::>(), + ) + .map_err(|e| { + TerminalError::new_with_code(400, format!("Cannot decode headers: {e:?}")) + })?; - Ok::<_, TerminalError>(( - T::deserialize(&mut (raw_input.input.into())).map_err(|e| { - TerminalError::new_with_code( - 400, - format!("Cannot decode input payload: {e:?}"), - ) - })?, - InputMetadata { - invocation_id: raw_input.invocation_id, - random_seed: raw_input.random_seed, - key: raw_input.key, - headers, - }, - )) - }); + Ok::<_, TerminalError>(( + T::deserialize(&mut (raw_input.input)).map_err(|e| { + TerminalError::new_with_code( + 400, + format!("Cannot decode input payload: {e:?}"), + ) + })?, + InputMetadata { + invocation_id: raw_input.invocation_id, + random_seed: raw_input.random_seed, + key: raw_input.key, + headers, + }, + )) + }); match input_result { Ok(Ok(i)) => { @@ -196,9 +196,8 @@ impl ContextInternal { let poll_future = self.create_poll_future(maybe_handle).map(|res| match res { Ok(Value::Void) => Ok(Ok(None)), - Ok(Value::Success(s)) => { - let mut b = Bytes::from(s); - let t = T::deserialize(&mut b).map_err(|e| ErrorInner::Deserialization { + Ok(Value::Success(mut s)) => { + let t = T::deserialize(&mut s).map_err(|e| ErrorInner::Deserialization { syscall: "get_state", err: Box::new(e), })?; @@ -241,7 +240,7 @@ impl ContextInternal { let mut inner_lock = must_lock!(self.inner); match t.serialize() { Ok(b) => { - let _ = inner_lock.vm.sys_state_set(key.to_owned(), b.to_vec()); + let _ = inner_lock.vm.sys_state_set(key.to_owned(), b); } Err(e) => { inner_lock.fail( @@ -311,9 +310,7 @@ impl ContextInternal { } }; - let maybe_handle = inner_lock - .vm - .sys_call(request_target.into(), input.to_vec()); + let maybe_handle = inner_lock.vm.sys_call(request_target.into(), input); drop(inner_lock); let poll_future = self.create_poll_future(maybe_handle).map(|res| match res { @@ -321,9 +318,8 @@ impl ContextInternal { variant: "empty", syscall: "call", }), - Ok(Value::Success(s)) => { - let mut b = Bytes::from(s); - let t = Res::deserialize(&mut b).map_err(|e| ErrorInner::Deserialization { + Ok(Value::Success(mut s)) => { + let t = Res::deserialize(&mut s).map_err(|e| ErrorInner::Deserialization { syscall: "call", err: Box::new(e), })?; @@ -353,9 +349,7 @@ impl ContextInternal { match Req::serialize(&req) { Ok(t) => { - let _ = inner_lock - .vm - .sys_send(request_target.into(), t.to_vec(), delay); + let _ = inner_lock.vm.sys_send(request_target.into(), t, delay); } Err(e) => { inner_lock.fail( @@ -392,9 +386,8 @@ impl ContextInternal { variant: "empty", syscall: "awakeable", }), - Ok(Value::Success(s)) => { - let mut b = Bytes::from(s); - let t = T::deserialize(&mut b).map_err(|e| ErrorInner::Deserialization { + Ok(Value::Success(mut s)) => { + let t = T::deserialize(&mut s).map_err(|e| ErrorInner::Deserialization { syscall: "awakeable", err: Box::new(e), })?; @@ -420,7 +413,7 @@ impl ContextInternal { Ok(b) => { let _ = inner_lock .vm - .sys_complete_awakeable(id.to_owned(), NonEmptyValue::Success(b.to_vec())); + .sys_complete_awakeable(id.to_owned(), NonEmptyValue::Success(b)); } Err(e) => { inner_lock.fail( @@ -451,9 +444,8 @@ impl ContextInternal { variant: "empty", syscall: "promise", }), - Ok(Value::Success(s)) => { - let mut b = Bytes::from(s); - let t = T::deserialize(&mut b).map_err(|e| ErrorInner::Deserialization { + Ok(Value::Success(mut s)) => { + let t = T::deserialize(&mut s).map_err(|e| ErrorInner::Deserialization { syscall: "promise", err: Box::new(e), })?; @@ -478,9 +470,8 @@ impl ContextInternal { let poll_future = self.create_poll_future(maybe_handle).map(|res| match res { Ok(Value::Void) => Ok(Ok(None)), - Ok(Value::Success(s)) => { - let mut b = Bytes::from(s); - let t = T::deserialize(&mut b).map_err(|e| ErrorInner::Deserialization { + Ok(Value::Success(mut s)) => { + let t = T::deserialize(&mut s).map_err(|e| ErrorInner::Deserialization { syscall: "peek_promise", err: Box::new(e), })?; @@ -503,7 +494,7 @@ impl ContextInternal { Ok(b) => { let _ = inner_lock .vm - .sys_complete_promise(name.to_owned(), NonEmptyValue::Success(b.to_vec())); + .sys_complete_promise(name.to_owned(), NonEmptyValue::Success(b)); } Err(e) => { inner_lock.fail( @@ -528,6 +519,44 @@ impl ContextInternal { name: &'a str, run_closure: R, ) -> impl Future> + Send + Sync + 'a + where + R: RunClosure + Send + Sync + 'a, + T: Serialize + Deserialize, + F: Future> + Send + Sync + 'a, + { + self.run_inner(name, RetryPolicy::Infinite, run_closure) + } + + pub fn run_with_retry<'a, R, F, T>( + &'a self, + name: &'a str, + retry_policy: RunRetryPolicy, + run_closure: R, + ) -> impl Future> + Send + Sync + 'a + where + R: RunClosure + Send + Sync + 'a, + T: Serialize + Deserialize, + F: Future> + Send + Sync + 'a, + { + self.run_inner( + name, + RetryPolicy::Exponential { + initial_interval: retry_policy.initial_interval, + factor: retry_policy.factor, + max_interval: retry_policy.max_interval, + max_attempts: retry_policy.max_attempts, + max_duration: retry_policy.max_duration, + }, + run_closure, + ) + } + + fn run_inner<'a, R, F, T>( + &'a self, + name: &'a str, + retry_policy: RetryPolicy, + run_closure: R, + ) -> impl Future> + Send + Sync + 'a where R: RunClosure + Send + Sync + 'a, T: Serialize + Deserialize, @@ -540,38 +569,36 @@ impl ContextInternal { // Enter the side effect match enter_result.map_err(ErrorInner::VM)? { - RunEnterResult::Executed(NonEmptyValue::Success(v)) => { - let mut b = Bytes::from(v); - let t = T::deserialize(&mut b).map_err(|e| ErrorInner::Deserialization { + RunEnterResult::Executed(NonEmptyValue::Success(mut v)) => { + let t = T::deserialize(&mut v).map_err(|e| ErrorInner::Deserialization { syscall: "run", err: Box::new(e), })?; return Ok(Ok(t)); } RunEnterResult::Executed(NonEmptyValue::Failure(f)) => return Ok(Err(f.into())), - RunEnterResult::NotExecuted => {} + RunEnterResult::NotExecuted(_) => {} }; // We need to run the closure + let run_start = Instant::now(); let res = match run_closure.run().await { - Ok(t) => NonEmptyValue::Success( - T::serialize(&t) - .map_err(|e| ErrorInner::Serialization { - syscall: "run", - err: Box::new(e), - })? - .to_vec(), - ), - Err(e) => match e.0 { - HandlerErrorInner::Retryable(err) => { - return Err(ErrorInner::RunResult { - name: name.to_owned(), - err, - } - .into()) + Ok(t) => RunExitResult::Success(T::serialize(&t).map_err(|e| { + ErrorInner::Serialization { + syscall: "run", + err: Box::new(e), } + })?), + Err(e) => match e.0 { + HandlerErrorInner::Retryable(err) => RunExitResult::RetryableFailure { + attempt_duration: run_start.elapsed(), + failure: Failure { + code: 500, + message: err.to_string(), + }, + }, HandlerErrorInner::Terminal(t) => { - NonEmptyValue::Failure(TerminalError(t).into()) + RunExitResult::TerminalFailure(TerminalError(t).into()) } }, }; @@ -579,7 +606,7 @@ impl ContextInternal { let handle = { must_lock!(this) .vm - .sys_run_exit(res) + .sys_run_exit(res, retry_policy) .map_err(ErrorInner::VM)? }; @@ -591,9 +618,8 @@ impl ContextInternal { syscall: "run", } .into()), - Value::Success(s) => { - let mut b = Bytes::from(s); - let t = T::deserialize(&mut b).map_err(|e| ErrorInner::Deserialization { + Value::Success(mut s) => { + let t = T::deserialize(&mut s).map_err(|e| ErrorInner::Deserialization { syscall: "run", err: Box::new(e), })?; @@ -614,7 +640,7 @@ impl ContextInternal { let res_to_write = match res { Ok(success) => match T::serialize(&success) { - Ok(t) => NonEmptyValue::Success(t.to_vec()), + Ok(t) => NonEmptyValue::Success(t), Err(e) => { inner_lock.fail( ErrorInner::Serialization { @@ -647,7 +673,7 @@ impl ContextInternal { let out = inner_lock.vm.take_output(); if let TakeOutputResult::Buffer(b) = out { - if !inner_lock.write.send(b.into()) { + if !inner_lock.write.send(b) { // Nothing we can do anymore here } } @@ -707,7 +733,7 @@ impl Future for VmPollFuture { let out = inner_lock.vm.take_output(); match out { TakeOutputResult::Buffer(b) => { - if !inner_lock.write.send(b.into()) { + if !inner_lock.write.send(b) { self.state = Some(PollState::Failed(ErrorInner::Suspended)); continue; } @@ -751,10 +777,11 @@ impl Future for VmPollFuture { // Pass read result to VM match read_result { - Some(Ok(b)) => inner_lock.vm.notify_input(b.to_vec()), + Some(Ok(b)) => inner_lock.vm.notify_input(b), Some(Err(e)) => inner_lock.vm.notify_error( "Error when reading the body".into(), e.to_string().into(), + None, ), None => inner_lock.vm.notify_input_closed(), } diff --git a/src/endpoint/mod.rs b/src/endpoint/mod.rs index 3bf5a93..8609b3d 100644 --- a/src/endpoint/mod.rs +++ b/src/endpoint/mod.rs @@ -82,14 +82,13 @@ impl Error { /// Returns the HTTP status code for this error. pub fn status_code(&self) -> u16 { match &self.0 { - ErrorInner::VM(e) => e.code, + ErrorInner::VM(e) => e.code(), ErrorInner::UnknownService(_) | ErrorInner::UnknownServiceHandler(_, _) => 404, ErrorInner::Suspended | ErrorInner::UnexpectedOutputClosed | ErrorInner::UnexpectedValueVariantForSyscall { .. } | ErrorInner::Deserialization { .. } | ErrorInner::Serialization { .. } - | ErrorInner::RunResult { .. } | ErrorInner::HandlerResult { .. } => 500, ErrorInner::BadDiscovery(_) => 415, ErrorInner::Header { .. } | ErrorInner::BadPath { .. } => 400, @@ -135,12 +134,6 @@ enum ErrorInner { #[source] err: BoxError, }, - #[error("Run '{name}' failed with retryable error: {err:?}'")] - RunResult { - name: String, - #[source] - err: BoxError, - }, #[error("Handler failed with retryable error: {err:?}'")] HandlerResult { #[source] @@ -182,8 +175,8 @@ impl Default for Builder { Self { svcs: Default::default(), discovery: crate::discovery::Endpoint { - max_protocol_version: 1, - min_protocol_version: 1, + max_protocol_version: 2, + min_protocol_version: 2, protocol_mode: Some(crate::discovery::ProtocolMode::BidiStream), services: vec![], }, @@ -372,10 +365,12 @@ impl BidiStreamRunner { async fn init_loop_vm(vm: &mut CoreVM, input_rx: &mut InputReceiver) -> Result<(), ErrorInner> { while !vm.is_ready_to_execute().map_err(ErrorInner::VM)? { match input_rx.recv().await { - Some(Ok(b)) => vm.notify_input(b.to_vec()), - Some(Err(e)) => { - vm.notify_error("Error when reading the body".into(), e.to_string().into()) - } + Some(Ok(b)) => vm.notify_input(b), + Some(Err(e)) => vm.notify_error( + "Error when reading the body".into(), + e.to_string().into(), + None, + ), None => vm.notify_input_closed(), } } diff --git a/src/lib.rs b/src/lib.rs index e80c540..63b2599 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -58,7 +58,7 @@ pub mod prelude { pub use crate::context::{ Context, ContextAwakeables, ContextClient, ContextPromises, ContextReadState, ContextSideEffects, ContextTimers, ContextWriteState, HeaderMap, ObjectContext, Request, - SharedObjectContext, SharedWorkflowContext, WorkflowContext, + RunRetryPolicy, SharedObjectContext, SharedWorkflowContext, WorkflowContext, }; pub use crate::endpoint::Endpoint; pub use crate::errors::{HandlerError, HandlerResult, TerminalError}; diff --git a/test-services/src/failing.rs b/test-services/src/failing.rs index b8b5ec0..b477617 100644 --- a/test-services/src/failing.rs +++ b/test-services/src/failing.rs @@ -2,6 +2,7 @@ use anyhow::anyhow; use restate_sdk::prelude::*; use std::sync::atomic::{AtomicI32, Ordering}; use std::sync::Arc; +use std::time::Duration; #[restate_sdk::object] #[name = "Failing"] @@ -65,17 +66,23 @@ impl Failing for FailingImpl { ) -> HandlerResult { let cloned_eventual_side_effect_calls = Arc::clone(&self.eventual_success_side_effects); let success_attempt = context - .run("failing_side_effect", || async move { - let current_attempt = - cloned_eventual_side_effect_calls.fetch_add(1, Ordering::SeqCst) + 1; + .run_with_retry( + "failing_side_effect", + RunRetryPolicy::new() + .with_initial_interval(Duration::from_millis(10)) + .with_factor(1.0), + || async move { + let current_attempt = + cloned_eventual_side_effect_calls.fetch_add(1, Ordering::SeqCst) + 1; - if current_attempt >= 4 { - cloned_eventual_side_effect_calls.store(0, Ordering::SeqCst); - Ok(current_attempt) - } else { - Err(anyhow!("Failed at attempt ${current_attempt}").into()) - } - }) + if current_attempt >= 4 { + cloned_eventual_side_effect_calls.store(0, Ordering::SeqCst); + Ok(current_attempt) + } else { + Err(anyhow!("Failed at attempt ${current_attempt}").into()) + } + }, + ) .await?; Ok(success_attempt) From c0f82a8076f4ba49f8e4b8fcc06ec875a2b87aba Mon Sep 17 00:00:00 2001 From: slinkydeveloper Date: Mon, 26 Aug 2024 15:21:58 +0200 Subject: [PATCH 2/7] Implementation of the new test from https://github.com/restatedev/sdk-test-suite/pull/10 --- test-services/src/failing.rs | 71 ++++++++++++++++++++++++++---------- 1 file changed, 52 insertions(+), 19 deletions(-) diff --git a/test-services/src/failing.rs b/test-services/src/failing.rs index b477617..12adafa 100644 --- a/test-services/src/failing.rs +++ b/test-services/src/failing.rs @@ -13,16 +13,22 @@ pub(crate) trait Failing { async fn call_terminally_failing_call(error_message: String) -> HandlerResult; #[name = "failingCallWithEventualSuccess"] async fn failing_call_with_eventual_success() -> HandlerResult; - #[name = "failingSideEffectWithEventualSuccess"] - async fn failing_side_effect_with_eventual_success() -> HandlerResult; #[name = "terminallyFailingSideEffect"] async fn terminally_failing_side_effect(error_message: String) -> HandlerResult<()>; + #[name = "sideEffectSucceedsAfterGivenAttempts"] + async fn side_effect_succeeds_after_given_attempts(minimum_attempts: i32) + -> HandlerResult; + #[name = "sideEffectFailsAfterGivenAttempts"] + async fn side_effect_fails_after_given_attempts( + retry_policy_max_retry_count: i32, + ) -> HandlerResult; } #[derive(Clone, Default)] pub(crate) struct FailingImpl { eventual_success_calls: Arc, eventual_success_side_effects: Arc, + eventual_failure_side_effects: Arc, } impl Failing for FailingImpl { @@ -60,11 +66,26 @@ impl Failing for FailingImpl { } } - async fn failing_side_effect_with_eventual_success( + async fn terminally_failing_side_effect( &self, context: ObjectContext<'_>, + error_message: String, + ) -> HandlerResult<()> { + context + .run("failing_side_effect", || async move { + Err::<(), _>(TerminalError::new(error_message).into()) + }) + .await?; + + unreachable!("This should be unreachable") + } + + async fn side_effect_succeeds_after_given_attempts( + &self, + context: ObjectContext<'_>, + minimum_attempts: i32, ) -> HandlerResult { - let cloned_eventual_side_effect_calls = Arc::clone(&self.eventual_success_side_effects); + let cloned_counter = Arc::clone(&self.eventual_success_side_effects); let success_attempt = context .run_with_retry( "failing_side_effect", @@ -72,14 +93,13 @@ impl Failing for FailingImpl { .with_initial_interval(Duration::from_millis(10)) .with_factor(1.0), || async move { - let current_attempt = - cloned_eventual_side_effect_calls.fetch_add(1, Ordering::SeqCst) + 1; + let current_attempt = cloned_counter.fetch_add(1, Ordering::SeqCst) + 1; - if current_attempt >= 4 { - cloned_eventual_side_effect_calls.store(0, Ordering::SeqCst); + if current_attempt >= minimum_attempts { + cloned_counter.store(0, Ordering::SeqCst); Ok(current_attempt) } else { - Err(anyhow!("Failed at attempt ${current_attempt}").into()) + Err(anyhow!("Failed at attempt {current_attempt}").into()) } }, ) @@ -88,17 +108,30 @@ impl Failing for FailingImpl { Ok(success_attempt) } - async fn terminally_failing_side_effect( + async fn side_effect_fails_after_given_attempts( &self, context: ObjectContext<'_>, - error_message: String, - ) -> HandlerResult<()> { - context - .run("failing_side_effect", || async move { - Err::<(), _>(TerminalError::new(error_message).into()) - }) - .await?; - - unreachable!("This should be unreachable") + retry_policy_max_retry_count: i32, + ) -> HandlerResult { + let cloned_counter = Arc::clone(&self.eventual_failure_side_effects); + if context + .run_with_retry::<_, _, ()>( + "failing_side_effect", + RunRetryPolicy::new() + .with_initial_interval(Duration::from_millis(10)) + .with_factor(1.0) + .with_max_attempts(retry_policy_max_retry_count as u32), + || async move { + let current_attempt = cloned_counter.fetch_add(1, Ordering::SeqCst) + 1; + Err(anyhow!("Failed at attempt {current_attempt}").into()) + }, + ) + .await + .is_err() + { + Ok(self.eventual_failure_side_effects.load(Ordering::SeqCst)) + } else { + Err(TerminalError::new("Expecting the side effect to fail!"))? + } } } From b7a525fd084b3ca7d6dd39cb91b6325849d64f65 Mon Sep 17 00:00:00 2001 From: slinkydeveloper Date: Mon, 26 Aug 2024 18:48:13 +0200 Subject: [PATCH 3/7] Nicer API for the run method --- examples/failures.rs | 2 +- examples/run.rs | 3 +- src/context/mod.rs | 30 +-- src/context/run.rs | 5 + src/endpoint/context.rs | 327 +++++++++++++++--------- src/endpoint/futures.rs | 22 +- src/lib.rs | 2 +- test-services/src/failing.rs | 39 ++- test-services/src/test_utils_service.rs | 2 +- 9 files changed, 262 insertions(+), 170 deletions(-) diff --git a/examples/failures.rs b/examples/failures.rs index bdc2c14..0ab4251 100644 --- a/examples/failures.rs +++ b/examples/failures.rs @@ -16,7 +16,7 @@ struct MyError; impl FailureExample for FailureExampleImpl { async fn do_run(&self, context: Context<'_>) -> HandlerResult<()> { context - .run("get_ip", || async move { + .run(|| async move { if rand::thread_rng().next_u32() % 4 == 0 { return Err(TerminalError::new("Failed!!!").into()); } diff --git a/examples/run.rs b/examples/run.rs index 2df394a..f6127d7 100644 --- a/examples/run.rs +++ b/examples/run.rs @@ -11,7 +11,7 @@ struct RunExampleImpl(reqwest::Client); impl RunExample for RunExampleImpl { async fn do_run(&self, context: Context<'_>) -> HandlerResult>> { let res = context - .run("get_ip", || async move { + .run(|| async move { let req = self.0.get("https://httpbin.org/ip").build()?; let res = self @@ -23,6 +23,7 @@ impl RunExample for RunExampleImpl { Ok(Json::from(res)) }) + .named("get_ip") .await? .into_inner(); diff --git a/src/context/mod.rs b/src/context/mod.rs index c2ce3aa..b682908 100644 --- a/src/context/mod.rs +++ b/src/context/mod.rs @@ -10,7 +10,8 @@ mod request; mod run; pub use request::{Request, RequestTarget}; -pub use run::{RunClosure, RunRetryPolicy}; +pub use run::{RunClosure, RunFuture, RunRetryPolicy}; + pub type HeaderMap = http::HeaderMap; /// Service handler context. @@ -372,34 +373,13 @@ impl<'ctx, CTX: private::SealedContext<'ctx>> ContextAwakeables<'ctx> for CTX {} pub trait ContextSideEffects<'ctx>: private::SealedContext<'ctx> { /// Run a non-deterministic operation and record its result. /// - fn run( - &self, - name: &'ctx str, - run_closure: R, - ) -> impl Future> + 'ctx - where - R: RunClosure + Send + Sync + 'ctx, - T: Serialize + Deserialize, - F: Future> + Send + Sync + 'ctx, - { - self.inner_context().run(name, run_closure) - } - - /// Run a non-deterministic operation and record its result. - /// - fn run_with_retry( - &self, - name: &'ctx str, - retry_policy: RunRetryPolicy, - run_closure: R, - ) -> impl Future> + 'ctx + fn run(&self, run_closure: R) -> impl RunFuture> + 'ctx where R: RunClosure + Send + Sync + 'ctx, - T: Serialize + Deserialize, F: Future> + Send + Sync + 'ctx, + T: Serialize + Deserialize + 'static, { - self.inner_context() - .run_with_retry(name, retry_policy, run_closure) + self.inner_context().run(run_closure) } /// Return a random seed inherently predictable, based on the invocation id, which is not secret. diff --git a/src/context/run.rs b/src/context/run.rs index 1c70583..45abb6d 100644 --- a/src/context/run.rs +++ b/src/context/run.rs @@ -25,6 +25,11 @@ where } } +pub trait RunFuture: Future { + fn with_retry_policy(self, retry_policy: RunRetryPolicy) -> Self; + fn named(self, name: impl Into) -> Self; +} + /// This struct represents the policy to execute retries for run closures. #[derive(Debug, Clone)] pub struct RunRetryPolicy { diff --git a/src/endpoint/context.rs b/src/endpoint/context.rs index 56bbd68..1812691 100644 --- a/src/endpoint/context.rs +++ b/src/endpoint/context.rs @@ -6,15 +6,19 @@ use crate::errors::{HandlerErrorInner, HandlerResult, TerminalError}; use crate::serde::{Deserialize, Serialize}; use futures::future::Either; use futures::{FutureExt, TryFutureExt}; +use pin_project_lite::pin_project; use restate_sdk_shared_core::{ AsyncResultHandle, CoreVM, Failure, NonEmptyValue, RetryPolicy, RunEnterResult, RunExitResult, SuspendedOrVMError, TakeOutputResult, Target, VMError, Value, VM, }; +use std::borrow::Cow; use std::collections::HashMap; use std::future::{ready, Future}; +use std::marker::PhantomData; +use std::mem; use std::pin::Pin; use std::sync::{Arc, Mutex}; -use std::task::Poll; +use std::task::{ready, Context, Poll}; use std::time::{Duration, Instant}; pub struct ContextInternalInner { @@ -514,125 +518,18 @@ impl ContextInternal { .sys_complete_promise(id.to_owned(), NonEmptyValue::Failure(failure.into())); } - pub fn run<'a, R, F, T>( + pub fn run<'a, Run, Fut, Res>( &'a self, - name: &'a str, - run_closure: R, - ) -> impl Future> + Send + Sync + 'a + run_closure: Run, + ) -> impl crate::context::RunFuture> + Send + Sync + 'a where - R: RunClosure + Send + Sync + 'a, - T: Serialize + Deserialize, - F: Future> + Send + Sync + 'a, - { - self.run_inner(name, RetryPolicy::Infinite, run_closure) - } - - pub fn run_with_retry<'a, R, F, T>( - &'a self, - name: &'a str, - retry_policy: RunRetryPolicy, - run_closure: R, - ) -> impl Future> + Send + Sync + 'a - where - R: RunClosure + Send + Sync + 'a, - T: Serialize + Deserialize, - F: Future> + Send + Sync + 'a, - { - self.run_inner( - name, - RetryPolicy::Exponential { - initial_interval: retry_policy.initial_interval, - factor: retry_policy.factor, - max_interval: retry_policy.max_interval, - max_attempts: retry_policy.max_attempts, - max_duration: retry_policy.max_duration, - }, - run_closure, - ) - } - - fn run_inner<'a, R, F, T>( - &'a self, - name: &'a str, - retry_policy: RetryPolicy, - run_closure: R, - ) -> impl Future> + Send + Sync + 'a - where - R: RunClosure + Send + Sync + 'a, - T: Serialize + Deserialize, - F: Future> + Send + Sync + 'a, + Run: RunClosure + Send + Sync + 'a, + Fut: Future> + Send + Sync + 'a, + Res: Serialize + Deserialize + 'static, { let this = Arc::clone(&self.inner); - InterceptErrorFuture::new(self.clone(), async move { - let enter_result = { must_lock!(this).vm.sys_run_enter(name.to_owned()) }; - - // Enter the side effect - match enter_result.map_err(ErrorInner::VM)? { - RunEnterResult::Executed(NonEmptyValue::Success(mut v)) => { - let t = T::deserialize(&mut v).map_err(|e| ErrorInner::Deserialization { - syscall: "run", - err: Box::new(e), - })?; - return Ok(Ok(t)); - } - RunEnterResult::Executed(NonEmptyValue::Failure(f)) => return Ok(Err(f.into())), - RunEnterResult::NotExecuted(_) => {} - }; - - // We need to run the closure - let run_start = Instant::now(); - let res = match run_closure.run().await { - Ok(t) => RunExitResult::Success(T::serialize(&t).map_err(|e| { - ErrorInner::Serialization { - syscall: "run", - err: Box::new(e), - } - })?), - Err(e) => match e.0 { - HandlerErrorInner::Retryable(err) => RunExitResult::RetryableFailure { - attempt_duration: run_start.elapsed(), - failure: Failure { - code: 500, - message: err.to_string(), - }, - }, - HandlerErrorInner::Terminal(t) => { - RunExitResult::TerminalFailure(TerminalError(t).into()) - } - }, - }; - - let handle = { - must_lock!(this) - .vm - .sys_run_exit(res, retry_policy) - .map_err(ErrorInner::VM)? - }; - - let value = self.create_poll_future(Ok(handle)).await?; - - match value { - Value::Void => Err(ErrorInner::UnexpectedValueVariantForSyscall { - variant: "empty", - syscall: "run", - } - .into()), - Value::Success(mut s) => { - let t = T::deserialize(&mut s).map_err(|e| ErrorInner::Deserialization { - syscall: "run", - err: Box::new(e), - })?; - Ok(Ok(t)) - } - Value::Failure(f) => Ok(Err(f.into())), - Value::StateKeys(_) => Err(ErrorInner::UnexpectedValueVariantForSyscall { - variant: "state_keys", - syscall: "run", - } - .into()), - } - }) + InterceptErrorFuture::new(self.clone(), RunFuture::new(this, run_closure)) } pub fn handle_handler_result(&self, res: HandlerResult) { @@ -699,10 +596,210 @@ impl ContextInternal { } } +pin_project! { + struct RunFuture { + name: String, + retry_policy: RetryPolicy, + phantom_data: PhantomData Ret>, + closure: Option, + inner_ctx: Option>>, + #[pin] + state: RunState, + } +} + +pin_project! { + #[project = RunStateProj] + enum RunState { + New, + ClosureRunning { + start_time: Instant, + #[pin] + fut: Fut, + }, + PollFutureRunning { + #[pin] + fut: VmPollFuture + } + } +} + +impl RunFuture { + fn new(inner_ctx: Arc>, closure: Run) -> Self { + Self { + name: "".to_string(), + retry_policy: RetryPolicy::Infinite, + phantom_data: PhantomData, + inner_ctx: Some(inner_ctx), + closure: Some(closure), + state: RunState::New, + } + } +} + +impl crate::context::RunFuture, Error>> + for RunFuture +where + Run: RunClosure + Send + Sync, + Fut: Future> + Send + Sync, + Ret: Serialize + Deserialize, +{ + fn with_retry_policy(mut self, retry_policy: RunRetryPolicy) -> Self { + self.retry_policy = RetryPolicy::Exponential { + initial_interval: retry_policy.initial_interval, + factor: retry_policy.factor, + max_interval: retry_policy.max_interval, + max_attempts: retry_policy.max_attempts, + max_duration: retry_policy.max_duration, + }; + self + } + + fn named(mut self, name: impl Into) -> Self { + self.name = name.into(); + self + } +} + +impl Future for RunFuture +where + Run: RunClosure + Send + Sync, + Res: Serialize + Deserialize, + Fut: Future> + Send + Sync, +{ + type Output = Result, Error>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut this = self.project(); + + loop { + match this.state.as_mut().project() { + RunStateProj::New => { + let enter_result = { + must_lock!(this + .inner_ctx + .as_mut() + .expect("Future should not be polled after returning Poll::Ready")) + .vm + .sys_run_enter(this.name.to_owned()) + }; + + // Enter the side effect + match enter_result.map_err(ErrorInner::VM)? { + RunEnterResult::Executed(NonEmptyValue::Success(mut v)) => { + let t = Res::deserialize(&mut v).map_err(|e| { + ErrorInner::Deserialization { + syscall: "run", + err: Box::new(e), + } + })?; + return Poll::Ready(Ok(Ok(t))); + } + RunEnterResult::Executed(NonEmptyValue::Failure(f)) => { + return Poll::Ready(Ok(Err(f.into()))) + } + RunEnterResult::NotExecuted(_) => {} + }; + + // We need to run the closure + this.state.set(RunState::ClosureRunning { + start_time: Instant::now(), + fut: this + .closure + .take() + .expect("Future should not be polled after returning Poll::Ready") + .run(), + }); + } + RunStateProj::ClosureRunning { start_time, fut } => { + let res = match ready!(fut.poll(cx)) { + Ok(t) => RunExitResult::Success(Res::serialize(&t).map_err(|e| { + ErrorInner::Serialization { + syscall: "run", + err: Box::new(e), + } + })?), + Err(e) => match e.0 { + HandlerErrorInner::Retryable(err) => RunExitResult::RetryableFailure { + attempt_duration: start_time.elapsed(), + failure: Failure { + code: 500, + message: err.to_string(), + }, + }, + HandlerErrorInner::Terminal(t) => { + RunExitResult::TerminalFailure(TerminalError(t).into()) + } + }, + }; + + let inner_ctx = this + .inner_ctx + .take() + .expect("Future should not be polled after returning Poll::Ready"); + + let handle = { + must_lock!(inner_ctx) + .vm + .sys_run_exit(res, mem::take(this.retry_policy)) + }; + + this.state.set(RunState::PollFutureRunning { + fut: VmPollFuture::new(Cow::Owned(inner_ctx), handle), + }); + } + RunStateProj::PollFutureRunning { fut } => { + let value = ready!(fut.poll(cx))?; + + return Poll::Ready(match value { + Value::Void => Err(ErrorInner::UnexpectedValueVariantForSyscall { + variant: "empty", + syscall: "run", + } + .into()), + Value::Success(mut s) => { + let t = Res::deserialize(&mut s).map_err(|e| { + ErrorInner::Deserialization { + syscall: "run", + err: Box::new(e), + } + })?; + Ok(Ok(t)) + } + Value::Failure(f) => Ok(Err(f.into())), + Value::StateKeys(_) => Err(ErrorInner::UnexpectedValueVariantForSyscall { + variant: "state_keys", + syscall: "run", + } + .into()), + }); + } + } + } + } +} + struct VmPollFuture { state: Option, } +impl VmPollFuture { + fn new( + inner: Cow<'_, Arc>>, + handle: Result, + ) -> Self { + VmPollFuture { + state: Some(match handle { + Ok(handle) => PollState::Init { + ctx: inner.into_owned(), + handle, + }, + Err(err) => PollState::Failed(ErrorInner::VM(err)), + }), + } + } +} + enum PollState { Init { ctx: Arc>, diff --git a/src/endpoint/futures.rs b/src/endpoint/futures.rs index 4836feb..f19ae1f 100644 --- a/src/endpoint/futures.rs +++ b/src/endpoint/futures.rs @@ -1,3 +1,4 @@ +use crate::context::{RunFuture, RunRetryPolicy}; use crate::endpoint::{ContextInternal, Error}; use pin_project_lite::pin_project; use std::future::Future; @@ -8,7 +9,7 @@ use tokio::sync::oneshot; use tracing::warn; /// Future that traps the execution at this point, but keeps waking up the waker -pub(super) struct TrapFuture(PhantomData); +pub(super) struct TrapFuture(PhantomData T>); impl Default for TrapFuture { fn default() -> Self { @@ -16,10 +17,6 @@ impl Default for TrapFuture { } } -/// This is always safe, because we simply use phantom data inside TrapFuture. -unsafe impl Send for TrapFuture {} -unsafe impl Sync for TrapFuture {} - impl Future for TrapFuture { type Output = T; @@ -68,6 +65,21 @@ where } } +impl RunFuture for InterceptErrorFuture +where + F: RunFuture>, +{ + fn with_retry_policy(mut self, retry_policy: RunRetryPolicy) -> Self { + self.fut = self.fut.with_retry_policy(retry_policy); + self + } + + fn named(mut self, name: impl Into) -> Self { + self.fut = self.fut.named(name); + self + } +} + pin_project! { /// Future that will stop polling when handler is suspended/failed pub(super) struct HandlerStateAwareFuture { diff --git a/src/lib.rs b/src/lib.rs index 63b2599..35f7e18 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -58,7 +58,7 @@ pub mod prelude { pub use crate::context::{ Context, ContextAwakeables, ContextClient, ContextPromises, ContextReadState, ContextSideEffects, ContextTimers, ContextWriteState, HeaderMap, ObjectContext, Request, - RunRetryPolicy, SharedObjectContext, SharedWorkflowContext, WorkflowContext, + RunFuture, RunRetryPolicy, SharedObjectContext, SharedWorkflowContext, WorkflowContext, }; pub use crate::endpoint::Endpoint; pub use crate::errors::{HandlerError, HandlerResult, TerminalError}; diff --git a/test-services/src/failing.rs b/test-services/src/failing.rs index 12adafa..a6249e1 100644 --- a/test-services/src/failing.rs +++ b/test-services/src/failing.rs @@ -72,9 +72,7 @@ impl Failing for FailingImpl { error_message: String, ) -> HandlerResult<()> { context - .run("failing_side_effect", || async move { - Err::<(), _>(TerminalError::new(error_message).into()) - }) + .run(|| async move { Err::<(), _>(TerminalError::new(error_message).into()) }) .await?; unreachable!("This should be unreachable") @@ -87,22 +85,22 @@ impl Failing for FailingImpl { ) -> HandlerResult { let cloned_counter = Arc::clone(&self.eventual_success_side_effects); let success_attempt = context - .run_with_retry( - "failing_side_effect", + .run(|| async move { + let current_attempt = cloned_counter.fetch_add(1, Ordering::SeqCst) + 1; + + if current_attempt >= minimum_attempts { + cloned_counter.store(0, Ordering::SeqCst); + Ok(current_attempt) + } else { + Err(anyhow!("Failed at attempt {current_attempt}").into()) + } + }) + .with_retry_policy( RunRetryPolicy::new() .with_initial_interval(Duration::from_millis(10)) .with_factor(1.0), - || async move { - let current_attempt = cloned_counter.fetch_add(1, Ordering::SeqCst) + 1; - - if current_attempt >= minimum_attempts { - cloned_counter.store(0, Ordering::SeqCst); - Ok(current_attempt) - } else { - Err(anyhow!("Failed at attempt {current_attempt}").into()) - } - }, ) + .named("failing_side_effect") .await?; Ok(success_attempt) @@ -115,16 +113,15 @@ impl Failing for FailingImpl { ) -> HandlerResult { let cloned_counter = Arc::clone(&self.eventual_failure_side_effects); if context - .run_with_retry::<_, _, ()>( - "failing_side_effect", + .run(|| async move { + let current_attempt = cloned_counter.fetch_add(1, Ordering::SeqCst) + 1; + Err::<(), _>(anyhow!("Failed at attempt {current_attempt}").into()) + }) + .with_retry_policy( RunRetryPolicy::new() .with_initial_interval(Duration::from_millis(10)) .with_factor(1.0) .with_max_attempts(retry_policy_max_retry_count as u32), - || async move { - let current_attempt = cloned_counter.fetch_add(1, Ordering::SeqCst) + 1; - Err(anyhow!("Failed at attempt {current_attempt}").into()) - }, ) .await .is_err() diff --git a/test-services/src/test_utils_service.rs b/test-services/src/test_utils_service.rs index cbfc688..a1d84a1 100644 --- a/test-services/src/test_utils_service.rs +++ b/test-services/src/test_utils_service.rs @@ -141,7 +141,7 @@ impl TestUtilsService for TestUtilsServiceImpl { for _ in 0..increments { let counter_clone = Arc::clone(&counter); context - .run("count", || async { + .run(|| async { counter_clone.fetch_add(1, Ordering::SeqCst); Ok(()) }) From 2c7276ae30ddae45e721e03f01ce3bd4e0515209 Mon Sep 17 00:00:00 2001 From: slinkydeveloper Date: Mon, 26 Aug 2024 18:58:30 +0200 Subject: [PATCH 4/7] Move futures around to better org this piece of code --- src/endpoint/context.rs | 385 +++++++------------- src/endpoint/futures.rs | 139 ------- src/endpoint/futures/async_result_poll.rs | 141 +++++++ src/endpoint/futures/handler_state_aware.rs | 65 ++++ src/endpoint/futures/intercept_error.rs | 60 +++ src/endpoint/futures/mod.rs | 4 + src/endpoint/futures/trap.rs | 22 ++ src/endpoint/mod.rs | 7 +- 8 files changed, 423 insertions(+), 400 deletions(-) delete mode 100644 src/endpoint/futures.rs create mode 100644 src/endpoint/futures/async_result_poll.rs create mode 100644 src/endpoint/futures/handler_state_aware.rs create mode 100644 src/endpoint/futures/intercept_error.rs create mode 100644 src/endpoint/futures/mod.rs create mode 100644 src/endpoint/futures/trap.rs diff --git a/src/endpoint/context.rs b/src/endpoint/context.rs index 1812691..892ba2e 100644 --- a/src/endpoint/context.rs +++ b/src/endpoint/context.rs @@ -1,5 +1,7 @@ use crate::context::{Request, RequestTarget, RunClosure, RunRetryPolicy}; -use crate::endpoint::futures::{InterceptErrorFuture, TrapFuture}; +use crate::endpoint::futures::async_result_poll::VmAsyncResultPollFuture; +use crate::endpoint::futures::intercept_error::InterceptErrorFuture; +use crate::endpoint::futures::trap::TrapFuture; use crate::endpoint::handler_state::HandlerStateNotifier; use crate::endpoint::{Error, ErrorInner, InputReceiver, OutputSender}; use crate::errors::{HandlerErrorInner, HandlerResult, TerminalError}; @@ -8,8 +10,8 @@ use futures::future::Either; use futures::{FutureExt, TryFutureExt}; use pin_project_lite::pin_project; use restate_sdk_shared_core::{ - AsyncResultHandle, CoreVM, Failure, NonEmptyValue, RetryPolicy, RunEnterResult, RunExitResult, - SuspendedOrVMError, TakeOutputResult, Target, VMError, Value, VM, + CoreVM, Failure, NonEmptyValue, RetryPolicy, RunEnterResult, RunExitResult, TakeOutputResult, + Target, Value, VM, }; use std::borrow::Cow; use std::collections::HashMap; @@ -22,9 +24,9 @@ use std::task::{ready, Context, Poll}; use std::time::{Duration, Instant}; pub struct ContextInternalInner { - vm: CoreVM, - read: InputReceiver, - write: OutputSender, + pub(crate) vm: CoreVM, + pub(crate) read: InputReceiver, + pub(crate) write: OutputSender, pub(super) handler_state: HandlerStateNotifier, } @@ -198,22 +200,23 @@ impl ContextInternal { ) -> impl Future, TerminalError>> + Send + Sync { let maybe_handle = { must_lock!(self.inner).vm.sys_state_get(key.to_owned()) }; - let poll_future = self.create_poll_future(maybe_handle).map(|res| match res { - Ok(Value::Void) => Ok(Ok(None)), - Ok(Value::Success(mut s)) => { - let t = T::deserialize(&mut s).map_err(|e| ErrorInner::Deserialization { + let poll_future = VmAsyncResultPollFuture::new(Cow::Borrowed(&self.inner), maybe_handle) + .map(|res| match res { + Ok(Value::Void) => Ok(Ok(None)), + Ok(Value::Success(mut s)) => { + let t = T::deserialize(&mut s).map_err(|e| ErrorInner::Deserialization { + syscall: "get_state", + err: Box::new(e), + })?; + Ok(Ok(Some(t))) + } + Ok(Value::Failure(f)) => Ok(Err(f.into())), + Ok(Value::StateKeys(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { + variant: "state_keys", syscall: "get_state", - err: Box::new(e), - })?; - Ok(Ok(Some(t))) - } - Ok(Value::Failure(f)) => Ok(Err(f.into())), - Ok(Value::StateKeys(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { - variant: "state_keys", - syscall: "get_state", - }), - Err(e) => Err(e), - }); + }), + Err(e) => Err(e), + }); InterceptErrorFuture::new(self.clone(), poll_future.map_err(Error)) } @@ -223,19 +226,20 @@ impl ContextInternal { ) -> impl Future, TerminalError>> + Send + Sync { let maybe_handle = { must_lock!(self.inner).vm.sys_state_get_keys() }; - let poll_future = self.create_poll_future(maybe_handle).map(|res| match res { - Ok(Value::Void) => Err(ErrorInner::UnexpectedValueVariantForSyscall { - variant: "empty", - syscall: "get_state", - }), - Ok(Value::Success(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { - variant: "success", - syscall: "get_state", - }), - Ok(Value::Failure(f)) => Ok(Err(f.into())), - Ok(Value::StateKeys(s)) => Ok(Ok(s)), - Err(e) => Err(e), - }); + let poll_future = VmAsyncResultPollFuture::new(Cow::Borrowed(&self.inner), maybe_handle) + .map(|res| match res { + Ok(Value::Void) => Err(ErrorInner::UnexpectedValueVariantForSyscall { + variant: "empty", + syscall: "get_state", + }), + Ok(Value::Success(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { + variant: "success", + syscall: "get_state", + }), + Ok(Value::Failure(f)) => Ok(Err(f.into())), + Ok(Value::StateKeys(s)) => Ok(Ok(s)), + Err(e) => Err(e), + }); InterceptErrorFuture::new(self.clone(), poll_future.map_err(Error)) } @@ -272,19 +276,20 @@ impl ContextInternal { ) -> impl Future> + Send + Sync { let maybe_handle = { must_lock!(self.inner).vm.sys_sleep(duration) }; - let poll_future = self.create_poll_future(maybe_handle).map(|res| match res { - Ok(Value::Void) => Ok(Ok(())), - Ok(Value::Success(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { - variant: "success", - syscall: "sleep", - }), - Ok(Value::Failure(f)) => Ok(Err(f.into())), - Err(e) => Err(e), - Ok(Value::StateKeys(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { - variant: "state_keys", - syscall: "sleep", - }), - }); + let poll_future = VmAsyncResultPollFuture::new(Cow::Borrowed(&self.inner), maybe_handle) + .map(|res| match res { + Ok(Value::Void) => Ok(Ok(())), + Ok(Value::Success(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { + variant: "success", + syscall: "sleep", + }), + Ok(Value::Failure(f)) => Ok(Err(f.into())), + Err(e) => Err(e), + Ok(Value::StateKeys(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { + variant: "state_keys", + syscall: "sleep", + }), + }); InterceptErrorFuture::new(self.clone(), poll_future.map_err(Error)) } @@ -317,25 +322,26 @@ impl ContextInternal { let maybe_handle = inner_lock.vm.sys_call(request_target.into(), input); drop(inner_lock); - let poll_future = self.create_poll_future(maybe_handle).map(|res| match res { - Ok(Value::Void) => Err(ErrorInner::UnexpectedValueVariantForSyscall { - variant: "empty", - syscall: "call", - }), - Ok(Value::Success(mut s)) => { - let t = Res::deserialize(&mut s).map_err(|e| ErrorInner::Deserialization { + let poll_future = VmAsyncResultPollFuture::new(Cow::Borrowed(&self.inner), maybe_handle) + .map(|res| match res { + Ok(Value::Void) => Err(ErrorInner::UnexpectedValueVariantForSyscall { + variant: "empty", syscall: "call", - err: Box::new(e), - })?; - Ok(Ok(t)) - } - Ok(Value::Failure(f)) => Ok(Err(f.into())), - Ok(Value::StateKeys(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { - variant: "state_keys", - syscall: "call", - }), - Err(e) => Err(e), - }); + }), + Ok(Value::Success(mut s)) => { + let t = Res::deserialize(&mut s).map_err(|e| ErrorInner::Deserialization { + syscall: "call", + err: Box::new(e), + })?; + Ok(Ok(t)) + } + Ok(Value::Failure(f)) => Ok(Err(f.into())), + Ok(Value::StateKeys(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { + variant: "state_keys", + syscall: "call", + }), + Err(e) => Err(e), + }); Either::Left(InterceptErrorFuture::new( self.clone(), @@ -385,25 +391,26 @@ impl ContextInternal { ), }; - let poll_future = self.create_poll_future(maybe_handle).map(|res| match res { - Ok(Value::Void) => Err(ErrorInner::UnexpectedValueVariantForSyscall { - variant: "empty", - syscall: "awakeable", - }), - Ok(Value::Success(mut s)) => { - let t = T::deserialize(&mut s).map_err(|e| ErrorInner::Deserialization { + let poll_future = VmAsyncResultPollFuture::new(Cow::Borrowed(&self.inner), maybe_handle) + .map(|res| match res { + Ok(Value::Void) => Err(ErrorInner::UnexpectedValueVariantForSyscall { + variant: "empty", syscall: "awakeable", - err: Box::new(e), - })?; - Ok(Ok(t)) - } - Ok(Value::Failure(f)) => Ok(Err(f.into())), - Ok(Value::StateKeys(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { - variant: "state_keys", - syscall: "awakeable", - }), - Err(e) => Err(e), - }); + }), + Ok(Value::Success(mut s)) => { + let t = T::deserialize(&mut s).map_err(|e| ErrorInner::Deserialization { + syscall: "awakeable", + err: Box::new(e), + })?; + Ok(Ok(t)) + } + Ok(Value::Failure(f)) => Ok(Err(f.into())), + Ok(Value::StateKeys(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { + variant: "state_keys", + syscall: "awakeable", + }), + Err(e) => Err(e), + }); ( awakeable_id, @@ -443,25 +450,26 @@ impl ContextInternal { ) -> impl Future> + Send + Sync { let maybe_handle = { must_lock!(self.inner).vm.sys_get_promise(name.to_owned()) }; - let poll_future = self.create_poll_future(maybe_handle).map(|res| match res { - Ok(Value::Void) => Err(ErrorInner::UnexpectedValueVariantForSyscall { - variant: "empty", - syscall: "promise", - }), - Ok(Value::Success(mut s)) => { - let t = T::deserialize(&mut s).map_err(|e| ErrorInner::Deserialization { + let poll_future = VmAsyncResultPollFuture::new(Cow::Borrowed(&self.inner), maybe_handle) + .map(|res| match res { + Ok(Value::Void) => Err(ErrorInner::UnexpectedValueVariantForSyscall { + variant: "empty", syscall: "promise", - err: Box::new(e), - })?; - Ok(Ok(t)) - } - Ok(Value::Failure(f)) => Ok(Err(f.into())), - Ok(Value::StateKeys(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { - variant: "state_keys", - syscall: "promise", - }), - Err(e) => Err(e), - }); + }), + Ok(Value::Success(mut s)) => { + let t = T::deserialize(&mut s).map_err(|e| ErrorInner::Deserialization { + syscall: "promise", + err: Box::new(e), + })?; + Ok(Ok(t)) + } + Ok(Value::Failure(f)) => Ok(Err(f.into())), + Ok(Value::StateKeys(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { + variant: "state_keys", + syscall: "promise", + }), + Err(e) => Err(e), + }); InterceptErrorFuture::new(self.clone(), poll_future.map_err(Error)) } @@ -472,22 +480,23 @@ impl ContextInternal { ) -> impl Future, TerminalError>> + Send + Sync { let maybe_handle = { must_lock!(self.inner).vm.sys_peek_promise(name.to_owned()) }; - let poll_future = self.create_poll_future(maybe_handle).map(|res| match res { - Ok(Value::Void) => Ok(Ok(None)), - Ok(Value::Success(mut s)) => { - let t = T::deserialize(&mut s).map_err(|e| ErrorInner::Deserialization { + let poll_future = VmAsyncResultPollFuture::new(Cow::Borrowed(&self.inner), maybe_handle) + .map(|res| match res { + Ok(Value::Void) => Ok(Ok(None)), + Ok(Value::Success(mut s)) => { + let t = T::deserialize(&mut s).map_err(|e| ErrorInner::Deserialization { + syscall: "peek_promise", + err: Box::new(e), + })?; + Ok(Ok(Some(t))) + } + Ok(Value::Failure(f)) => Ok(Err(f.into())), + Ok(Value::StateKeys(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { + variant: "state_keys", syscall: "peek_promise", - err: Box::new(e), - })?; - Ok(Ok(Some(t))) - } - Ok(Value::Failure(f)) => Ok(Err(f.into())), - Ok(Value::StateKeys(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { - variant: "state_keys", - syscall: "peek_promise", - }), - Err(e) => Err(e), - }); + }), + Err(e) => Err(e), + }); InterceptErrorFuture::new(self.clone(), poll_future.map_err(Error)) } @@ -579,21 +588,6 @@ impl ContextInternal { pub(super) fn fail(&self, e: Error) { must_lock!(self.inner).fail(e) } - - fn create_poll_future( - &self, - handle: Result, - ) -> impl Future> + Send + Sync { - VmPollFuture { - state: Some(match handle { - Ok(handle) => PollState::Init { - ctx: Arc::clone(&self.inner), - handle, - }, - Err(err) => PollState::Failed(ErrorInner::VM(err)), - }), - } - } } pin_project! { @@ -619,7 +613,7 @@ pin_project! { }, PollFutureRunning { #[pin] - fut: VmPollFuture + fut: VmAsyncResultPollFuture } } } @@ -745,7 +739,7 @@ where }; this.state.set(RunState::PollFutureRunning { - fut: VmPollFuture::new(Cow::Owned(inner_ctx), handle), + fut: VmAsyncResultPollFuture::new(Cow::Owned(inner_ctx), handle), }); } RunStateProj::PollFutureRunning { fut } => { @@ -778,128 +772,3 @@ where } } } - -struct VmPollFuture { - state: Option, -} - -impl VmPollFuture { - fn new( - inner: Cow<'_, Arc>>, - handle: Result, - ) -> Self { - VmPollFuture { - state: Some(match handle { - Ok(handle) => PollState::Init { - ctx: inner.into_owned(), - handle, - }, - Err(err) => PollState::Failed(ErrorInner::VM(err)), - }), - } - } -} - -enum PollState { - Init { - ctx: Arc>, - handle: AsyncResultHandle, - }, - WaitingInput { - ctx: Arc>, - handle: AsyncResultHandle, - }, - Failed(ErrorInner), -} - -impl Future for VmPollFuture { - type Output = Result; - - fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { - loop { - match self - .state - .take() - .expect("Future should not be polled after Poll::Ready") - { - PollState::Init { ctx, handle } => { - // Acquire lock - let mut inner_lock = must_lock!(ctx); - - // Let's consume some output - let out = inner_lock.vm.take_output(); - match out { - TakeOutputResult::Buffer(b) => { - if !inner_lock.write.send(b) { - self.state = Some(PollState::Failed(ErrorInner::Suspended)); - continue; - } - } - TakeOutputResult::EOF => { - self.state = - Some(PollState::Failed(ErrorInner::UnexpectedOutputClosed)); - continue; - } - } - - // Notify that we reached an await point - inner_lock.vm.notify_await_point(handle); - - // At this point let's try to take the async result - match inner_lock.vm.take_async_result(handle) { - Ok(Some(v)) => return Poll::Ready(Ok(v)), - Ok(None) => { - drop(inner_lock); - self.state = Some(PollState::WaitingInput { ctx, handle }); - } - Err(SuspendedOrVMError::Suspended(_)) => { - self.state = Some(PollState::Failed(ErrorInner::Suspended)); - } - Err(SuspendedOrVMError::VM(e)) => { - self.state = Some(PollState::Failed(ErrorInner::VM(e))); - } - } - } - PollState::WaitingInput { ctx, handle } => { - let mut inner_lock = must_lock!(ctx); - - let read_result = match inner_lock.read.poll_recv(cx) { - Poll::Ready(t) => t, - Poll::Pending => { - drop(inner_lock); - self.state = Some(PollState::WaitingInput { ctx, handle }); - return Poll::Pending; - } - }; - - // Pass read result to VM - match read_result { - Some(Ok(b)) => inner_lock.vm.notify_input(b), - Some(Err(e)) => inner_lock.vm.notify_error( - "Error when reading the body".into(), - e.to_string().into(), - None, - ), - None => inner_lock.vm.notify_input_closed(), - } - - // Now try to take async result again - match inner_lock.vm.take_async_result(handle) { - Ok(Some(v)) => return Poll::Ready(Ok(v)), - Ok(None) => { - drop(inner_lock); - self.state = Some(PollState::WaitingInput { ctx, handle }); - } - Err(SuspendedOrVMError::Suspended(_)) => { - self.state = Some(PollState::Failed(ErrorInner::Suspended)); - } - Err(SuspendedOrVMError::VM(e)) => { - self.state = Some(PollState::Failed(ErrorInner::VM(e))); - } - } - } - PollState::Failed(err) => return Poll::Ready(Err(err)), - } - } - } -} diff --git a/src/endpoint/futures.rs b/src/endpoint/futures.rs deleted file mode 100644 index f19ae1f..0000000 --- a/src/endpoint/futures.rs +++ /dev/null @@ -1,139 +0,0 @@ -use crate::context::{RunFuture, RunRetryPolicy}; -use crate::endpoint::{ContextInternal, Error}; -use pin_project_lite::pin_project; -use std::future::Future; -use std::marker::PhantomData; -use std::pin::Pin; -use std::task::{ready, Context, Poll}; -use tokio::sync::oneshot; -use tracing::warn; - -/// Future that traps the execution at this point, but keeps waking up the waker -pub(super) struct TrapFuture(PhantomData T>); - -impl Default for TrapFuture { - fn default() -> Self { - Self(PhantomData) - } -} - -impl Future for TrapFuture { - type Output = T; - - fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll { - ctx.waker().wake_by_ref(); - Poll::Pending - } -} - -pin_project! { - /// Future that intercepts errors of inner future, and passes them to ContextInternal - pub(super) struct InterceptErrorFuture{ - #[pin] - fut: F, - ctx: ContextInternal - } -} - -impl InterceptErrorFuture { - pub(super) fn new(ctx: ContextInternal, fut: F) -> Self { - Self { fut, ctx } - } -} - -impl Future for InterceptErrorFuture -where - F: Future>, -{ - type Output = R; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.project(); - let result = ready!(this.fut.poll(cx)); - - match result { - Ok(r) => Poll::Ready(r), - Err(e) => { - this.ctx.fail(e); - - // Here is the secret sauce. This will immediately cause the whole future chain to be polled, - // but the poll here will be intercepted by HandlerStateAwareFuture - cx.waker().wake_by_ref(); - Poll::Pending - } - } - } -} - -impl RunFuture for InterceptErrorFuture -where - F: RunFuture>, -{ - fn with_retry_policy(mut self, retry_policy: RunRetryPolicy) -> Self { - self.fut = self.fut.with_retry_policy(retry_policy); - self - } - - fn named(mut self, name: impl Into) -> Self { - self.fut = self.fut.named(name); - self - } -} - -pin_project! { - /// Future that will stop polling when handler is suspended/failed - pub(super) struct HandlerStateAwareFuture { - #[pin] - fut: F, - handler_state_rx: oneshot::Receiver, - handler_context: ContextInternal, - } -} - -impl HandlerStateAwareFuture { - pub(super) fn new( - handler_context: ContextInternal, - handler_state_rx: oneshot::Receiver, - fut: F, - ) -> HandlerStateAwareFuture { - HandlerStateAwareFuture { - fut, - handler_state_rx, - handler_context, - } - } -} - -impl Future for HandlerStateAwareFuture -where - F: Future, -{ - type Output = Result; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.project(); - - match this.handler_state_rx.try_recv() { - Ok(e) => { - warn!( - rpc.system = "restate", - rpc.service = %this.handler_context.service_name(), - rpc.method = %this.handler_context.handler_name(), - "Error while processing handler {e:#}" - ); - this.handler_context.consume_to_end(); - Poll::Ready(Err(e)) - } - Err(oneshot::error::TryRecvError::Empty) => match this.fut.poll(cx) { - Poll::Ready(out) => { - this.handler_context.consume_to_end(); - Poll::Ready(Ok(out)) - } - Poll::Pending => Poll::Pending, - }, - Err(oneshot::error::TryRecvError::Closed) => { - panic!("This is unexpected, this future is still being polled although the sender side was dropped. This should not be possible, because the sender is dropped when this future returns Poll:ready().") - } - } - } -} diff --git a/src/endpoint/futures/async_result_poll.rs b/src/endpoint/futures/async_result_poll.rs new file mode 100644 index 0000000..8f6ef5d --- /dev/null +++ b/src/endpoint/futures/async_result_poll.rs @@ -0,0 +1,141 @@ +use crate::endpoint::context::ContextInternalInner; +use crate::endpoint::ErrorInner; +use restate_sdk_shared_core::{ + AsyncResultHandle, SuspendedOrVMError, TakeOutputResult, VMError, Value, VM, +}; +use std::borrow::Cow; +use std::future::Future; +use std::pin::Pin; +use std::sync::{Arc, Mutex}; +use std::task::Poll; + +pub(crate) struct VmAsyncResultPollFuture { + state: Option, +} + +impl VmAsyncResultPollFuture { + pub fn new( + inner: Cow<'_, Arc>>, + handle: Result, + ) -> Self { + VmAsyncResultPollFuture { + state: Some(match handle { + Ok(handle) => PollState::Init { + ctx: inner.into_owned(), + handle, + }, + Err(err) => PollState::Failed(ErrorInner::VM(err)), + }), + } + } +} + +enum PollState { + Init { + ctx: Arc>, + handle: AsyncResultHandle, + }, + WaitingInput { + ctx: Arc>, + handle: AsyncResultHandle, + }, + Failed(ErrorInner), +} + +macro_rules! must_lock { + ($mutex:expr) => { + $mutex.try_lock().expect("You're trying to await two futures at the same time and/or trying to perform some operation on the restate context while awaiting a future. This is not supported!") + }; +} + +impl Future for VmAsyncResultPollFuture { + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { + loop { + match self + .state + .take() + .expect("Future should not be polled after Poll::Ready") + { + PollState::Init { ctx, handle } => { + // Acquire lock + let mut inner_lock = must_lock!(ctx); + + // Let's consume some output + let out = inner_lock.vm.take_output(); + match out { + TakeOutputResult::Buffer(b) => { + if !inner_lock.write.send(b) { + self.state = Some(PollState::Failed(ErrorInner::Suspended)); + continue; + } + } + TakeOutputResult::EOF => { + self.state = + Some(PollState::Failed(ErrorInner::UnexpectedOutputClosed)); + continue; + } + } + + // Notify that we reached an await point + inner_lock.vm.notify_await_point(handle); + + // At this point let's try to take the async result + match inner_lock.vm.take_async_result(handle) { + Ok(Some(v)) => return Poll::Ready(Ok(v)), + Ok(None) => { + drop(inner_lock); + self.state = Some(PollState::WaitingInput { ctx, handle }); + } + Err(SuspendedOrVMError::Suspended(_)) => { + self.state = Some(PollState::Failed(ErrorInner::Suspended)); + } + Err(SuspendedOrVMError::VM(e)) => { + self.state = Some(PollState::Failed(ErrorInner::VM(e))); + } + } + } + PollState::WaitingInput { ctx, handle } => { + let mut inner_lock = must_lock!(ctx); + + let read_result = match inner_lock.read.poll_recv(cx) { + Poll::Ready(t) => t, + Poll::Pending => { + drop(inner_lock); + self.state = Some(PollState::WaitingInput { ctx, handle }); + return Poll::Pending; + } + }; + + // Pass read result to VM + match read_result { + Some(Ok(b)) => inner_lock.vm.notify_input(b), + Some(Err(e)) => inner_lock.vm.notify_error( + "Error when reading the body".into(), + e.to_string().into(), + None, + ), + None => inner_lock.vm.notify_input_closed(), + } + + // Now try to take async result again + match inner_lock.vm.take_async_result(handle) { + Ok(Some(v)) => return Poll::Ready(Ok(v)), + Ok(None) => { + drop(inner_lock); + self.state = Some(PollState::WaitingInput { ctx, handle }); + } + Err(SuspendedOrVMError::Suspended(_)) => { + self.state = Some(PollState::Failed(ErrorInner::Suspended)); + } + Err(SuspendedOrVMError::VM(e)) => { + self.state = Some(PollState::Failed(ErrorInner::VM(e))); + } + } + } + PollState::Failed(err) => return Poll::Ready(Err(err)), + } + } + } +} diff --git a/src/endpoint/futures/handler_state_aware.rs b/src/endpoint/futures/handler_state_aware.rs new file mode 100644 index 0000000..4ed225f --- /dev/null +++ b/src/endpoint/futures/handler_state_aware.rs @@ -0,0 +1,65 @@ +use crate::endpoint::{ContextInternal, Error}; +use pin_project_lite::pin_project; +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::sync::oneshot; +use tracing::warn; + +pin_project! { + /// Future that will stop polling when handler is suspended/failed + pub struct HandlerStateAwareFuture { + #[pin] + fut: F, + handler_state_rx: oneshot::Receiver, + handler_context: ContextInternal, + } +} + +impl HandlerStateAwareFuture { + pub fn new( + handler_context: ContextInternal, + handler_state_rx: oneshot::Receiver, + fut: F, + ) -> HandlerStateAwareFuture { + HandlerStateAwareFuture { + fut, + handler_state_rx, + handler_context, + } + } +} + +impl Future for HandlerStateAwareFuture +where + F: Future, +{ + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + + match this.handler_state_rx.try_recv() { + Ok(e) => { + warn!( + rpc.system = "restate", + rpc.service = %this.handler_context.service_name(), + rpc.method = %this.handler_context.handler_name(), + "Error while processing handler {e:#}" + ); + this.handler_context.consume_to_end(); + Poll::Ready(Err(e)) + } + Err(oneshot::error::TryRecvError::Empty) => match this.fut.poll(cx) { + Poll::Ready(out) => { + this.handler_context.consume_to_end(); + Poll::Ready(Ok(out)) + } + Poll::Pending => Poll::Pending, + }, + Err(oneshot::error::TryRecvError::Closed) => { + panic!("This is unexpected, this future is still being polled although the sender side was dropped. This should not be possible, because the sender is dropped when this future returns Poll:ready().") + } + } + } +} diff --git a/src/endpoint/futures/intercept_error.rs b/src/endpoint/futures/intercept_error.rs new file mode 100644 index 0000000..b187bc5 --- /dev/null +++ b/src/endpoint/futures/intercept_error.rs @@ -0,0 +1,60 @@ +use crate::context::{RunFuture, RunRetryPolicy}; +use crate::endpoint::{ContextInternal, Error}; +use pin_project_lite::pin_project; +use std::future::Future; +use std::pin::Pin; +use std::task::{ready, Context, Poll}; + +pin_project! { + /// Future that intercepts errors of inner future, and passes them to ContextInternal + pub struct InterceptErrorFuture{ + #[pin] + fut: F, + ctx: ContextInternal + } +} + +impl InterceptErrorFuture { + pub fn new(ctx: ContextInternal, fut: F) -> Self { + Self { fut, ctx } + } +} + +impl Future for InterceptErrorFuture +where + F: Future>, +{ + type Output = R; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + let result = ready!(this.fut.poll(cx)); + + match result { + Ok(r) => Poll::Ready(r), + Err(e) => { + this.ctx.fail(e); + + // Here is the secret sauce. This will immediately cause the whole future chain to be polled, + // but the poll here will be intercepted by HandlerStateAwareFuture + cx.waker().wake_by_ref(); + Poll::Pending + } + } + } +} + +impl RunFuture for InterceptErrorFuture +where + F: RunFuture>, +{ + fn with_retry_policy(mut self, retry_policy: RunRetryPolicy) -> Self { + self.fut = self.fut.with_retry_policy(retry_policy); + self + } + + fn named(mut self, name: impl Into) -> Self { + self.fut = self.fut.named(name); + self + } +} diff --git a/src/endpoint/futures/mod.rs b/src/endpoint/futures/mod.rs new file mode 100644 index 0000000..9f0b03f --- /dev/null +++ b/src/endpoint/futures/mod.rs @@ -0,0 +1,4 @@ +pub mod async_result_poll; +pub mod handler_state_aware; +pub mod intercept_error; +pub mod trap; diff --git a/src/endpoint/futures/trap.rs b/src/endpoint/futures/trap.rs new file mode 100644 index 0000000..9b0269d --- /dev/null +++ b/src/endpoint/futures/trap.rs @@ -0,0 +1,22 @@ +use std::future::Future; +use std::marker::PhantomData; +use std::pin::Pin; +use std::task::{Context, Poll}; + +/// Future that traps the execution at this point, but keeps waking up the waker +pub struct TrapFuture(PhantomData T>); + +impl Default for TrapFuture { + fn default() -> Self { + Self(PhantomData) + } +} + +impl Future for TrapFuture { + type Output = T; + + fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll { + ctx.waker().wake_by_ref(); + Poll::Pending + } +} diff --git a/src/endpoint/mod.rs b/src/endpoint/mod.rs index 8609b3d..65e59db 100644 --- a/src/endpoint/mod.rs +++ b/src/endpoint/mod.rs @@ -2,7 +2,8 @@ mod context; mod futures; mod handler_state; -use crate::endpoint::futures::InterceptErrorFuture; +use crate::endpoint::futures::handler_state_aware::HandlerStateAwareFuture; +use crate::endpoint::futures::intercept_error::InterceptErrorFuture; use crate::endpoint::handler_state::HandlerStateNotifier; use crate::service::{Discoverable, Service}; use ::futures::future::BoxFuture; @@ -98,7 +99,7 @@ impl Error { } #[derive(Debug, thiserror::Error)] -enum ErrorInner { +pub(crate) enum ErrorInner { #[error("Received a request for unknown service '{0}'")] UnknownService(String), #[error("Received a request for unknown service handler '{0}/{1}'")] @@ -359,7 +360,7 @@ impl BidiStreamRunner { let user_code_fut = InterceptErrorFuture::new(ctx.clone(), svc.handle(ctx.clone())); // Wrap it in handler state aware future - futures::HandlerStateAwareFuture::new(ctx.clone(), handler_state_rx, user_code_fut).await + HandlerStateAwareFuture::new(ctx.clone(), handler_state_rx, user_code_fut).await } async fn init_loop_vm(vm: &mut CoreVM, input_rx: &mut InputReceiver) -> Result<(), ErrorInner> { From 4ec056e3848c53f41651b7f882a8de2421976c79 Mon Sep 17 00:00:00 2001 From: slinkydeveloper Date: Tue, 27 Aug 2024 09:25:46 +0200 Subject: [PATCH 5/7] Polishing and docs --- src/context/mod.rs | 2 +- src/context/run.rs | 21 +++++++++++++++++++-- 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/src/context/mod.rs b/src/context/mod.rs index b682908..0726c0b 100644 --- a/src/context/mod.rs +++ b/src/context/mod.rs @@ -372,7 +372,7 @@ impl<'ctx, CTX: private::SealedContext<'ctx>> ContextAwakeables<'ctx> for CTX {} /// Trait exposing Restate functionalities to deal with non-deterministic operations. pub trait ContextSideEffects<'ctx>: private::SealedContext<'ctx> { /// Run a non-deterministic operation and record its result. - /// + #[must_use] fn run(&self, run_closure: R) -> impl RunFuture> + 'ctx where R: RunClosure + Send + Sync + 'ctx, diff --git a/src/context/run.rs b/src/context/run.rs index 45abb6d..3d63dbe 100644 --- a/src/context/run.rs +++ b/src/context/run.rs @@ -25,8 +25,17 @@ where } } +/// Future created using [`super::ContextSideEffects::run`]. pub trait RunFuture: Future { + /// Provide a custom retry policy for this `run` operation. + /// + /// If unspecified, the `run` will be retried using the [Restate invoker retry policy](https://docs.restate.dev/operate/configuration/server), + /// which by default retries indefinitely. fn with_retry_policy(self, retry_policy: RunRetryPolicy) -> Self; + + /// Define a name for this `run` operation. + /// + /// This is used mainly for observability. fn named(self, name: impl Into) -> Self; } @@ -82,16 +91,24 @@ impl RunRetryPolicy { self } - /// Gives up retrying when either this number of attempts is reached, + /// Gives up retrying when either at least the given number of attempts is reached, /// or `max_duration` (if set) is reached first. + /// + /// **Note:** The number of actual retries may be higher than the provided value. + /// This is due to the nature of the run operation, which executes the closure on the service and sends the result afterward to Restate. + /// /// Infinite retries if this field and `max_duration` are unset. pub fn with_max_attempts(mut self, max_attempts: u32) -> Self { self.max_attempts = Some(max_attempts); self } - /// Gives up retrying when either the retry loop lasted for this given max duration, + /// Gives up retrying when either the retry loop lasted at least for this given max duration, /// or `max_attempts` (if set) is reached first. + /// + /// **Note:** The real retry loop duration may be higher than the given duration. + /// This is due to the nature of the run operation, which executes the closure on the service and sends the result afterward to Restate. + /// /// Infinite retries if this field and `max_attempts` are unset. pub fn with_max_duration(mut self, max_duration: Duration) -> Self { self.max_duration = Some(max_duration); From b69cb45fae422bfc5ca97f7cca11c251f58a62e0 Mon Sep 17 00:00:00 2001 From: slinkydeveloper Date: Wed, 28 Aug 2024 09:31:36 +0200 Subject: [PATCH 6/7] Use main for now --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 4fa5bd5..de08c1d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,7 +22,7 @@ pin-project-lite = "0.2" rand = { version = "0.8.5", optional = true } regress = "0.10" restate-sdk-macros = { version = "0.2.1", path = "macros" } -restate-sdk-shared-core = { git = "https://github.com/restatedev/sdk-shared-core.git", branch = "side-effect-retry" } +restate-sdk-shared-core = { git = "https://github.com/restatedev/sdk-shared-core.git", branch = "main" } serde = "1.0" serde_json = "1.0" thiserror = "1.0.63" From dca950e6e403637e2df287ef196175f65caf4d6e Mon Sep 17 00:00:00 2001 From: slinkydeveloper Date: Wed, 28 Aug 2024 09:37:46 +0200 Subject: [PATCH 7/7] Use test suite 2.0 --- .github/workflows/integration.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/integration.yaml b/.github/workflows/integration.yaml index ac4016e..86bd744 100644 --- a/.github/workflows/integration.yaml +++ b/.github/workflows/integration.yaml @@ -41,7 +41,7 @@ jobs: name: "Features integration test (sdk-test-suite version ${{ matrix.sdk-test-suite }})" strategy: matrix: - sdk-test-suite: [ "1.8" ] + sdk-test-suite: [ "2.0" ] permissions: contents: read issues: read