diff --git a/.github/workflows/integration.yaml b/.github/workflows/integration.yaml index 865c228..9952479 100644 --- a/.github/workflows/integration.yaml +++ b/.github/workflows/integration.yaml @@ -105,7 +105,7 @@ jobs: cache-to: type=gha,mode=max,scope=${{ github.workflow }} - name: Run test tool - uses: restatedev/sdk-test-suite@v2.4 + uses: restatedev/sdk-test-suite@v3.0 with: restateContainerImage: ${{ inputs.restateCommit != '' && 'localhost/restatedev/restate-commit-download:latest' || (inputs.restateImage != '' && inputs.restateImage || 'ghcr.io/restatedev/restate:main') }} serviceContainerImage: "restatedev/rust-test-services" diff --git a/Cargo.toml b/Cargo.toml index 97dc91d..a5104b2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,6 +12,7 @@ default = ["http_server", "rand", "uuid"] hyper = ["dep:hyper", "http-body-util", "restate-sdk-shared-core/http"] http_server = ["hyper", "hyper/server", "hyper/http2", "hyper-util", "tokio/net", "tokio/signal", "tokio/macros"] + [dependencies] bytes = "1.6.1" futures = "0.3" @@ -23,8 +24,7 @@ pin-project-lite = "0.2" rand = { version = "0.8.5", optional = true } regress = "0.10" restate-sdk-macros = { version = "0.3.2", path = "macros" } -restate-sdk-shared-core = "0.1.0" -sha2 = "=0.11.0-pre.3" +restate-sdk-shared-core = { git = "https://github.com/restatedev/sdk-shared-core.git", branch = "main", features = ["request_identity", "sha2_random_seed", "http"] } serde = "1.0" serde_json = "1.0" thiserror = "1.0.63" diff --git a/README.md b/README.md index ea39d16..1e01ccd 100644 --- a/README.md +++ b/README.md @@ -121,10 +121,11 @@ The Rust SDK is currently in active development, and might break across releases The compatibility with Restate is described in the following table: -| Restate Server\sdk-rust | 0.0/0.1/0.2 | 0.3 | -|-------------------------|-------------|-----| -| 1.0 | ✅ | ❌ | -| 1.1 | ✅ | ✅ | +| Restate Server\sdk-rust | 0.0/0.1/0.2 | 0.3 | 0.4 | +|-------------------------|-------------|-----|-----| +| 1.0 | ✅ | ❌ | ❌ | +| 1.1 | ✅ | ✅ | ❌ | +| 1.2 | ✅ | ✅ | ✅ | ## Contributing diff --git a/examples/cron.rs b/examples/cron.rs index 4b43f19..eef7107 100644 --- a/examples/cron.rs +++ b/examples/cron.rs @@ -77,7 +77,7 @@ impl PeriodicTaskImpl { .object_client::(context.key()) .run() // And send with a delay - .send_with_delay(Duration::from_secs(10)); + .send_after(Duration::from_secs(10)); } } diff --git a/src/context/mod.rs b/src/context/mod.rs index 08fa37a..4a0129e 100644 --- a/src/context/mod.rs +++ b/src/context/mod.rs @@ -8,7 +8,7 @@ use std::time::Duration; mod request; mod run; -pub use request::{Request, RequestTarget}; +pub use request::{CallFuture, InvocationHandle, Request, RequestTarget}; pub use run::{RunClosure, RunFuture, RunRetryPolicy}; pub type HeaderMap = http::HeaderMap; @@ -370,20 +370,20 @@ impl<'ctx, CTX: private::SealedContext<'ctx>> ContextTimers<'ctx> for CTX {} /// // To a Service: /// ctx.service_client::() /// .my_handler(String::from("Hi!")) -/// .send_with_delay(Duration::from_millis(5000)); +/// .send_after(Duration::from_millis(5000)); /// /// // To a Virtual Object: /// ctx.object_client::("Mary") /// .my_handler(String::from("Hi!")) -/// .send_with_delay(Duration::from_millis(5000)); +/// .send_after(Duration::from_millis(5000)); /// /// // To a Workflow: /// ctx.workflow_client::("my-workflow-id") /// .run(String::from("Hi!")) -/// .send_with_delay(Duration::from_millis(5000)); +/// .send_after(Duration::from_millis(5000)); /// ctx.workflow_client::("my-workflow-id") /// .interact_with_workflow() -/// .send_with_delay(Duration::from_millis(5000)); +/// .send_after(Duration::from_millis(5000)); /// # Ok(()) /// # } /// ``` @@ -433,6 +433,11 @@ pub trait ContextClient<'ctx>: private::SealedContext<'ctx> { Request::new(self.inner_context(), request_target, req) } + /// Create an [`InvocationHandle`] from an invocation id. + fn invocation_handle(&self, invocation_id: String) -> impl InvocationHandle + 'ctx { + self.inner_context().invocation_handle(invocation_id) + } + /// Create a service client. The service client is generated by the [`restate_sdk_macros::service`] macro with the same name of the trait suffixed with `Client`. /// /// ```rust,no_run @@ -454,7 +459,7 @@ pub trait ContextClient<'ctx>: private::SealedContext<'ctx> { /// client.handle().send(); /// /// // Schedule the request to be executed later - /// client.handle().send_with_delay(Duration::from_secs(60)); + /// client.handle().send_after(Duration::from_secs(60)); /// # } /// ``` fn service_client(&self) -> C @@ -485,7 +490,7 @@ pub trait ContextClient<'ctx>: private::SealedContext<'ctx> { /// client.handle().send(); /// /// // Schedule the request to be executed later - /// client.handle().send_with_delay(Duration::from_secs(60)); + /// client.handle().send_after(Duration::from_secs(60)); /// # } /// ``` fn object_client(&self, key: impl Into) -> C @@ -516,7 +521,7 @@ pub trait ContextClient<'ctx>: private::SealedContext<'ctx> { /// client.handle().send(); /// /// // Schedule the request to be executed later - /// client.handle().send_with_delay(Duration::from_secs(60)); + /// client.handle().send_after(Duration::from_secs(60)); /// # } /// ``` fn workflow_client(&self, key: impl Into) -> C @@ -627,7 +632,7 @@ pub trait ContextAwakeables<'ctx>: private::SealedContext<'ctx> { &self, ) -> ( String, - impl Future> + Send + Sync + 'ctx, + impl Future> + Send + 'ctx, ) { self.inner_context().awakeable() } diff --git a/src/context/request.rs b/src/context/request.rs index d0104f5..0284628 100644 --- a/src/context/request.rs +++ b/src/context/request.rs @@ -72,6 +72,7 @@ impl fmt::Display for RequestTarget { pub struct Request<'a, Req, Res = ()> { ctx: &'a ContextInternal, request_target: RequestTarget, + idempotency_key: Option, req: Req, res: PhantomData, } @@ -81,33 +82,54 @@ impl<'a, Req, Res> Request<'a, Req, Res> { Self { ctx, request_target, + idempotency_key: None, req, res: PhantomData, } } + /// Add idempotency key to the request + pub fn idempotency_key(mut self, idempotency_key: impl Into) -> Self { + self.idempotency_key = Some(idempotency_key.into()); + self + } + /// Call a service. This returns a future encapsulating the response. - pub fn call(self) -> impl Future> + Send + pub fn call(self) -> impl CallFuture> + Send where Req: Serialize + 'static, Res: Deserialize + 'static, { - self.ctx.call(self.request_target, self.req) + self.ctx + .call(self.request_target, self.idempotency_key, self.req) } /// Send the request to the service, without waiting for the response. - pub fn send(self) + pub fn send(self) -> impl InvocationHandle where Req: Serialize + 'static, { - self.ctx.send(self.request_target, self.req, None) + self.ctx + .send(self.request_target, self.idempotency_key, self.req, None) } /// Schedule the request to the service, without waiting for the response. - pub fn send_with_delay(self, duration: Duration) + pub fn send_after(self, delay: Duration) -> impl InvocationHandle where Req: Serialize + 'static, { - self.ctx.send(self.request_target, self.req, Some(duration)) + self.ctx.send( + self.request_target, + self.idempotency_key, + self.req, + Some(delay), + ) } } + +pub trait InvocationHandle { + fn invocation_id(&self) -> impl Future> + Send; + fn cancel(&self) -> impl Future> + Send; +} + +pub trait CallFuture: Future + InvocationHandle {} diff --git a/src/endpoint/context.rs b/src/endpoint/context.rs index 536dcb5..89dd52c 100644 --- a/src/endpoint/context.rs +++ b/src/endpoint/context.rs @@ -1,4 +1,6 @@ -use crate::context::{Request, RequestTarget, RunClosure, RunRetryPolicy}; +use crate::context::{ + CallFuture, InvocationHandle, Request, RequestTarget, RunClosure, RunFuture, RunRetryPolicy, +}; use crate::endpoint::futures::async_result_poll::VmAsyncResultPollFuture; use crate::endpoint::futures::intercept_error::InterceptErrorFuture; use crate::endpoint::futures::trap::TrapFuture; @@ -6,12 +8,12 @@ use crate::endpoint::handler_state::HandlerStateNotifier; use crate::endpoint::{Error, ErrorInner, InputReceiver, OutputSender}; use crate::errors::{HandlerErrorInner, HandlerResult, TerminalError}; use crate::serde::{Deserialize, Serialize}; -use futures::future::Either; +use futures::future::{BoxFuture, Either, Shared}; use futures::{FutureExt, TryFutureExt}; use pin_project_lite::pin_project; use restate_sdk_shared_core::{ - CoreVM, Failure, NonEmptyValue, RetryPolicy, RunEnterResult, RunExitResult, TakeOutputResult, - Target, Value, VM, + CoreVM, DoProgressResponse, Error as CoreError, NonEmptyValue, NotificationHandle, RetryPolicy, + RunExitResult, TakeOutputResult, Target, Value, VM, }; use std::borrow::Cow; use std::collections::HashMap; @@ -21,7 +23,7 @@ use std::mem; use std::pin::Pin; use std::sync::{Arc, Mutex}; use std::task::{ready, Context, Poll}; -use std::time::{Duration, Instant}; +use std::time::{Duration, Instant, SystemTime}; pub struct ContextInternalInner { pub(crate) vm: CoreVM, @@ -46,8 +48,11 @@ impl ContextInternalInner { } pub(super) fn fail(&mut self, e: Error) { - self.vm - .notify_error(e.0.to_string().into(), format!("{:#}", e.0).into(), None); + self.vm.notify_error( + CoreError::new(500u16, e.0.to_string()) + .with_stacktrace(Cow::Owned(format!("{:#}", e.0))), + None, + ); self.handler_state.mark_error(e); } } @@ -94,6 +99,18 @@ macro_rules! must_lock { }; } +macro_rules! unwrap_or_trap { + ($inner_lock:expr, $res:expr) => { + match $res { + Ok(t) => t, + Err(e) => { + $inner_lock.fail(e.into()); + return Either::Right(TrapFuture::default()); + } + } + }; +} + #[derive(Debug, Eq, PartialEq)] pub struct InputMetadata { pub invocation_id: String, @@ -109,16 +126,22 @@ impl From for Target { service: name, handler, key: None, + idempotency_key: None, + headers: vec![], }, RequestTarget::Object { name, key, handler } => Target { service: name, handler, key: Some(key), + idempotency_key: None, + headers: vec![], }, RequestTarget::Workflow { name, key, handler } => Target { service: name, handler, key: Some(key), + idempotency_key: None, + headers: vec![], }, } } @@ -197,51 +220,45 @@ impl ContextInternal { pub fn get( &self, key: &str, - ) -> impl Future, TerminalError>> + Send + Sync { - let maybe_handle = { must_lock!(self.inner).vm.sys_state_get(key.to_owned()) }; - - let poll_future = VmAsyncResultPollFuture::new(Cow::Borrowed(&self.inner), maybe_handle) - .map(|res| match res { - Ok(Value::Void) => Ok(Ok(None)), - Ok(Value::Success(mut s)) => { - let t = T::deserialize(&mut s).map_err(|e| ErrorInner::Deserialization { - syscall: "get_state", - err: Box::new(e), - })?; - Ok(Ok(Some(t))) - } - Ok(Value::Failure(f)) => Ok(Err(f.into())), - Ok(Value::StateKeys(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { - variant: "state_keys", - syscall: "get_state", - }), - Err(e) => Err(e), - }); + ) -> impl Future, TerminalError>> + Send { + let mut inner_lock = must_lock!(self.inner); + let handle = unwrap_or_trap!(inner_lock, inner_lock.vm.sys_state_get(key.to_owned())); + + let poll_future = get_async_result(Arc::clone(&self.inner), handle).map(|res| match res { + Ok(Value::Void) => Ok(Ok(None)), + Ok(Value::Success(mut s)) => { + let t = + T::deserialize(&mut s).map_err(|e| Error::deserialization("get_state", e))?; + Ok(Ok(Some(t))) + } + Ok(Value::Failure(f)) => Ok(Err(f.into())), + Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall { + variant: <&'static str>::from(v), + syscall: "get_state", + } + .into()), + Err(e) => Err(e), + }); - InterceptErrorFuture::new(self.clone(), poll_future.map_err(Error)) + Either::Left(InterceptErrorFuture::new(self.clone(), poll_future)) } - pub fn get_keys( - &self, - ) -> impl Future, TerminalError>> + Send + Sync { - let maybe_handle = { must_lock!(self.inner).vm.sys_state_get_keys() }; - - let poll_future = VmAsyncResultPollFuture::new(Cow::Borrowed(&self.inner), maybe_handle) - .map(|res| match res { - Ok(Value::Void) => Err(ErrorInner::UnexpectedValueVariantForSyscall { - variant: "empty", - syscall: "get_state", - }), - Ok(Value::Success(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { - variant: "success", - syscall: "get_state", - }), - Ok(Value::Failure(f)) => Ok(Err(f.into())), - Ok(Value::StateKeys(s)) => Ok(Ok(s)), - Err(e) => Err(e), - }); + pub fn get_keys(&self) -> impl Future, TerminalError>> + Send { + let mut inner_lock = must_lock!(self.inner); + let handle = unwrap_or_trap!(inner_lock, inner_lock.vm.sys_state_get_keys()); + + let poll_future = get_async_result(Arc::clone(&self.inner), handle).map(|res| match res { + Ok(Value::Failure(f)) => Ok(Err(f.into())), + Ok(Value::StateKeys(s)) => Ok(Ok(s)), + Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall { + variant: <&'static str>::from(v), + syscall: "get_keys", + } + .into()), + Err(e) => Err(e), + }); - InterceptErrorFuture::new(self.clone(), poll_future.map_err(Error)) + Either::Left(InterceptErrorFuture::new(self.clone(), poll_future)) } pub fn set(&self, key: &str, t: T) { @@ -251,13 +268,7 @@ impl ContextInternal { let _ = inner_lock.vm.sys_state_set(key.to_owned(), b); } Err(e) => { - inner_lock.fail( - ErrorInner::Serialization { - syscall: "set_state", - err: Box::new(e), - } - .into(), - ); + inner_lock.fail(Error::serialization("set_state", e)); } } } @@ -272,26 +283,29 @@ impl ContextInternal { pub fn sleep( &self, - duration: Duration, - ) -> impl Future> + Send + Sync { - let maybe_handle = { must_lock!(self.inner).vm.sys_sleep(duration) }; - - let poll_future = VmAsyncResultPollFuture::new(Cow::Borrowed(&self.inner), maybe_handle) - .map(|res| match res { - Ok(Value::Void) => Ok(Ok(())), - Ok(Value::Success(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { - variant: "success", - syscall: "sleep", - }), - Ok(Value::Failure(f)) => Ok(Err(f.into())), - Err(e) => Err(e), - Ok(Value::StateKeys(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { - variant: "state_keys", - syscall: "sleep", - }), - }); + sleep_duration: Duration, + ) -> impl Future> + Send { + let now = SystemTime::now() + .duration_since(SystemTime::UNIX_EPOCH) + .expect("Duration since unix epoch cannot fail"); + let mut inner_lock = must_lock!(self.inner); + let handle = unwrap_or_trap!( + inner_lock, + inner_lock.vm.sys_sleep(now + sleep_duration, Some(now)) + ); + + let poll_future = get_async_result(Arc::clone(&self.inner), handle).map(|res| match res { + Ok(Value::Void) => Ok(Ok(())), + Ok(Value::Failure(f)) => Ok(Err(f.into())), + Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall { + variant: <&'static str>::from(v), + syscall: "sleep", + } + .into()), + Err(e) => Err(e), + }); - InterceptErrorFuture::new(self.clone(), poll_future.map_err(Error)) + Either::Left(InterceptErrorFuture::new(self.clone(), poll_future)) } pub fn request(&self, request_target: RequestTarget, req: Req) -> Request { @@ -301,120 +315,174 @@ impl ContextInternal { pub fn call( &self, request_target: RequestTarget, + idempotency_key: Option, req: Req, - ) -> impl Future> + Send + Sync { + ) -> impl CallFuture> + Send { let mut inner_lock = must_lock!(self.inner); - let input = match Req::serialize(&req) { - Ok(t) => t, - Err(e) => { - inner_lock.fail( - ErrorInner::Serialization { - syscall: "call", - err: Box::new(e), - } - .into(), - ); - return Either::Right(TrapFuture::default()); - } - }; + let mut target: Target = request_target.into(); + target.idempotency_key = idempotency_key; + let input = unwrap_or_trap!( + inner_lock, + Req::serialize(&req).map_err(|e| Error::serialization("call", e)) + ); - let maybe_handle = inner_lock.vm.sys_call(request_target.into(), input); + let call_handle = unwrap_or_trap!(inner_lock, inner_lock.vm.sys_call(target, input)); drop(inner_lock); - let poll_future = VmAsyncResultPollFuture::new(Cow::Borrowed(&self.inner), maybe_handle) + // Let's prepare the two futures here + let invocation_id_fut = InterceptErrorFuture::new( + self.clone(), + get_async_result( + Arc::clone(&self.inner), + call_handle.invocation_id_notification_handle, + ) .map(|res| match res { - Ok(Value::Void) => Err(ErrorInner::UnexpectedValueVariantForSyscall { - variant: "empty", + Ok(Value::Failure(f)) => Ok(Err(f.into())), + Ok(Value::InvocationId(s)) => Ok(Ok(s)), + Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall { + variant: <&'static str>::from(v), syscall: "call", - }), - Ok(Value::Success(mut s)) => { - let t = Res::deserialize(&mut s).map_err(|e| ErrorInner::Deserialization { - syscall: "call", - err: Box::new(e), - })?; - Ok(Ok(t)) } - Ok(Value::Failure(f)) => Ok(Err(f.into())), - Ok(Value::StateKeys(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { - variant: "state_keys", + .into()), + Err(e) => Err(e), + }), + ); + let result_future = InterceptErrorFuture::new( + self.clone(), + get_async_result( + Arc::clone(&self.inner), + call_handle.call_notification_handle, + ) + .map(|res| match res { + Ok(Value::Success(mut s)) => Ok(Ok( + Res::deserialize(&mut s).map_err(|e| Error::deserialization("call", e))? + )), + Ok(Value::Failure(f)) => Ok(Err(TerminalError::from(f))), + Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall { + variant: <&'static str>::from(v), syscall: "call", - }), + } + .into()), Err(e) => Err(e), - }); + }), + ); - Either::Left(InterceptErrorFuture::new( - self.clone(), - poll_future.map_err(Error), - )) + Either::Left(CallFutureImpl { + invocation_id_future: invocation_id_fut.shared(), + result_future, + ctx: self.clone(), + }) } pub fn send( &self, request_target: RequestTarget, + idempotency_key: Option, req: Req, delay: Option, - ) { + ) -> impl InvocationHandle { let mut inner_lock = must_lock!(self.inner); - match Req::serialize(&req) { - Ok(t) => { - let _ = inner_lock.vm.sys_send(request_target.into(), t, delay); + let mut target: Target = request_target.into(); + target.idempotency_key = idempotency_key; + + let input = match Req::serialize(&req) { + Ok(b) => b, + Err(e) => { + inner_lock.fail(Error::serialization("call", e)); + return Either::Right(TrapFuture::<()>::default()); } + }; + + let send_handle = match inner_lock.vm.sys_send( + target, + input, + delay.map(|delay| { + SystemTime::now() + .duration_since(SystemTime::UNIX_EPOCH) + .expect("Duration since unix epoch cannot fail") + + delay + }), + ) { + Ok(h) => h, Err(e) => { - inner_lock.fail( - ErrorInner::Serialization { - syscall: "call", - err: Box::new(e), - } - .into(), - ); + inner_lock.fail(e.into()); + return Either::Right(TrapFuture::<()>::default()); } }; + drop(inner_lock); + + let invocation_id_fut = InterceptErrorFuture::new( + self.clone(), + get_async_result( + Arc::clone(&self.inner), + send_handle.invocation_id_notification_handle, + ) + .map(|res| match res { + Ok(Value::Failure(f)) => Ok(Err(f.into())), + Ok(Value::InvocationId(s)) => Ok(Ok(s)), + Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall { + variant: <&'static str>::from(v), + syscall: "call", + } + .into()), + Err(e) => Err(e), + }), + ); + + Either::Left(SendRequestHandle { + invocation_id_future: invocation_id_fut.shared(), + ctx: self.clone(), + }) + } + + pub fn invocation_handle(&self, invocation_id: String) -> impl InvocationHandle { + InvocationIdBackedInvocationHandle { + ctx: self.clone(), + invocation_id, + } } pub fn awakeable( &self, ) -> ( String, - impl Future> + Send + Sync, + impl Future> + Send, ) { - let maybe_awakeable_id_and_handle = { must_lock!(self.inner).vm.sys_awakeable() }; - - let (awakeable_id, maybe_handle) = match maybe_awakeable_id_and_handle { - Ok((s, handle)) => (s, Ok(handle)), - Err(e) => ( - // TODO NOW this is REALLY BAD. The reason for this is that we would need to return a future of a future instead, which is not nice. - // we assume for the time being this works because no user should use the awakeable without doing any other syscall first, which will prevent this invalid awakeable id to work in the first place. - "invalid".to_owned(), - Err(e), - ), + let mut inner_lock = must_lock!(self.inner); + let maybe_awakeable_id_and_handle = inner_lock.vm.sys_awakeable(); + + let (awakeable_id, handle) = match maybe_awakeable_id_and_handle { + Ok((s, handle)) => (s, handle), + Err(e) => { + inner_lock.fail(e.into()); + return ( + // TODO NOW this is REALLY BAD. The reason for this is that we would need to return a future of a future instead, which is not nice. + // we assume for the time being this works because no user should use the awakeable without doing any other syscall first, which will prevent this invalid awakeable id to work in the first place. + "invalid".to_owned(), + Either::Right(TrapFuture::default()), + ); + } }; + drop(inner_lock); - let poll_future = VmAsyncResultPollFuture::new(Cow::Borrowed(&self.inner), maybe_handle) - .map(|res| match res { - Ok(Value::Void) => Err(ErrorInner::UnexpectedValueVariantForSyscall { - variant: "empty", - syscall: "awakeable", - }), - Ok(Value::Success(mut s)) => { - let t = T::deserialize(&mut s).map_err(|e| ErrorInner::Deserialization { - syscall: "awakeable", - err: Box::new(e), - })?; - Ok(Ok(t)) - } - Ok(Value::Failure(f)) => Ok(Err(f.into())), - Ok(Value::StateKeys(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { - variant: "state_keys", - syscall: "awakeable", - }), - Err(e) => Err(e), - }); + let poll_future = get_async_result(Arc::clone(&self.inner), handle).map(|res| match res { + Ok(Value::Success(mut s)) => Ok(Ok( + T::deserialize(&mut s).map_err(|e| Error::deserialization("awakeable", e))? + )), + Ok(Value::Failure(f)) => Ok(Err(f.into())), + Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall { + variant: <&'static str>::from(v), + syscall: "awakeable", + } + .into()), + Err(e) => Err(e), + }); ( awakeable_id, - InterceptErrorFuture::new(self.clone(), poll_future.map_err(Error)), + Either::Left(InterceptErrorFuture::new(self.clone(), poll_future)), ) } @@ -427,13 +495,7 @@ impl ContextInternal { .sys_complete_awakeable(id.to_owned(), NonEmptyValue::Success(b)); } Err(e) => { - inner_lock.fail( - ErrorInner::Serialization { - syscall: "resolve_awakeable", - err: Box::new(e), - } - .into(), - ); + inner_lock.fail(Error::serialization("resolve_awakeable", e)); } } } @@ -447,58 +509,53 @@ impl ContextInternal { pub fn promise( &self, name: &str, - ) -> impl Future> + Send + Sync { - let maybe_handle = { must_lock!(self.inner).vm.sys_get_promise(name.to_owned()) }; + ) -> impl Future> + Send { + let mut inner_lock = must_lock!(self.inner); + let handle = unwrap_or_trap!(inner_lock, inner_lock.vm.sys_get_promise(name.to_owned())); + drop(inner_lock); - let poll_future = VmAsyncResultPollFuture::new(Cow::Borrowed(&self.inner), maybe_handle) - .map(|res| match res { - Ok(Value::Void) => Err(ErrorInner::UnexpectedValueVariantForSyscall { - variant: "empty", - syscall: "promise", - }), - Ok(Value::Success(mut s)) => { - let t = T::deserialize(&mut s).map_err(|e| ErrorInner::Deserialization { - syscall: "promise", - err: Box::new(e), - })?; - Ok(Ok(t)) - } - Ok(Value::Failure(f)) => Ok(Err(f.into())), - Ok(Value::StateKeys(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { - variant: "state_keys", - syscall: "promise", - }), - Err(e) => Err(e), - }); + let poll_future = get_async_result(Arc::clone(&self.inner), handle).map(|res| match res { + Ok(Value::Success(mut s)) => { + let t = T::deserialize(&mut s).map_err(|e| Error::deserialization("promise", e))?; + Ok(Ok(t)) + } + Ok(Value::Failure(f)) => Ok(Err(f.into())), + Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall { + variant: <&'static str>::from(v), + syscall: "promise", + } + .into()), + Err(e) => Err(e), + }); - InterceptErrorFuture::new(self.clone(), poll_future.map_err(Error)) + Either::Left(InterceptErrorFuture::new(self.clone(), poll_future)) } pub fn peek_promise( &self, name: &str, - ) -> impl Future, TerminalError>> + Send + Sync { - let maybe_handle = { must_lock!(self.inner).vm.sys_peek_promise(name.to_owned()) }; + ) -> impl Future, TerminalError>> + Send { + let mut inner_lock = must_lock!(self.inner); + let handle = unwrap_or_trap!(inner_lock, inner_lock.vm.sys_peek_promise(name.to_owned())); + drop(inner_lock); - let poll_future = VmAsyncResultPollFuture::new(Cow::Borrowed(&self.inner), maybe_handle) - .map(|res| match res { - Ok(Value::Void) => Ok(Ok(None)), - Ok(Value::Success(mut s)) => { - let t = T::deserialize(&mut s).map_err(|e| ErrorInner::Deserialization { - syscall: "peek_promise", - err: Box::new(e), - })?; - Ok(Ok(Some(t))) - } - Ok(Value::Failure(f)) => Ok(Err(f.into())), - Ok(Value::StateKeys(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { - variant: "state_keys", - syscall: "peek_promise", - }), - Err(e) => Err(e), - }); + let poll_future = get_async_result(Arc::clone(&self.inner), handle).map(|res| match res { + Ok(Value::Void) => Ok(Ok(None)), + Ok(Value::Success(mut s)) => { + let t = T::deserialize(&mut s) + .map_err(|e| Error::deserialization("peek_promise", e))?; + Ok(Ok(Some(t))) + } + Ok(Value::Failure(f)) => Ok(Err(f.into())), + Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall { + variant: <&'static str>::from(v), + syscall: "peek_promise", + } + .into()), + Err(e) => Err(e), + }); - InterceptErrorFuture::new(self.clone(), poll_future.map_err(Error)) + Either::Left(InterceptErrorFuture::new(self.clone(), poll_future)) } pub fn resolve_promise(&self, name: &str, t: T) { @@ -530,17 +587,17 @@ impl ContextInternal { pub fn run<'a, Run, Fut, Out>( &'a self, run_closure: Run, - ) -> impl crate::context::RunFuture> + Send + 'a + ) -> impl RunFuture> + Send + 'a where Run: RunClosure + Send + 'a, Fut: Future> + Send + 'a, Out: Serialize + Deserialize + 'static, { let this = Arc::clone(&self.inner); - - InterceptErrorFuture::new(self.clone(), RunFuture::new(this, run_closure)) + InterceptErrorFuture::new(self.clone(), RunFutureImpl::new(this, run_closure)) } + // Used by codegen pub fn handle_handler_result(&self, res: HandlerResult) { let mut inner_lock = must_lock!(self.inner); @@ -591,52 +648,80 @@ impl ContextInternal { } pin_project! { - struct RunFuture { + struct RunFutureImpl { name: String, retry_policy: RetryPolicy, phantom_data: PhantomData Ret>, - closure: Option, - inner_ctx: Option>>, #[pin] - state: RunState, + state: RunState, } } pin_project! { #[project = RunStateProj] - enum RunState { - New, + enum RunState { + New { + ctx: Option>>, + closure: Option, + }, ClosureRunning { + ctx: Option>>, + handle: NotificationHandle, start_time: Instant, #[pin] - fut: Fut, + closure_fut: RunFnFut, }, - PollFutureRunning { - #[pin] - fut: VmAsyncResultPollFuture + WaitingResultFut { + result_fut: BoxFuture<'static, Result, Error>> } } } -impl RunFuture { - fn new(inner_ctx: Arc>, closure: Run) -> Self { +impl RunFutureImpl { + fn new(ctx: Arc>, closure: Run) -> Self { Self { name: "".to_string(), retry_policy: RetryPolicy::Infinite, phantom_data: PhantomData, - inner_ctx: Some(inner_ctx), - closure: Some(closure), - state: RunState::New, + state: RunState::New { + ctx: Some(ctx), + closure: Some(closure), + }, } } + + fn boxed_result_fut( + ctx: Arc>, + handle: NotificationHandle, + ) -> BoxFuture<'static, Result, Error>> + where + Ret: Deserialize, + { + get_async_result(Arc::clone(&ctx), handle) + .map(|res| match res { + Ok(Value::Success(mut s)) => { + let t = + Ret::deserialize(&mut s).map_err(|e| Error::deserialization("run", e))?; + Ok(Ok(t)) + } + Ok(Value::Failure(f)) => Ok(Err(f.into())), + Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall { + variant: <&'static str>::from(v), + syscall: "run", + } + .into()), + Err(e) => Err(e), + }) + .boxed() + } } -impl crate::context::RunFuture, Error>> - for RunFuture +impl RunFuture, Error>> + for RunFutureImpl where - Run: RunClosure + Send, - Fut: Future> + Send, - Out: Serialize + Deserialize, + Run: RunClosure + Send, + Ret: Serialize + Deserialize, + RunFnFut: Future> + Send, { fn retry_policy(mut self, retry_policy: RunRetryPolicy) -> Self { self.retry_policy = RetryPolicy::Exponential { @@ -655,59 +740,67 @@ where } } -impl Future for RunFuture +impl Future for RunFutureImpl where - Run: RunClosure + Send, - Out: Serialize + Deserialize, - Fut: Future> + Send, + Run: RunClosure + Send, + Ret: Serialize + Deserialize, + RunFnFut: Future> + Send, { - type Output = Result, Error>; + type Output = Result, Error>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let mut this = self.project(); loop { match this.state.as_mut().project() { - RunStateProj::New => { - let enter_result = { - must_lock!(this - .inner_ctx - .as_mut() - .expect("Future should not be polled after returning Poll::Ready")) - .vm - .sys_run_enter(this.name.to_owned()) - }; + RunStateProj::New { ctx, closure, .. } => { + let ctx = ctx + .take() + .expect("Future should not be polled after returning Poll::Ready"); + let closure = closure + .take() + .expect("Future should not be polled after returning Poll::Ready"); + let mut inner_ctx = must_lock!(ctx); - // Enter the side effect - match enter_result.map_err(ErrorInner::VM)? { - RunEnterResult::Executed(NonEmptyValue::Success(mut v)) => { - let t = Out::deserialize(&mut v).map_err(|e| { - ErrorInner::Deserialization { - syscall: "run", - err: Box::new(e), - } - })?; - return Poll::Ready(Ok(Ok(t))); + let handle = inner_ctx + .vm + .sys_run(this.name.to_owned()) + .map_err(ErrorInner::from)?; + + // Now we do progress once to check whether this closure should be executed or not. + match inner_ctx.vm.do_progress(vec![handle]) { + Ok(DoProgressResponse::ExecuteRun(handle_to_run)) => { + // In case it returns ExecuteRun, it must be the handle we just gave it, + // and it means we need to execute the closure + assert_eq!(handle, handle_to_run); + + drop(inner_ctx); + this.state.set(RunState::ClosureRunning { + ctx: Some(ctx), + handle, + start_time: Instant::now(), + closure_fut: closure.run(), + }); } - RunEnterResult::Executed(NonEmptyValue::Failure(f)) => { - return Poll::Ready(Ok(Err(f.into()))) + _ => { + drop(inner_ctx); + // In all the other cases, just move on waiting the result, + // the poll future state will take care of doing whatever needs to be done here, + // that is propagating state machine error, or result, or whatever + this.state.set(RunState::WaitingResultFut { + result_fut: Self::boxed_result_fut(Arc::clone(&ctx), handle), + }) } - RunEnterResult::NotExecuted(_) => {} - }; - - // We need to run the closure - this.state.set(RunState::ClosureRunning { - start_time: Instant::now(), - fut: this - .closure - .take() - .expect("Future should not be polled after returning Poll::Ready") - .run(), - }); + } } - RunStateProj::ClosureRunning { start_time, fut } => { - let res = match ready!(fut.poll(cx)) { - Ok(t) => RunExitResult::Success(Out::serialize(&t).map_err(|e| { + RunStateProj::ClosureRunning { + ctx, + handle, + start_time, + closure_fut, + } => { + let res = match ready!(closure_fut.poll(cx)) { + Ok(t) => RunExitResult::Success(Ret::serialize(&t).map_err(|e| { ErrorInner::Serialization { syscall: "run", err: Box::new(e), @@ -716,10 +809,7 @@ where Err(e) => match e.0 { HandlerErrorInner::Retryable(err) => RunExitResult::RetryableFailure { attempt_duration: start_time.elapsed(), - failure: Failure { - code: 500, - message: err.to_string(), - }, + error: CoreError::new(500u16, err.to_string()), }, HandlerErrorInner::Terminal(t) => { RunExitResult::TerminalFailure(TerminalError(t).into()) @@ -727,48 +817,179 @@ where }, }; - let inner_ctx = this - .inner_ctx + let ctx = ctx .take() .expect("Future should not be polled after returning Poll::Ready"); - - let handle = { - must_lock!(inner_ctx) - .vm - .sys_run_exit(res, mem::take(this.retry_policy)) + let handle = *handle; + + let _ = { + must_lock!(ctx).vm.propose_run_completion( + handle, + res, + mem::take(this.retry_policy), + ) }; - this.state.set(RunState::PollFutureRunning { - fut: VmAsyncResultPollFuture::new(Cow::Owned(inner_ctx), handle), - }); - } - RunStateProj::PollFutureRunning { fut } => { - let value = ready!(fut.poll(cx))?; - - return Poll::Ready(match value { - Value::Void => Err(ErrorInner::UnexpectedValueVariantForSyscall { - variant: "empty", - syscall: "run", - } - .into()), - Value::Success(mut s) => { - let t = Out::deserialize(&mut s).map_err(|e| { - ErrorInner::Deserialization { - syscall: "run", - err: Box::new(e), - } - })?; - Ok(Ok(t)) - } - Value::Failure(f) => Ok(Err(f.into())), - Value::StateKeys(_) => Err(ErrorInner::UnexpectedValueVariantForSyscall { - variant: "state_keys", - syscall: "run", - } - .into()), + this.state.set(RunState::WaitingResultFut { + result_fut: Self::boxed_result_fut(Arc::clone(&ctx), handle), }); } + RunStateProj::WaitingResultFut { result_fut } => return result_fut.poll_unpin(cx), } } } } + +struct SendRequestHandle { + invocation_id_future: Shared, + ctx: ContextInternal, +} + +impl> + Send> InvocationHandle + for SendRequestHandle +{ + fn invocation_id(&self) -> impl Future> + Send { + Shared::clone(&self.invocation_id_future) + } + + fn cancel(&self) -> impl Future> + Send { + let cloned_invocation_id_fut = Shared::clone(&self.invocation_id_future); + let cloned_ctx = Arc::clone(&self.ctx.inner); + async move { + let inv_id = cloned_invocation_id_fut.await?; + let mut inner_lock = must_lock!(cloned_ctx); + let _ = inner_lock.vm.sys_cancel_invocation(inv_id); + drop(inner_lock); + Ok(()) + } + } +} + +pin_project! { + struct CallFutureImpl { + #[pin] + invocation_id_future: Shared, + #[pin] + result_future: ResultFut, + ctx: ContextInternal, + } +} + +impl Future for CallFutureImpl +where + InvIdFut: Future> + Send, + ResultFut: Future> + Send, +{ + type Output = ResultFut::Output; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + this.result_future.poll(cx) + } +} + +impl InvocationHandle for CallFutureImpl +where + InvIdFut: Future> + Send, +{ + fn invocation_id(&self) -> impl Future> + Send { + Shared::clone(&self.invocation_id_future) + } + + fn cancel(&self) -> impl Future> + Send { + let cloned_invocation_id_fut = Shared::clone(&self.invocation_id_future); + let cloned_ctx = Arc::clone(&self.ctx.inner); + async move { + let inv_id = cloned_invocation_id_fut.await?; + let mut inner_lock = must_lock!(cloned_ctx); + let _ = inner_lock.vm.sys_cancel_invocation(inv_id); + drop(inner_lock); + Ok(()) + } + } +} + +impl CallFuture> + for CallFutureImpl +where + InvIdFut: Future> + Send, + ResultFut: Future> + Send, +{ +} + +impl InvocationHandle for Either +where + A: InvocationHandle, + B: InvocationHandle, +{ + fn invocation_id(&self) -> impl Future> + Send { + match self { + Either::Left(l) => Either::Left(l.invocation_id()), + Either::Right(r) => Either::Right(r.invocation_id()), + } + } + + fn cancel(&self) -> impl Future> + Send { + match self { + Either::Left(l) => Either::Left(l.cancel()), + Either::Right(r) => Either::Right(r.cancel()), + } + } +} + +impl CallFuture for Either +where + A: CallFuture, + B: CallFuture, +{ +} + +struct InvocationIdBackedInvocationHandle { + ctx: ContextInternal, + invocation_id: String, +} + +impl InvocationHandle for InvocationIdBackedInvocationHandle { + fn invocation_id(&self) -> impl Future> + Send { + ready(Ok(self.invocation_id.clone())) + } + + fn cancel(&self) -> impl Future> + Send { + let mut inner_lock = must_lock!(self.ctx.inner); + let _ = inner_lock + .vm + .sys_cancel_invocation(self.invocation_id.clone()); + ready(Ok(())) + } +} + +impl Error { + fn serialization( + syscall: &'static str, + e: E, + ) -> Self { + ErrorInner::Serialization { + syscall, + err: Box::new(e), + } + .into() + } + + fn deserialization( + syscall: &'static str, + e: E, + ) -> Self { + ErrorInner::Deserialization { + syscall, + err: Box::new(e), + } + .into() + } +} + +fn get_async_result( + ctx: Arc>, + handle: NotificationHandle, +) -> impl Future> + Send { + VmAsyncResultPollFuture::new(ctx, handle).map_err(Error::from) +} diff --git a/src/endpoint/futures/async_result_poll.rs b/src/endpoint/futures/async_result_poll.rs index 8f6ef5d..63eaadb 100644 --- a/src/endpoint/futures/async_result_poll.rs +++ b/src/endpoint/futures/async_result_poll.rs @@ -1,45 +1,39 @@ use crate::endpoint::context::ContextInternalInner; use crate::endpoint::ErrorInner; use restate_sdk_shared_core::{ - AsyncResultHandle, SuspendedOrVMError, TakeOutputResult, VMError, Value, VM, + DoProgressResponse, Error as CoreError, NotificationHandle, TakeOutputResult, TerminalFailure, + Value, VM, }; -use std::borrow::Cow; use std::future::Future; use std::pin::Pin; use std::sync::{Arc, Mutex}; use std::task::Poll; pub(crate) struct VmAsyncResultPollFuture { - state: Option, + state: Option, } impl VmAsyncResultPollFuture { - pub fn new( - inner: Cow<'_, Arc>>, - handle: Result, - ) -> Self { + pub fn new(ctx: Arc>, handle: NotificationHandle) -> Self { VmAsyncResultPollFuture { - state: Some(match handle { - Ok(handle) => PollState::Init { - ctx: inner.into_owned(), - handle, - }, - Err(err) => PollState::Failed(ErrorInner::VM(err)), - }), + state: Some(AsyncResultPollState::Init { ctx, handle }), } } } -enum PollState { +enum AsyncResultPollState { Init { ctx: Arc>, - handle: AsyncResultHandle, + handle: NotificationHandle, + }, + PollProgress { + ctx: Arc>, + handle: NotificationHandle, }, WaitingInput { ctx: Arc>, - handle: AsyncResultHandle, + handle: NotificationHandle, }, - Failed(ErrorInner), } macro_rules! must_lock { @@ -58,52 +52,35 @@ impl Future for VmAsyncResultPollFuture { .take() .expect("Future should not be polled after Poll::Ready") { - PollState::Init { ctx, handle } => { - // Acquire lock + AsyncResultPollState::Init { ctx, handle } => { let mut inner_lock = must_lock!(ctx); - // Let's consume some output + // Let's consume some output to begin with let out = inner_lock.vm.take_output(); match out { TakeOutputResult::Buffer(b) => { if !inner_lock.write.send(b) { - self.state = Some(PollState::Failed(ErrorInner::Suspended)); - continue; + return Poll::Ready(Err(ErrorInner::Suspended)); } } TakeOutputResult::EOF => { - self.state = - Some(PollState::Failed(ErrorInner::UnexpectedOutputClosed)); - continue; + return Poll::Ready(Err(ErrorInner::UnexpectedOutputClosed)) } } - // Notify that we reached an await point - inner_lock.vm.notify_await_point(handle); - - // At this point let's try to take the async result - match inner_lock.vm.take_async_result(handle) { - Ok(Some(v)) => return Poll::Ready(Ok(v)), - Ok(None) => { - drop(inner_lock); - self.state = Some(PollState::WaitingInput { ctx, handle }); - } - Err(SuspendedOrVMError::Suspended(_)) => { - self.state = Some(PollState::Failed(ErrorInner::Suspended)); - } - Err(SuspendedOrVMError::VM(e)) => { - self.state = Some(PollState::Failed(ErrorInner::VM(e))); - } - } + // We can now start polling + drop(inner_lock); + self.state = Some(AsyncResultPollState::PollProgress { ctx, handle }); } - PollState::WaitingInput { ctx, handle } => { + AsyncResultPollState::WaitingInput { ctx, handle } => { let mut inner_lock = must_lock!(ctx); let read_result = match inner_lock.read.poll_recv(cx) { Poll::Ready(t) => t, Poll::Pending => { + // Still need to wait for input drop(inner_lock); - self.state = Some(PollState::WaitingInput { ctx, handle }); + self.state = Some(AsyncResultPollState::WaitingInput { ctx, handle }); return Poll::Pending; } }; @@ -112,29 +89,56 @@ impl Future for VmAsyncResultPollFuture { match read_result { Some(Ok(b)) => inner_lock.vm.notify_input(b), Some(Err(e)) => inner_lock.vm.notify_error( - "Error when reading the body".into(), - e.to_string().into(), + CoreError::new(500u16, format!("Error when reading the body {e:?}",)), None, ), None => inner_lock.vm.notify_input_closed(), } - // Now try to take async result again - match inner_lock.vm.take_async_result(handle) { - Ok(Some(v)) => return Poll::Ready(Ok(v)), - Ok(None) => { + // It's time to poll progress again + drop(inner_lock); + self.state = Some(AsyncResultPollState::PollProgress { ctx, handle }); + } + AsyncResultPollState::PollProgress { ctx, handle } => { + let mut inner_lock = must_lock!(ctx); + + match inner_lock.vm.do_progress(vec![handle]) { + Ok(DoProgressResponse::AnyCompleted) => { + // We're good, we got the response + } + Ok(DoProgressResponse::ReadFromInput) => { drop(inner_lock); - self.state = Some(PollState::WaitingInput { ctx, handle }); + self.state = Some(AsyncResultPollState::WaitingInput { ctx, handle }); + continue; + } + Ok(DoProgressResponse::ExecuteRun(_)) => { + unimplemented!() } - Err(SuspendedOrVMError::Suspended(_)) => { - self.state = Some(PollState::Failed(ErrorInner::Suspended)); + Ok(DoProgressResponse::WaitingPendingRun) => { + unimplemented!() } - Err(SuspendedOrVMError::VM(e)) => { - self.state = Some(PollState::Failed(ErrorInner::VM(e))); + Ok(DoProgressResponse::CancelSignalReceived) => { + return Poll::Ready(Ok(Value::Failure(TerminalFailure { + code: 409, + message: "cancelled".to_string(), + }))) + } + Err(e) => { + return Poll::Ready(Err(e.into())); + } + }; + + // At this point let's try to take the notification + match inner_lock.vm.take_notification(handle) { + Ok(Some(v)) => return Poll::Ready(Ok(v)), + Ok(None) => { + panic!( + "This is not supposed to happen, handle was flagged as completed" + ) } + Err(e) => return Poll::Ready(Err(e.into())), } } - PollState::Failed(err) => return Poll::Ready(Err(err)), } } } diff --git a/src/endpoint/futures/intercept_error.rs b/src/endpoint/futures/intercept_error.rs index b486fbf..81eafe5 100644 --- a/src/endpoint/futures/intercept_error.rs +++ b/src/endpoint/futures/intercept_error.rs @@ -1,5 +1,6 @@ -use crate::context::{RunFuture, RunRetryPolicy}; +use crate::context::{InvocationHandle, RunFuture, RunRetryPolicy}; use crate::endpoint::{ContextInternal, Error}; +use crate::errors::TerminalError; use pin_project_lite::pin_project; use std::future::Future; use std::pin::Pin; @@ -58,3 +59,13 @@ where self } } + +impl InvocationHandle for InterceptErrorFuture { + fn invocation_id(&self) -> impl Future> + Send { + self.fut.invocation_id() + } + + fn cancel(&self) -> impl Future> + Send { + self.fut.cancel() + } +} diff --git a/src/endpoint/futures/trap.rs b/src/endpoint/futures/trap.rs index 9b0269d..b0a4ae9 100644 --- a/src/endpoint/futures/trap.rs +++ b/src/endpoint/futures/trap.rs @@ -1,3 +1,5 @@ +use crate::context::{CallFuture, InvocationHandle}; +use crate::errors::TerminalError; use std::future::Future; use std::marker::PhantomData; use std::pin::Pin; @@ -20,3 +22,15 @@ impl Future for TrapFuture { Poll::Pending } } + +impl InvocationHandle for TrapFuture { + fn invocation_id(&self) -> impl Future> + Send { + TrapFuture::default() + } + + fn cancel(&self) -> impl Future> + Send { + TrapFuture::default() + } +} + +impl CallFuture for TrapFuture {} diff --git a/src/endpoint/mod.rs b/src/endpoint/mod.rs index 90f907a..0bea694 100644 --- a/src/endpoint/mod.rs +++ b/src/endpoint/mod.rs @@ -11,7 +11,7 @@ use ::futures::{Stream, StreamExt}; use bytes::Bytes; pub use context::{ContextInternal, InputMetadata}; use restate_sdk_shared_core::{ - CoreVM, Header, HeaderMap, IdentityVerifier, KeyError, ResponseHead, VMError, VerifyError, VM, + CoreVM, Error as CoreError, Header, HeaderMap, IdentityVerifier, KeyError, VerifyError, VM, }; use std::collections::HashMap; use std::future::poll_fn; @@ -105,7 +105,7 @@ pub(crate) enum ErrorInner { #[error("Received a request for unknown service handler '{0}/{1}'")] UnknownServiceHandler(String, String), #[error("Error when processing the request: {0:?}")] - VM(#[from] VMError), + VM(#[from] CoreError), #[error("Error when verifying identity: {0:?}")] IdentityVerification(#[from] VerifyError), #[error("Cannot convert header '{0}', reason: {1}")] @@ -142,6 +142,27 @@ pub(crate) enum ErrorInner { }, } +impl From for ErrorInner { + fn from(_: restate_sdk_shared_core::SuspendedError) -> Self { + Self::Suspended + } +} + +impl From for ErrorInner { + fn from(value: restate_sdk_shared_core::SuspendedOrVMError) -> Self { + match value { + restate_sdk_shared_core::SuspendedOrVMError::Suspended(e) => e.into(), + restate_sdk_shared_core::SuspendedOrVMError::VM(e) => e.into(), + } + } +} + +impl From for Error { + fn from(e: CoreError) -> Self { + ErrorInner::from(e).into() + } +} + struct BoxedService( Box>> + Send + Sync + 'static>, ); @@ -176,8 +197,8 @@ impl Default for Builder { Self { svcs: Default::default(), discovery: crate::discovery::Endpoint { - max_protocol_version: 2, - min_protocol_version: 2, + max_protocol_version: 4, + min_protocol_version: 4, protocol_mode: Some(crate::discovery::ProtocolMode::BidiStream), services: vec![], }, @@ -274,13 +295,11 @@ impl Endpoint { } return Ok(Response::ReplyNow { - response_head: ResponseHead { - status_code: 200, - headers: vec![Header { - key: "content-type".into(), - value: DISCOVERY_CONTENT_TYPE.into(), - }], - }, + status_code: 200, + headers: vec![Header { + key: "content-type".into(), + value: DISCOVERY_CONTENT_TYPE.into(), + }], body: Bytes::from( serde_json::to_string(&self.0.discovery) .expect("Discovery should be serializable"), @@ -296,13 +315,16 @@ impl Endpoint { Some(last_elements) => (last_elements[1].to_owned(), last_elements[2].to_owned()), }; - let vm = CoreVM::new(headers).map_err(ErrorInner::VM)?; + let vm = CoreVM::new(headers, Default::default()).map_err(ErrorInner::VM)?; if !self.0.svcs.contains_key(&svc_name) { return Err(ErrorInner::UnknownService(svc_name.to_owned()).into()); } + let response_head = vm.get_response_head(); + Ok(Response::BidiStream { - response_head: vm.get_response_head(), + status_code: response_head.status_code, + headers: response_head.headers, handler: BidiStreamRunner { svc_name, handler_name, @@ -315,11 +337,13 @@ impl Endpoint { pub enum Response { ReplyNow { - response_head: ResponseHead, + status_code: u16, + headers: Vec
, body: Bytes, }, BidiStream { - response_head: ResponseHead, + status_code: u16, + headers: Vec
, handler: BidiStreamRunner, }, } @@ -391,8 +415,7 @@ async fn init_loop_vm(vm: &mut CoreVM, input_rx: &mut InputReceiver) -> Result<( match input_rx.recv().await { Some(Ok(b)) => vm.notify_input(b), Some(Err(e)) => vm.notify_error( - "Error when reading the body".into(), - e.to_string().into(), + CoreError::new(500u16, format!("Error when reading the body: {e}")), None, ), None => vm.notify_input_closed(), diff --git a/src/errors.rs b/src/errors.rs index b56064e..5f9783d 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -16,7 +16,7 @@ //! ``` //! //! You can catch terminal exceptions. For example, you can catch the terminal exception that comes out of a [call to another service][crate::context::ContextClient], and build your control flow around it. -use restate_sdk_shared_core::Failure; +use restate_sdk_shared_core::TerminalFailure; use std::error::Error as StdError; use std::fmt; @@ -141,8 +141,8 @@ impl AsRef for TerminalError { } } -impl From for TerminalError { - fn from(value: Failure) -> Self { +impl From for TerminalError { + fn from(value: TerminalFailure) -> Self { Self(TerminalErrorInner { code: value.code, message: value.message, @@ -150,7 +150,7 @@ impl From for TerminalError { } } -impl From for Failure { +impl From for TerminalFailure { fn from(value: TerminalError) -> Self { Self { code: value.0.code, diff --git a/src/hyper.rs b/src/hyper.rs index 4367a16..a5445e6 100644 --- a/src/hyper.rs +++ b/src/hyper.rs @@ -10,7 +10,7 @@ use http::{response, HeaderName, HeaderValue, Request, Response}; use http_body_util::{BodyExt, Either, Full}; use hyper::body::{Body, Frame, Incoming}; use hyper::service::Service; -use restate_sdk_shared_core::ResponseHead; +use restate_sdk_shared_core::Header; use std::convert::Infallible; use std::future::{ready, Ready}; use std::ops::Deref; @@ -56,13 +56,18 @@ impl Service> for HyperEndpoint { match endpoint_response { endpoint::Response::ReplyNow { - response_head, + status_code, + headers, body, - } => ready(Ok(response_builder_from_response_head(response_head) - .body(Either::Left(Full::new(body))) - .expect("Headers should be valid"))), + } => ready(Ok(response_builder_from_response_head( + status_code, + headers, + ) + .body(Either::Left(Full::new(body))) + .expect("Headers should be valid"))), endpoint::Response::BidiStream { - response_head, + status_code, + headers, handler, } => { let input_receiver = @@ -73,24 +78,30 @@ impl Service> for HyperEndpoint { let handler_fut = Box::pin(handler.handle(input_receiver, output_sender)); - ready(Ok(response_builder_from_response_head(response_head) - .body(Either::Right(BidiStreamRunner { - fut: Some(handler_fut), - output_rx, - end_stream: false, - })) - .expect("Headers should be valid"))) + ready(Ok(response_builder_from_response_head( + status_code, + headers, + ) + .body(Either::Right(BidiStreamRunner { + fut: Some(handler_fut), + output_rx, + end_stream: false, + })) + .expect("Headers should be valid"))) } } } } -fn response_builder_from_response_head(response_head: ResponseHead) -> response::Builder { +fn response_builder_from_response_head( + status_code: u16, + headers: Vec
, +) -> response::Builder { let mut response_builder = Response::builder() - .status(response_head.status_code) + .status(status_code) .header(X_RESTATE_SERVER, X_RESTATE_SERVER_VALUE); - for header in response_head.headers { + for header in headers { response_builder = response_builder.header(header.key.deref(), header.value.deref()); } diff --git a/src/lib.rs b/src/lib.rs index 930e9d6..842211b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -506,9 +506,10 @@ pub mod prelude { pub use crate::http_server::HttpServer; pub use crate::context::{ - Context, ContextAwakeables, ContextClient, ContextPromises, ContextReadState, - ContextSideEffects, ContextTimers, ContextWriteState, HeaderMap, ObjectContext, Request, - RunFuture, RunRetryPolicy, SharedObjectContext, SharedWorkflowContext, WorkflowContext, + CallFuture, Context, ContextAwakeables, ContextClient, ContextPromises, ContextReadState, + ContextSideEffects, ContextTimers, ContextWriteState, HeaderMap, InvocationHandle, + ObjectContext, Request, RunFuture, RunRetryPolicy, SharedObjectContext, + SharedWorkflowContext, WorkflowContext, }; pub use crate::endpoint::Endpoint; pub use crate::errors::{HandlerError, HandlerResult, TerminalError}; diff --git a/test-services/Dockerfile b/test-services/Dockerfile index c041674..50e4ff7 100644 --- a/test-services/Dockerfile +++ b/test-services/Dockerfile @@ -7,5 +7,6 @@ RUN cargo build -p test-services RUN cp ./target/debug/test-services /bin/server ENV RUST_LOG="debug,restate_shared_core=trace" +ENV RUST_BACKTRACE=1 CMD ["/bin/server"] \ No newline at end of file diff --git a/test-services/README.md b/test-services/README.md index 259a391..b08ef7c 100644 --- a/test-services/README.md +++ b/test-services/README.md @@ -9,5 +9,5 @@ $ podman build -f test-services/Dockerfile -t restatedev/rust-test-services . To run (download the [sdk-test-suite](https://github.com/restatedev/sdk-test-suite) first): ```shell -$ java -jar restate-sdk-test-suite.jar run restatedev/rust-test-services +$ java -jar restate-sdk-test-suite.jar run localhost/restatedev/rust-test-services:latest ``` \ No newline at end of file diff --git a/test-services/exclusions.yaml b/test-services/exclusions.yaml index c868ed3..5c9150a 100644 --- a/test-services/exclusions.yaml +++ b/test-services/exclusions.yaml @@ -1,14 +1,21 @@ exclusions: "alwaysSuspending": - - "dev.restate.sdktesting.tests.AwaitTimeout" + - "dev.restate.sdktesting.tests.Combinators.awakeableOrTimeoutUsingAwaitAny" + - "dev.restate.sdktesting.tests.Combinators.awakeableOrTimeoutUsingAwakeableTimeoutCommand" + - "dev.restate.sdktesting.tests.Combinators.firstSuccessfulCompletedAwakeable" "default": - - "dev.restate.sdktesting.tests.AwaitTimeout" - - "dev.restate.sdktesting.tests.RawHandler" + - "dev.restate.sdktesting.tests.Combinators.awakeableOrTimeoutUsingAwaitAny" + - "dev.restate.sdktesting.tests.Combinators.awakeableOrTimeoutUsingAwakeableTimeoutCommand" + - "dev.restate.sdktesting.tests.Combinators.firstSuccessfulCompletedAwakeable" "singleThreadSinglePartition": - - "dev.restate.sdktesting.tests.AwaitTimeout" - - "dev.restate.sdktesting.tests.RawHandler" + - "dev.restate.sdktesting.tests.Combinators.awakeableOrTimeoutUsingAwaitAny" + - "dev.restate.sdktesting.tests.Combinators.awakeableOrTimeoutUsingAwakeableTimeoutCommand" + - "dev.restate.sdktesting.tests.Combinators.firstSuccessfulCompletedAwakeable" "threeNodes": - - "dev.restate.sdktesting.tests.AwaitTimeout" - - "dev.restate.sdktesting.tests.RawHandler" + - "dev.restate.sdktesting.tests.Combinators.awakeableOrTimeoutUsingAwaitAny" + - "dev.restate.sdktesting.tests.Combinators.awakeableOrTimeoutUsingAwakeableTimeoutCommand" + - "dev.restate.sdktesting.tests.Combinators.firstSuccessfulCompletedAwakeable" "threeNodesAlwaysSuspending": - - "dev.restate.sdktesting.tests.AwaitTimeout" + - "dev.restate.sdktesting.tests.Combinators.awakeableOrTimeoutUsingAwaitAny" + - "dev.restate.sdktesting.tests.Combinators.awakeableOrTimeoutUsingAwakeableTimeoutCommand" + - "dev.restate.sdktesting.tests.Combinators.firstSuccessfulCompletedAwakeable" diff --git a/test-services/src/main.rs b/test-services/src/main.rs index d4f261a..bde49a9 100644 --- a/test-services/src/main.rs +++ b/test-services/src/main.rs @@ -9,6 +9,7 @@ mod map_object; mod non_deterministic; mod proxy; mod test_utils_service; +mod virtual_object_command_interpreter; use restate_sdk::prelude::{Endpoint, HttpServer}; use std::env; @@ -76,6 +77,13 @@ async fn main() { test_utils_service::TestUtilsServiceImpl, )) } + if services == "*" || services.contains("VirtualObjectCommandInterpreter") { + builder = builder.bind( + virtual_object_command_interpreter::VirtualObjectCommandInterpreter::serve( + virtual_object_command_interpreter::VirtualObjectCommandInterpreterImpl, + ), + ) + } if let Ok(key) = env::var("E2E_REQUEST_SIGNING_ENV") { builder = builder.identity_key(&key).unwrap() diff --git a/test-services/src/proxy.rs b/test-services/src/proxy.rs index 36954f6..6b1f221 100644 --- a/test-services/src/proxy.rs +++ b/test-services/src/proxy.rs @@ -11,6 +11,7 @@ pub(crate) struct ProxyRequest { service_name: String, virtual_object_key: Option, handler_name: String, + idempotency_key: Option, message: Vec, delay_millis: Option, } @@ -46,7 +47,7 @@ pub(crate) trait Proxy { #[name = "call"] async fn call(req: Json) -> HandlerResult>>; #[name = "oneWayCall"] - async fn one_way_call(req: Json) -> HandlerResult<()>; + async fn one_way_call(req: Json) -> HandlerResult; #[name = "manyCalls"] async fn many_calls(req: Json>) -> HandlerResult<()>; } @@ -59,27 +60,33 @@ impl Proxy for ProxyImpl { ctx: Context<'_>, Json(req): Json, ) -> HandlerResult>> { - Ok(ctx - .request::, Vec>(req.to_target(), req.message) - .call() - .await? - .into()) + let mut request = ctx.request::, Vec>(req.to_target(), req.message); + if let Some(idempotency_key) = req.idempotency_key { + request = request.idempotency_key(idempotency_key); + } + Ok(request.call().await?.into()) } async fn one_way_call( &self, ctx: Context<'_>, Json(req): Json, - ) -> HandlerResult<()> { - let request = ctx.request::<_, ()>(req.to_target(), req.message); + ) -> HandlerResult { + let mut request = ctx.request::<_, ()>(req.to_target(), req.message); + if let Some(idempotency_key) = req.idempotency_key { + request = request.idempotency_key(idempotency_key); + } - if let Some(delay_millis) = req.delay_millis { - request.send_with_delay(Duration::from_millis(delay_millis)); + let invocation_id = if let Some(delay_millis) = req.delay_millis { + request + .send_after(Duration::from_millis(delay_millis)) + .invocation_id() + .await? } else { - request.send(); - } + request.send().invocation_id().await? + }; - Ok(()) + Ok(invocation_id) } async fn many_calls( @@ -90,11 +97,14 @@ impl Proxy for ProxyImpl { let mut futures: Vec, TerminalError>>> = vec![]; for req in requests { - let restate_req = + let mut restate_req = ctx.request::<_, Vec>(req.proxy_request.to_target(), req.proxy_request.message); + if let Some(idempotency_key) = req.proxy_request.idempotency_key { + restate_req = restate_req.idempotency_key(idempotency_key); + } if req.one_way_call { if let Some(delay_millis) = req.proxy_request.delay_millis { - restate_req.send_with_delay(Duration::from_millis(delay_millis)); + restate_req.send_after(Duration::from_millis(delay_millis)); } else { restate_req.send(); } diff --git a/test-services/src/test_utils_service.rs b/test-services/src/test_utils_service.rs index a1d84a1..0152062 100644 --- a/test-services/src/test_utils_service.rs +++ b/test-services/src/test_utils_service.rs @@ -1,48 +1,12 @@ -use crate::awakeable_holder; -use crate::list_object::ListObjectClient; use futures::future::BoxFuture; use futures::FutureExt; use restate_sdk::prelude::*; -use serde::{Deserialize, Serialize}; use std::collections::HashMap; +use std::convert::Infallible; use std::sync::atomic::{AtomicU8, Ordering}; use std::sync::Arc; use std::time::Duration; -#[derive(Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub(crate) struct CreateAwakeableAndAwaitItRequest { - awakeable_key: String, - await_timeout: Option, -} - -#[derive(Serialize, Deserialize)] -#[serde(tag = "type")] -#[serde(rename_all_fields = "camelCase")] -pub(crate) enum CreateAwakeableAndAwaitItResponse { - #[serde(rename = "timeout")] - Timeout, - #[serde(rename = "result")] - Result { value: String }, -} - -#[derive(Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub(crate) struct InterpretRequest { - list_name: String, - commands: Vec, -} - -#[derive(Serialize, Deserialize)] -#[serde(tag = "type")] -#[serde(rename_all_fields = "camelCase")] -pub(crate) enum InterpretCommand { - #[serde(rename = "createAwakeableAndAwaitIt")] - CreateAwakeableAndAwaitIt { awakeable_key: String }, - #[serde(rename = "getEnvVariable")] - GetEnvVariable { env_name: String }, -} - #[restate_sdk::service] #[name = "TestUtilsService"] pub(crate) trait TestUtilsService { @@ -50,20 +14,16 @@ pub(crate) trait TestUtilsService { async fn echo(input: String) -> HandlerResult; #[name = "uppercaseEcho"] async fn uppercase_echo(input: String) -> HandlerResult; + #[name = "rawEcho"] + async fn raw_echo(input: Vec) -> Result, Infallible>; #[name = "echoHeaders"] async fn echo_headers() -> HandlerResult>>; - #[name = "createAwakeableAndAwaitIt"] - async fn create_awakeable_and_await_it( - req: Json, - ) -> HandlerResult>; #[name = "sleepConcurrently"] async fn sleep_concurrently(millis_durations: Json>) -> HandlerResult<()>; #[name = "countExecutedSideEffects"] async fn count_executed_side_effects(increments: u32) -> HandlerResult; - #[name = "getEnvVariable"] - async fn get_env_variable(env: String) -> HandlerResult; - #[name = "interpretCommands"] - async fn interpret_commands(req: Json) -> HandlerResult<()>; + #[name = "cancelInvocation"] + async fn cancel_invocation(invocation_id: String) -> Result<(), TerminalError>; } pub(crate) struct TestUtilsServiceImpl; @@ -77,6 +37,10 @@ impl TestUtilsService for TestUtilsServiceImpl { Ok(input.to_ascii_uppercase()) } + async fn raw_echo(&self, _: Context<'_>, input: Vec) -> Result, Infallible> { + Ok(input) + } + async fn echo_headers( &self, context: Context<'_>, @@ -92,27 +56,6 @@ impl TestUtilsService for TestUtilsServiceImpl { Ok(headers.into()) } - async fn create_awakeable_and_await_it( - &self, - context: Context<'_>, - Json(req): Json, - ) -> HandlerResult> { - if req.await_timeout.is_some() { - unimplemented!("await timeout is not yet implemented"); - } - - let (awk_id, awakeable) = context.awakeable::(); - - context - .object_client::(req.awakeable_key) - .hold(awk_id) - .call() - .await?; - let value = awakeable.await?; - - Ok(CreateAwakeableAndAwaitItResponse::Result { value }.into()) - } - async fn sleep_concurrently( &self, context: Context<'_>, @@ -151,37 +94,12 @@ impl TestUtilsService for TestUtilsServiceImpl { Ok(counter.load(Ordering::SeqCst) as u32) } - async fn get_env_variable(&self, _: Context<'_>, env: String) -> HandlerResult { - Ok(std::env::var(env).ok().unwrap_or_default()) - } - - async fn interpret_commands( + async fn cancel_invocation( &self, - context: Context<'_>, - Json(req): Json, - ) -> HandlerResult<()> { - let list_client = context.object_client::(req.list_name); - - for cmd in req.commands { - match cmd { - InterpretCommand::CreateAwakeableAndAwaitIt { awakeable_key } => { - let (awk_id, awakeable) = context.awakeable::(); - context - .object_client::(awakeable_key) - .hold(awk_id) - .call() - .await?; - let value = awakeable.await?; - list_client.append(value).send(); - } - InterpretCommand::GetEnvVariable { env_name } => { - list_client - .append(std::env::var(env_name).ok().unwrap_or_default()) - .send(); - } - } - } - + ctx: Context<'_>, + invocation_id: String, + ) -> Result<(), TerminalError> { + ctx.invocation_handle(invocation_id).cancel().await?; Ok(()) } } diff --git a/test-services/src/virtual_object_command_interpreter.rs b/test-services/src/virtual_object_command_interpreter.rs new file mode 100644 index 0000000..994d20c --- /dev/null +++ b/test-services/src/virtual_object_command_interpreter.rs @@ -0,0 +1,249 @@ +use anyhow::anyhow; +use futures::TryFutureExt; +use restate_sdk::prelude::*; +use serde::{Deserialize, Serialize}; +use std::time::Duration; + +#[derive(Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub(crate) struct InterpretRequest { + commands: Vec, +} + +#[derive(Serialize, Deserialize)] +#[serde(tag = "type")] +#[serde(rename_all_fields = "camelCase")] +pub(crate) enum Command { + #[serde(rename = "awaitAnySuccessful")] + AwaitAnySuccessful { commands: Vec }, + #[serde(rename = "awaitAny")] + AwaitAny { commands: Vec }, + #[serde(rename = "awaitOne")] + AwaitOne { command: AwaitableCommand }, + #[serde(rename = "awaitAwakeableOrTimeout")] + AwaitAwakeableOrTimeout { + awakeable_key: String, + timeout_millis: u64, + }, + #[serde(rename = "resolveAwakeable")] + ResolveAwakeable { + awakeable_key: String, + value: String, + }, + #[serde(rename = "rejectAwakeable")] + RejectAwakeable { + awakeable_key: String, + reason: String, + }, + #[serde(rename = "getEnvVariable")] + GetEnvVariable { env_name: String }, +} + +#[derive(Serialize, Deserialize)] +#[serde(tag = "type")] +#[serde(rename_all_fields = "camelCase")] +pub(crate) enum AwaitableCommand { + #[serde(rename = "createAwakeable")] + CreateAwakeable { awakeable_key: String }, + #[serde(rename = "sleep")] + Sleep { timeout_millis: u64 }, + #[serde(rename = "runThrowTerminalException")] + RunThrowTerminalException { reason: String }, +} + +#[derive(Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub(crate) struct ResolveAwakeable { + awakeable_key: String, + value: String, +} + +#[derive(Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub(crate) struct RejectAwakeable { + awakeable_key: String, + reason: String, +} + +#[restate_sdk::object] +#[name = "VirtualObjectCommandInterpreter"] +pub(crate) trait VirtualObjectCommandInterpreter { + #[name = "interpretCommands"] + async fn interpret_commands(req: Json) -> HandlerResult; + + #[name = "resolveAwakeable"] + #[shared] + async fn resolve_awakeable(req: Json) -> HandlerResult<()>; + + #[name = "rejectAwakeable"] + #[shared] + async fn reject_awakeable(req: Json) -> HandlerResult<()>; + + #[name = "hasAwakeable"] + #[shared] + async fn has_awakeable(awakeable_key: String) -> HandlerResult; + + #[name = "getResults"] + #[shared] + async fn get_results() -> HandlerResult>>; +} + +pub(crate) struct VirtualObjectCommandInterpreterImpl; + +impl VirtualObjectCommandInterpreter for VirtualObjectCommandInterpreterImpl { + async fn interpret_commands( + &self, + context: ObjectContext<'_>, + Json(req): Json, + ) -> HandlerResult { + let mut last_result: String = Default::default(); + + for cmd in req.commands { + match cmd { + Command::AwaitAny { .. } => { + Err(anyhow!("AwaitAny is currently unsupported in the Rust SDK"))? + } + Command::AwaitAnySuccessful { .. } => Err(anyhow!( + "AwaitAnySuccessful is currently unsupported in the Rust SDK" + ))?, + Command::AwaitAwakeableOrTimeout { .. } => Err(anyhow!( + "AwaitAwakeableOrTimeout is currently unsupported in the Rust SDK" + ))?, + Command::AwaitOne { command } => { + last_result = match command { + AwaitableCommand::CreateAwakeable { awakeable_key } => { + let (awakeable_id, fut) = context.awakeable::(); + context.set::(&format!("awk-{awakeable_key}"), awakeable_id); + fut.await? + } + AwaitableCommand::Sleep { timeout_millis } => { + context + .sleep(Duration::from_millis(timeout_millis)) + .map_ok(|_| "sleep".to_string()) + .await? + } + AwaitableCommand::RunThrowTerminalException { reason } => { + context + .run::<_, _, String>( + || async move { Err(TerminalError::new(reason))? }, + ) + .await? + } + } + } + Command::GetEnvVariable { env_name } => { + last_result = std::env::var(env_name).ok().unwrap_or_default(); + } + Command::ResolveAwakeable { + awakeable_key, + value, + } => { + let Some(awakeable_id) = context + .get::(&format!("awk-{awakeable_key}")) + .await? + else { + Err(TerminalError::new( + "Awakeable is not registered yet".to_string(), + ))? + }; + + context.resolve_awakeable(&awakeable_id, value); + last_result = Default::default(); + } + Command::RejectAwakeable { + awakeable_key, + reason, + } => { + let Some(awakeable_id) = context + .get::(&format!("awk-{awakeable_key}")) + .await? + else { + Err(TerminalError::new( + "Awakeable is not registered yet".to_string(), + ))? + }; + + context.reject_awakeable(&awakeable_id, TerminalError::new(reason)); + last_result = Default::default(); + } + } + + let mut old_results = context + .get::>>("results") + .await? + .unwrap_or_default() + .into_inner(); + old_results.push(last_result.clone()); + context.set("results", Json(old_results)); + } + + Ok(last_result) + } + + async fn resolve_awakeable( + &self, + context: SharedObjectContext<'_>, + req: Json, + ) -> Result<(), HandlerError> { + let ResolveAwakeable { + awakeable_key, + value, + } = req.into_inner(); + let Some(awakeable_id) = context + .get::(&format!("awk-{awakeable_key}")) + .await? + else { + Err(TerminalError::new( + "Awakeable is not registered yet".to_string(), + ))? + }; + + context.resolve_awakeable(&awakeable_id, value); + + Ok(()) + } + + async fn reject_awakeable( + &self, + context: SharedObjectContext<'_>, + req: Json, + ) -> Result<(), HandlerError> { + let RejectAwakeable { + awakeable_key, + reason, + } = req.into_inner(); + let Some(awakeable_id) = context + .get::(&format!("awk-{awakeable_key}")) + .await? + else { + Err(TerminalError::new( + "Awakeable is not registered yet".to_string(), + ))? + }; + + context.reject_awakeable(&awakeable_id, TerminalError::new(reason)); + + Ok(()) + } + + async fn has_awakeable( + &self, + context: SharedObjectContext<'_>, + awakeable_key: String, + ) -> Result { + Ok(context + .get::(&format!("awk-{awakeable_key}")) + .await? + .is_some()) + } + + async fn get_results( + &self, + context: SharedObjectContext<'_>, + ) -> Result>, HandlerError> { + Ok(context + .get::>>("results") + .await? + .unwrap_or_default()) + } +}