From 885826cc3d7755f93ccc8fee66f07c1adaff6ec9 Mon Sep 17 00:00:00 2001 From: slinkydeveloper Date: Wed, 18 Sep 2024 14:13:45 +0200 Subject: [PATCH 01/10] IDK what i'm doing here --- Cargo.toml | 4 +- examples/cron.rs | 2 +- src/context/mod.rs | 13 +- src/context/request.rs | 15 +- src/endpoint/context.rs | 331 +++++++++++++++++++--- src/endpoint/futures/async_result_poll.rs | 8 +- src/endpoint/futures/intercept_error.rs | 15 +- src/endpoint/futures/trap.rs | 12 + src/endpoint/mod.rs | 37 +-- src/errors.rs | 8 +- src/lib.rs | 7 +- test-services/Dockerfile | 1 + test-services/src/proxy.rs | 19 +- test-services/src/test_utils_service.rs | 12 + 14 files changed, 396 insertions(+), 88 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 97dc91d..a5104b2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,6 +12,7 @@ default = ["http_server", "rand", "uuid"] hyper = ["dep:hyper", "http-body-util", "restate-sdk-shared-core/http"] http_server = ["hyper", "hyper/server", "hyper/http2", "hyper-util", "tokio/net", "tokio/signal", "tokio/macros"] + [dependencies] bytes = "1.6.1" futures = "0.3" @@ -23,8 +24,7 @@ pin-project-lite = "0.2" rand = { version = "0.8.5", optional = true } regress = "0.10" restate-sdk-macros = { version = "0.3.2", path = "macros" } -restate-sdk-shared-core = "0.1.0" -sha2 = "=0.11.0-pre.3" +restate-sdk-shared-core = { git = "https://github.com/restatedev/sdk-shared-core.git", branch = "main", features = ["request_identity", "sha2_random_seed", "http"] } serde = "1.0" serde_json = "1.0" thiserror = "1.0.63" diff --git a/examples/cron.rs b/examples/cron.rs index 4b43f19..eef7107 100644 --- a/examples/cron.rs +++ b/examples/cron.rs @@ -77,7 +77,7 @@ impl PeriodicTaskImpl { .object_client::(context.key()) .run() // And send with a delay - .send_with_delay(Duration::from_secs(10)); + .send_after(Duration::from_secs(10)); } } diff --git a/src/context/mod.rs b/src/context/mod.rs index 08fa37a..0f4e363 100644 --- a/src/context/mod.rs +++ b/src/context/mod.rs @@ -8,7 +8,7 @@ use std::time::Duration; mod request; mod run; -pub use request::{Request, RequestTarget}; +pub use request::{CallFuture, InvocationHandle, Request, RequestTarget}; pub use run::{RunClosure, RunFuture, RunRetryPolicy}; pub type HeaderMap = http::HeaderMap; @@ -433,6 +433,11 @@ pub trait ContextClient<'ctx>: private::SealedContext<'ctx> { Request::new(self.inner_context(), request_target, req) } + /// Create an [`InvocationHandle`] from an invocation id. + fn invocation_handle(&self, invocation_id: String) -> impl InvocationHandle + 'ctx { + self.inner_context().invocation_handle(invocation_id) + } + /// Create a service client. The service client is generated by the [`restate_sdk_macros::service`] macro with the same name of the trait suffixed with `Client`. /// /// ```rust,no_run @@ -454,7 +459,7 @@ pub trait ContextClient<'ctx>: private::SealedContext<'ctx> { /// client.handle().send(); /// /// // Schedule the request to be executed later - /// client.handle().send_with_delay(Duration::from_secs(60)); + /// client.handle().send_after(Duration::from_secs(60)); /// # } /// ``` fn service_client(&self) -> C @@ -485,7 +490,7 @@ pub trait ContextClient<'ctx>: private::SealedContext<'ctx> { /// client.handle().send(); /// /// // Schedule the request to be executed later - /// client.handle().send_with_delay(Duration::from_secs(60)); + /// client.handle().send_after(Duration::from_secs(60)); /// # } /// ``` fn object_client(&self, key: impl Into) -> C @@ -516,7 +521,7 @@ pub trait ContextClient<'ctx>: private::SealedContext<'ctx> { /// client.handle().send(); /// /// // Schedule the request to be executed later - /// client.handle().send_with_delay(Duration::from_secs(60)); + /// client.handle().send_after(Duration::from_secs(60)); /// # } /// ``` fn workflow_client(&self, key: impl Into) -> C diff --git a/src/context/request.rs b/src/context/request.rs index d0104f5..ac87e09 100644 --- a/src/context/request.rs +++ b/src/context/request.rs @@ -87,7 +87,7 @@ impl<'a, Req, Res> Request<'a, Req, Res> { } /// Call a service. This returns a future encapsulating the response. - pub fn call(self) -> impl Future> + Send + pub fn call(self) -> impl CallFuture> + Send where Req: Serialize + 'static, Res: Deserialize + 'static, @@ -96,7 +96,7 @@ impl<'a, Req, Res> Request<'a, Req, Res> { } /// Send the request to the service, without waiting for the response. - pub fn send(self) + pub fn send(self) -> impl InvocationHandle where Req: Serialize + 'static, { @@ -104,10 +104,17 @@ impl<'a, Req, Res> Request<'a, Req, Res> { } /// Schedule the request to the service, without waiting for the response. - pub fn send_with_delay(self, duration: Duration) + pub fn send_after(self, delay: Duration) -> impl InvocationHandle where Req: Serialize + 'static, { - self.ctx.send(self.request_target, self.req, Some(duration)) + self.ctx.send(self.request_target, self.req, Some(delay)) } } + +pub trait InvocationHandle { + fn invocation_id(&self) -> impl Future> + Send; + fn cancel(&self); +} + +pub trait CallFuture: Future + InvocationHandle {} diff --git a/src/endpoint/context.rs b/src/endpoint/context.rs index 536dcb5..ae2ff65 100644 --- a/src/endpoint/context.rs +++ b/src/endpoint/context.rs @@ -1,4 +1,6 @@ -use crate::context::{Request, RequestTarget, RunClosure, RunRetryPolicy}; +use crate::context::{ + CallFuture, InvocationHandle, Request, RequestTarget, RunClosure, RunRetryPolicy, +}; use crate::endpoint::futures::async_result_poll::VmAsyncResultPollFuture; use crate::endpoint::futures::intercept_error::InterceptErrorFuture; use crate::endpoint::futures::trap::TrapFuture; @@ -10,8 +12,8 @@ use futures::future::Either; use futures::{FutureExt, TryFutureExt}; use pin_project_lite::pin_project; use restate_sdk_shared_core::{ - CoreVM, Failure, NonEmptyValue, RetryPolicy, RunEnterResult, RunExitResult, TakeOutputResult, - Target, Value, VM, + CoreVM, NonEmptyValue, NotificationHandle, RetryPolicy, RunExitResult, SendHandle, + TakeOutputResult, Target, TerminalFailure, Value, VM, }; use std::borrow::Cow; use std::collections::HashMap; @@ -21,7 +23,7 @@ use std::mem; use std::pin::Pin; use std::sync::{Arc, Mutex}; use std::task::{ready, Context, Poll}; -use std::time::{Duration, Instant}; +use std::time::{Duration, Instant, SystemTime}; pub struct ContextInternalInner { pub(crate) vm: CoreVM, @@ -47,7 +49,7 @@ impl ContextInternalInner { pub(super) fn fail(&mut self, e: Error) { self.vm - .notify_error(e.0.to_string().into(), format!("{:#}", e.0).into(), None); + .notify_error(e.0.to_string().into(), format!("{:#}", e.0).into()); self.handler_state.mark_error(e); } } @@ -215,6 +217,10 @@ impl ContextInternal { variant: "state_keys", syscall: "get_state", }), + Ok(Value::InvocationId(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { + variant: "invocation_id", + syscall: "get_state", + }), Err(e) => Err(e), }); @@ -230,11 +236,15 @@ impl ContextInternal { .map(|res| match res { Ok(Value::Void) => Err(ErrorInner::UnexpectedValueVariantForSyscall { variant: "empty", - syscall: "get_state", + syscall: "get_keys", }), Ok(Value::Success(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { variant: "success", - syscall: "get_state", + syscall: "get_keys", + }), + Ok(Value::InvocationId(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { + variant: "invocation_id", + syscall: "get_keys", }), Ok(Value::Failure(f)) => Ok(Err(f.into())), Ok(Value::StateKeys(s)) => Ok(Ok(s)), @@ -272,9 +282,16 @@ impl ContextInternal { pub fn sleep( &self, - duration: Duration, + sleep_duration: Duration, ) -> impl Future> + Send + Sync { - let maybe_handle = { must_lock!(self.inner).vm.sys_sleep(duration) }; + let now = SystemTime::now() + .duration_since(SystemTime::UNIX_EPOCH) + .expect("Duration since unix epoch cannot fail"); + let maybe_handle = { + must_lock!(self.inner) + .vm + .sys_sleep(now + sleep_duration, Some(now)) + }; let poll_future = VmAsyncResultPollFuture::new(Cow::Borrowed(&self.inner), maybe_handle) .map(|res| match res { @@ -289,6 +306,10 @@ impl ContextInternal { variant: "state_keys", syscall: "sleep", }), + Ok(Value::InvocationId(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { + variant: "invocation_id", + syscall: "sleep", + }), }); InterceptErrorFuture::new(self.clone(), poll_future.map_err(Error)) @@ -302,7 +323,7 @@ impl ContextInternal { &self, request_target: RequestTarget, req: Req, - ) -> impl Future> + Send + Sync { + ) -> impl CallFuture> + Send + Sync { let mut inner_lock = must_lock!(self.inner); let input = match Req::serialize(&req) { @@ -319,34 +340,20 @@ impl ContextInternal { } }; - let maybe_handle = inner_lock.vm.sys_call(request_target.into(), input); + let maybe_handle = inner_lock.vm.sys_call(request_target.into(), input).map(|ch | ch.call_notification_handle); drop(inner_lock); - let poll_future = VmAsyncResultPollFuture::new(Cow::Borrowed(&self.inner), maybe_handle) - .map(|res| match res { - Ok(Value::Void) => Err(ErrorInner::UnexpectedValueVariantForSyscall { - variant: "empty", - syscall: "call", - }), - Ok(Value::Success(mut s)) => { - let t = Res::deserialize(&mut s).map_err(|e| ErrorInner::Deserialization { - syscall: "call", - err: Box::new(e), - })?; - Ok(Ok(t)) - } - Ok(Value::Failure(f)) => Ok(Err(f.into())), - Ok(Value::StateKeys(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { - variant: "state_keys", - syscall: "call", - }), - Err(e) => Err(e), - }); + let call_future_impl = CallFutureImpl { + poll_future: VmAsyncResultPollFuture::new( + Cow::Borrowed(&self.inner), + maybe_handle.clone(), + ), + res: PhantomData, + ctx: self.clone(), + call_handle: maybe_handle.ok(), + }; - Either::Left(InterceptErrorFuture::new( - self.clone(), - poll_future.map_err(Error), - )) + Either::Left(InterceptErrorFuture::new(self.clone(), call_future_impl)) } pub fn send( @@ -354,12 +361,26 @@ impl ContextInternal { request_target: RequestTarget, req: Req, delay: Option, - ) { + ) -> impl InvocationHandle { let mut inner_lock = must_lock!(self.inner); match Req::serialize(&req) { Ok(t) => { - let _ = inner_lock.vm.sys_send(request_target.into(), t, delay); + let result = inner_lock.vm.sys_send( + request_target.into(), + t, + delay.map(|delay| { + SystemTime::now() + .duration_since(SystemTime::UNIX_EPOCH) + .expect("Duration since unix epoch cannot fail") + + delay + }), + ); + drop(inner_lock); + SendRequestHandle { + ctx: self.clone(), + send_handle: result.ok(), + } } Err(e) => { inner_lock.fail( @@ -369,8 +390,19 @@ impl ContextInternal { } .into(), ); + SendRequestHandle { + ctx: self.clone(), + send_handle: None, + } } - }; + } + } + + pub fn invocation_handle(&self, invocation_id: String) -> impl InvocationHandle { + InvocationIdBackedInvocationHandle { + ctx: self.clone(), + invocation_id, + } } pub fn awakeable( @@ -409,6 +441,10 @@ impl ContextInternal { variant: "state_keys", syscall: "awakeable", }), + Ok(Value::InvocationId(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { + variant: "invocation_id", + syscall: "awakeable", + }), Err(e) => Err(e), }); @@ -468,6 +504,10 @@ impl ContextInternal { variant: "state_keys", syscall: "promise", }), + Ok(Value::InvocationId(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { + variant: "invocation_id", + syscall: "promise", + }), Err(e) => Err(e), }); @@ -495,6 +535,10 @@ impl ContextInternal { variant: "state_keys", syscall: "peek_promise", }), + Ok(Value::InvocationId(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { + variant: "invocation_id", + syscall: "peek_promise", + }), Err(e) => Err(e), }); @@ -716,7 +760,7 @@ where Err(e) => match e.0 { HandlerErrorInner::Retryable(err) => RunExitResult::RetryableFailure { attempt_duration: start_time.elapsed(), - failure: Failure { + failure: TerminalFailure { code: 500, message: err.to_string(), }, @@ -766,9 +810,218 @@ where syscall: "run", } .into()), + Value::InvocationId(_) => { + Err(ErrorInner::UnexpectedValueVariantForSyscall { + variant: "invocation_id", + syscall: "run", + } + .into()) + } }); } } } } } + +struct SendRequestHandle { + ctx: ContextInternal, + send_handle: Option, +} + +impl InvocationHandle for SendRequestHandle { + fn invocation_id(&self) -> impl Future> + Send { + if let Some(ref send_handle) = self.send_handle { + let maybe_handle = { + must_lock!(self.ctx.inner) + .vm + .sys_get_call_invocation_id(GetInvocationIdTarget::SendEntry(*send_handle)) + }; + + let poll_future = VmAsyncResultPollFuture::new( + Cow::Borrowed(&self.ctx.inner), + maybe_handle, + ) + .map(|res| match res { + Ok(Value::Failure(f)) => Ok(Err(f.into())), + Ok(Value::InvocationId(s)) => Ok(Ok(s)), + Err(e) => Err(e), + Ok(Value::StateKeys(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { + variant: "state_keys", + syscall: "get_call_invocation_id", + }), + Ok(Value::Void) => Err(ErrorInner::UnexpectedValueVariantForSyscall { + variant: "void", + syscall: "get_call_invocation_id", + }), + Ok(Value::Success(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { + variant: "success", + syscall: "get_call_invocation_id", + }), + }); + + Either::Left(InterceptErrorFuture::new( + self.ctx.clone(), + poll_future.map_err(Error), + )) + } else { + // If the send didn't succeed, trap the execution + Either::Right(TrapFuture::default()) + } + } + + fn cancel(&self) { + if let Some(ref send_handle) = self.send_handle { + let mut inner_lock = must_lock!(self.ctx.inner); + let _ = inner_lock + .vm + .sys_cancel_invocation(CancelInvocationTarget::SendEntry(*send_handle)); + } + // If the send didn't succeed, then simply ignore the cancel + } +} + +pin_project! { + struct CallFutureImpl { + #[pin] + poll_future: VmAsyncResultPollFuture, + res: PhantomData R>, + ctx: ContextInternal, + call_handle: Option, + } +} + +impl Future for CallFutureImpl { + type Output = Result, Error>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + + this.poll_future + .poll(cx) + .map(|res| match res { + Ok(Value::Void) => Err(ErrorInner::UnexpectedValueVariantForSyscall { + variant: "empty", + syscall: "call", + }), + Ok(Value::Success(mut s)) => { + let t = Res::deserialize(&mut s).map_err(|e| ErrorInner::Deserialization { + syscall: "call", + err: Box::new(e), + })?; + Ok(Ok(t)) + } + Ok(Value::Failure(f)) => Ok(Err(f.into())), + Ok(Value::StateKeys(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { + variant: "state_keys", + syscall: "call", + }), + Ok(Value::InvocationId(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { + variant: "invocation_id", + syscall: "call", + }), + Err(e) => Err(e), + }) + .map(|res| res.map_err(Error)) + } +} + +impl InvocationHandle for CallFutureImpl { + fn invocation_id(&self) -> impl Future> + Send { + if let Some(ref call_handle) = self.call_handle { + let maybe_handle = { + must_lock!(self.ctx.inner) + .vm + .sys_get_call_invocation_id(GetInvocationIdTarget::CallEntry(*call_handle)) + }; + + let poll_future = VmAsyncResultPollFuture::new( + Cow::Borrowed(&self.ctx.inner), + maybe_handle, + ) + .map(|res| match res { + Ok(Value::Failure(f)) => Ok(Err(f.into())), + Ok(Value::InvocationId(s)) => Ok(Ok(s)), + Err(e) => Err(e), + Ok(Value::StateKeys(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { + variant: "state_keys", + syscall: "get_call_invocation_id", + }), + Ok(Value::Void) => Err(ErrorInner::UnexpectedValueVariantForSyscall { + variant: "void", + syscall: "get_call_invocation_id", + }), + Ok(Value::Success(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { + variant: "success", + syscall: "get_call_invocation_id", + }), + }); + + Either::Left(InterceptErrorFuture::new( + self.ctx.clone(), + poll_future.map_err(Error), + )) + } else { + // If the send didn't succeed, trap the execution + Either::Right(TrapFuture::default()) + } + } + + fn cancel(&self) { + if let Some(ref call_handle) = self.call_handle { + let mut inner_lock = must_lock!(self.ctx.inner); + let _ = inner_lock + .vm + .sys_cancel_invocation(CancelInvocationTarget::CallEntry(*call_handle)); + } + // If the send didn't succeed, then simply ignore the cancel + } +} + +impl CallFuture, Error>> + for CallFutureImpl +{ +} + +impl InvocationHandle for Either { + fn invocation_id(&self) -> impl Future> + Send { + match self { + Either::Left(l) => Either::Left(l.invocation_id()), + Either::Right(r) => Either::Right(r.invocation_id()), + } + } + + fn cancel(&self) { + match self { + Either::Left(l) => l.cancel(), + Either::Right(r) => r.cancel(), + } + } +} + +impl CallFuture for Either +where + A: CallFuture, + B: CallFuture, +{ +} + +struct InvocationIdBackedInvocationHandle { + ctx: ContextInternal, + invocation_id: String, +} + +impl InvocationHandle for InvocationIdBackedInvocationHandle { + fn invocation_id(&self) -> impl Future> + Send { + ready(Ok(self.invocation_id.clone())) + } + + fn cancel(&self) { + let mut inner_lock = must_lock!(self.ctx.inner); + let _ = inner_lock + .vm + .sys_cancel_invocation(CancelInvocationTarget::InvocationId( + self.invocation_id.clone(), + )); + } +} diff --git a/src/endpoint/futures/async_result_poll.rs b/src/endpoint/futures/async_result_poll.rs index 8f6ef5d..23f7fa2 100644 --- a/src/endpoint/futures/async_result_poll.rs +++ b/src/endpoint/futures/async_result_poll.rs @@ -1,7 +1,7 @@ use crate::endpoint::context::ContextInternalInner; use crate::endpoint::ErrorInner; use restate_sdk_shared_core::{ - AsyncResultHandle, SuspendedOrVMError, TakeOutputResult, VMError, Value, VM, + NotificationHandle, SuspendedOrVMError, TakeOutputResult, Value, VM, }; use std::borrow::Cow; use std::future::Future; @@ -16,7 +16,7 @@ pub(crate) struct VmAsyncResultPollFuture { impl VmAsyncResultPollFuture { pub fn new( inner: Cow<'_, Arc>>, - handle: Result, + handle: Result, ) -> Self { VmAsyncResultPollFuture { state: Some(match handle { @@ -33,11 +33,11 @@ impl VmAsyncResultPollFuture { enum PollState { Init { ctx: Arc>, - handle: AsyncResultHandle, + handle: NotificationHandle, }, WaitingInput { ctx: Arc>, - handle: AsyncResultHandle, + handle: NotificationHandle, }, Failed(ErrorInner), } diff --git a/src/endpoint/futures/intercept_error.rs b/src/endpoint/futures/intercept_error.rs index b486fbf..606df95 100644 --- a/src/endpoint/futures/intercept_error.rs +++ b/src/endpoint/futures/intercept_error.rs @@ -1,5 +1,6 @@ -use crate::context::{RunFuture, RunRetryPolicy}; +use crate::context::{CallFuture, InvocationHandle, RunFuture, RunRetryPolicy}; use crate::endpoint::{ContextInternal, Error}; +use crate::errors::TerminalError; use pin_project_lite::pin_project; use std::future::Future; use std::pin::Pin; @@ -58,3 +59,15 @@ where self } } + +impl InvocationHandle for InterceptErrorFuture { + fn invocation_id(&self) -> impl Future> + Send { + self.fut.invocation_id() + } + + fn cancel(&self) { + self.fut.cancel() + } +} + +impl CallFuture for InterceptErrorFuture where F: CallFuture> {} diff --git a/src/endpoint/futures/trap.rs b/src/endpoint/futures/trap.rs index 9b0269d..b614de1 100644 --- a/src/endpoint/futures/trap.rs +++ b/src/endpoint/futures/trap.rs @@ -1,3 +1,5 @@ +use crate::context::{CallFuture, InvocationHandle}; +use crate::errors::TerminalError; use std::future::Future; use std::marker::PhantomData; use std::pin::Pin; @@ -20,3 +22,13 @@ impl Future for TrapFuture { Poll::Pending } } + +impl InvocationHandle for TrapFuture { + fn invocation_id(&self) -> impl Future> + Send { + TrapFuture::default() + } + + fn cancel(&self) {} +} + +impl CallFuture for TrapFuture {} diff --git a/src/endpoint/mod.rs b/src/endpoint/mod.rs index 90f907a..3a20feb 100644 --- a/src/endpoint/mod.rs +++ b/src/endpoint/mod.rs @@ -10,9 +10,7 @@ use ::futures::future::BoxFuture; use ::futures::{Stream, StreamExt}; use bytes::Bytes; pub use context::{ContextInternal, InputMetadata}; -use restate_sdk_shared_core::{ - CoreVM, Header, HeaderMap, IdentityVerifier, KeyError, ResponseHead, VMError, VerifyError, VM, -}; +use restate_sdk_shared_core::{CoreVM, Error as VMError, Header, HeaderMap, IdentityVerifier, KeyError, VerifyError, VM}; use std::collections::HashMap; use std::future::poll_fn; use std::pin::Pin; @@ -176,8 +174,8 @@ impl Default for Builder { Self { svcs: Default::default(), discovery: crate::discovery::Endpoint { - max_protocol_version: 2, - min_protocol_version: 2, + max_protocol_version: 3, + min_protocol_version: 3, protocol_mode: Some(crate::discovery::ProtocolMode::BidiStream), services: vec![], }, @@ -274,13 +272,11 @@ impl Endpoint { } return Ok(Response::ReplyNow { - response_head: ResponseHead { - status_code: 200, - headers: vec![Header { - key: "content-type".into(), - value: DISCOVERY_CONTENT_TYPE.into(), - }], - }, + status_code: 200, + headers: vec![Header { + key: "content-type".into(), + value: DISCOVERY_CONTENT_TYPE.into(), + }], body: Bytes::from( serde_json::to_string(&self.0.discovery) .expect("Discovery should be serializable"), @@ -296,13 +292,16 @@ impl Endpoint { Some(last_elements) => (last_elements[1].to_owned(), last_elements[2].to_owned()), }; - let vm = CoreVM::new(headers).map_err(ErrorInner::VM)?; + let vm = CoreVM::new(headers, Default::default()).map_err(ErrorInner::VM)?; if !self.0.svcs.contains_key(&svc_name) { return Err(ErrorInner::UnknownService(svc_name.to_owned()).into()); } + let response_head = vm.get_response_head(); + Ok(Response::BidiStream { - response_head: vm.get_response_head(), + status_code: response_head.status_code, + headers: response_head.headers, handler: BidiStreamRunner { svc_name, handler_name, @@ -315,11 +314,13 @@ impl Endpoint { pub enum Response { ReplyNow { - response_head: ResponseHead, + status_code: u16, + headers: Vec
, body: Bytes, }, BidiStream { - response_head: ResponseHead, + status_code: u16, + headers: Vec
, handler: BidiStreamRunner, }, } @@ -391,8 +392,8 @@ async fn init_loop_vm(vm: &mut CoreVM, input_rx: &mut InputReceiver) -> Result<( match input_rx.recv().await { Some(Ok(b)) => vm.notify_input(b), Some(Err(e)) => vm.notify_error( - "Error when reading the body".into(), - e.to_string().into(), + VMError::new(500u16, format!("Error when reading the body: {e}")) + , None, ), None => vm.notify_input_closed(), diff --git a/src/errors.rs b/src/errors.rs index b56064e..5f9783d 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -16,7 +16,7 @@ //! ``` //! //! You can catch terminal exceptions. For example, you can catch the terminal exception that comes out of a [call to another service][crate::context::ContextClient], and build your control flow around it. -use restate_sdk_shared_core::Failure; +use restate_sdk_shared_core::TerminalFailure; use std::error::Error as StdError; use std::fmt; @@ -141,8 +141,8 @@ impl AsRef for TerminalError { } } -impl From for TerminalError { - fn from(value: Failure) -> Self { +impl From for TerminalError { + fn from(value: TerminalFailure) -> Self { Self(TerminalErrorInner { code: value.code, message: value.message, @@ -150,7 +150,7 @@ impl From for TerminalError { } } -impl From for Failure { +impl From for TerminalFailure { fn from(value: TerminalError) -> Self { Self { code: value.0.code, diff --git a/src/lib.rs b/src/lib.rs index 930e9d6..842211b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -506,9 +506,10 @@ pub mod prelude { pub use crate::http_server::HttpServer; pub use crate::context::{ - Context, ContextAwakeables, ContextClient, ContextPromises, ContextReadState, - ContextSideEffects, ContextTimers, ContextWriteState, HeaderMap, ObjectContext, Request, - RunFuture, RunRetryPolicy, SharedObjectContext, SharedWorkflowContext, WorkflowContext, + CallFuture, Context, ContextAwakeables, ContextClient, ContextPromises, ContextReadState, + ContextSideEffects, ContextTimers, ContextWriteState, HeaderMap, InvocationHandle, + ObjectContext, Request, RunFuture, RunRetryPolicy, SharedObjectContext, + SharedWorkflowContext, WorkflowContext, }; pub use crate::endpoint::Endpoint; pub use crate::errors::{HandlerError, HandlerResult, TerminalError}; diff --git a/test-services/Dockerfile b/test-services/Dockerfile index c041674..50e4ff7 100644 --- a/test-services/Dockerfile +++ b/test-services/Dockerfile @@ -7,5 +7,6 @@ RUN cargo build -p test-services RUN cp ./target/debug/test-services /bin/server ENV RUST_LOG="debug,restate_shared_core=trace" +ENV RUST_BACKTRACE=1 CMD ["/bin/server"] \ No newline at end of file diff --git a/test-services/src/proxy.rs b/test-services/src/proxy.rs index 36954f6..d443a94 100644 --- a/test-services/src/proxy.rs +++ b/test-services/src/proxy.rs @@ -46,7 +46,7 @@ pub(crate) trait Proxy { #[name = "call"] async fn call(req: Json) -> HandlerResult>>; #[name = "oneWayCall"] - async fn one_way_call(req: Json) -> HandlerResult<()>; + async fn one_way_call(req: Json) -> HandlerResult; #[name = "manyCalls"] async fn many_calls(req: Json>) -> HandlerResult<()>; } @@ -70,16 +70,19 @@ impl Proxy for ProxyImpl { &self, ctx: Context<'_>, Json(req): Json, - ) -> HandlerResult<()> { + ) -> HandlerResult { let request = ctx.request::<_, ()>(req.to_target(), req.message); - if let Some(delay_millis) = req.delay_millis { - request.send_with_delay(Duration::from_millis(delay_millis)); + let invocation_id = if let Some(delay_millis) = req.delay_millis { + request + .send_after(Duration::from_millis(delay_millis)) + .invocation_id() + .await? } else { - request.send(); - } + request.send().invocation_id().await? + }; - Ok(()) + Ok(invocation_id) } async fn many_calls( @@ -94,7 +97,7 @@ impl Proxy for ProxyImpl { ctx.request::<_, Vec>(req.proxy_request.to_target(), req.proxy_request.message); if req.one_way_call { if let Some(delay_millis) = req.proxy_request.delay_millis { - restate_req.send_with_delay(Duration::from_millis(delay_millis)); + restate_req.send_after(Duration::from_millis(delay_millis)); } else { restate_req.send(); } diff --git a/test-services/src/test_utils_service.rs b/test-services/src/test_utils_service.rs index a1d84a1..429e6a2 100644 --- a/test-services/src/test_utils_service.rs +++ b/test-services/src/test_utils_service.rs @@ -5,6 +5,7 @@ use futures::FutureExt; use restate_sdk::prelude::*; use serde::{Deserialize, Serialize}; use std::collections::HashMap; +use std::convert::Infallible; use std::sync::atomic::{AtomicU8, Ordering}; use std::sync::Arc; use std::time::Duration; @@ -62,6 +63,8 @@ pub(crate) trait TestUtilsService { async fn count_executed_side_effects(increments: u32) -> HandlerResult; #[name = "getEnvVariable"] async fn get_env_variable(env: String) -> HandlerResult; + #[name = "cancelInvocation"] + async fn cancel_invocation(invocation_id: String) -> Result<(), Infallible>; #[name = "interpretCommands"] async fn interpret_commands(req: Json) -> HandlerResult<()>; } @@ -155,6 +158,15 @@ impl TestUtilsService for TestUtilsServiceImpl { Ok(std::env::var(env).ok().unwrap_or_default()) } + async fn cancel_invocation( + &self, + ctx: Context<'_>, + invocation_id: String, + ) -> Result<(), Infallible> { + ctx.invocation_handle(invocation_id).cancel(); + Ok(()) + } + async fn interpret_commands( &self, context: Context<'_>, From 12b6bce3b41cbb47ffe827640eabb2b963a7008a Mon Sep 17 00:00:00 2001 From: slinkydeveloper Date: Mon, 3 Mar 2025 18:52:57 +0100 Subject: [PATCH 02/10] This seems to be getting somewhere --- src/context/mod.rs | 2 +- src/context/request.rs | 23 +- src/endpoint/context.rs | 990 +++++++++++----------- src/endpoint/futures/async_result_poll.rs | 129 +-- src/endpoint/futures/intercept_error.rs | 2 +- src/endpoint/futures/trap.rs | 4 +- src/endpoint/mod.rs | 30 +- src/hyper.rs | 43 +- 8 files changed, 647 insertions(+), 576 deletions(-) diff --git a/src/context/mod.rs b/src/context/mod.rs index 0f4e363..b3425b6 100644 --- a/src/context/mod.rs +++ b/src/context/mod.rs @@ -632,7 +632,7 @@ pub trait ContextAwakeables<'ctx>: private::SealedContext<'ctx> { &self, ) -> ( String, - impl Future> + Send + Sync + 'ctx, + impl Future> + Send + 'ctx, ) { self.inner_context().awakeable() } diff --git a/src/context/request.rs b/src/context/request.rs index ac87e09..0284628 100644 --- a/src/context/request.rs +++ b/src/context/request.rs @@ -72,6 +72,7 @@ impl fmt::Display for RequestTarget { pub struct Request<'a, Req, Res = ()> { ctx: &'a ContextInternal, request_target: RequestTarget, + idempotency_key: Option, req: Req, res: PhantomData, } @@ -81,18 +82,26 @@ impl<'a, Req, Res> Request<'a, Req, Res> { Self { ctx, request_target, + idempotency_key: None, req, res: PhantomData, } } + /// Add idempotency key to the request + pub fn idempotency_key(mut self, idempotency_key: impl Into) -> Self { + self.idempotency_key = Some(idempotency_key.into()); + self + } + /// Call a service. This returns a future encapsulating the response. pub fn call(self) -> impl CallFuture> + Send where Req: Serialize + 'static, Res: Deserialize + 'static, { - self.ctx.call(self.request_target, self.req) + self.ctx + .call(self.request_target, self.idempotency_key, self.req) } /// Send the request to the service, without waiting for the response. @@ -100,7 +109,8 @@ impl<'a, Req, Res> Request<'a, Req, Res> { where Req: Serialize + 'static, { - self.ctx.send(self.request_target, self.req, None) + self.ctx + .send(self.request_target, self.idempotency_key, self.req, None) } /// Schedule the request to the service, without waiting for the response. @@ -108,13 +118,18 @@ impl<'a, Req, Res> Request<'a, Req, Res> { where Req: Serialize + 'static, { - self.ctx.send(self.request_target, self.req, Some(delay)) + self.ctx.send( + self.request_target, + self.idempotency_key, + self.req, + Some(delay), + ) } } pub trait InvocationHandle { fn invocation_id(&self) -> impl Future> + Send; - fn cancel(&self); + fn cancel(&self) -> impl Future> + Send; } pub trait CallFuture: Future + InvocationHandle {} diff --git a/src/endpoint/context.rs b/src/endpoint/context.rs index ae2ff65..5a0a317 100644 --- a/src/endpoint/context.rs +++ b/src/endpoint/context.rs @@ -8,12 +8,13 @@ use crate::endpoint::handler_state::HandlerStateNotifier; use crate::endpoint::{Error, ErrorInner, InputReceiver, OutputSender}; use crate::errors::{HandlerErrorInner, HandlerResult, TerminalError}; use crate::serde::{Deserialize, Serialize}; -use futures::future::Either; +use futures::future::{Either, Shared}; use futures::{FutureExt, TryFutureExt}; use pin_project_lite::pin_project; use restate_sdk_shared_core::{ - CoreVM, NonEmptyValue, NotificationHandle, RetryPolicy, RunExitResult, SendHandle, - TakeOutputResult, Target, TerminalFailure, Value, VM, + CallHandle as CoreCallHandle, CoreVM, DoProgressResponse, Error as CoreError, NonEmptyValue, + NotificationHandle, RetryPolicy, RunExitResult, SendHandle as CoreSendHandle, TakeOutputResult, + Target, TerminalFailure, Value, VM, }; use std::borrow::Cow; use std::collections::HashMap; @@ -48,8 +49,10 @@ impl ContextInternalInner { } pub(super) fn fail(&mut self, e: Error) { - self.vm - .notify_error(e.0.to_string().into(), format!("{:#}", e.0).into()); + self.vm.notify_error( + CoreError::new(500u16, e.0.to_string()).with_stacktrace(Cow::Owned(format!("{:#}", e.0))), + None, + ); self.handler_state.mark_error(e); } } @@ -96,6 +99,18 @@ macro_rules! must_lock { }; } +macro_rules! unwrap_or_trap { + ($inner_lock:expr, $res:expr) => { + match $res { + Ok(t) => t, + Err(e) => { + $inner_lock.fail(e.into()); + return Either::Right(TrapFuture::default()); + } + } + }; +} + #[derive(Debug, Eq, PartialEq)] pub struct InputMetadata { pub invocation_id: String, @@ -111,16 +126,22 @@ impl From for Target { service: name, handler, key: None, + idempotency_key: None, + headers: vec![], }, RequestTarget::Object { name, key, handler } => Target { service: name, handler, key: Some(key), + idempotency_key: None, + headers: vec![], }, RequestTarget::Workflow { name, key, handler } => Target { service: name, handler, key: Some(key), + idempotency_key: None, + headers: vec![], }, } } @@ -199,59 +220,45 @@ impl ContextInternal { pub fn get( &self, key: &str, - ) -> impl Future, TerminalError>> + Send + Sync { - let maybe_handle = { must_lock!(self.inner).vm.sys_state_get(key.to_owned()) }; + ) -> impl Future, TerminalError>> + Send { + let mut inner_lock = must_lock!(self.inner); + let handle = unwrap_or_trap!(inner_lock, inner_lock.vm.sys_state_get(key.to_owned())); + + let poll_future = get_async_result(Arc::clone(&self.inner), handle).map(|res| match res { + Ok(Value::Void) => Ok(Ok(None)), + Ok(Value::Success(mut s)) => { + let t = + T::deserialize(&mut s).map_err(|e| Error::deserialization("get_state", e))?; + Ok(Ok(Some(t))) + } + Ok(Value::Failure(f)) => Ok(Err(f.into())), + Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall { + variant: <&'static str>::from(v), + syscall: "get_state", + } + .into()), + Err(e) => Err(e.into()), + }); - let poll_future = VmAsyncResultPollFuture::new(Cow::Borrowed(&self.inner), maybe_handle) - .map(|res| match res { - Ok(Value::Void) => Ok(Ok(None)), - Ok(Value::Success(mut s)) => { - let t = T::deserialize(&mut s).map_err(|e| ErrorInner::Deserialization { - syscall: "get_state", - err: Box::new(e), - })?; - Ok(Ok(Some(t))) - } - Ok(Value::Failure(f)) => Ok(Err(f.into())), - Ok(Value::StateKeys(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { - variant: "state_keys", - syscall: "get_state", - }), - Ok(Value::InvocationId(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { - variant: "invocation_id", - syscall: "get_state", - }), - Err(e) => Err(e), - }); - - InterceptErrorFuture::new(self.clone(), poll_future.map_err(Error)) + Either::Left(InterceptErrorFuture::new(self.clone(), poll_future)) } - pub fn get_keys( - &self, - ) -> impl Future, TerminalError>> + Send + Sync { - let maybe_handle = { must_lock!(self.inner).vm.sys_state_get_keys() }; - - let poll_future = VmAsyncResultPollFuture::new(Cow::Borrowed(&self.inner), maybe_handle) - .map(|res| match res { - Ok(Value::Void) => Err(ErrorInner::UnexpectedValueVariantForSyscall { - variant: "empty", - syscall: "get_keys", - }), - Ok(Value::Success(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { - variant: "success", - syscall: "get_keys", - }), - Ok(Value::InvocationId(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { - variant: "invocation_id", - syscall: "get_keys", - }), - Ok(Value::Failure(f)) => Ok(Err(f.into())), - Ok(Value::StateKeys(s)) => Ok(Ok(s)), - Err(e) => Err(e), - }); + pub fn get_keys(&self) -> impl Future, TerminalError>> + Send { + let mut inner_lock = must_lock!(self.inner); + let handle = unwrap_or_trap!(inner_lock, inner_lock.vm.sys_state_get_keys()); + + let poll_future = get_async_result(Arc::clone(&self.inner), handle).map(|res| match res { + Ok(Value::Failure(f)) => Ok(Err(f.into())), + Ok(Value::StateKeys(s)) => Ok(Ok(s)), + Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall { + variant: <&'static str>::from(v), + syscall: "get_keys", + } + .into()), + Err(e) => Err(e.into()), + }); - InterceptErrorFuture::new(self.clone(), poll_future.map_err(Error)) + Either::Left(InterceptErrorFuture::new(self.clone(), poll_future)) } pub fn set(&self, key: &str, t: T) { @@ -261,13 +268,7 @@ impl ContextInternal { let _ = inner_lock.vm.sys_state_set(key.to_owned(), b); } Err(e) => { - inner_lock.fail( - ErrorInner::Serialization { - syscall: "set_state", - err: Box::new(e), - } - .into(), - ); + inner_lock.fail(Error::serialization("set_state", e)); } } } @@ -283,36 +284,28 @@ impl ContextInternal { pub fn sleep( &self, sleep_duration: Duration, - ) -> impl Future> + Send + Sync { + ) -> impl Future> + Send { let now = SystemTime::now() .duration_since(SystemTime::UNIX_EPOCH) .expect("Duration since unix epoch cannot fail"); - let maybe_handle = { - must_lock!(self.inner) - .vm - .sys_sleep(now + sleep_duration, Some(now)) - }; + let mut inner_lock = must_lock!(self.inner); + let handle = unwrap_or_trap!( + inner_lock, + inner_lock.vm.sys_sleep(now + sleep_duration, Some(now)) + ); + + let poll_future = get_async_result(Arc::clone(&self.inner), handle).map(|res| match res { + Ok(Value::Void) => Ok(Ok(())), + Ok(Value::Failure(f)) => Ok(Err(f.into())), + Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall { + variant: <&'static str>::from(v), + syscall: "sleep", + } + .into()), + Err(e) => Err(e.into()), + }); - let poll_future = VmAsyncResultPollFuture::new(Cow::Borrowed(&self.inner), maybe_handle) - .map(|res| match res { - Ok(Value::Void) => Ok(Ok(())), - Ok(Value::Success(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { - variant: "success", - syscall: "sleep", - }), - Ok(Value::Failure(f)) => Ok(Err(f.into())), - Err(e) => Err(e), - Ok(Value::StateKeys(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { - variant: "state_keys", - syscall: "sleep", - }), - Ok(Value::InvocationId(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { - variant: "invocation_id", - syscall: "sleep", - }), - }); - - InterceptErrorFuture::new(self.clone(), poll_future.map_err(Error)) + Either::Left(InterceptErrorFuture::new(self.clone(), poll_future)) } pub fn request(&self, request_target: RequestTarget, req: Req) -> Request { @@ -322,80 +315,129 @@ impl ContextInternal { pub fn call( &self, request_target: RequestTarget, + idempotency_key: Option, req: Req, - ) -> impl CallFuture> + Send + Sync { + ) -> impl CallFuture> + Send { let mut inner_lock = must_lock!(self.inner); - let input = match Req::serialize(&req) { - Ok(t) => t, - Err(e) => { - inner_lock.fail( - ErrorInner::Serialization { - syscall: "call", - err: Box::new(e), - } - .into(), - ); - return Either::Right(TrapFuture::default()); - } - }; + let mut target: Target = request_target.into(); + target.idempotency_key = idempotency_key; + let input = unwrap_or_trap!( + inner_lock, + Req::serialize(&req).map_err(|e| Error::serialization("call", e)) + ); - let maybe_handle = inner_lock.vm.sys_call(request_target.into(), input).map(|ch | ch.call_notification_handle); + let call_handle = unwrap_or_trap!(inner_lock, inner_lock.vm.sys_call(target, input)); drop(inner_lock); - let call_future_impl = CallFutureImpl { - poll_future: VmAsyncResultPollFuture::new( - Cow::Borrowed(&self.inner), - maybe_handle.clone(), - ), + // Let's prepare the two futures here + let call_invocation_id_fut = InterceptErrorFuture::new( + self.clone(), + get_async_result( + Arc::clone(&self.inner), + call_handle.invocation_id_notification_handle, + ) + .map(|res| match res { + Ok(Value::Failure(f)) => Ok(Err(f.into())), + Ok(Value::InvocationId(s)) => Ok(Ok(s)), + Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall { + variant: <&'static str>::from(v), + syscall: "call", + }.into()), + Err(e) => Err(e.into()), + }), + ); + let call_result_future = InterceptErrorFuture::new( + self.clone(), + get_async_result( + Arc::clone(&self.inner), + call_handle.call_notification_handle, + ) + .map(|res| match res { + Ok(Value::Success(mut s)) => { + let t = + Res::deserialize(&mut s).map_err(|e| Error::deserialization("call", e))?; + Ok(Ok(t)) + } + Ok(Value::Failure(f)) => Ok(Err(f.into())), + Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall { + variant: <&'static str>::from(v), + syscall: "call", + } + .into()), + Err(e) => Err(e.into()), + }), + ); + + Either::Left(CallFutureImpl { + invocation_id_future: call_invocation_id_fut.shared(), + result_future: call_result_future, res: PhantomData, ctx: self.clone(), - call_handle: maybe_handle.ok(), - }; - - Either::Left(InterceptErrorFuture::new(self.clone(), call_future_impl)) + }) } pub fn send( &self, request_target: RequestTarget, + idempotency_key: Option, req: Req, delay: Option, ) -> impl InvocationHandle { let mut inner_lock = must_lock!(self.inner); - match Req::serialize(&req) { - Ok(t) => { - let result = inner_lock.vm.sys_send( - request_target.into(), - t, - delay.map(|delay| { - SystemTime::now() - .duration_since(SystemTime::UNIX_EPOCH) - .expect("Duration since unix epoch cannot fail") - + delay - }), - ); - drop(inner_lock); - SendRequestHandle { - ctx: self.clone(), - send_handle: result.ok(), - } + let mut target: Target = request_target.into(); + target.idempotency_key = idempotency_key; + + let input = match Req::serialize(&req) { + Ok(b,) => b, + Err(e) => { + inner_lock.fail(Error::serialization("call", e)); + return Either::Right(TrapFuture::<()>::default()) + } + }; + + let send_handle = match + inner_lock.vm.sys_send( + target, + input, + delay.map(|delay| { + SystemTime::now() + .duration_since(SystemTime::UNIX_EPOCH) + .expect("Duration since unix epoch cannot fail") + + delay + }) + ) { + Ok(h,) => h, + Err(e) => { + inner_lock.fail(e.into()); + return Either::Right(TrapFuture::<()>::default()) } - Err(e) => { - inner_lock.fail( - ErrorInner::Serialization { - syscall: "call", - err: Box::new(e), - } - .into(), - ); - SendRequestHandle { - ctx: self.clone(), - send_handle: None, + }; + drop(inner_lock); + + let call_invocation_id_fut = InterceptErrorFuture::new( + self.clone(), + get_async_result( + Arc::clone(&self.inner), + send_handle.invocation_id_notification_handle, + ) + .map(|res| match res { + Ok(Value::Failure(f)) => Ok(Err(f.into())), + Ok(Value::InvocationId(s)) => Ok(Ok(s)), + Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall { + variant: <&'static str>::from(v), + syscall: "call", } - } - } + .into()), + Err(e) => Err(e.into()), + }), + ); + + Either::Left(SendRequestHandle { + invocation_id_future: call_invocation_id_fut.shared(), + ctx: self.clone(), + }) } pub fn invocation_handle(&self, invocation_id: String) -> impl InvocationHandle { @@ -409,48 +451,40 @@ impl ContextInternal { &self, ) -> ( String, - impl Future> + Send + Sync, + impl Future> + Send, ) { - let maybe_awakeable_id_and_handle = { must_lock!(self.inner).vm.sys_awakeable() }; - - let (awakeable_id, maybe_handle) = match maybe_awakeable_id_and_handle { - Ok((s, handle)) => (s, Ok(handle)), - Err(e) => ( - // TODO NOW this is REALLY BAD. The reason for this is that we would need to return a future of a future instead, which is not nice. - // we assume for the time being this works because no user should use the awakeable without doing any other syscall first, which will prevent this invalid awakeable id to work in the first place. - "invalid".to_owned(), - Err(e), - ), + let mut inner_lock = must_lock!(self.inner); + let maybe_awakeable_id_and_handle = inner_lock.vm.sys_awakeable(); + + let (awakeable_id, handle) = match maybe_awakeable_id_and_handle { + Ok((s, handle)) => (s, handle), + Err(e) => { + inner_lock.fail(e.into()); + return ( + // TODO NOW this is REALLY BAD. The reason for this is that we would need to return a future of a future instead, which is not nice. + // we assume for the time being this works because no user should use the awakeable without doing any other syscall first, which will prevent this invalid awakeable id to work in the first place. + "invalid".to_owned(), + Either::Right(TrapFuture::default()), + ); + } }; + drop(inner_lock); - let poll_future = VmAsyncResultPollFuture::new(Cow::Borrowed(&self.inner), maybe_handle) - .map(|res| match res { - Ok(Value::Void) => Err(ErrorInner::UnexpectedValueVariantForSyscall { - variant: "empty", - syscall: "awakeable", - }), - Ok(Value::Success(mut s)) => { - let t = T::deserialize(&mut s).map_err(|e| ErrorInner::Deserialization { - syscall: "awakeable", - err: Box::new(e), - })?; - Ok(Ok(t)) - } - Ok(Value::Failure(f)) => Ok(Err(f.into())), - Ok(Value::StateKeys(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { - variant: "state_keys", - syscall: "awakeable", - }), - Ok(Value::InvocationId(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { - variant: "invocation_id", - syscall: "awakeable", - }), - Err(e) => Err(e), - }); + let poll_future = get_async_result(Arc::clone(&self.inner), handle).map(|res| match res { + Ok(Value::Success(mut s)) => { + Ok(Ok(T::deserialize(&mut s).map_err(|e| Error::deserialization("awakeable", e))?)) + } + Ok(Value::Failure(f)) => Ok(Err(f.into())), + Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall { + variant: <&'static str>::from(v), + syscall: "awakeable", + }.into()), + Err(e) => Err(e), + }); ( awakeable_id, - InterceptErrorFuture::new(self.clone(), poll_future.map_err(Error)), + Either::Left(InterceptErrorFuture::new(self.clone(), poll_future)), ) } @@ -463,13 +497,7 @@ impl ContextInternal { .sys_complete_awakeable(id.to_owned(), NonEmptyValue::Success(b)); } Err(e) => { - inner_lock.fail( - ErrorInner::Serialization { - syscall: "resolve_awakeable", - err: Box::new(e), - } - .into(), - ); + inner_lock.fail(Error::serialization("resolve_awakeable", e)); } } } @@ -483,66 +511,53 @@ impl ContextInternal { pub fn promise( &self, name: &str, - ) -> impl Future> + Send + Sync { - let maybe_handle = { must_lock!(self.inner).vm.sys_get_promise(name.to_owned()) }; + ) -> impl Future> + Send { + let mut inner_lock = must_lock!(self.inner); + let handle = unwrap_or_trap!(inner_lock, inner_lock.vm.sys_get_promise(name.to_owned())); + drop(inner_lock); - let poll_future = VmAsyncResultPollFuture::new(Cow::Borrowed(&self.inner), maybe_handle) - .map(|res| match res { - Ok(Value::Void) => Err(ErrorInner::UnexpectedValueVariantForSyscall { - variant: "empty", - syscall: "promise", - }), - Ok(Value::Success(mut s)) => { - let t = T::deserialize(&mut s).map_err(|e| ErrorInner::Deserialization { - syscall: "promise", - err: Box::new(e), - })?; - Ok(Ok(t)) - } - Ok(Value::Failure(f)) => Ok(Err(f.into())), - Ok(Value::StateKeys(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { - variant: "state_keys", - syscall: "promise", - }), - Ok(Value::InvocationId(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { - variant: "invocation_id", - syscall: "promise", - }), - Err(e) => Err(e), - }); - - InterceptErrorFuture::new(self.clone(), poll_future.map_err(Error)) + let poll_future = get_async_result(Arc::clone(&self.inner), handle).map(|res| match res { + Ok(Value::Success(mut s)) => { + let t = T::deserialize(&mut s).map_err(|e| Error::deserialization("promise", e))?; + Ok(Ok(t)) + } + Ok(Value::Failure(f)) => Ok(Err(f.into())), + Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall { + variant: <&'static str>::from(v), + syscall: "promise", + } + .into()), + Err(e) => Err(e.into()), + }); + + Either::Left(InterceptErrorFuture::new(self.clone(), poll_future)) } pub fn peek_promise( &self, name: &str, - ) -> impl Future, TerminalError>> + Send + Sync { - let maybe_handle = { must_lock!(self.inner).vm.sys_peek_promise(name.to_owned()) }; + ) -> impl Future, TerminalError>> + Send { + let mut inner_lock = must_lock!(self.inner); + let handle = unwrap_or_trap!(inner_lock, inner_lock.vm.sys_peek_promise(name.to_owned())); + drop(inner_lock); - let poll_future = VmAsyncResultPollFuture::new(Cow::Borrowed(&self.inner), maybe_handle) - .map(|res| match res { - Ok(Value::Void) => Ok(Ok(None)), - Ok(Value::Success(mut s)) => { - let t = T::deserialize(&mut s).map_err(|e| ErrorInner::Deserialization { - syscall: "peek_promise", - err: Box::new(e), - })?; - Ok(Ok(Some(t))) - } - Ok(Value::Failure(f)) => Ok(Err(f.into())), - Ok(Value::StateKeys(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { - variant: "state_keys", - syscall: "peek_promise", - }), - Ok(Value::InvocationId(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { - variant: "invocation_id", - syscall: "peek_promise", - }), - Err(e) => Err(e), - }); - - InterceptErrorFuture::new(self.clone(), poll_future.map_err(Error)) + let poll_future = get_async_result(Arc::clone(&self.inner), handle).map(|res| match res { + Ok(Value::Void) => Ok(Ok(None)), + Ok(Value::Success(mut s)) => { + let t = T::deserialize(&mut s) + .map_err(|e| Error::deserialization("peek_promise", e))?; + Ok(Ok(Some(t))) + } + Ok(Value::Failure(f)) => Ok(Err(f.into())), + Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall { + variant: <&'static str>::from(v), + syscall: "peek_promise", + } + .into()), + Err(e) => Err(e), + }); + + Either::Left(InterceptErrorFuture::new(self.clone(), poll_future)) } pub fn resolve_promise(&self, name: &str, t: T) { @@ -585,6 +600,7 @@ impl ContextInternal { InterceptErrorFuture::new(self.clone(), RunFuture::new(this, run_closure)) } + // Used by codegen pub fn handle_handler_result(&self, res: HandlerResult) { let mut inner_lock = must_lock!(self.inner); @@ -708,281 +724,192 @@ where type Output = Result, Error>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut this = self.project(); - - loop { - match this.state.as_mut().project() { - RunStateProj::New => { - let enter_result = { - must_lock!(this - .inner_ctx - .as_mut() - .expect("Future should not be polled after returning Poll::Ready")) - .vm - .sys_run_enter(this.name.to_owned()) - }; - - // Enter the side effect - match enter_result.map_err(ErrorInner::VM)? { - RunEnterResult::Executed(NonEmptyValue::Success(mut v)) => { - let t = Out::deserialize(&mut v).map_err(|e| { - ErrorInner::Deserialization { - syscall: "run", - err: Box::new(e), - } - })?; - return Poll::Ready(Ok(Ok(t))); - } - RunEnterResult::Executed(NonEmptyValue::Failure(f)) => { - return Poll::Ready(Ok(Err(f.into()))) - } - RunEnterResult::NotExecuted(_) => {} - }; - - // We need to run the closure - this.state.set(RunState::ClosureRunning { - start_time: Instant::now(), - fut: this - .closure - .take() - .expect("Future should not be polled after returning Poll::Ready") - .run(), - }); - } - RunStateProj::ClosureRunning { start_time, fut } => { - let res = match ready!(fut.poll(cx)) { - Ok(t) => RunExitResult::Success(Out::serialize(&t).map_err(|e| { - ErrorInner::Serialization { - syscall: "run", - err: Box::new(e), - } - })?), - Err(e) => match e.0 { - HandlerErrorInner::Retryable(err) => RunExitResult::RetryableFailure { - attempt_duration: start_time.elapsed(), - failure: TerminalFailure { - code: 500, - message: err.to_string(), - }, - }, - HandlerErrorInner::Terminal(t) => { - RunExitResult::TerminalFailure(TerminalError(t).into()) - } - }, - }; - - let inner_ctx = this - .inner_ctx - .take() - .expect("Future should not be polled after returning Poll::Ready"); - - let handle = { - must_lock!(inner_ctx) - .vm - .sys_run_exit(res, mem::take(this.retry_policy)) - }; - - this.state.set(RunState::PollFutureRunning { - fut: VmAsyncResultPollFuture::new(Cow::Owned(inner_ctx), handle), - }); - } - RunStateProj::PollFutureRunning { fut } => { - let value = ready!(fut.poll(cx))?; - - return Poll::Ready(match value { - Value::Void => Err(ErrorInner::UnexpectedValueVariantForSyscall { - variant: "empty", - syscall: "run", - } - .into()), - Value::Success(mut s) => { - let t = Out::deserialize(&mut s).map_err(|e| { - ErrorInner::Deserialization { - syscall: "run", - err: Box::new(e), - } - })?; - Ok(Ok(t)) - } - Value::Failure(f) => Ok(Err(f.into())), - Value::StateKeys(_) => Err(ErrorInner::UnexpectedValueVariantForSyscall { - variant: "state_keys", - syscall: "run", - } - .into()), - Value::InvocationId(_) => { - Err(ErrorInner::UnexpectedValueVariantForSyscall { - variant: "invocation_id", - syscall: "run", - } - .into()) - } - }); - } - } - } + unimplemented!() + // let mut this = self.project(); + + // loop { + // match this.state.as_mut().project() { + // RunStateProj::New => { + // let enter_result = { + // must_lock!(this + // .inner_ctx + // .as_mut() + // .expect("Future should not be polled after returning Poll::Ready")) + // .vm + // .sys_run_enter(this.name.to_owned()) + // }; + // + // // Enter the side effect + // match enter_result.map_err(ErrorInner::VM)? { + // RunEnterResult::Executed(NonEmptyValue::Success(mut v)) => { + // let t = Out::deserialize(&mut v).map_err(|e| { + // ErrorInner::Deserialization { + // syscall: "run", + // err: Box::new(e), + // } + // })?; + // return Poll::Ready(Ok(Ok(t))); + // } + // RunEnterResult::Executed(NonEmptyValue::Failure(f)) => { + // return Poll::Ready(Ok(Err(f.into()))) + // } + // RunEnterResult::NotExecuted(_) => {} + // }; + // + // // We need to run the closure + // this.state.set(RunState::ClosureRunning { + // start_time: Instant::now(), + // fut: this + // .closure + // .take() + // .expect("Future should not be polled after returning Poll::Ready") + // .run(), + // }); + // } + // RunStateProj::ClosureRunning { start_time, fut } => { + // let res = match ready!(fut.poll(cx)) { + // Ok(t) => RunExitResult::Success(Out::serialize(&t).map_err(|e| { + // ErrorInner::Serialization { + // syscall: "run", + // err: Box::new(e), + // } + // })?), + // Err(e) => match e.0 { + // HandlerErrorInner::Retryable(err) => RunExitResult::RetryableFailure { + // attempt_duration: start_time.elapsed(), + // failure: TerminalFailure { + // code: 500, + // message: err.to_string(), + // }, + // }, + // HandlerErrorInner::Terminal(t) => { + // RunExitResult::TerminalFailure(TerminalError(t).into()) + // } + // }, + // }; + // + // let inner_ctx = this + // .inner_ctx + // .take() + // .expect("Future should not be polled after returning Poll::Ready"); + // + // let handle = { + // must_lock!(inner_ctx) + // .vm + // .sys_run_exit(res, mem::take(this.retry_policy)) + // }; + // + // this.state.set(RunState::PollFutureRunning { + // fut: VmAsyncResultPollFuture::maybe_new(Cow::Owned(inner_ctx), handle), + // }); + // } + // RunStateProj::PollFutureRunning { fut } => { + // let value = ready!(fut.poll(cx))?; + // + // return Poll::Ready(match value { + // Value::Void => Err(ErrorInner::UnexpectedValueVariantForSyscall { + // variant: "empty", + // syscall: "run", + // } + // .into()), + // Value::Success(mut s) => { + // let t = Out::deserialize(&mut s).map_err(|e| { + // ErrorInner::Deserialization { + // syscall: "run", + // err: Box::new(e), + // } + // })?; + // Ok(Ok(t)) + // } + // Value::Failure(f) => Ok(Err(f.into())), + // Value::StateKeys(_) => Err(ErrorInner::UnexpectedValueVariantForSyscall { + // variant: "state_keys", + // syscall: "run", + // } + // .into()), + // Value::InvocationId(_) => { + // Err(ErrorInner::UnexpectedValueVariantForSyscall { + // variant: "invocation_id", + // syscall: "run", + // } + // .into()) + // } + // }); + // } + // } + // } } } -struct SendRequestHandle { +struct SendRequestHandle { + invocation_id_future: Shared, ctx: ContextInternal, - send_handle: Option, } -impl InvocationHandle for SendRequestHandle { +impl> + Send> InvocationHandle + for SendRequestHandle +{ fn invocation_id(&self) -> impl Future> + Send { - if let Some(ref send_handle) = self.send_handle { - let maybe_handle = { - must_lock!(self.ctx.inner) - .vm - .sys_get_call_invocation_id(GetInvocationIdTarget::SendEntry(*send_handle)) - }; - - let poll_future = VmAsyncResultPollFuture::new( - Cow::Borrowed(&self.ctx.inner), - maybe_handle, - ) - .map(|res| match res { - Ok(Value::Failure(f)) => Ok(Err(f.into())), - Ok(Value::InvocationId(s)) => Ok(Ok(s)), - Err(e) => Err(e), - Ok(Value::StateKeys(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { - variant: "state_keys", - syscall: "get_call_invocation_id", - }), - Ok(Value::Void) => Err(ErrorInner::UnexpectedValueVariantForSyscall { - variant: "void", - syscall: "get_call_invocation_id", - }), - Ok(Value::Success(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { - variant: "success", - syscall: "get_call_invocation_id", - }), - }); - - Either::Left(InterceptErrorFuture::new( - self.ctx.clone(), - poll_future.map_err(Error), - )) - } else { - // If the send didn't succeed, trap the execution - Either::Right(TrapFuture::default()) - } + let cloned_invocation_id_fut = Shared::clone(&self.invocation_id_future); + async move { cloned_invocation_id_fut.await } } - fn cancel(&self) { - if let Some(ref send_handle) = self.send_handle { - let mut inner_lock = must_lock!(self.ctx.inner); - let _ = inner_lock - .vm - .sys_cancel_invocation(CancelInvocationTarget::SendEntry(*send_handle)); + fn cancel(&self) -> impl Future> + Send { + let cloned_invocation_id_fut = Shared::clone(&self.invocation_id_future); + let cloned_ctx = Arc::clone(&self.ctx.inner); + async move { + let inv_id = cloned_invocation_id_fut.await?; + let mut inner_lock = must_lock!(cloned_ctx); + let _ = inner_lock.vm.sys_cancel_invocation(inv_id); + drop(inner_lock); + Ok(()) } - // If the send didn't succeed, then simply ignore the cancel } } pin_project! { - struct CallFutureImpl { + struct CallFutureImpl { + #[pin] + invocation_id_future: Shared, #[pin] - poll_future: VmAsyncResultPollFuture, - res: PhantomData R>, + result_future: ResultFut, + res: PhantomData Res>, ctx: ContextInternal, - call_handle: Option, } } -impl Future for CallFutureImpl { - type Output = Result, Error>; +impl> + Send, Res> Future + for CallFutureImpl +{ + type Output = ResultFut::Output; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.project(); - - this.poll_future - .poll(cx) - .map(|res| match res { - Ok(Value::Void) => Err(ErrorInner::UnexpectedValueVariantForSyscall { - variant: "empty", - syscall: "call", - }), - Ok(Value::Success(mut s)) => { - let t = Res::deserialize(&mut s).map_err(|e| ErrorInner::Deserialization { - syscall: "call", - err: Box::new(e), - })?; - Ok(Ok(t)) - } - Ok(Value::Failure(f)) => Ok(Err(f.into())), - Ok(Value::StateKeys(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { - variant: "state_keys", - syscall: "call", - }), - Ok(Value::InvocationId(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { - variant: "invocation_id", - syscall: "call", - }), - Err(e) => Err(e), - }) - .map(|res| res.map_err(Error)) + this.result_future.poll(cx) } } -impl InvocationHandle for CallFutureImpl { +impl> + Send, ResultFut, Res> + InvocationHandle for CallFutureImpl +{ fn invocation_id(&self) -> impl Future> + Send { - if let Some(ref call_handle) = self.call_handle { - let maybe_handle = { - must_lock!(self.ctx.inner) - .vm - .sys_get_call_invocation_id(GetInvocationIdTarget::CallEntry(*call_handle)) - }; - - let poll_future = VmAsyncResultPollFuture::new( - Cow::Borrowed(&self.ctx.inner), - maybe_handle, - ) - .map(|res| match res { - Ok(Value::Failure(f)) => Ok(Err(f.into())), - Ok(Value::InvocationId(s)) => Ok(Ok(s)), - Err(e) => Err(e), - Ok(Value::StateKeys(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { - variant: "state_keys", - syscall: "get_call_invocation_id", - }), - Ok(Value::Void) => Err(ErrorInner::UnexpectedValueVariantForSyscall { - variant: "void", - syscall: "get_call_invocation_id", - }), - Ok(Value::Success(_)) => Err(ErrorInner::UnexpectedValueVariantForSyscall { - variant: "success", - syscall: "get_call_invocation_id", - }), - }); - - Either::Left(InterceptErrorFuture::new( - self.ctx.clone(), - poll_future.map_err(Error), - )) - } else { - // If the send didn't succeed, trap the execution - Either::Right(TrapFuture::default()) - } + let cloned_invocation_id_fut = Shared::clone(&self.invocation_id_future); + async move { cloned_invocation_id_fut.await } } - fn cancel(&self) { - if let Some(ref call_handle) = self.call_handle { - let mut inner_lock = must_lock!(self.ctx.inner); - let _ = inner_lock - .vm - .sys_cancel_invocation(CancelInvocationTarget::CallEntry(*call_handle)); + fn cancel(&self) -> impl Future> + Send { + let cloned_invocation_id_fut = Shared::clone(&self.invocation_id_future); + let cloned_ctx = Arc::clone(&self.ctx.inner); + async move { + let inv_id = cloned_invocation_id_fut.await?; + let mut inner_lock = must_lock!(cloned_ctx); + let _ = inner_lock.vm.sys_cancel_invocation(inv_id); + drop(inner_lock); + Ok(()) } - // If the send didn't succeed, then simply ignore the cancel } } -impl CallFuture, Error>> - for CallFutureImpl -{ -} - impl InvocationHandle for Either { fn invocation_id(&self) -> impl Future> + Send { match self { @@ -991,10 +918,10 @@ impl InvocationHandle for Either } } - fn cancel(&self) { + fn cancel(&self) -> impl Future> + Send { match self { - Either::Left(l) => l.cancel(), - Either::Right(r) => r.cancel(), + Either::Left(l) => Either::Left(l.cancel()), + Either::Right(r) => Either::Right(r.cancel()), } } } @@ -1004,6 +931,8 @@ where A: CallFuture, B: CallFuture, { + + } struct InvocationIdBackedInvocationHandle { @@ -1016,12 +945,99 @@ impl InvocationHandle for InvocationIdBackedInvocationHandle { ready(Ok(self.invocation_id.clone())) } - fn cancel(&self) { + fn cancel(&self) -> impl Future> + Send { let mut inner_lock = must_lock!(self.ctx.inner); let _ = inner_lock .vm - .sys_cancel_invocation(CancelInvocationTarget::InvocationId( - self.invocation_id.clone(), - )); + .sys_cancel_invocation(self.invocation_id.clone()); + ready(Ok(())) } } + +impl Error { + fn serialization( + syscall: &'static str, + e: E, + ) -> Self { + ErrorInner::Serialization { + syscall, + err: Box::new(e), + } + .into() + } + + fn deserialization( + syscall: &'static str, + e: E, + ) -> Self { + ErrorInner::Deserialization { + syscall, + err: Box::new(e), + } + .into() + } +} + +fn get_async_result( + ctx: Arc>, + handle: NotificationHandle, +) -> impl Future> + Send { + async move { + loop { + let mut inner_lock = must_lock!(ctx); + + // Let's consume some output to begin with + let out = inner_lock.vm.take_output(); + match out { + TakeOutputResult::Buffer(b) => { + if !inner_lock.write.send(b) { + return Err(ErrorInner::Suspended); + } + } + TakeOutputResult::EOF => return Err(ErrorInner::UnexpectedOutputClosed), + } + + // Let's do some progress now + match inner_lock.vm.do_progress(vec![handle]) { + Ok(DoProgressResponse::AnyCompleted) => { + // We're good, we got the response + break; + } + Ok(DoProgressResponse::ReadFromInput) => { + drop(inner_lock); + match inner_lock.read.recv().await { + Some(Ok(b)) => must_lock!(ctx).vm.notify_input(b), + Some(Err(e)) => must_lock!(ctx).vm.notify_error( + CoreError::new(500u16, format!("Error when reading the body {e:?}",)), + None, + ), + None => must_lock!(ctx).vm.notify_input_closed(), + } + continue; + } + Ok(DoProgressResponse::ExecuteRun(_)) => { + unimplemented!() + } + Ok(DoProgressResponse::WaitingPendingRun) => { + unimplemented!() + } + Ok(DoProgressResponse::CancelSignalReceived) => { + unimplemented!() + } + Err(e) => { + return Err(e.into()); + } + }; + } + let mut inner_lock = must_lock!(ctx); + + // At this point let's try to take the notification + match inner_lock.vm.take_notification(handle) { + Ok(Some(v)) => return Ok(v), + Ok(None) => { + panic!("This is not supposed to happen, handle was flagged as completed") + } + Err(e) => return Err(e.into()), + } + }.map_err(Error::from) +} diff --git a/src/endpoint/futures/async_result_poll.rs b/src/endpoint/futures/async_result_poll.rs index 23f7fa2..7e89afd 100644 --- a/src/endpoint/futures/async_result_poll.rs +++ b/src/endpoint/futures/async_result_poll.rs @@ -1,7 +1,7 @@ use crate::endpoint::context::ContextInternalInner; -use crate::endpoint::ErrorInner; +use crate::endpoint::{BoxError, ErrorInner}; use restate_sdk_shared_core::{ - NotificationHandle, SuspendedOrVMError, TakeOutputResult, Value, VM, + DoProgressResponse, Error as CoreError, NotificationHandle, TakeOutputResult, Value, VM, }; use std::borrow::Cow; use std::future::Future; @@ -10,35 +10,39 @@ use std::sync::{Arc, Mutex}; use std::task::Poll; pub(crate) struct VmAsyncResultPollFuture { + ctx: Arc>, state: Option, } impl VmAsyncResultPollFuture { - pub fn new( + pub fn maybe_new( inner: Cow<'_, Arc>>, - handle: Result, + handle: Result, ) -> Self { VmAsyncResultPollFuture { + ctx: inner.into_owned(), state: Some(match handle { - Ok(handle) => PollState::Init { - ctx: inner.into_owned(), - handle, - }, + Ok(handle) => PollState::Init(handle), Err(err) => PollState::Failed(ErrorInner::VM(err)), }), } } + + pub fn new( + inner: Cow<'_, Arc>>, + handle: NotificationHandle, + ) -> Self { + VmAsyncResultPollFuture { + ctx: inner.into_owned(), + state: Some(PollState::Init(handle)), + } + } } enum PollState { - Init { - ctx: Arc>, - handle: NotificationHandle, - }, - WaitingInput { - ctx: Arc>, - handle: NotificationHandle, - }, + Init(NotificationHandle), + PollProgress(NotificationHandle), + WaitingInput(NotificationHandle), Failed(ErrorInner), } @@ -52,58 +56,38 @@ impl Future for VmAsyncResultPollFuture { type Output = Result; fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { + let ctx = &self.ctx; + let mut inner_lock = must_lock!(ctx); + let state = &mut self.state; + loop { - match self - .state + match state .take() .expect("Future should not be polled after Poll::Ready") { - PollState::Init { ctx, handle } => { - // Acquire lock - let mut inner_lock = must_lock!(ctx); - - // Let's consume some output + PollState::Init(handle) => { + // Let's consume some output to begin with let out = inner_lock.vm.take_output(); match out { TakeOutputResult::Buffer(b) => { if !inner_lock.write.send(b) { - self.state = Some(PollState::Failed(ErrorInner::Suspended)); - continue; + return Poll::Ready(Err(ErrorInner::Suspended)); } } TakeOutputResult::EOF => { - self.state = - Some(PollState::Failed(ErrorInner::UnexpectedOutputClosed)); - continue; + return Poll::Ready(Err(ErrorInner::UnexpectedOutputClosed)) } } - // Notify that we reached an await point - inner_lock.vm.notify_await_point(handle); - - // At this point let's try to take the async result - match inner_lock.vm.take_async_result(handle) { - Ok(Some(v)) => return Poll::Ready(Ok(v)), - Ok(None) => { - drop(inner_lock); - self.state = Some(PollState::WaitingInput { ctx, handle }); - } - Err(SuspendedOrVMError::Suspended(_)) => { - self.state = Some(PollState::Failed(ErrorInner::Suspended)); - } - Err(SuspendedOrVMError::VM(e)) => { - self.state = Some(PollState::Failed(ErrorInner::VM(e))); - } - } + // We can now start polling + *state = Some(PollState::PollProgress(handle)); } - PollState::WaitingInput { ctx, handle } => { - let mut inner_lock = must_lock!(ctx); - + PollState::WaitingInput(handle) => { let read_result = match inner_lock.read.poll_recv(cx) { Poll::Ready(t) => t, Poll::Pending => { - drop(inner_lock); - self.state = Some(PollState::WaitingInput { ctx, handle }); + // Still need to wait for input + *state = Some(PollState::WaitingInput(handle)); return Poll::Pending; } }; @@ -112,26 +96,47 @@ impl Future for VmAsyncResultPollFuture { match read_result { Some(Ok(b)) => inner_lock.vm.notify_input(b), Some(Err(e)) => inner_lock.vm.notify_error( - "Error when reading the body".into(), - e.to_string().into(), + CoreError::new(500u16, format!("Error when reading the body {e:?}",)), None, ), None => inner_lock.vm.notify_input_closed(), } - // Now try to take async result again - match inner_lock.vm.take_async_result(handle) { - Ok(Some(v)) => return Poll::Ready(Ok(v)), - Ok(None) => { - drop(inner_lock); - self.state = Some(PollState::WaitingInput { ctx, handle }); + // It's time to poll progress again + *state = Some(PollState::PollProgress(handle)); + } + PollState::PollProgress(handle) => { + match inner_lock.vm.do_progress(vec![handle]) { + Ok(DoProgressResponse::AnyCompleted) => { + // We're good, we got the response } - Err(SuspendedOrVMError::Suspended(_)) => { - self.state = Some(PollState::Failed(ErrorInner::Suspended)); + Ok(DoProgressResponse::ReadFromInput) => { + *state = Some(PollState::WaitingInput(handle)); + continue; } - Err(SuspendedOrVMError::VM(e)) => { - self.state = Some(PollState::Failed(ErrorInner::VM(e))); + Ok(DoProgressResponse::ExecuteRun(_)) => { + unimplemented!() + } + Ok(DoProgressResponse::WaitingPendingRun) => { + unimplemented!() + } + Ok(DoProgressResponse::CancelSignalReceived) => { + unimplemented!() + } + Err(e) => { + return Poll::Ready(Err(e.into())); + } + }; + + // At this point let's try to take the notification + match inner_lock.vm.take_notification(handle) { + Ok(Some(v)) => return Poll::Ready(Ok(v)), + Ok(None) => { + panic!( + "This is not supposed to happen, handle was flagged as completed" + ) } + Err(e) => return Poll::Ready(Err(e.into())), } } PollState::Failed(err) => return Poll::Ready(Err(err)), diff --git a/src/endpoint/futures/intercept_error.rs b/src/endpoint/futures/intercept_error.rs index 606df95..7c6e617 100644 --- a/src/endpoint/futures/intercept_error.rs +++ b/src/endpoint/futures/intercept_error.rs @@ -65,7 +65,7 @@ impl InvocationHandle for InterceptErrorFuture { self.fut.invocation_id() } - fn cancel(&self) { + fn cancel(&self) -> impl Future> + Send { self.fut.cancel() } } diff --git a/src/endpoint/futures/trap.rs b/src/endpoint/futures/trap.rs index b614de1..b0a4ae9 100644 --- a/src/endpoint/futures/trap.rs +++ b/src/endpoint/futures/trap.rs @@ -28,7 +28,9 @@ impl InvocationHandle for TrapFuture { TrapFuture::default() } - fn cancel(&self) {} + fn cancel(&self) -> impl Future> + Send { + TrapFuture::default() + } } impl CallFuture for TrapFuture {} diff --git a/src/endpoint/mod.rs b/src/endpoint/mod.rs index 3a20feb..fbbe2de 100644 --- a/src/endpoint/mod.rs +++ b/src/endpoint/mod.rs @@ -10,7 +10,9 @@ use ::futures::future::BoxFuture; use ::futures::{Stream, StreamExt}; use bytes::Bytes; pub use context::{ContextInternal, InputMetadata}; -use restate_sdk_shared_core::{CoreVM, Error as VMError, Header, HeaderMap, IdentityVerifier, KeyError, VerifyError, VM}; +use restate_sdk_shared_core::{ + CoreVM, Error as CoreError, Header, HeaderMap, IdentityVerifier, KeyError, VerifyError, VM, +}; use std::collections::HashMap; use std::future::poll_fn; use std::pin::Pin; @@ -103,7 +105,7 @@ pub(crate) enum ErrorInner { #[error("Received a request for unknown service handler '{0}/{1}'")] UnknownServiceHandler(String, String), #[error("Error when processing the request: {0:?}")] - VM(#[from] VMError), + VM(#[from] CoreError), #[error("Error when verifying identity: {0:?}")] IdentityVerification(#[from] VerifyError), #[error("Cannot convert header '{0}', reason: {1}")] @@ -140,6 +142,27 @@ pub(crate) enum ErrorInner { }, } +impl From for ErrorInner { + fn from(_: restate_sdk_shared_core::SuspendedError) -> Self { + Self::Suspended + } +} + +impl From for ErrorInner { + fn from(value: restate_sdk_shared_core::SuspendedOrVMError) -> Self { + match value { + restate_sdk_shared_core::SuspendedOrVMError::Suspended(e) => e.into(), + restate_sdk_shared_core::SuspendedOrVMError::VM(e) => e.into(), + } + } +} + +impl From for Error { + fn from(e: CoreError) -> Self { + ErrorInner::from(e).into() + } +} + struct BoxedService( Box>> + Send + Sync + 'static>, ); @@ -392,8 +415,7 @@ async fn init_loop_vm(vm: &mut CoreVM, input_rx: &mut InputReceiver) -> Result<( match input_rx.recv().await { Some(Ok(b)) => vm.notify_input(b), Some(Err(e)) => vm.notify_error( - VMError::new(500u16, format!("Error when reading the body: {e}")) - , + CoreError::new(500u16, format!("Error when reading the body: {e}")), None, ), None => vm.notify_input_closed(), diff --git a/src/hyper.rs b/src/hyper.rs index 4367a16..7b92fdc 100644 --- a/src/hyper.rs +++ b/src/hyper.rs @@ -10,7 +10,7 @@ use http::{response, HeaderName, HeaderValue, Request, Response}; use http_body_util::{BodyExt, Either, Full}; use hyper::body::{Body, Frame, Incoming}; use hyper::service::Service; -use restate_sdk_shared_core::ResponseHead; +use restate_sdk_shared_core::{Header, ResponseHead}; use std::convert::Infallible; use std::future::{ready, Ready}; use std::ops::Deref; @@ -56,13 +56,18 @@ impl Service> for HyperEndpoint { match endpoint_response { endpoint::Response::ReplyNow { - response_head, + status_code, + headers, body, - } => ready(Ok(response_builder_from_response_head(response_head) - .body(Either::Left(Full::new(body))) - .expect("Headers should be valid"))), + } => ready(Ok(response_builder_from_response_head( + status_code, + headers, + ) + .body(Either::Left(Full::new(body))) + .expect("Headers should be valid"))), endpoint::Response::BidiStream { - response_head, + status_code, + headers, handler, } => { let input_receiver = @@ -73,24 +78,30 @@ impl Service> for HyperEndpoint { let handler_fut = Box::pin(handler.handle(input_receiver, output_sender)); - ready(Ok(response_builder_from_response_head(response_head) - .body(Either::Right(BidiStreamRunner { - fut: Some(handler_fut), - output_rx, - end_stream: false, - })) - .expect("Headers should be valid"))) + ready(Ok(response_builder_from_response_head( + status_code, + headers, + ) + .body(Either::Right(BidiStreamRunner { + fut: Some(handler_fut), + output_rx, + end_stream: false, + })) + .expect("Headers should be valid"))) } } } } -fn response_builder_from_response_head(response_head: ResponseHead) -> response::Builder { +fn response_builder_from_response_head( + status_code: u16, + headers: Vec
, +) -> response::Builder { let mut response_builder = Response::builder() - .status(response_head.status_code) + .status(status_code) .header(X_RESTATE_SERVER, X_RESTATE_SERVER_VALUE); - for header in response_head.headers { + for header in headers { response_builder = response_builder.header(header.key.deref(), header.value.deref()); } From 76193c12f589595b7bed9c34ff64c2426a96438e Mon Sep 17 00:00:00 2001 From: slinkydeveloper Date: Tue, 4 Mar 2025 12:44:25 +0100 Subject: [PATCH 03/10] Working! --- src/context/mod.rs | 10 +- src/endpoint/context.rs | 196 ++++++++-------------- src/endpoint/futures/async_result_poll.rs | 77 ++++----- src/endpoint/futures/intercept_error.rs | 4 +- src/hyper.rs | 2 +- test-services/src/test_utils_service.rs | 7 +- 6 files changed, 119 insertions(+), 177 deletions(-) diff --git a/src/context/mod.rs b/src/context/mod.rs index b3425b6..4a0129e 100644 --- a/src/context/mod.rs +++ b/src/context/mod.rs @@ -370,20 +370,20 @@ impl<'ctx, CTX: private::SealedContext<'ctx>> ContextTimers<'ctx> for CTX {} /// // To a Service: /// ctx.service_client::() /// .my_handler(String::from("Hi!")) -/// .send_with_delay(Duration::from_millis(5000)); +/// .send_after(Duration::from_millis(5000)); /// /// // To a Virtual Object: /// ctx.object_client::("Mary") /// .my_handler(String::from("Hi!")) -/// .send_with_delay(Duration::from_millis(5000)); +/// .send_after(Duration::from_millis(5000)); /// /// // To a Workflow: /// ctx.workflow_client::("my-workflow-id") /// .run(String::from("Hi!")) -/// .send_with_delay(Duration::from_millis(5000)); +/// .send_after(Duration::from_millis(5000)); /// ctx.workflow_client::("my-workflow-id") /// .interact_with_workflow() -/// .send_with_delay(Duration::from_millis(5000)); +/// .send_after(Duration::from_millis(5000)); /// # Ok(()) /// # } /// ``` @@ -632,7 +632,7 @@ pub trait ContextAwakeables<'ctx>: private::SealedContext<'ctx> { &self, ) -> ( String, - impl Future> + Send + 'ctx, + impl Future> + Send + 'ctx, ) { self.inner_context().awakeable() } diff --git a/src/endpoint/context.rs b/src/endpoint/context.rs index 5a0a317..e00df54 100644 --- a/src/endpoint/context.rs +++ b/src/endpoint/context.rs @@ -12,18 +12,16 @@ use futures::future::{Either, Shared}; use futures::{FutureExt, TryFutureExt}; use pin_project_lite::pin_project; use restate_sdk_shared_core::{ - CallHandle as CoreCallHandle, CoreVM, DoProgressResponse, Error as CoreError, NonEmptyValue, - NotificationHandle, RetryPolicy, RunExitResult, SendHandle as CoreSendHandle, TakeOutputResult, - Target, TerminalFailure, Value, VM, + CoreVM, Error as CoreError, NonEmptyValue, NotificationHandle, RetryPolicy, TakeOutputResult, + Target, Value, VM, }; use std::borrow::Cow; use std::collections::HashMap; use std::future::{ready, Future}; use std::marker::PhantomData; -use std::mem; use std::pin::Pin; use std::sync::{Arc, Mutex}; -use std::task::{ready, Context, Poll}; +use std::task::{Context, Poll}; use std::time::{Duration, Instant, SystemTime}; pub struct ContextInternalInner { @@ -50,7 +48,8 @@ impl ContextInternalInner { pub(super) fn fail(&mut self, e: Error) { self.vm.notify_error( - CoreError::new(500u16, e.0.to_string()).with_stacktrace(Cow::Owned(format!("{:#}", e.0))), + CoreError::new(500u16, e.0.to_string()) + .with_stacktrace(Cow::Owned(format!("{:#}", e.0))), None, ); self.handler_state.mark_error(e); @@ -237,7 +236,7 @@ impl ContextInternal { syscall: "get_state", } .into()), - Err(e) => Err(e.into()), + Err(e) => Err(e), }); Either::Left(InterceptErrorFuture::new(self.clone(), poll_future)) @@ -255,7 +254,7 @@ impl ContextInternal { syscall: "get_keys", } .into()), - Err(e) => Err(e.into()), + Err(e) => Err(e), }); Either::Left(InterceptErrorFuture::new(self.clone(), poll_future)) @@ -302,7 +301,7 @@ impl ContextInternal { syscall: "sleep", } .into()), - Err(e) => Err(e.into()), + Err(e) => Err(e), }); Either::Left(InterceptErrorFuture::new(self.clone(), poll_future)) @@ -331,7 +330,7 @@ impl ContextInternal { drop(inner_lock); // Let's prepare the two futures here - let call_invocation_id_fut = InterceptErrorFuture::new( + let invocation_id_fut = InterceptErrorFuture::new( self.clone(), get_async_result( Arc::clone(&self.inner), @@ -343,36 +342,34 @@ impl ContextInternal { Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall { variant: <&'static str>::from(v), syscall: "call", - }.into()), - Err(e) => Err(e.into()), + } + .into()), + Err(e) => Err(e), }), ); - let call_result_future = InterceptErrorFuture::new( + let result_future = InterceptErrorFuture::new( self.clone(), get_async_result( Arc::clone(&self.inner), call_handle.call_notification_handle, ) .map(|res| match res { - Ok(Value::Success(mut s)) => { - let t = - Res::deserialize(&mut s).map_err(|e| Error::deserialization("call", e))?; - Ok(Ok(t)) - } - Ok(Value::Failure(f)) => Ok(Err(f.into())), + Ok(Value::Success(mut s)) => Ok(Ok( + Res::deserialize(&mut s).map_err(|e| Error::deserialization("call", e))? + )), + Ok(Value::Failure(f)) => Ok(Err(TerminalError::from(f))), Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall { variant: <&'static str>::from(v), syscall: "call", } .into()), - Err(e) => Err(e.into()), + Err(e) => Err(e), }), ); Either::Left(CallFutureImpl { - invocation_id_future: call_invocation_id_fut.shared(), - result_future: call_result_future, - res: PhantomData, + invocation_id_future: invocation_id_fut.shared(), + result_future, ctx: self.clone(), }) } @@ -390,33 +387,32 @@ impl ContextInternal { target.idempotency_key = idempotency_key; let input = match Req::serialize(&req) { - Ok(b,) => b, - Err(e) => { - inner_lock.fail(Error::serialization("call", e)); - return Either::Right(TrapFuture::<()>::default()) - } + Ok(b) => b, + Err(e) => { + inner_lock.fail(Error::serialization("call", e)); + return Either::Right(TrapFuture::<()>::default()); + } }; - let send_handle = match - inner_lock.vm.sys_send( - target, - input, - delay.map(|delay| { - SystemTime::now() - .duration_since(SystemTime::UNIX_EPOCH) - .expect("Duration since unix epoch cannot fail") - + delay - }) - ) { - Ok(h,) => h, - Err(e) => { + let send_handle = match inner_lock.vm.sys_send( + target, + input, + delay.map(|delay| { + SystemTime::now() + .duration_since(SystemTime::UNIX_EPOCH) + .expect("Duration since unix epoch cannot fail") + + delay + }), + ) { + Ok(h) => h, + Err(e) => { inner_lock.fail(e.into()); - return Either::Right(TrapFuture::<()>::default()) + return Either::Right(TrapFuture::<()>::default()); } }; drop(inner_lock); - let call_invocation_id_fut = InterceptErrorFuture::new( + let invocation_id_fut = InterceptErrorFuture::new( self.clone(), get_async_result( Arc::clone(&self.inner), @@ -430,12 +426,12 @@ impl ContextInternal { syscall: "call", } .into()), - Err(e) => Err(e.into()), + Err(e) => Err(e), }), ); Either::Left(SendRequestHandle { - invocation_id_future: call_invocation_id_fut.shared(), + invocation_id_future: invocation_id_fut.shared(), ctx: self.clone(), }) } @@ -471,14 +467,15 @@ impl ContextInternal { drop(inner_lock); let poll_future = get_async_result(Arc::clone(&self.inner), handle).map(|res| match res { - Ok(Value::Success(mut s)) => { - Ok(Ok(T::deserialize(&mut s).map_err(|e| Error::deserialization("awakeable", e))?)) - } + Ok(Value::Success(mut s)) => Ok(Ok( + T::deserialize(&mut s).map_err(|e| Error::deserialization("awakeable", e))? + )), Ok(Value::Failure(f)) => Ok(Err(f.into())), Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall { variant: <&'static str>::from(v), syscall: "awakeable", - }.into()), + } + .into()), Err(e) => Err(e), }); @@ -527,7 +524,7 @@ impl ContextInternal { syscall: "promise", } .into()), - Err(e) => Err(e.into()), + Err(e) => Err(e), }); Either::Left(InterceptErrorFuture::new(self.clone(), poll_future)) @@ -723,7 +720,7 @@ where { type Output = Result, Error>; - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + fn poll(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll { unimplemented!() // let mut this = self.project(); @@ -850,8 +847,7 @@ impl> + Send> Invocation for SendRequestHandle { fn invocation_id(&self) -> impl Future> + Send { - let cloned_invocation_id_fut = Shared::clone(&self.invocation_id_future); - async move { cloned_invocation_id_fut.await } + Shared::clone(&self.invocation_id_future) } fn cancel(&self) -> impl Future> + Send { @@ -868,18 +864,19 @@ impl> + Send> Invocation } pin_project! { - struct CallFutureImpl { - #[pin] + struct CallFutureImpl { + #[pin] invocation_id_future: Shared, #[pin] result_future: ResultFut, - res: PhantomData Res>, ctx: ContextInternal, } } -impl> + Send, Res> Future - for CallFutureImpl +impl Future for CallFutureImpl +where + InvIdFut: Future> + Send, + ResultFut: Future> + Send, { type Output = ResultFut::Output; @@ -889,12 +886,12 @@ impl> + } } -impl> + Send, ResultFut, Res> - InvocationHandle for CallFutureImpl +impl InvocationHandle for CallFutureImpl +where + InvIdFut: Future> + Send, { fn invocation_id(&self) -> impl Future> + Send { - let cloned_invocation_id_fut = Shared::clone(&self.invocation_id_future); - async move { cloned_invocation_id_fut.await } + Shared::clone(&self.invocation_id_future) } fn cancel(&self) -> impl Future> + Send { @@ -910,7 +907,19 @@ impl> + Send, ResultFut, } } -impl InvocationHandle for Either { +impl CallFuture> + for CallFutureImpl +where + InvIdFut: Future> + Send, + ResultFut: Future> + Send, +{ +} + +impl InvocationHandle for Either +where + A: InvocationHandle, + B: InvocationHandle, +{ fn invocation_id(&self) -> impl Future> + Send { match self { Either::Left(l) => Either::Left(l.invocation_id()), @@ -931,8 +940,6 @@ where A: CallFuture, B: CallFuture, { - - } struct InvocationIdBackedInvocationHandle { @@ -982,62 +989,5 @@ fn get_async_result( ctx: Arc>, handle: NotificationHandle, ) -> impl Future> + Send { - async move { - loop { - let mut inner_lock = must_lock!(ctx); - - // Let's consume some output to begin with - let out = inner_lock.vm.take_output(); - match out { - TakeOutputResult::Buffer(b) => { - if !inner_lock.write.send(b) { - return Err(ErrorInner::Suspended); - } - } - TakeOutputResult::EOF => return Err(ErrorInner::UnexpectedOutputClosed), - } - - // Let's do some progress now - match inner_lock.vm.do_progress(vec![handle]) { - Ok(DoProgressResponse::AnyCompleted) => { - // We're good, we got the response - break; - } - Ok(DoProgressResponse::ReadFromInput) => { - drop(inner_lock); - match inner_lock.read.recv().await { - Some(Ok(b)) => must_lock!(ctx).vm.notify_input(b), - Some(Err(e)) => must_lock!(ctx).vm.notify_error( - CoreError::new(500u16, format!("Error when reading the body {e:?}",)), - None, - ), - None => must_lock!(ctx).vm.notify_input_closed(), - } - continue; - } - Ok(DoProgressResponse::ExecuteRun(_)) => { - unimplemented!() - } - Ok(DoProgressResponse::WaitingPendingRun) => { - unimplemented!() - } - Ok(DoProgressResponse::CancelSignalReceived) => { - unimplemented!() - } - Err(e) => { - return Err(e.into()); - } - }; - } - let mut inner_lock = must_lock!(ctx); - - // At this point let's try to take the notification - match inner_lock.vm.take_notification(handle) { - Ok(Some(v)) => return Ok(v), - Ok(None) => { - panic!("This is not supposed to happen, handle was flagged as completed") - } - Err(e) => return Err(e.into()), - } - }.map_err(Error::from) + VmAsyncResultPollFuture::new(ctx, handle).map_err(Error::from) } diff --git a/src/endpoint/futures/async_result_poll.rs b/src/endpoint/futures/async_result_poll.rs index 7e89afd..74b5ee7 100644 --- a/src/endpoint/futures/async_result_poll.rs +++ b/src/endpoint/futures/async_result_poll.rs @@ -1,49 +1,38 @@ use crate::endpoint::context::ContextInternalInner; -use crate::endpoint::{BoxError, ErrorInner}; +use crate::endpoint::ErrorInner; use restate_sdk_shared_core::{ DoProgressResponse, Error as CoreError, NotificationHandle, TakeOutputResult, Value, VM, }; -use std::borrow::Cow; use std::future::Future; use std::pin::Pin; use std::sync::{Arc, Mutex}; use std::task::Poll; pub(crate) struct VmAsyncResultPollFuture { - ctx: Arc>, - state: Option, + state: Option, } impl VmAsyncResultPollFuture { - pub fn maybe_new( - inner: Cow<'_, Arc>>, - handle: Result, - ) -> Self { + pub fn new(ctx: Arc>, handle: NotificationHandle) -> Self { VmAsyncResultPollFuture { - ctx: inner.into_owned(), - state: Some(match handle { - Ok(handle) => PollState::Init(handle), - Err(err) => PollState::Failed(ErrorInner::VM(err)), - }), - } - } - - pub fn new( - inner: Cow<'_, Arc>>, - handle: NotificationHandle, - ) -> Self { - VmAsyncResultPollFuture { - ctx: inner.into_owned(), - state: Some(PollState::Init(handle)), + state: Some(AsyncResultPollState::Init { ctx, handle }), } } } -enum PollState { - Init(NotificationHandle), - PollProgress(NotificationHandle), - WaitingInput(NotificationHandle), - Failed(ErrorInner), +enum AsyncResultPollState { + Init { + ctx: Arc>, + handle: NotificationHandle, + }, + PollProgress { + ctx: Arc>, + handle: NotificationHandle, + }, + WaitingInput { + ctx: Arc>, + handle: NotificationHandle, + }, } macro_rules! must_lock { @@ -56,16 +45,15 @@ impl Future for VmAsyncResultPollFuture { type Output = Result; fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { - let ctx = &self.ctx; - let mut inner_lock = must_lock!(ctx); - let state = &mut self.state; - loop { - match state + match self + .state .take() .expect("Future should not be polled after Poll::Ready") { - PollState::Init(handle) => { + AsyncResultPollState::Init { ctx, handle } => { + let mut inner_lock = must_lock!(ctx); + // Let's consume some output to begin with let out = inner_lock.vm.take_output(); match out { @@ -80,14 +68,18 @@ impl Future for VmAsyncResultPollFuture { } // We can now start polling - *state = Some(PollState::PollProgress(handle)); + drop(inner_lock); + self.state = Some(AsyncResultPollState::PollProgress { ctx, handle }); } - PollState::WaitingInput(handle) => { + AsyncResultPollState::WaitingInput { ctx, handle } => { + let mut inner_lock = must_lock!(ctx); + let read_result = match inner_lock.read.poll_recv(cx) { Poll::Ready(t) => t, Poll::Pending => { // Still need to wait for input - *state = Some(PollState::WaitingInput(handle)); + drop(inner_lock); + self.state = Some(AsyncResultPollState::WaitingInput { ctx, handle }); return Poll::Pending; } }; @@ -103,15 +95,19 @@ impl Future for VmAsyncResultPollFuture { } // It's time to poll progress again - *state = Some(PollState::PollProgress(handle)); + drop(inner_lock); + self.state = Some(AsyncResultPollState::PollProgress { ctx, handle }); } - PollState::PollProgress(handle) => { + AsyncResultPollState::PollProgress { ctx, handle } => { + let mut inner_lock = must_lock!(ctx); + match inner_lock.vm.do_progress(vec![handle]) { Ok(DoProgressResponse::AnyCompleted) => { // We're good, we got the response } Ok(DoProgressResponse::ReadFromInput) => { - *state = Some(PollState::WaitingInput(handle)); + drop(inner_lock); + self.state = Some(AsyncResultPollState::WaitingInput { ctx, handle }); continue; } Ok(DoProgressResponse::ExecuteRun(_)) => { @@ -139,7 +135,6 @@ impl Future for VmAsyncResultPollFuture { Err(e) => return Poll::Ready(Err(e.into())), } } - PollState::Failed(err) => return Poll::Ready(Err(err)), } } } diff --git a/src/endpoint/futures/intercept_error.rs b/src/endpoint/futures/intercept_error.rs index 7c6e617..81eafe5 100644 --- a/src/endpoint/futures/intercept_error.rs +++ b/src/endpoint/futures/intercept_error.rs @@ -1,4 +1,4 @@ -use crate::context::{CallFuture, InvocationHandle, RunFuture, RunRetryPolicy}; +use crate::context::{InvocationHandle, RunFuture, RunRetryPolicy}; use crate::endpoint::{ContextInternal, Error}; use crate::errors::TerminalError; use pin_project_lite::pin_project; @@ -69,5 +69,3 @@ impl InvocationHandle for InterceptErrorFuture { self.fut.cancel() } } - -impl CallFuture for InterceptErrorFuture where F: CallFuture> {} diff --git a/src/hyper.rs b/src/hyper.rs index 7b92fdc..a5445e6 100644 --- a/src/hyper.rs +++ b/src/hyper.rs @@ -10,7 +10,7 @@ use http::{response, HeaderName, HeaderValue, Request, Response}; use http_body_util::{BodyExt, Either, Full}; use hyper::body::{Body, Frame, Incoming}; use hyper::service::Service; -use restate_sdk_shared_core::{Header, ResponseHead}; +use restate_sdk_shared_core::Header; use std::convert::Infallible; use std::future::{ready, Ready}; use std::ops::Deref; diff --git a/test-services/src/test_utils_service.rs b/test-services/src/test_utils_service.rs index 429e6a2..4e47b87 100644 --- a/test-services/src/test_utils_service.rs +++ b/test-services/src/test_utils_service.rs @@ -5,7 +5,6 @@ use futures::FutureExt; use restate_sdk::prelude::*; use serde::{Deserialize, Serialize}; use std::collections::HashMap; -use std::convert::Infallible; use std::sync::atomic::{AtomicU8, Ordering}; use std::sync::Arc; use std::time::Duration; @@ -64,7 +63,7 @@ pub(crate) trait TestUtilsService { #[name = "getEnvVariable"] async fn get_env_variable(env: String) -> HandlerResult; #[name = "cancelInvocation"] - async fn cancel_invocation(invocation_id: String) -> Result<(), Infallible>; + async fn cancel_invocation(invocation_id: String) -> Result<(), TerminalError>; #[name = "interpretCommands"] async fn interpret_commands(req: Json) -> HandlerResult<()>; } @@ -162,8 +161,8 @@ impl TestUtilsService for TestUtilsServiceImpl { &self, ctx: Context<'_>, invocation_id: String, - ) -> Result<(), Infallible> { - ctx.invocation_handle(invocation_id).cancel(); + ) -> Result<(), TerminalError> { + ctx.invocation_handle(invocation_id).cancel().await?; Ok(()) } From ae09e68ae706f28c862691bb02bdbbdb4db52ecf Mon Sep 17 00:00:00 2001 From: slinkydeveloper Date: Tue, 4 Mar 2025 12:45:13 +0100 Subject: [PATCH 04/10] For now revert the changes to the test services to run once the integration tests --- test-services/src/proxy.rs | 17 +++++++---------- test-services/src/test_utils_service.rs | 11 ----------- 2 files changed, 7 insertions(+), 21 deletions(-) diff --git a/test-services/src/proxy.rs b/test-services/src/proxy.rs index d443a94..e12a7f0 100644 --- a/test-services/src/proxy.rs +++ b/test-services/src/proxy.rs @@ -46,7 +46,7 @@ pub(crate) trait Proxy { #[name = "call"] async fn call(req: Json) -> HandlerResult>>; #[name = "oneWayCall"] - async fn one_way_call(req: Json) -> HandlerResult; + async fn one_way_call(req: Json) -> HandlerResult<()>; #[name = "manyCalls"] async fn many_calls(req: Json>) -> HandlerResult<()>; } @@ -70,19 +70,16 @@ impl Proxy for ProxyImpl { &self, ctx: Context<'_>, Json(req): Json, - ) -> HandlerResult { + ) -> HandlerResult<()> { let request = ctx.request::<_, ()>(req.to_target(), req.message); - let invocation_id = if let Some(delay_millis) = req.delay_millis { - request - .send_after(Duration::from_millis(delay_millis)) - .invocation_id() - .await? + if let Some(delay_millis) = req.delay_millis { + request.send_after(Duration::from_millis(delay_millis)); } else { - request.send().invocation_id().await? - }; + request.send(); + } - Ok(invocation_id) + Ok(()) } async fn many_calls( diff --git a/test-services/src/test_utils_service.rs b/test-services/src/test_utils_service.rs index 4e47b87..a1d84a1 100644 --- a/test-services/src/test_utils_service.rs +++ b/test-services/src/test_utils_service.rs @@ -62,8 +62,6 @@ pub(crate) trait TestUtilsService { async fn count_executed_side_effects(increments: u32) -> HandlerResult; #[name = "getEnvVariable"] async fn get_env_variable(env: String) -> HandlerResult; - #[name = "cancelInvocation"] - async fn cancel_invocation(invocation_id: String) -> Result<(), TerminalError>; #[name = "interpretCommands"] async fn interpret_commands(req: Json) -> HandlerResult<()>; } @@ -157,15 +155,6 @@ impl TestUtilsService for TestUtilsServiceImpl { Ok(std::env::var(env).ok().unwrap_or_default()) } - async fn cancel_invocation( - &self, - ctx: Context<'_>, - invocation_id: String, - ) -> Result<(), TerminalError> { - ctx.invocation_handle(invocation_id).cancel().await?; - Ok(()) - } - async fn interpret_commands( &self, context: Context<'_>, From 7a2a5085de8918808b8de1b72a34834e784c888f Mon Sep 17 00:00:00 2001 From: slinkydeveloper Date: Tue, 4 Mar 2025 12:57:40 +0100 Subject: [PATCH 05/10] Fix protocol version --- src/endpoint/mod.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/endpoint/mod.rs b/src/endpoint/mod.rs index fbbe2de..0bea694 100644 --- a/src/endpoint/mod.rs +++ b/src/endpoint/mod.rs @@ -197,8 +197,8 @@ impl Default for Builder { Self { svcs: Default::default(), discovery: crate::discovery::Endpoint { - max_protocol_version: 3, - min_protocol_version: 3, + max_protocol_version: 4, + min_protocol_version: 4, protocol_mode: Some(crate::discovery::ProtocolMode::BidiStream), services: vec![], }, From 24dd03fb3d94dcca4a7beaba4f5d142d68f768aa Mon Sep 17 00:00:00 2001 From: slinkydeveloper Date: Tue, 4 Mar 2025 14:18:23 +0100 Subject: [PATCH 06/10] Fix cancellation propagation --- src/endpoint/futures/async_result_poll.rs | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/endpoint/futures/async_result_poll.rs b/src/endpoint/futures/async_result_poll.rs index 74b5ee7..63eaadb 100644 --- a/src/endpoint/futures/async_result_poll.rs +++ b/src/endpoint/futures/async_result_poll.rs @@ -1,7 +1,8 @@ use crate::endpoint::context::ContextInternalInner; use crate::endpoint::ErrorInner; use restate_sdk_shared_core::{ - DoProgressResponse, Error as CoreError, NotificationHandle, TakeOutputResult, Value, VM, + DoProgressResponse, Error as CoreError, NotificationHandle, TakeOutputResult, TerminalFailure, + Value, VM, }; use std::future::Future; use std::pin::Pin; @@ -117,7 +118,10 @@ impl Future for VmAsyncResultPollFuture { unimplemented!() } Ok(DoProgressResponse::CancelSignalReceived) => { - unimplemented!() + return Poll::Ready(Ok(Value::Failure(TerminalFailure { + code: 409, + message: "cancelled".to_string(), + }))) } Err(e) => { return Poll::Ready(Err(e.into())); From 752e7d81cca75658f016b2d97816f3d40fa9e2b8 Mon Sep 17 00:00:00 2001 From: slinkydeveloper Date: Tue, 4 Mar 2025 16:37:46 +0100 Subject: [PATCH 07/10] Run should now work --- src/endpoint/context.rs | 300 ++++++++++++++++++++-------------------- 1 file changed, 151 insertions(+), 149 deletions(-) diff --git a/src/endpoint/context.rs b/src/endpoint/context.rs index e00df54..89dd52c 100644 --- a/src/endpoint/context.rs +++ b/src/endpoint/context.rs @@ -1,5 +1,5 @@ use crate::context::{ - CallFuture, InvocationHandle, Request, RequestTarget, RunClosure, RunRetryPolicy, + CallFuture, InvocationHandle, Request, RequestTarget, RunClosure, RunFuture, RunRetryPolicy, }; use crate::endpoint::futures::async_result_poll::VmAsyncResultPollFuture; use crate::endpoint::futures::intercept_error::InterceptErrorFuture; @@ -8,20 +8,21 @@ use crate::endpoint::handler_state::HandlerStateNotifier; use crate::endpoint::{Error, ErrorInner, InputReceiver, OutputSender}; use crate::errors::{HandlerErrorInner, HandlerResult, TerminalError}; use crate::serde::{Deserialize, Serialize}; -use futures::future::{Either, Shared}; +use futures::future::{BoxFuture, Either, Shared}; use futures::{FutureExt, TryFutureExt}; use pin_project_lite::pin_project; use restate_sdk_shared_core::{ - CoreVM, Error as CoreError, NonEmptyValue, NotificationHandle, RetryPolicy, TakeOutputResult, - Target, Value, VM, + CoreVM, DoProgressResponse, Error as CoreError, NonEmptyValue, NotificationHandle, RetryPolicy, + RunExitResult, TakeOutputResult, Target, Value, VM, }; use std::borrow::Cow; use std::collections::HashMap; use std::future::{ready, Future}; use std::marker::PhantomData; +use std::mem; use std::pin::Pin; use std::sync::{Arc, Mutex}; -use std::task::{Context, Poll}; +use std::task::{ready, Context, Poll}; use std::time::{Duration, Instant, SystemTime}; pub struct ContextInternalInner { @@ -586,15 +587,14 @@ impl ContextInternal { pub fn run<'a, Run, Fut, Out>( &'a self, run_closure: Run, - ) -> impl crate::context::RunFuture> + Send + 'a + ) -> impl RunFuture> + Send + 'a where Run: RunClosure + Send + 'a, Fut: Future> + Send + 'a, Out: Serialize + Deserialize + 'static, { let this = Arc::clone(&self.inner); - - InterceptErrorFuture::new(self.clone(), RunFuture::new(this, run_closure)) + InterceptErrorFuture::new(self.clone(), RunFutureImpl::new(this, run_closure)) } // Used by codegen @@ -648,52 +648,80 @@ impl ContextInternal { } pin_project! { - struct RunFuture { + struct RunFutureImpl { name: String, retry_policy: RetryPolicy, phantom_data: PhantomData Ret>, - closure: Option, - inner_ctx: Option>>, #[pin] - state: RunState, + state: RunState, } } pin_project! { #[project = RunStateProj] - enum RunState { - New, + enum RunState { + New { + ctx: Option>>, + closure: Option, + }, ClosureRunning { + ctx: Option>>, + handle: NotificationHandle, start_time: Instant, #[pin] - fut: Fut, + closure_fut: RunFnFut, }, - PollFutureRunning { - #[pin] - fut: VmAsyncResultPollFuture + WaitingResultFut { + result_fut: BoxFuture<'static, Result, Error>> } } } -impl RunFuture { - fn new(inner_ctx: Arc>, closure: Run) -> Self { +impl RunFutureImpl { + fn new(ctx: Arc>, closure: Run) -> Self { Self { name: "".to_string(), retry_policy: RetryPolicy::Infinite, phantom_data: PhantomData, - inner_ctx: Some(inner_ctx), - closure: Some(closure), - state: RunState::New, + state: RunState::New { + ctx: Some(ctx), + closure: Some(closure), + }, } } + + fn boxed_result_fut( + ctx: Arc>, + handle: NotificationHandle, + ) -> BoxFuture<'static, Result, Error>> + where + Ret: Deserialize, + { + get_async_result(Arc::clone(&ctx), handle) + .map(|res| match res { + Ok(Value::Success(mut s)) => { + let t = + Ret::deserialize(&mut s).map_err(|e| Error::deserialization("run", e))?; + Ok(Ok(t)) + } + Ok(Value::Failure(f)) => Ok(Err(f.into())), + Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall { + variant: <&'static str>::from(v), + syscall: "run", + } + .into()), + Err(e) => Err(e), + }) + .boxed() + } } -impl crate::context::RunFuture, Error>> - for RunFuture +impl RunFuture, Error>> + for RunFutureImpl where - Run: RunClosure + Send, - Fut: Future> + Send, - Out: Serialize + Deserialize, + Run: RunClosure + Send, + Ret: Serialize + Deserialize, + RunFnFut: Future> + Send, { fn retry_policy(mut self, retry_policy: RunRetryPolicy) -> Self { self.retry_policy = RetryPolicy::Exponential { @@ -712,129 +740,103 @@ where } } -impl Future for RunFuture +impl Future for RunFutureImpl where - Run: RunClosure + Send, - Out: Serialize + Deserialize, - Fut: Future> + Send, + Run: RunClosure + Send, + Ret: Serialize + Deserialize, + RunFnFut: Future> + Send, { - type Output = Result, Error>; - - fn poll(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll { - unimplemented!() - // let mut this = self.project(); - - // loop { - // match this.state.as_mut().project() { - // RunStateProj::New => { - // let enter_result = { - // must_lock!(this - // .inner_ctx - // .as_mut() - // .expect("Future should not be polled after returning Poll::Ready")) - // .vm - // .sys_run_enter(this.name.to_owned()) - // }; - // - // // Enter the side effect - // match enter_result.map_err(ErrorInner::VM)? { - // RunEnterResult::Executed(NonEmptyValue::Success(mut v)) => { - // let t = Out::deserialize(&mut v).map_err(|e| { - // ErrorInner::Deserialization { - // syscall: "run", - // err: Box::new(e), - // } - // })?; - // return Poll::Ready(Ok(Ok(t))); - // } - // RunEnterResult::Executed(NonEmptyValue::Failure(f)) => { - // return Poll::Ready(Ok(Err(f.into()))) - // } - // RunEnterResult::NotExecuted(_) => {} - // }; - // - // // We need to run the closure - // this.state.set(RunState::ClosureRunning { - // start_time: Instant::now(), - // fut: this - // .closure - // .take() - // .expect("Future should not be polled after returning Poll::Ready") - // .run(), - // }); - // } - // RunStateProj::ClosureRunning { start_time, fut } => { - // let res = match ready!(fut.poll(cx)) { - // Ok(t) => RunExitResult::Success(Out::serialize(&t).map_err(|e| { - // ErrorInner::Serialization { - // syscall: "run", - // err: Box::new(e), - // } - // })?), - // Err(e) => match e.0 { - // HandlerErrorInner::Retryable(err) => RunExitResult::RetryableFailure { - // attempt_duration: start_time.elapsed(), - // failure: TerminalFailure { - // code: 500, - // message: err.to_string(), - // }, - // }, - // HandlerErrorInner::Terminal(t) => { - // RunExitResult::TerminalFailure(TerminalError(t).into()) - // } - // }, - // }; - // - // let inner_ctx = this - // .inner_ctx - // .take() - // .expect("Future should not be polled after returning Poll::Ready"); - // - // let handle = { - // must_lock!(inner_ctx) - // .vm - // .sys_run_exit(res, mem::take(this.retry_policy)) - // }; - // - // this.state.set(RunState::PollFutureRunning { - // fut: VmAsyncResultPollFuture::maybe_new(Cow::Owned(inner_ctx), handle), - // }); - // } - // RunStateProj::PollFutureRunning { fut } => { - // let value = ready!(fut.poll(cx))?; - // - // return Poll::Ready(match value { - // Value::Void => Err(ErrorInner::UnexpectedValueVariantForSyscall { - // variant: "empty", - // syscall: "run", - // } - // .into()), - // Value::Success(mut s) => { - // let t = Out::deserialize(&mut s).map_err(|e| { - // ErrorInner::Deserialization { - // syscall: "run", - // err: Box::new(e), - // } - // })?; - // Ok(Ok(t)) - // } - // Value::Failure(f) => Ok(Err(f.into())), - // Value::StateKeys(_) => Err(ErrorInner::UnexpectedValueVariantForSyscall { - // variant: "state_keys", - // syscall: "run", - // } - // .into()), - // Value::InvocationId(_) => { - // Err(ErrorInner::UnexpectedValueVariantForSyscall { - // variant: "invocation_id", - // syscall: "run", - // } - // .into()) - // } - // }); - // } - // } - // } + type Output = Result, Error>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut this = self.project(); + + loop { + match this.state.as_mut().project() { + RunStateProj::New { ctx, closure, .. } => { + let ctx = ctx + .take() + .expect("Future should not be polled after returning Poll::Ready"); + let closure = closure + .take() + .expect("Future should not be polled after returning Poll::Ready"); + let mut inner_ctx = must_lock!(ctx); + + let handle = inner_ctx + .vm + .sys_run(this.name.to_owned()) + .map_err(ErrorInner::from)?; + + // Now we do progress once to check whether this closure should be executed or not. + match inner_ctx.vm.do_progress(vec![handle]) { + Ok(DoProgressResponse::ExecuteRun(handle_to_run)) => { + // In case it returns ExecuteRun, it must be the handle we just gave it, + // and it means we need to execute the closure + assert_eq!(handle, handle_to_run); + + drop(inner_ctx); + this.state.set(RunState::ClosureRunning { + ctx: Some(ctx), + handle, + start_time: Instant::now(), + closure_fut: closure.run(), + }); + } + _ => { + drop(inner_ctx); + // In all the other cases, just move on waiting the result, + // the poll future state will take care of doing whatever needs to be done here, + // that is propagating state machine error, or result, or whatever + this.state.set(RunState::WaitingResultFut { + result_fut: Self::boxed_result_fut(Arc::clone(&ctx), handle), + }) + } + } + } + RunStateProj::ClosureRunning { + ctx, + handle, + start_time, + closure_fut, + } => { + let res = match ready!(closure_fut.poll(cx)) { + Ok(t) => RunExitResult::Success(Ret::serialize(&t).map_err(|e| { + ErrorInner::Serialization { + syscall: "run", + err: Box::new(e), + } + })?), + Err(e) => match e.0 { + HandlerErrorInner::Retryable(err) => RunExitResult::RetryableFailure { + attempt_duration: start_time.elapsed(), + error: CoreError::new(500u16, err.to_string()), + }, + HandlerErrorInner::Terminal(t) => { + RunExitResult::TerminalFailure(TerminalError(t).into()) + } + }, + }; + + let ctx = ctx + .take() + .expect("Future should not be polled after returning Poll::Ready"); + let handle = *handle; + + let _ = { + must_lock!(ctx).vm.propose_run_completion( + handle, + res, + mem::take(this.retry_policy), + ) + }; + + this.state.set(RunState::WaitingResultFut { + result_fut: Self::boxed_result_fut(Arc::clone(&ctx), handle), + }); + } + RunStateProj::WaitingResultFut { result_fut } => return result_fut.poll_unpin(cx), + } + } } } From 546b02a66202f47290a2d7b03d59a0ad8adfcb62 Mon Sep 17 00:00:00 2001 From: slinkydeveloper Date: Tue, 4 Mar 2025 16:44:59 +0100 Subject: [PATCH 08/10] Revert "For now revert the changes to the test services to run once the integration tests" This reverts commit ae09e68ae706f28c862691bb02bdbbdb4db52ecf. --- test-services/src/proxy.rs | 17 ++++++++++------- test-services/src/test_utils_service.rs | 11 +++++++++++ 2 files changed, 21 insertions(+), 7 deletions(-) diff --git a/test-services/src/proxy.rs b/test-services/src/proxy.rs index e12a7f0..d443a94 100644 --- a/test-services/src/proxy.rs +++ b/test-services/src/proxy.rs @@ -46,7 +46,7 @@ pub(crate) trait Proxy { #[name = "call"] async fn call(req: Json) -> HandlerResult>>; #[name = "oneWayCall"] - async fn one_way_call(req: Json) -> HandlerResult<()>; + async fn one_way_call(req: Json) -> HandlerResult; #[name = "manyCalls"] async fn many_calls(req: Json>) -> HandlerResult<()>; } @@ -70,16 +70,19 @@ impl Proxy for ProxyImpl { &self, ctx: Context<'_>, Json(req): Json, - ) -> HandlerResult<()> { + ) -> HandlerResult { let request = ctx.request::<_, ()>(req.to_target(), req.message); - if let Some(delay_millis) = req.delay_millis { - request.send_after(Duration::from_millis(delay_millis)); + let invocation_id = if let Some(delay_millis) = req.delay_millis { + request + .send_after(Duration::from_millis(delay_millis)) + .invocation_id() + .await? } else { - request.send(); - } + request.send().invocation_id().await? + }; - Ok(()) + Ok(invocation_id) } async fn many_calls( diff --git a/test-services/src/test_utils_service.rs b/test-services/src/test_utils_service.rs index a1d84a1..4e47b87 100644 --- a/test-services/src/test_utils_service.rs +++ b/test-services/src/test_utils_service.rs @@ -62,6 +62,8 @@ pub(crate) trait TestUtilsService { async fn count_executed_side_effects(increments: u32) -> HandlerResult; #[name = "getEnvVariable"] async fn get_env_variable(env: String) -> HandlerResult; + #[name = "cancelInvocation"] + async fn cancel_invocation(invocation_id: String) -> Result<(), TerminalError>; #[name = "interpretCommands"] async fn interpret_commands(req: Json) -> HandlerResult<()>; } @@ -155,6 +157,15 @@ impl TestUtilsService for TestUtilsServiceImpl { Ok(std::env::var(env).ok().unwrap_or_default()) } + async fn cancel_invocation( + &self, + ctx: Context<'_>, + invocation_id: String, + ) -> Result<(), TerminalError> { + ctx.invocation_handle(invocation_id).cancel().await?; + Ok(()) + } + async fn interpret_commands( &self, context: Context<'_>, From e4a585688829231b9945d3b3cb7b4b6872206e4e Mon Sep 17 00:00:00 2001 From: slinkydeveloper Date: Tue, 4 Mar 2025 18:00:16 +0100 Subject: [PATCH 09/10] Adapt to test suite 3.0 --- .github/workflows/integration.yaml | 2 +- test-services/README.md | 2 +- test-services/exclusions.yaml | 23 +- test-services/src/main.rs | 8 + test-services/src/proxy.rs | 21 +- test-services/src/test_utils_service.rs | 107 +------- .../src/virtual_object_command_interpreter.rs | 249 ++++++++++++++++++ 7 files changed, 295 insertions(+), 117 deletions(-) create mode 100644 test-services/src/virtual_object_command_interpreter.rs diff --git a/.github/workflows/integration.yaml b/.github/workflows/integration.yaml index 865c228..9952479 100644 --- a/.github/workflows/integration.yaml +++ b/.github/workflows/integration.yaml @@ -105,7 +105,7 @@ jobs: cache-to: type=gha,mode=max,scope=${{ github.workflow }} - name: Run test tool - uses: restatedev/sdk-test-suite@v2.4 + uses: restatedev/sdk-test-suite@v3.0 with: restateContainerImage: ${{ inputs.restateCommit != '' && 'localhost/restatedev/restate-commit-download:latest' || (inputs.restateImage != '' && inputs.restateImage || 'ghcr.io/restatedev/restate:main') }} serviceContainerImage: "restatedev/rust-test-services" diff --git a/test-services/README.md b/test-services/README.md index 259a391..b08ef7c 100644 --- a/test-services/README.md +++ b/test-services/README.md @@ -9,5 +9,5 @@ $ podman build -f test-services/Dockerfile -t restatedev/rust-test-services . To run (download the [sdk-test-suite](https://github.com/restatedev/sdk-test-suite) first): ```shell -$ java -jar restate-sdk-test-suite.jar run restatedev/rust-test-services +$ java -jar restate-sdk-test-suite.jar run localhost/restatedev/rust-test-services:latest ``` \ No newline at end of file diff --git a/test-services/exclusions.yaml b/test-services/exclusions.yaml index c868ed3..5c9150a 100644 --- a/test-services/exclusions.yaml +++ b/test-services/exclusions.yaml @@ -1,14 +1,21 @@ exclusions: "alwaysSuspending": - - "dev.restate.sdktesting.tests.AwaitTimeout" + - "dev.restate.sdktesting.tests.Combinators.awakeableOrTimeoutUsingAwaitAny" + - "dev.restate.sdktesting.tests.Combinators.awakeableOrTimeoutUsingAwakeableTimeoutCommand" + - "dev.restate.sdktesting.tests.Combinators.firstSuccessfulCompletedAwakeable" "default": - - "dev.restate.sdktesting.tests.AwaitTimeout" - - "dev.restate.sdktesting.tests.RawHandler" + - "dev.restate.sdktesting.tests.Combinators.awakeableOrTimeoutUsingAwaitAny" + - "dev.restate.sdktesting.tests.Combinators.awakeableOrTimeoutUsingAwakeableTimeoutCommand" + - "dev.restate.sdktesting.tests.Combinators.firstSuccessfulCompletedAwakeable" "singleThreadSinglePartition": - - "dev.restate.sdktesting.tests.AwaitTimeout" - - "dev.restate.sdktesting.tests.RawHandler" + - "dev.restate.sdktesting.tests.Combinators.awakeableOrTimeoutUsingAwaitAny" + - "dev.restate.sdktesting.tests.Combinators.awakeableOrTimeoutUsingAwakeableTimeoutCommand" + - "dev.restate.sdktesting.tests.Combinators.firstSuccessfulCompletedAwakeable" "threeNodes": - - "dev.restate.sdktesting.tests.AwaitTimeout" - - "dev.restate.sdktesting.tests.RawHandler" + - "dev.restate.sdktesting.tests.Combinators.awakeableOrTimeoutUsingAwaitAny" + - "dev.restate.sdktesting.tests.Combinators.awakeableOrTimeoutUsingAwakeableTimeoutCommand" + - "dev.restate.sdktesting.tests.Combinators.firstSuccessfulCompletedAwakeable" "threeNodesAlwaysSuspending": - - "dev.restate.sdktesting.tests.AwaitTimeout" + - "dev.restate.sdktesting.tests.Combinators.awakeableOrTimeoutUsingAwaitAny" + - "dev.restate.sdktesting.tests.Combinators.awakeableOrTimeoutUsingAwakeableTimeoutCommand" + - "dev.restate.sdktesting.tests.Combinators.firstSuccessfulCompletedAwakeable" diff --git a/test-services/src/main.rs b/test-services/src/main.rs index d4f261a..bde49a9 100644 --- a/test-services/src/main.rs +++ b/test-services/src/main.rs @@ -9,6 +9,7 @@ mod map_object; mod non_deterministic; mod proxy; mod test_utils_service; +mod virtual_object_command_interpreter; use restate_sdk::prelude::{Endpoint, HttpServer}; use std::env; @@ -76,6 +77,13 @@ async fn main() { test_utils_service::TestUtilsServiceImpl, )) } + if services == "*" || services.contains("VirtualObjectCommandInterpreter") { + builder = builder.bind( + virtual_object_command_interpreter::VirtualObjectCommandInterpreter::serve( + virtual_object_command_interpreter::VirtualObjectCommandInterpreterImpl, + ), + ) + } if let Ok(key) = env::var("E2E_REQUEST_SIGNING_ENV") { builder = builder.identity_key(&key).unwrap() diff --git a/test-services/src/proxy.rs b/test-services/src/proxy.rs index d443a94..6b1f221 100644 --- a/test-services/src/proxy.rs +++ b/test-services/src/proxy.rs @@ -11,6 +11,7 @@ pub(crate) struct ProxyRequest { service_name: String, virtual_object_key: Option, handler_name: String, + idempotency_key: Option, message: Vec, delay_millis: Option, } @@ -59,11 +60,11 @@ impl Proxy for ProxyImpl { ctx: Context<'_>, Json(req): Json, ) -> HandlerResult>> { - Ok(ctx - .request::, Vec>(req.to_target(), req.message) - .call() - .await? - .into()) + let mut request = ctx.request::, Vec>(req.to_target(), req.message); + if let Some(idempotency_key) = req.idempotency_key { + request = request.idempotency_key(idempotency_key); + } + Ok(request.call().await?.into()) } async fn one_way_call( @@ -71,7 +72,10 @@ impl Proxy for ProxyImpl { ctx: Context<'_>, Json(req): Json, ) -> HandlerResult { - let request = ctx.request::<_, ()>(req.to_target(), req.message); + let mut request = ctx.request::<_, ()>(req.to_target(), req.message); + if let Some(idempotency_key) = req.idempotency_key { + request = request.idempotency_key(idempotency_key); + } let invocation_id = if let Some(delay_millis) = req.delay_millis { request @@ -93,8 +97,11 @@ impl Proxy for ProxyImpl { let mut futures: Vec, TerminalError>>> = vec![]; for req in requests { - let restate_req = + let mut restate_req = ctx.request::<_, Vec>(req.proxy_request.to_target(), req.proxy_request.message); + if let Some(idempotency_key) = req.proxy_request.idempotency_key { + restate_req = restate_req.idempotency_key(idempotency_key); + } if req.one_way_call { if let Some(delay_millis) = req.proxy_request.delay_millis { restate_req.send_after(Duration::from_millis(delay_millis)); diff --git a/test-services/src/test_utils_service.rs b/test-services/src/test_utils_service.rs index 4e47b87..0152062 100644 --- a/test-services/src/test_utils_service.rs +++ b/test-services/src/test_utils_service.rs @@ -1,48 +1,12 @@ -use crate::awakeable_holder; -use crate::list_object::ListObjectClient; use futures::future::BoxFuture; use futures::FutureExt; use restate_sdk::prelude::*; -use serde::{Deserialize, Serialize}; use std::collections::HashMap; +use std::convert::Infallible; use std::sync::atomic::{AtomicU8, Ordering}; use std::sync::Arc; use std::time::Duration; -#[derive(Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub(crate) struct CreateAwakeableAndAwaitItRequest { - awakeable_key: String, - await_timeout: Option, -} - -#[derive(Serialize, Deserialize)] -#[serde(tag = "type")] -#[serde(rename_all_fields = "camelCase")] -pub(crate) enum CreateAwakeableAndAwaitItResponse { - #[serde(rename = "timeout")] - Timeout, - #[serde(rename = "result")] - Result { value: String }, -} - -#[derive(Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub(crate) struct InterpretRequest { - list_name: String, - commands: Vec, -} - -#[derive(Serialize, Deserialize)] -#[serde(tag = "type")] -#[serde(rename_all_fields = "camelCase")] -pub(crate) enum InterpretCommand { - #[serde(rename = "createAwakeableAndAwaitIt")] - CreateAwakeableAndAwaitIt { awakeable_key: String }, - #[serde(rename = "getEnvVariable")] - GetEnvVariable { env_name: String }, -} - #[restate_sdk::service] #[name = "TestUtilsService"] pub(crate) trait TestUtilsService { @@ -50,22 +14,16 @@ pub(crate) trait TestUtilsService { async fn echo(input: String) -> HandlerResult; #[name = "uppercaseEcho"] async fn uppercase_echo(input: String) -> HandlerResult; + #[name = "rawEcho"] + async fn raw_echo(input: Vec) -> Result, Infallible>; #[name = "echoHeaders"] async fn echo_headers() -> HandlerResult>>; - #[name = "createAwakeableAndAwaitIt"] - async fn create_awakeable_and_await_it( - req: Json, - ) -> HandlerResult>; #[name = "sleepConcurrently"] async fn sleep_concurrently(millis_durations: Json>) -> HandlerResult<()>; #[name = "countExecutedSideEffects"] async fn count_executed_side_effects(increments: u32) -> HandlerResult; - #[name = "getEnvVariable"] - async fn get_env_variable(env: String) -> HandlerResult; #[name = "cancelInvocation"] async fn cancel_invocation(invocation_id: String) -> Result<(), TerminalError>; - #[name = "interpretCommands"] - async fn interpret_commands(req: Json) -> HandlerResult<()>; } pub(crate) struct TestUtilsServiceImpl; @@ -79,6 +37,10 @@ impl TestUtilsService for TestUtilsServiceImpl { Ok(input.to_ascii_uppercase()) } + async fn raw_echo(&self, _: Context<'_>, input: Vec) -> Result, Infallible> { + Ok(input) + } + async fn echo_headers( &self, context: Context<'_>, @@ -94,27 +56,6 @@ impl TestUtilsService for TestUtilsServiceImpl { Ok(headers.into()) } - async fn create_awakeable_and_await_it( - &self, - context: Context<'_>, - Json(req): Json, - ) -> HandlerResult> { - if req.await_timeout.is_some() { - unimplemented!("await timeout is not yet implemented"); - } - - let (awk_id, awakeable) = context.awakeable::(); - - context - .object_client::(req.awakeable_key) - .hold(awk_id) - .call() - .await?; - let value = awakeable.await?; - - Ok(CreateAwakeableAndAwaitItResponse::Result { value }.into()) - } - async fn sleep_concurrently( &self, context: Context<'_>, @@ -153,10 +94,6 @@ impl TestUtilsService for TestUtilsServiceImpl { Ok(counter.load(Ordering::SeqCst) as u32) } - async fn get_env_variable(&self, _: Context<'_>, env: String) -> HandlerResult { - Ok(std::env::var(env).ok().unwrap_or_default()) - } - async fn cancel_invocation( &self, ctx: Context<'_>, @@ -165,34 +102,4 @@ impl TestUtilsService for TestUtilsServiceImpl { ctx.invocation_handle(invocation_id).cancel().await?; Ok(()) } - - async fn interpret_commands( - &self, - context: Context<'_>, - Json(req): Json, - ) -> HandlerResult<()> { - let list_client = context.object_client::(req.list_name); - - for cmd in req.commands { - match cmd { - InterpretCommand::CreateAwakeableAndAwaitIt { awakeable_key } => { - let (awk_id, awakeable) = context.awakeable::(); - context - .object_client::(awakeable_key) - .hold(awk_id) - .call() - .await?; - let value = awakeable.await?; - list_client.append(value).send(); - } - InterpretCommand::GetEnvVariable { env_name } => { - list_client - .append(std::env::var(env_name).ok().unwrap_or_default()) - .send(); - } - } - } - - Ok(()) - } } diff --git a/test-services/src/virtual_object_command_interpreter.rs b/test-services/src/virtual_object_command_interpreter.rs new file mode 100644 index 0000000..994d20c --- /dev/null +++ b/test-services/src/virtual_object_command_interpreter.rs @@ -0,0 +1,249 @@ +use anyhow::anyhow; +use futures::TryFutureExt; +use restate_sdk::prelude::*; +use serde::{Deserialize, Serialize}; +use std::time::Duration; + +#[derive(Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub(crate) struct InterpretRequest { + commands: Vec, +} + +#[derive(Serialize, Deserialize)] +#[serde(tag = "type")] +#[serde(rename_all_fields = "camelCase")] +pub(crate) enum Command { + #[serde(rename = "awaitAnySuccessful")] + AwaitAnySuccessful { commands: Vec }, + #[serde(rename = "awaitAny")] + AwaitAny { commands: Vec }, + #[serde(rename = "awaitOne")] + AwaitOne { command: AwaitableCommand }, + #[serde(rename = "awaitAwakeableOrTimeout")] + AwaitAwakeableOrTimeout { + awakeable_key: String, + timeout_millis: u64, + }, + #[serde(rename = "resolveAwakeable")] + ResolveAwakeable { + awakeable_key: String, + value: String, + }, + #[serde(rename = "rejectAwakeable")] + RejectAwakeable { + awakeable_key: String, + reason: String, + }, + #[serde(rename = "getEnvVariable")] + GetEnvVariable { env_name: String }, +} + +#[derive(Serialize, Deserialize)] +#[serde(tag = "type")] +#[serde(rename_all_fields = "camelCase")] +pub(crate) enum AwaitableCommand { + #[serde(rename = "createAwakeable")] + CreateAwakeable { awakeable_key: String }, + #[serde(rename = "sleep")] + Sleep { timeout_millis: u64 }, + #[serde(rename = "runThrowTerminalException")] + RunThrowTerminalException { reason: String }, +} + +#[derive(Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub(crate) struct ResolveAwakeable { + awakeable_key: String, + value: String, +} + +#[derive(Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub(crate) struct RejectAwakeable { + awakeable_key: String, + reason: String, +} + +#[restate_sdk::object] +#[name = "VirtualObjectCommandInterpreter"] +pub(crate) trait VirtualObjectCommandInterpreter { + #[name = "interpretCommands"] + async fn interpret_commands(req: Json) -> HandlerResult; + + #[name = "resolveAwakeable"] + #[shared] + async fn resolve_awakeable(req: Json) -> HandlerResult<()>; + + #[name = "rejectAwakeable"] + #[shared] + async fn reject_awakeable(req: Json) -> HandlerResult<()>; + + #[name = "hasAwakeable"] + #[shared] + async fn has_awakeable(awakeable_key: String) -> HandlerResult; + + #[name = "getResults"] + #[shared] + async fn get_results() -> HandlerResult>>; +} + +pub(crate) struct VirtualObjectCommandInterpreterImpl; + +impl VirtualObjectCommandInterpreter for VirtualObjectCommandInterpreterImpl { + async fn interpret_commands( + &self, + context: ObjectContext<'_>, + Json(req): Json, + ) -> HandlerResult { + let mut last_result: String = Default::default(); + + for cmd in req.commands { + match cmd { + Command::AwaitAny { .. } => { + Err(anyhow!("AwaitAny is currently unsupported in the Rust SDK"))? + } + Command::AwaitAnySuccessful { .. } => Err(anyhow!( + "AwaitAnySuccessful is currently unsupported in the Rust SDK" + ))?, + Command::AwaitAwakeableOrTimeout { .. } => Err(anyhow!( + "AwaitAwakeableOrTimeout is currently unsupported in the Rust SDK" + ))?, + Command::AwaitOne { command } => { + last_result = match command { + AwaitableCommand::CreateAwakeable { awakeable_key } => { + let (awakeable_id, fut) = context.awakeable::(); + context.set::(&format!("awk-{awakeable_key}"), awakeable_id); + fut.await? + } + AwaitableCommand::Sleep { timeout_millis } => { + context + .sleep(Duration::from_millis(timeout_millis)) + .map_ok(|_| "sleep".to_string()) + .await? + } + AwaitableCommand::RunThrowTerminalException { reason } => { + context + .run::<_, _, String>( + || async move { Err(TerminalError::new(reason))? }, + ) + .await? + } + } + } + Command::GetEnvVariable { env_name } => { + last_result = std::env::var(env_name).ok().unwrap_or_default(); + } + Command::ResolveAwakeable { + awakeable_key, + value, + } => { + let Some(awakeable_id) = context + .get::(&format!("awk-{awakeable_key}")) + .await? + else { + Err(TerminalError::new( + "Awakeable is not registered yet".to_string(), + ))? + }; + + context.resolve_awakeable(&awakeable_id, value); + last_result = Default::default(); + } + Command::RejectAwakeable { + awakeable_key, + reason, + } => { + let Some(awakeable_id) = context + .get::(&format!("awk-{awakeable_key}")) + .await? + else { + Err(TerminalError::new( + "Awakeable is not registered yet".to_string(), + ))? + }; + + context.reject_awakeable(&awakeable_id, TerminalError::new(reason)); + last_result = Default::default(); + } + } + + let mut old_results = context + .get::>>("results") + .await? + .unwrap_or_default() + .into_inner(); + old_results.push(last_result.clone()); + context.set("results", Json(old_results)); + } + + Ok(last_result) + } + + async fn resolve_awakeable( + &self, + context: SharedObjectContext<'_>, + req: Json, + ) -> Result<(), HandlerError> { + let ResolveAwakeable { + awakeable_key, + value, + } = req.into_inner(); + let Some(awakeable_id) = context + .get::(&format!("awk-{awakeable_key}")) + .await? + else { + Err(TerminalError::new( + "Awakeable is not registered yet".to_string(), + ))? + }; + + context.resolve_awakeable(&awakeable_id, value); + + Ok(()) + } + + async fn reject_awakeable( + &self, + context: SharedObjectContext<'_>, + req: Json, + ) -> Result<(), HandlerError> { + let RejectAwakeable { + awakeable_key, + reason, + } = req.into_inner(); + let Some(awakeable_id) = context + .get::(&format!("awk-{awakeable_key}")) + .await? + else { + Err(TerminalError::new( + "Awakeable is not registered yet".to_string(), + ))? + }; + + context.reject_awakeable(&awakeable_id, TerminalError::new(reason)); + + Ok(()) + } + + async fn has_awakeable( + &self, + context: SharedObjectContext<'_>, + awakeable_key: String, + ) -> Result { + Ok(context + .get::(&format!("awk-{awakeable_key}")) + .await? + .is_some()) + } + + async fn get_results( + &self, + context: SharedObjectContext<'_>, + ) -> Result>, HandlerError> { + Ok(context + .get::>>("results") + .await? + .unwrap_or_default()) + } +} From 08e31c0d34d53a67c89a484322d797e2a36ed01a Mon Sep 17 00:00:00 2001 From: slinkydeveloper Date: Tue, 4 Mar 2025 18:10:37 +0100 Subject: [PATCH 10/10] Update compat matrix table --- README.md | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index ea39d16..1e01ccd 100644 --- a/README.md +++ b/README.md @@ -121,10 +121,11 @@ The Rust SDK is currently in active development, and might break across releases The compatibility with Restate is described in the following table: -| Restate Server\sdk-rust | 0.0/0.1/0.2 | 0.3 | -|-------------------------|-------------|-----| -| 1.0 | ✅ | ❌ | -| 1.1 | ✅ | ✅ | +| Restate Server\sdk-rust | 0.0/0.1/0.2 | 0.3 | 0.4 | +|-------------------------|-------------|-----|-----| +| 1.0 | ✅ | ❌ | ❌ | +| 1.1 | ✅ | ✅ | ❌ | +| 1.2 | ✅ | ✅ | ✅ | ## Contributing