diff --git a/.github/workflows/integration.yaml b/.github/workflows/integration.yaml index ac4016e..86bd744 100644 --- a/.github/workflows/integration.yaml +++ b/.github/workflows/integration.yaml @@ -41,7 +41,7 @@ jobs: name: "Features integration test (sdk-test-suite version ${{ matrix.sdk-test-suite }})" strategy: matrix: - sdk-test-suite: [ "1.8" ] + sdk-test-suite: [ "2.0" ] permissions: contents: read issues: read diff --git a/Cargo.toml b/Cargo.toml index 2270314..de08c1d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,7 +22,7 @@ pin-project-lite = "0.2" rand = { version = "0.8.5", optional = true } regress = "0.10" restate-sdk-macros = { version = "0.2.1", path = "macros" } -restate-sdk-shared-core = { version = "0.0.5" } +restate-sdk-shared-core = { git = "https://github.com/restatedev/sdk-shared-core.git", branch = "main" } serde = "1.0" serde_json = "1.0" thiserror = "1.0.63" diff --git a/examples/failures.rs b/examples/failures.rs index bdc2c14..0ab4251 100644 --- a/examples/failures.rs +++ b/examples/failures.rs @@ -16,7 +16,7 @@ struct MyError; impl FailureExample for FailureExampleImpl { async fn do_run(&self, context: Context<'_>) -> HandlerResult<()> { context - .run("get_ip", || async move { + .run(|| async move { if rand::thread_rng().next_u32() % 4 == 0 { return Err(TerminalError::new("Failed!!!").into()); } diff --git a/examples/run.rs b/examples/run.rs index 2df394a..f6127d7 100644 --- a/examples/run.rs +++ b/examples/run.rs @@ -11,7 +11,7 @@ struct RunExampleImpl(reqwest::Client); impl RunExample for RunExampleImpl { async fn do_run(&self, context: Context<'_>) -> HandlerResult>> { let res = context - .run("get_ip", || async move { + .run(|| async move { let req = self.0.get("https://httpbin.org/ip").build()?; let res = self @@ -23,6 +23,7 @@ impl RunExample for RunExampleImpl { Ok(Json::from(res)) }) + .named("get_ip") .await? .into_inner(); diff --git a/src/context/mod.rs b/src/context/mod.rs index 69ca125..0726c0b 100644 --- a/src/context/mod.rs +++ b/src/context/mod.rs @@ -10,7 +10,8 @@ mod request; mod run; pub use request::{Request, RequestTarget}; -pub use run::RunClosure; +pub use run::{RunClosure, RunFuture, RunRetryPolicy}; + pub type HeaderMap = http::HeaderMap; /// Service handler context. @@ -371,18 +372,14 @@ impl<'ctx, CTX: private::SealedContext<'ctx>> ContextAwakeables<'ctx> for CTX {} /// Trait exposing Restate functionalities to deal with non-deterministic operations. pub trait ContextSideEffects<'ctx>: private::SealedContext<'ctx> { /// Run a non-deterministic operation and record its result. - /// - fn run( - &self, - name: &'ctx str, - run_closure: R, - ) -> impl Future> + 'ctx + #[must_use] + fn run(&self, run_closure: R) -> impl RunFuture> + 'ctx where R: RunClosure + Send + Sync + 'ctx, - T: Serialize + Deserialize, F: Future> + Send + Sync + 'ctx, + T: Serialize + Deserialize + 'static, { - self.inner_context().run(name, run_closure) + self.inner_context().run(run_closure) } /// Return a random seed inherently predictable, based on the invocation id, which is not secret. diff --git a/src/context/run.rs b/src/context/run.rs index 1d90fc4..3d63dbe 100644 --- a/src/context/run.rs +++ b/src/context/run.rs @@ -1,6 +1,7 @@ use crate::errors::HandlerResult; use crate::serde::{Deserialize, Serialize}; use std::future::Future; +use std::time::Duration; /// Run closure trait pub trait RunClosure { @@ -23,3 +24,94 @@ where self() } } + +/// Future created using [`super::ContextSideEffects::run`]. +pub trait RunFuture: Future { + /// Provide a custom retry policy for this `run` operation. + /// + /// If unspecified, the `run` will be retried using the [Restate invoker retry policy](https://docs.restate.dev/operate/configuration/server), + /// which by default retries indefinitely. + fn with_retry_policy(self, retry_policy: RunRetryPolicy) -> Self; + + /// Define a name for this `run` operation. + /// + /// This is used mainly for observability. + fn named(self, name: impl Into) -> Self; +} + +/// This struct represents the policy to execute retries for run closures. +#[derive(Debug, Clone)] +pub struct RunRetryPolicy { + pub(crate) initial_interval: Duration, + pub(crate) factor: f32, + pub(crate) max_interval: Option, + pub(crate) max_attempts: Option, + pub(crate) max_duration: Option, +} + +impl Default for RunRetryPolicy { + fn default() -> Self { + Self { + initial_interval: Duration::from_millis(100), + factor: 2.0, + max_interval: Some(Duration::from_secs(2)), + max_attempts: None, + max_duration: Some(Duration::from_secs(50)), + } + } +} + +impl RunRetryPolicy { + /// Create a new retry policy. + pub fn new() -> Self { + Self { + initial_interval: Duration::from_millis(100), + factor: 1.0, + max_interval: None, + max_attempts: None, + max_duration: None, + } + } + + /// Initial interval for the first retry attempt. + pub fn with_initial_interval(mut self, initial_interval: Duration) -> Self { + self.initial_interval = initial_interval; + self + } + + /// Maximum interval between retries. + pub fn with_factor(mut self, factor: f32) -> Self { + self.factor = factor; + self + } + + /// Maximum interval between retries. + pub fn with_max_interval(mut self, max_interval: Duration) -> Self { + self.max_interval = Some(max_interval); + self + } + + /// Gives up retrying when either at least the given number of attempts is reached, + /// or `max_duration` (if set) is reached first. + /// + /// **Note:** The number of actual retries may be higher than the provided value. + /// This is due to the nature of the run operation, which executes the closure on the service and sends the result afterward to Restate. + /// + /// Infinite retries if this field and `max_duration` are unset. + pub fn with_max_attempts(mut self, max_attempts: u32) -> Self { + self.max_attempts = Some(max_attempts); + self + } + + /// Gives up retrying when either the retry loop lasted at least for this given max duration, + /// or `max_attempts` (if set) is reached first. + /// + /// **Note:** The real retry loop duration may be higher than the given duration. + /// This is due to the nature of the run operation, which executes the closure on the service and sends the result afterward to Restate. + /// + /// Infinite retries if this field and `max_attempts` are unset. + pub fn with_max_duration(mut self, max_duration: Duration) -> Self { + self.max_duration = Some(max_duration); + self + } +} diff --git a/src/endpoint/context.rs b/src/endpoint/context.rs index 29eb8f1..892ba2e 100644 --- a/src/endpoint/context.rs +++ b/src/endpoint/context.rs @@ -1,27 +1,32 @@ -use crate::context::{Request, RequestTarget, RunClosure}; -use crate::endpoint::futures::{InterceptErrorFuture, TrapFuture}; +use crate::context::{Request, RequestTarget, RunClosure, RunRetryPolicy}; +use crate::endpoint::futures::async_result_poll::VmAsyncResultPollFuture; +use crate::endpoint::futures::intercept_error::InterceptErrorFuture; +use crate::endpoint::futures::trap::TrapFuture; use crate::endpoint::handler_state::HandlerStateNotifier; use crate::endpoint::{Error, ErrorInner, InputReceiver, OutputSender}; use crate::errors::{HandlerErrorInner, HandlerResult, TerminalError}; use crate::serde::{Deserialize, Serialize}; -use bytes::Bytes; use futures::future::Either; use futures::{FutureExt, TryFutureExt}; +use pin_project_lite::pin_project; use restate_sdk_shared_core::{ - AsyncResultHandle, CoreVM, NonEmptyValue, RunEnterResult, SuspendedOrVMError, TakeOutputResult, - Target, VMError, Value, VM, + CoreVM, Failure, NonEmptyValue, RetryPolicy, RunEnterResult, RunExitResult, TakeOutputResult, + Target, Value, VM, }; +use std::borrow::Cow; use std::collections::HashMap; use std::future::{ready, Future}; +use std::marker::PhantomData; +use std::mem; use std::pin::Pin; use std::sync::{Arc, Mutex}; -use std::task::Poll; -use std::time::Duration; +use std::task::{ready, Context, Poll}; +use std::time::{Duration, Instant}; pub struct ContextInternalInner { - vm: CoreVM, - read: InputReceiver, - write: OutputSender, + pub(crate) vm: CoreVM, + pub(crate) read: InputReceiver, + pub(crate) write: OutputSender, pub(super) handler_state: HandlerStateNotifier, } @@ -42,7 +47,7 @@ impl ContextInternalInner { pub(super) fn fail(&mut self, e: Error) { self.vm - .notify_error(e.0.to_string().into(), format!("{:#}", e.0).into()); + .notify_error(e.0.to_string().into(), format!("{:#}", e.0).into(), None); self.handler_state.mark_error(e); } } @@ -130,37 +135,38 @@ impl ContextInternal { pub fn input(&self) -> impl Future { let mut inner_lock = must_lock!(self.inner); - let input_result = inner_lock - .vm - .sys_input() - .map_err(ErrorInner::VM) - .map(|raw_input| { - let headers = http::HeaderMap::::try_from( - &raw_input - .headers - .into_iter() - .map(|h| (h.key.to_string(), h.value.to_string())) - .collect::>(), - ) - .map_err(|e| { - TerminalError::new_with_code(400, format!("Cannot decode headers: {e:?}")) - })?; - - Ok::<_, TerminalError>(( - T::deserialize(&mut (raw_input.input.into())).map_err(|e| { - TerminalError::new_with_code( - 400, - format!("Cannot decode input payload: {e:?}"), - ) - })?, - InputMetadata { - invocation_id: raw_input.invocation_id, - random_seed: raw_input.random_seed, - key: raw_input.key, - headers, - }, - )) - }); + let input_result = + inner_lock + .vm + .sys_input() + .map_err(ErrorInner::VM) + .map(|mut raw_input| { + let headers = http::HeaderMap::::try_from( + &raw_input + .headers + .into_iter() + .map(|h| (h.key.to_string(), h.value.to_string())) + .collect::>(), + ) + .map_err(|e| { + TerminalError::new_with_code(400, format!("Cannot decode headers: {e:?}")) + })?; + + Ok::<_, TerminalError>(( + T::deserialize(&mut (raw_input.input)).map_err(|e| { + TerminalError::new_with_code( + 400, + format!("Cannot decode input payload: {e:?}"), + ) + })?, + InputMetadata { + invocation_id: raw_input.invocation_id, + random_seed: raw_input.random_seed, + key: raw_input.key, + headers, + }, + )) + }); match input_result { Ok(Ok(i)) => { @@ -194,23 +200,23 @@ impl ContextInternal { ) -> impl Future, TerminalError>> + Send + Sync { let maybe_handle = { must_lock!(self.inner).vm.sys_state_get(key.to_owned()) }; - let poll_future = self.create_poll_future(maybe_handle).map(|res| match res { - Ok(Value::Void) => Ok(Ok(None)), - Ok(Value::Success(s)) => { - let mut b = Bytes::from(s); - let t = T::deserialize(&mut b).map_err(|e| ErrorInner::Deserialization { + let poll_future = VmAsyncResultPollFuture::new(Cow::Borrowed(&self.inner), maybe_handle) + .map(|res| match res { + Ok(Value::Void) => Ok(Ok(None)), + Ok(Value::Success(mut s)) => { + let t = T::deserialize(&mut s).map_err(|e| ErrorInner::Deserialization { + syscall: "get_state", + err: Box::new(e), + })?; + Ok(Ok(Some(t))) + } + Ok(Value::Failure(f)) => Ok(Err(f.into())), + Ok(Value::StateKeys(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { + variant: "state_keys", syscall: "get_state", - err: Box::new(e), - })?; - Ok(Ok(Some(t))) - } - Ok(Value::Failure(f)) => Ok(Err(f.into())), - Ok(Value::StateKeys(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { - variant: "state_keys", - syscall: "get_state", - }), - Err(e) => Err(e), - }); + }), + Err(e) => Err(e), + }); InterceptErrorFuture::new(self.clone(), poll_future.map_err(Error)) } @@ -220,19 +226,20 @@ impl ContextInternal { ) -> impl Future, TerminalError>> + Send + Sync { let maybe_handle = { must_lock!(self.inner).vm.sys_state_get_keys() }; - let poll_future = self.create_poll_future(maybe_handle).map(|res| match res { - Ok(Value::Void) => Err(ErrorInner::UnexpectedValueVariantForSyscall { - variant: "empty", - syscall: "get_state", - }), - Ok(Value::Success(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { - variant: "success", - syscall: "get_state", - }), - Ok(Value::Failure(f)) => Ok(Err(f.into())), - Ok(Value::StateKeys(s)) => Ok(Ok(s)), - Err(e) => Err(e), - }); + let poll_future = VmAsyncResultPollFuture::new(Cow::Borrowed(&self.inner), maybe_handle) + .map(|res| match res { + Ok(Value::Void) => Err(ErrorInner::UnexpectedValueVariantForSyscall { + variant: "empty", + syscall: "get_state", + }), + Ok(Value::Success(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { + variant: "success", + syscall: "get_state", + }), + Ok(Value::Failure(f)) => Ok(Err(f.into())), + Ok(Value::StateKeys(s)) => Ok(Ok(s)), + Err(e) => Err(e), + }); InterceptErrorFuture::new(self.clone(), poll_future.map_err(Error)) } @@ -241,7 +248,7 @@ impl ContextInternal { let mut inner_lock = must_lock!(self.inner); match t.serialize() { Ok(b) => { - let _ = inner_lock.vm.sys_state_set(key.to_owned(), b.to_vec()); + let _ = inner_lock.vm.sys_state_set(key.to_owned(), b); } Err(e) => { inner_lock.fail( @@ -269,19 +276,20 @@ impl ContextInternal { ) -> impl Future> + Send + Sync { let maybe_handle = { must_lock!(self.inner).vm.sys_sleep(duration) }; - let poll_future = self.create_poll_future(maybe_handle).map(|res| match res { - Ok(Value::Void) => Ok(Ok(())), - Ok(Value::Success(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { - variant: "success", - syscall: "sleep", - }), - Ok(Value::Failure(f)) => Ok(Err(f.into())), - Err(e) => Err(e), - Ok(Value::StateKeys(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { - variant: "state_keys", - syscall: "sleep", - }), - }); + let poll_future = VmAsyncResultPollFuture::new(Cow::Borrowed(&self.inner), maybe_handle) + .map(|res| match res { + Ok(Value::Void) => Ok(Ok(())), + Ok(Value::Success(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { + variant: "success", + syscall: "sleep", + }), + Ok(Value::Failure(f)) => Ok(Err(f.into())), + Err(e) => Err(e), + Ok(Value::StateKeys(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { + variant: "state_keys", + syscall: "sleep", + }), + }); InterceptErrorFuture::new(self.clone(), poll_future.map_err(Error)) } @@ -311,31 +319,29 @@ impl ContextInternal { } }; - let maybe_handle = inner_lock - .vm - .sys_call(request_target.into(), input.to_vec()); + let maybe_handle = inner_lock.vm.sys_call(request_target.into(), input); drop(inner_lock); - let poll_future = self.create_poll_future(maybe_handle).map(|res| match res { - Ok(Value::Void) => Err(ErrorInner::UnexpectedValueVariantForSyscall { - variant: "empty", - syscall: "call", - }), - Ok(Value::Success(s)) => { - let mut b = Bytes::from(s); - let t = Res::deserialize(&mut b).map_err(|e| ErrorInner::Deserialization { + let poll_future = VmAsyncResultPollFuture::new(Cow::Borrowed(&self.inner), maybe_handle) + .map(|res| match res { + Ok(Value::Void) => Err(ErrorInner::UnexpectedValueVariantForSyscall { + variant: "empty", syscall: "call", - err: Box::new(e), - })?; - Ok(Ok(t)) - } - Ok(Value::Failure(f)) => Ok(Err(f.into())), - Ok(Value::StateKeys(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { - variant: "state_keys", - syscall: "call", - }), - Err(e) => Err(e), - }); + }), + Ok(Value::Success(mut s)) => { + let t = Res::deserialize(&mut s).map_err(|e| ErrorInner::Deserialization { + syscall: "call", + err: Box::new(e), + })?; + Ok(Ok(t)) + } + Ok(Value::Failure(f)) => Ok(Err(f.into())), + Ok(Value::StateKeys(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { + variant: "state_keys", + syscall: "call", + }), + Err(e) => Err(e), + }); Either::Left(InterceptErrorFuture::new( self.clone(), @@ -353,9 +359,7 @@ impl ContextInternal { match Req::serialize(&req) { Ok(t) => { - let _ = inner_lock - .vm - .sys_send(request_target.into(), t.to_vec(), delay); + let _ = inner_lock.vm.sys_send(request_target.into(), t, delay); } Err(e) => { inner_lock.fail( @@ -387,26 +391,26 @@ impl ContextInternal { ), }; - let poll_future = self.create_poll_future(maybe_handle).map(|res| match res { - Ok(Value::Void) => Err(ErrorInner::UnexpectedValueVariantForSyscall { - variant: "empty", - syscall: "awakeable", - }), - Ok(Value::Success(s)) => { - let mut b = Bytes::from(s); - let t = T::deserialize(&mut b).map_err(|e| ErrorInner::Deserialization { + let poll_future = VmAsyncResultPollFuture::new(Cow::Borrowed(&self.inner), maybe_handle) + .map(|res| match res { + Ok(Value::Void) => Err(ErrorInner::UnexpectedValueVariantForSyscall { + variant: "empty", syscall: "awakeable", - err: Box::new(e), - })?; - Ok(Ok(t)) - } - Ok(Value::Failure(f)) => Ok(Err(f.into())), - Ok(Value::StateKeys(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { - variant: "state_keys", - syscall: "awakeable", - }), - Err(e) => Err(e), - }); + }), + Ok(Value::Success(mut s)) => { + let t = T::deserialize(&mut s).map_err(|e| ErrorInner::Deserialization { + syscall: "awakeable", + err: Box::new(e), + })?; + Ok(Ok(t)) + } + Ok(Value::Failure(f)) => Ok(Err(f.into())), + Ok(Value::StateKeys(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { + variant: "state_keys", + syscall: "awakeable", + }), + Err(e) => Err(e), + }); ( awakeable_id, @@ -420,7 +424,7 @@ impl ContextInternal { Ok(b) => { let _ = inner_lock .vm - .sys_complete_awakeable(id.to_owned(), NonEmptyValue::Success(b.to_vec())); + .sys_complete_awakeable(id.to_owned(), NonEmptyValue::Success(b)); } Err(e) => { inner_lock.fail( @@ -446,26 +450,26 @@ impl ContextInternal { ) -> impl Future> + Send + Sync { let maybe_handle = { must_lock!(self.inner).vm.sys_get_promise(name.to_owned()) }; - let poll_future = self.create_poll_future(maybe_handle).map(|res| match res { - Ok(Value::Void) => Err(ErrorInner::UnexpectedValueVariantForSyscall { - variant: "empty", - syscall: "promise", - }), - Ok(Value::Success(s)) => { - let mut b = Bytes::from(s); - let t = T::deserialize(&mut b).map_err(|e| ErrorInner::Deserialization { + let poll_future = VmAsyncResultPollFuture::new(Cow::Borrowed(&self.inner), maybe_handle) + .map(|res| match res { + Ok(Value::Void) => Err(ErrorInner::UnexpectedValueVariantForSyscall { + variant: "empty", syscall: "promise", - err: Box::new(e), - })?; - Ok(Ok(t)) - } - Ok(Value::Failure(f)) => Ok(Err(f.into())), - Ok(Value::StateKeys(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { - variant: "state_keys", - syscall: "promise", - }), - Err(e) => Err(e), - }); + }), + Ok(Value::Success(mut s)) => { + let t = T::deserialize(&mut s).map_err(|e| ErrorInner::Deserialization { + syscall: "promise", + err: Box::new(e), + })?; + Ok(Ok(t)) + } + Ok(Value::Failure(f)) => Ok(Err(f.into())), + Ok(Value::StateKeys(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { + variant: "state_keys", + syscall: "promise", + }), + Err(e) => Err(e), + }); InterceptErrorFuture::new(self.clone(), poll_future.map_err(Error)) } @@ -476,23 +480,23 @@ impl ContextInternal { ) -> impl Future, TerminalError>> + Send + Sync { let maybe_handle = { must_lock!(self.inner).vm.sys_peek_promise(name.to_owned()) }; - let poll_future = self.create_poll_future(maybe_handle).map(|res| match res { - Ok(Value::Void) => Ok(Ok(None)), - Ok(Value::Success(s)) => { - let mut b = Bytes::from(s); - let t = T::deserialize(&mut b).map_err(|e| ErrorInner::Deserialization { + let poll_future = VmAsyncResultPollFuture::new(Cow::Borrowed(&self.inner), maybe_handle) + .map(|res| match res { + Ok(Value::Void) => Ok(Ok(None)), + Ok(Value::Success(mut s)) => { + let t = T::deserialize(&mut s).map_err(|e| ErrorInner::Deserialization { + syscall: "peek_promise", + err: Box::new(e), + })?; + Ok(Ok(Some(t))) + } + Ok(Value::Failure(f)) => Ok(Err(f.into())), + Ok(Value::StateKeys(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { + variant: "state_keys", syscall: "peek_promise", - err: Box::new(e), - })?; - Ok(Ok(Some(t))) - } - Ok(Value::Failure(f)) => Ok(Err(f.into())), - Ok(Value::StateKeys(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { - variant: "state_keys", - syscall: "peek_promise", - }), - Err(e) => Err(e), - }); + }), + Err(e) => Err(e), + }); InterceptErrorFuture::new(self.clone(), poll_future.map_err(Error)) } @@ -503,7 +507,7 @@ impl ContextInternal { Ok(b) => { let _ = inner_lock .vm - .sys_complete_promise(name.to_owned(), NonEmptyValue::Success(b.to_vec())); + .sys_complete_promise(name.to_owned(), NonEmptyValue::Success(b)); } Err(e) => { inner_lock.fail( @@ -523,90 +527,18 @@ impl ContextInternal { .sys_complete_promise(id.to_owned(), NonEmptyValue::Failure(failure.into())); } - pub fn run<'a, R, F, T>( + pub fn run<'a, Run, Fut, Res>( &'a self, - name: &'a str, - run_closure: R, - ) -> impl Future> + Send + Sync + 'a + run_closure: Run, + ) -> impl crate::context::RunFuture> + Send + Sync + 'a where - R: RunClosure + Send + Sync + 'a, - T: Serialize + Deserialize, - F: Future> + Send + Sync + 'a, + Run: RunClosure + Send + Sync + 'a, + Fut: Future> + Send + Sync + 'a, + Res: Serialize + Deserialize + 'static, { let this = Arc::clone(&self.inner); - InterceptErrorFuture::new(self.clone(), async move { - let enter_result = { must_lock!(this).vm.sys_run_enter(name.to_owned()) }; - - // Enter the side effect - match enter_result.map_err(ErrorInner::VM)? { - RunEnterResult::Executed(NonEmptyValue::Success(v)) => { - let mut b = Bytes::from(v); - let t = T::deserialize(&mut b).map_err(|e| ErrorInner::Deserialization { - syscall: "run", - err: Box::new(e), - })?; - return Ok(Ok(t)); - } - RunEnterResult::Executed(NonEmptyValue::Failure(f)) => return Ok(Err(f.into())), - RunEnterResult::NotExecuted => {} - }; - - // We need to run the closure - let res = match run_closure.run().await { - Ok(t) => NonEmptyValue::Success( - T::serialize(&t) - .map_err(|e| ErrorInner::Serialization { - syscall: "run", - err: Box::new(e), - })? - .to_vec(), - ), - Err(e) => match e.0 { - HandlerErrorInner::Retryable(err) => { - return Err(ErrorInner::RunResult { - name: name.to_owned(), - err, - } - .into()) - } - HandlerErrorInner::Terminal(t) => { - NonEmptyValue::Failure(TerminalError(t).into()) - } - }, - }; - - let handle = { - must_lock!(this) - .vm - .sys_run_exit(res) - .map_err(ErrorInner::VM)? - }; - - let value = self.create_poll_future(Ok(handle)).await?; - - match value { - Value::Void => Err(ErrorInner::UnexpectedValueVariantForSyscall { - variant: "empty", - syscall: "run", - } - .into()), - Value::Success(s) => { - let mut b = Bytes::from(s); - let t = T::deserialize(&mut b).map_err(|e| ErrorInner::Deserialization { - syscall: "run", - err: Box::new(e), - })?; - Ok(Ok(t)) - } - Value::Failure(f) => Ok(Err(f.into())), - Value::StateKeys(_) => Err(ErrorInner::UnexpectedValueVariantForSyscall { - variant: "state_keys", - syscall: "run", - } - .into()), - } - }) + InterceptErrorFuture::new(self.clone(), RunFuture::new(this, run_closure)) } pub fn handle_handler_result(&self, res: HandlerResult) { @@ -614,7 +546,7 @@ impl ContextInternal { let res_to_write = match res { Ok(success) => match T::serialize(&success) { - Ok(t) => NonEmptyValue::Success(t.to_vec()), + Ok(t) => NonEmptyValue::Success(t), Err(e) => { inner_lock.fail( ErrorInner::Serialization { @@ -647,7 +579,7 @@ impl ContextInternal { let out = inner_lock.vm.take_output(); if let TakeOutputResult::Buffer(b) = out { - if !inner_lock.write.send(b.into()) { + if !inner_lock.write.send(b) { // Nothing we can do anymore here } } @@ -656,125 +588,186 @@ impl ContextInternal { pub(super) fn fail(&self, e: Error) { must_lock!(self.inner).fail(e) } +} - fn create_poll_future( - &self, - handle: Result, - ) -> impl Future> + Send + Sync { - VmPollFuture { - state: Some(match handle { - Ok(handle) => PollState::Init { - ctx: Arc::clone(&self.inner), - handle, - }, - Err(err) => PollState::Failed(ErrorInner::VM(err)), - }), +pin_project! { + struct RunFuture { + name: String, + retry_policy: RetryPolicy, + phantom_data: PhantomData Ret>, + closure: Option, + inner_ctx: Option>>, + #[pin] + state: RunState, + } +} + +pin_project! { + #[project = RunStateProj] + enum RunState { + New, + ClosureRunning { + start_time: Instant, + #[pin] + fut: Fut, + }, + PollFutureRunning { + #[pin] + fut: VmAsyncResultPollFuture } } } -struct VmPollFuture { - state: Option, +impl RunFuture { + fn new(inner_ctx: Arc>, closure: Run) -> Self { + Self { + name: "".to_string(), + retry_policy: RetryPolicy::Infinite, + phantom_data: PhantomData, + inner_ctx: Some(inner_ctx), + closure: Some(closure), + state: RunState::New, + } + } } -enum PollState { - Init { - ctx: Arc>, - handle: AsyncResultHandle, - }, - WaitingInput { - ctx: Arc>, - handle: AsyncResultHandle, - }, - Failed(ErrorInner), +impl crate::context::RunFuture, Error>> + for RunFuture +where + Run: RunClosure + Send + Sync, + Fut: Future> + Send + Sync, + Ret: Serialize + Deserialize, +{ + fn with_retry_policy(mut self, retry_policy: RunRetryPolicy) -> Self { + self.retry_policy = RetryPolicy::Exponential { + initial_interval: retry_policy.initial_interval, + factor: retry_policy.factor, + max_interval: retry_policy.max_interval, + max_attempts: retry_policy.max_attempts, + max_duration: retry_policy.max_duration, + }; + self + } + + fn named(mut self, name: impl Into) -> Self { + self.name = name.into(); + self + } } -impl Future for VmPollFuture { - type Output = Result; +impl Future for RunFuture +where + Run: RunClosure + Send + Sync, + Res: Serialize + Deserialize, + Fut: Future> + Send + Sync, +{ + type Output = Result, Error>; - fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { - loop { - match self - .state - .take() - .expect("Future should not be polled after Poll::Ready") - { - PollState::Init { ctx, handle } => { - // Acquire lock - let mut inner_lock = must_lock!(ctx); - - // Let's consume some output - let out = inner_lock.vm.take_output(); - match out { - TakeOutputResult::Buffer(b) => { - if !inner_lock.write.send(b.into()) { - self.state = Some(PollState::Failed(ErrorInner::Suspended)); - continue; - } - } - TakeOutputResult::EOF => { - self.state = - Some(PollState::Failed(ErrorInner::UnexpectedOutputClosed)); - continue; - } - } + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut this = self.project(); - // Notify that we reached an await point - inner_lock.vm.notify_await_point(handle); + loop { + match this.state.as_mut().project() { + RunStateProj::New => { + let enter_result = { + must_lock!(this + .inner_ctx + .as_mut() + .expect("Future should not be polled after returning Poll::Ready")) + .vm + .sys_run_enter(this.name.to_owned()) + }; - // At this point let's try to take the async result - match inner_lock.vm.take_async_result(handle) { - Ok(Some(v)) => return Poll::Ready(Ok(v)), - Ok(None) => { - drop(inner_lock); - self.state = Some(PollState::WaitingInput { ctx, handle }); - } - Err(SuspendedOrVMError::Suspended(_)) => { - self.state = Some(PollState::Failed(ErrorInner::Suspended)); + // Enter the side effect + match enter_result.map_err(ErrorInner::VM)? { + RunEnterResult::Executed(NonEmptyValue::Success(mut v)) => { + let t = Res::deserialize(&mut v).map_err(|e| { + ErrorInner::Deserialization { + syscall: "run", + err: Box::new(e), + } + })?; + return Poll::Ready(Ok(Ok(t))); } - Err(SuspendedOrVMError::VM(e)) => { - self.state = Some(PollState::Failed(ErrorInner::VM(e))); + RunEnterResult::Executed(NonEmptyValue::Failure(f)) => { + return Poll::Ready(Ok(Err(f.into()))) } - } + RunEnterResult::NotExecuted(_) => {} + }; + + // We need to run the closure + this.state.set(RunState::ClosureRunning { + start_time: Instant::now(), + fut: this + .closure + .take() + .expect("Future should not be polled after returning Poll::Ready") + .run(), + }); } - PollState::WaitingInput { ctx, handle } => { - let mut inner_lock = must_lock!(ctx); - - let read_result = match inner_lock.read.poll_recv(cx) { - Poll::Ready(t) => t, - Poll::Pending => { - drop(inner_lock); - self.state = Some(PollState::WaitingInput { ctx, handle }); - return Poll::Pending; - } + RunStateProj::ClosureRunning { start_time, fut } => { + let res = match ready!(fut.poll(cx)) { + Ok(t) => RunExitResult::Success(Res::serialize(&t).map_err(|e| { + ErrorInner::Serialization { + syscall: "run", + err: Box::new(e), + } + })?), + Err(e) => match e.0 { + HandlerErrorInner::Retryable(err) => RunExitResult::RetryableFailure { + attempt_duration: start_time.elapsed(), + failure: Failure { + code: 500, + message: err.to_string(), + }, + }, + HandlerErrorInner::Terminal(t) => { + RunExitResult::TerminalFailure(TerminalError(t).into()) + } + }, }; - // Pass read result to VM - match read_result { - Some(Ok(b)) => inner_lock.vm.notify_input(b.to_vec()), - Some(Err(e)) => inner_lock.vm.notify_error( - "Error when reading the body".into(), - e.to_string().into(), - ), - None => inner_lock.vm.notify_input_closed(), - } + let inner_ctx = this + .inner_ctx + .take() + .expect("Future should not be polled after returning Poll::Ready"); - // Now try to take async result again - match inner_lock.vm.take_async_result(handle) { - Ok(Some(v)) => return Poll::Ready(Ok(v)), - Ok(None) => { - drop(inner_lock); - self.state = Some(PollState::WaitingInput { ctx, handle }); + let handle = { + must_lock!(inner_ctx) + .vm + .sys_run_exit(res, mem::take(this.retry_policy)) + }; + + this.state.set(RunState::PollFutureRunning { + fut: VmAsyncResultPollFuture::new(Cow::Owned(inner_ctx), handle), + }); + } + RunStateProj::PollFutureRunning { fut } => { + let value = ready!(fut.poll(cx))?; + + return Poll::Ready(match value { + Value::Void => Err(ErrorInner::UnexpectedValueVariantForSyscall { + variant: "empty", + syscall: "run", } - Err(SuspendedOrVMError::Suspended(_)) => { - self.state = Some(PollState::Failed(ErrorInner::Suspended)); + .into()), + Value::Success(mut s) => { + let t = Res::deserialize(&mut s).map_err(|e| { + ErrorInner::Deserialization { + syscall: "run", + err: Box::new(e), + } + })?; + Ok(Ok(t)) } - Err(SuspendedOrVMError::VM(e)) => { - self.state = Some(PollState::Failed(ErrorInner::VM(e))); + Value::Failure(f) => Ok(Err(f.into())), + Value::StateKeys(_) => Err(ErrorInner::UnexpectedValueVariantForSyscall { + variant: "state_keys", + syscall: "run", } - } + .into()), + }); } - PollState::Failed(err) => return Poll::Ready(Err(err)), } } } diff --git a/src/endpoint/futures/async_result_poll.rs b/src/endpoint/futures/async_result_poll.rs new file mode 100644 index 0000000..8f6ef5d --- /dev/null +++ b/src/endpoint/futures/async_result_poll.rs @@ -0,0 +1,141 @@ +use crate::endpoint::context::ContextInternalInner; +use crate::endpoint::ErrorInner; +use restate_sdk_shared_core::{ + AsyncResultHandle, SuspendedOrVMError, TakeOutputResult, VMError, Value, VM, +}; +use std::borrow::Cow; +use std::future::Future; +use std::pin::Pin; +use std::sync::{Arc, Mutex}; +use std::task::Poll; + +pub(crate) struct VmAsyncResultPollFuture { + state: Option, +} + +impl VmAsyncResultPollFuture { + pub fn new( + inner: Cow<'_, Arc>>, + handle: Result, + ) -> Self { + VmAsyncResultPollFuture { + state: Some(match handle { + Ok(handle) => PollState::Init { + ctx: inner.into_owned(), + handle, + }, + Err(err) => PollState::Failed(ErrorInner::VM(err)), + }), + } + } +} + +enum PollState { + Init { + ctx: Arc>, + handle: AsyncResultHandle, + }, + WaitingInput { + ctx: Arc>, + handle: AsyncResultHandle, + }, + Failed(ErrorInner), +} + +macro_rules! must_lock { + ($mutex:expr) => { + $mutex.try_lock().expect("You're trying to await two futures at the same time and/or trying to perform some operation on the restate context while awaiting a future. This is not supported!") + }; +} + +impl Future for VmAsyncResultPollFuture { + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { + loop { + match self + .state + .take() + .expect("Future should not be polled after Poll::Ready") + { + PollState::Init { ctx, handle } => { + // Acquire lock + let mut inner_lock = must_lock!(ctx); + + // Let's consume some output + let out = inner_lock.vm.take_output(); + match out { + TakeOutputResult::Buffer(b) => { + if !inner_lock.write.send(b) { + self.state = Some(PollState::Failed(ErrorInner::Suspended)); + continue; + } + } + TakeOutputResult::EOF => { + self.state = + Some(PollState::Failed(ErrorInner::UnexpectedOutputClosed)); + continue; + } + } + + // Notify that we reached an await point + inner_lock.vm.notify_await_point(handle); + + // At this point let's try to take the async result + match inner_lock.vm.take_async_result(handle) { + Ok(Some(v)) => return Poll::Ready(Ok(v)), + Ok(None) => { + drop(inner_lock); + self.state = Some(PollState::WaitingInput { ctx, handle }); + } + Err(SuspendedOrVMError::Suspended(_)) => { + self.state = Some(PollState::Failed(ErrorInner::Suspended)); + } + Err(SuspendedOrVMError::VM(e)) => { + self.state = Some(PollState::Failed(ErrorInner::VM(e))); + } + } + } + PollState::WaitingInput { ctx, handle } => { + let mut inner_lock = must_lock!(ctx); + + let read_result = match inner_lock.read.poll_recv(cx) { + Poll::Ready(t) => t, + Poll::Pending => { + drop(inner_lock); + self.state = Some(PollState::WaitingInput { ctx, handle }); + return Poll::Pending; + } + }; + + // Pass read result to VM + match read_result { + Some(Ok(b)) => inner_lock.vm.notify_input(b), + Some(Err(e)) => inner_lock.vm.notify_error( + "Error when reading the body".into(), + e.to_string().into(), + None, + ), + None => inner_lock.vm.notify_input_closed(), + } + + // Now try to take async result again + match inner_lock.vm.take_async_result(handle) { + Ok(Some(v)) => return Poll::Ready(Ok(v)), + Ok(None) => { + drop(inner_lock); + self.state = Some(PollState::WaitingInput { ctx, handle }); + } + Err(SuspendedOrVMError::Suspended(_)) => { + self.state = Some(PollState::Failed(ErrorInner::Suspended)); + } + Err(SuspendedOrVMError::VM(e)) => { + self.state = Some(PollState::Failed(ErrorInner::VM(e))); + } + } + } + PollState::Failed(err) => return Poll::Ready(Err(err)), + } + } + } +} diff --git a/src/endpoint/futures.rs b/src/endpoint/futures/handler_state_aware.rs similarity index 53% rename from src/endpoint/futures.rs rename to src/endpoint/futures/handler_state_aware.rs index 4836feb..4ed225f 100644 --- a/src/endpoint/futures.rs +++ b/src/endpoint/futures/handler_state_aware.rs @@ -1,76 +1,14 @@ use crate::endpoint::{ContextInternal, Error}; use pin_project_lite::pin_project; use std::future::Future; -use std::marker::PhantomData; use std::pin::Pin; -use std::task::{ready, Context, Poll}; +use std::task::{Context, Poll}; use tokio::sync::oneshot; use tracing::warn; -/// Future that traps the execution at this point, but keeps waking up the waker -pub(super) struct TrapFuture(PhantomData); - -impl Default for TrapFuture { - fn default() -> Self { - Self(PhantomData) - } -} - -/// This is always safe, because we simply use phantom data inside TrapFuture. -unsafe impl Send for TrapFuture {} -unsafe impl Sync for TrapFuture {} - -impl Future for TrapFuture { - type Output = T; - - fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll { - ctx.waker().wake_by_ref(); - Poll::Pending - } -} - -pin_project! { - /// Future that intercepts errors of inner future, and passes them to ContextInternal - pub(super) struct InterceptErrorFuture{ - #[pin] - fut: F, - ctx: ContextInternal - } -} - -impl InterceptErrorFuture { - pub(super) fn new(ctx: ContextInternal, fut: F) -> Self { - Self { fut, ctx } - } -} - -impl Future for InterceptErrorFuture -where - F: Future>, -{ - type Output = R; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.project(); - let result = ready!(this.fut.poll(cx)); - - match result { - Ok(r) => Poll::Ready(r), - Err(e) => { - this.ctx.fail(e); - - // Here is the secret sauce. This will immediately cause the whole future chain to be polled, - // but the poll here will be intercepted by HandlerStateAwareFuture - cx.waker().wake_by_ref(); - Poll::Pending - } - } - } -} - pin_project! { /// Future that will stop polling when handler is suspended/failed - pub(super) struct HandlerStateAwareFuture { + pub struct HandlerStateAwareFuture { #[pin] fut: F, handler_state_rx: oneshot::Receiver, @@ -79,7 +17,7 @@ pin_project! { } impl HandlerStateAwareFuture { - pub(super) fn new( + pub fn new( handler_context: ContextInternal, handler_state_rx: oneshot::Receiver, fut: F, diff --git a/src/endpoint/futures/intercept_error.rs b/src/endpoint/futures/intercept_error.rs new file mode 100644 index 0000000..b187bc5 --- /dev/null +++ b/src/endpoint/futures/intercept_error.rs @@ -0,0 +1,60 @@ +use crate::context::{RunFuture, RunRetryPolicy}; +use crate::endpoint::{ContextInternal, Error}; +use pin_project_lite::pin_project; +use std::future::Future; +use std::pin::Pin; +use std::task::{ready, Context, Poll}; + +pin_project! { + /// Future that intercepts errors of inner future, and passes them to ContextInternal + pub struct InterceptErrorFuture{ + #[pin] + fut: F, + ctx: ContextInternal + } +} + +impl InterceptErrorFuture { + pub fn new(ctx: ContextInternal, fut: F) -> Self { + Self { fut, ctx } + } +} + +impl Future for InterceptErrorFuture +where + F: Future>, +{ + type Output = R; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + let result = ready!(this.fut.poll(cx)); + + match result { + Ok(r) => Poll::Ready(r), + Err(e) => { + this.ctx.fail(e); + + // Here is the secret sauce. This will immediately cause the whole future chain to be polled, + // but the poll here will be intercepted by HandlerStateAwareFuture + cx.waker().wake_by_ref(); + Poll::Pending + } + } + } +} + +impl RunFuture for InterceptErrorFuture +where + F: RunFuture>, +{ + fn with_retry_policy(mut self, retry_policy: RunRetryPolicy) -> Self { + self.fut = self.fut.with_retry_policy(retry_policy); + self + } + + fn named(mut self, name: impl Into) -> Self { + self.fut = self.fut.named(name); + self + } +} diff --git a/src/endpoint/futures/mod.rs b/src/endpoint/futures/mod.rs new file mode 100644 index 0000000..9f0b03f --- /dev/null +++ b/src/endpoint/futures/mod.rs @@ -0,0 +1,4 @@ +pub mod async_result_poll; +pub mod handler_state_aware; +pub mod intercept_error; +pub mod trap; diff --git a/src/endpoint/futures/trap.rs b/src/endpoint/futures/trap.rs new file mode 100644 index 0000000..9b0269d --- /dev/null +++ b/src/endpoint/futures/trap.rs @@ -0,0 +1,22 @@ +use std::future::Future; +use std::marker::PhantomData; +use std::pin::Pin; +use std::task::{Context, Poll}; + +/// Future that traps the execution at this point, but keeps waking up the waker +pub struct TrapFuture(PhantomData T>); + +impl Default for TrapFuture { + fn default() -> Self { + Self(PhantomData) + } +} + +impl Future for TrapFuture { + type Output = T; + + fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll { + ctx.waker().wake_by_ref(); + Poll::Pending + } +} diff --git a/src/endpoint/mod.rs b/src/endpoint/mod.rs index 3bf5a93..65e59db 100644 --- a/src/endpoint/mod.rs +++ b/src/endpoint/mod.rs @@ -2,7 +2,8 @@ mod context; mod futures; mod handler_state; -use crate::endpoint::futures::InterceptErrorFuture; +use crate::endpoint::futures::handler_state_aware::HandlerStateAwareFuture; +use crate::endpoint::futures::intercept_error::InterceptErrorFuture; use crate::endpoint::handler_state::HandlerStateNotifier; use crate::service::{Discoverable, Service}; use ::futures::future::BoxFuture; @@ -82,14 +83,13 @@ impl Error { /// Returns the HTTP status code for this error. pub fn status_code(&self) -> u16 { match &self.0 { - ErrorInner::VM(e) => e.code, + ErrorInner::VM(e) => e.code(), ErrorInner::UnknownService(_) | ErrorInner::UnknownServiceHandler(_, _) => 404, ErrorInner::Suspended | ErrorInner::UnexpectedOutputClosed | ErrorInner::UnexpectedValueVariantForSyscall { .. } | ErrorInner::Deserialization { .. } | ErrorInner::Serialization { .. } - | ErrorInner::RunResult { .. } | ErrorInner::HandlerResult { .. } => 500, ErrorInner::BadDiscovery(_) => 415, ErrorInner::Header { .. } | ErrorInner::BadPath { .. } => 400, @@ -99,7 +99,7 @@ impl Error { } #[derive(Debug, thiserror::Error)] -enum ErrorInner { +pub(crate) enum ErrorInner { #[error("Received a request for unknown service '{0}'")] UnknownService(String), #[error("Received a request for unknown service handler '{0}/{1}'")] @@ -135,12 +135,6 @@ enum ErrorInner { #[source] err: BoxError, }, - #[error("Run '{name}' failed with retryable error: {err:?}'")] - RunResult { - name: String, - #[source] - err: BoxError, - }, #[error("Handler failed with retryable error: {err:?}'")] HandlerResult { #[source] @@ -182,8 +176,8 @@ impl Default for Builder { Self { svcs: Default::default(), discovery: crate::discovery::Endpoint { - max_protocol_version: 1, - min_protocol_version: 1, + max_protocol_version: 2, + min_protocol_version: 2, protocol_mode: Some(crate::discovery::ProtocolMode::BidiStream), services: vec![], }, @@ -366,16 +360,18 @@ impl BidiStreamRunner { let user_code_fut = InterceptErrorFuture::new(ctx.clone(), svc.handle(ctx.clone())); // Wrap it in handler state aware future - futures::HandlerStateAwareFuture::new(ctx.clone(), handler_state_rx, user_code_fut).await + HandlerStateAwareFuture::new(ctx.clone(), handler_state_rx, user_code_fut).await } async fn init_loop_vm(vm: &mut CoreVM, input_rx: &mut InputReceiver) -> Result<(), ErrorInner> { while !vm.is_ready_to_execute().map_err(ErrorInner::VM)? { match input_rx.recv().await { - Some(Ok(b)) => vm.notify_input(b.to_vec()), - Some(Err(e)) => { - vm.notify_error("Error when reading the body".into(), e.to_string().into()) - } + Some(Ok(b)) => vm.notify_input(b), + Some(Err(e)) => vm.notify_error( + "Error when reading the body".into(), + e.to_string().into(), + None, + ), None => vm.notify_input_closed(), } } diff --git a/src/lib.rs b/src/lib.rs index e80c540..35f7e18 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -58,7 +58,7 @@ pub mod prelude { pub use crate::context::{ Context, ContextAwakeables, ContextClient, ContextPromises, ContextReadState, ContextSideEffects, ContextTimers, ContextWriteState, HeaderMap, ObjectContext, Request, - SharedObjectContext, SharedWorkflowContext, WorkflowContext, + RunFuture, RunRetryPolicy, SharedObjectContext, SharedWorkflowContext, WorkflowContext, }; pub use crate::endpoint::Endpoint; pub use crate::errors::{HandlerError, HandlerResult, TerminalError}; diff --git a/test-services/src/failing.rs b/test-services/src/failing.rs index b8b5ec0..a6249e1 100644 --- a/test-services/src/failing.rs +++ b/test-services/src/failing.rs @@ -2,6 +2,7 @@ use anyhow::anyhow; use restate_sdk::prelude::*; use std::sync::atomic::{AtomicI32, Ordering}; use std::sync::Arc; +use std::time::Duration; #[restate_sdk::object] #[name = "Failing"] @@ -12,16 +13,22 @@ pub(crate) trait Failing { async fn call_terminally_failing_call(error_message: String) -> HandlerResult; #[name = "failingCallWithEventualSuccess"] async fn failing_call_with_eventual_success() -> HandlerResult; - #[name = "failingSideEffectWithEventualSuccess"] - async fn failing_side_effect_with_eventual_success() -> HandlerResult; #[name = "terminallyFailingSideEffect"] async fn terminally_failing_side_effect(error_message: String) -> HandlerResult<()>; + #[name = "sideEffectSucceedsAfterGivenAttempts"] + async fn side_effect_succeeds_after_given_attempts(minimum_attempts: i32) + -> HandlerResult; + #[name = "sideEffectFailsAfterGivenAttempts"] + async fn side_effect_fails_after_given_attempts( + retry_policy_max_retry_count: i32, + ) -> HandlerResult; } #[derive(Clone, Default)] pub(crate) struct FailingImpl { eventual_success_calls: Arc, eventual_success_side_effects: Arc, + eventual_failure_side_effects: Arc, } impl Failing for FailingImpl { @@ -59,39 +66,69 @@ impl Failing for FailingImpl { } } - async fn failing_side_effect_with_eventual_success( + async fn terminally_failing_side_effect( + &self, + context: ObjectContext<'_>, + error_message: String, + ) -> HandlerResult<()> { + context + .run(|| async move { Err::<(), _>(TerminalError::new(error_message).into()) }) + .await?; + + unreachable!("This should be unreachable") + } + + async fn side_effect_succeeds_after_given_attempts( &self, context: ObjectContext<'_>, + minimum_attempts: i32, ) -> HandlerResult { - let cloned_eventual_side_effect_calls = Arc::clone(&self.eventual_success_side_effects); + let cloned_counter = Arc::clone(&self.eventual_success_side_effects); let success_attempt = context - .run("failing_side_effect", || async move { - let current_attempt = - cloned_eventual_side_effect_calls.fetch_add(1, Ordering::SeqCst) + 1; + .run(|| async move { + let current_attempt = cloned_counter.fetch_add(1, Ordering::SeqCst) + 1; - if current_attempt >= 4 { - cloned_eventual_side_effect_calls.store(0, Ordering::SeqCst); + if current_attempt >= minimum_attempts { + cloned_counter.store(0, Ordering::SeqCst); Ok(current_attempt) } else { - Err(anyhow!("Failed at attempt ${current_attempt}").into()) + Err(anyhow!("Failed at attempt {current_attempt}").into()) } }) + .with_retry_policy( + RunRetryPolicy::new() + .with_initial_interval(Duration::from_millis(10)) + .with_factor(1.0), + ) + .named("failing_side_effect") .await?; Ok(success_attempt) } - async fn terminally_failing_side_effect( + async fn side_effect_fails_after_given_attempts( &self, context: ObjectContext<'_>, - error_message: String, - ) -> HandlerResult<()> { - context - .run("failing_side_effect", || async move { - Err::<(), _>(TerminalError::new(error_message).into()) + retry_policy_max_retry_count: i32, + ) -> HandlerResult { + let cloned_counter = Arc::clone(&self.eventual_failure_side_effects); + if context + .run(|| async move { + let current_attempt = cloned_counter.fetch_add(1, Ordering::SeqCst) + 1; + Err::<(), _>(anyhow!("Failed at attempt {current_attempt}").into()) }) - .await?; - - unreachable!("This should be unreachable") + .with_retry_policy( + RunRetryPolicy::new() + .with_initial_interval(Duration::from_millis(10)) + .with_factor(1.0) + .with_max_attempts(retry_policy_max_retry_count as u32), + ) + .await + .is_err() + { + Ok(self.eventual_failure_side_effects.load(Ordering::SeqCst)) + } else { + Err(TerminalError::new("Expecting the side effect to fail!"))? + } } } diff --git a/test-services/src/test_utils_service.rs b/test-services/src/test_utils_service.rs index cbfc688..a1d84a1 100644 --- a/test-services/src/test_utils_service.rs +++ b/test-services/src/test_utils_service.rs @@ -141,7 +141,7 @@ impl TestUtilsService for TestUtilsServiceImpl { for _ in 0..increments { let counter_clone = Arc::clone(&counter); context - .run("count", || async { + .run(|| async { counter_clone.fetch_add(1, Ordering::SeqCst); Ok(()) })