diff --git a/macros/src/ast.rs b/macros/src/ast.rs index 03ceeaa..b485b64 100644 --- a/macros/src/ast.rs +++ b/macros/src/ast.rs @@ -16,8 +16,8 @@ use syn::parse::{Parse, ParseStream}; use syn::spanned::Spanned; use syn::token::Comma; use syn::{ - braced, parenthesized, Attribute, Error, Expr, ExprLit, FnArg, GenericArgument, Ident, Lit, - Pat, PatType, Path, PathArguments, Result, ReturnType, Token, Type, Visibility, + braced, parenthesized, parse_quote, Attribute, Error, Expr, ExprLit, FnArg, GenericArgument, + Ident, Lit, Pat, PatType, Path, PathArguments, Result, ReturnType, Token, Type, Visibility, }; /// Accumulates multiple errors into a result. @@ -145,7 +145,7 @@ pub(crate) struct Handler { pub(crate) restate_name: String, pub(crate) ident: Ident, pub(crate) arg: Option, - pub(crate) output: ReturnType, + pub(crate) output: Type, } impl Parse for Handler { @@ -189,20 +189,24 @@ impl Parse for Handler { errors?; // Parse return type - let output: ReturnType = input.parse()?; + let return_type: ReturnType = input.parse()?; input.parse::()?; - match &output { - ReturnType::Default => {} + let output: Type = match &return_type { + ReturnType::Default => { + parse_quote!(()) + } ReturnType::Type(_, ty) => { - if handler_result_parameter(ty).is_none() { + if let Some(ty) = extract_handler_result_parameter(ty) { + ty + } else { return Err(Error::new( - output.span(), + return_type.span(), "Only restate_sdk::prelude::HandlerResult is supported as return type", )); } } - } + }; // Process attributes let mut is_shared = false; @@ -259,7 +263,7 @@ fn read_literal_attribute_name(attr: &Attribute) -> Result> { .transpose() } -fn handler_result_parameter(ty: &Type) -> Option<&Type> { +fn extract_handler_result_parameter(ty: &Type) -> Option { let path = match ty { Type::Path(ty) => &ty.path, _ => return None, @@ -280,7 +284,7 @@ fn handler_result_parameter(ty: &Type) -> Option<&Type> { } match &bracketed.args[0] { - GenericArgument::Type(arg) => Some(arg), + GenericArgument::Type(arg) => Some(arg.clone()), _ => None, } } diff --git a/macros/src/gen.rs b/macros/src/gen.rs index 8d8b066..f67435d 100644 --- a/macros/src/gen.rs +++ b/macros/src/gen.rs @@ -2,12 +2,13 @@ use crate::ast::{Handler, Object, Service, ServiceInner, ServiceType, Workflow}; use proc_macro2::TokenStream as TokenStream2; use proc_macro2::{Ident, Literal}; use quote::{format_ident, quote, ToTokens}; -use syn::{parse_quote, Attribute, ReturnType, Type, Visibility}; +use syn::{Attribute, PatType, Visibility}; pub(crate) struct ServiceGenerator<'a> { pub(crate) service_ty: ServiceType, pub(crate) restate_name: &'a str, pub(crate) service_ident: &'a Ident, + pub(crate) client_ident: Ident, pub(crate) serve_ident: Ident, pub(crate) vis: &'a Visibility, pub(crate) attrs: &'a [Attribute], @@ -20,6 +21,7 @@ impl<'a> ServiceGenerator<'a> { service_ty, restate_name: &s.restate_name, service_ident: &s.ident, + client_ident: format_ident!("{}Client", s.ident), serve_ident: format_ident!("Serve{}", s.ident), vis: &s.vis, attrs: &s.attrs, @@ -50,8 +52,6 @@ impl<'a> ServiceGenerator<'a> { .. } = self; - let unit_type: &Type = &parse_quote!(()); - let handler_fns = handlers .iter() .map( @@ -66,13 +66,9 @@ impl<'a> ServiceGenerator<'a> { (ServiceType::Workflow, false) => quote! { ::restate_sdk::prelude::WorkflowContext }, }; - let output = match output { - ReturnType::Type(_, ref ty) => ty.as_ref(), - ReturnType::Default => unit_type, - }; quote! { #( #attrs )* - fn #ident(&self, context: #ctx, #( #args ),*) -> impl std::future::Future + ::core::marker::Send; + fn #ident(&self, context: #ctx, #( #args ),*) -> impl std::future::Future> + ::core::marker::Send; } }, ); @@ -223,6 +219,123 @@ impl<'a> ServiceGenerator<'a> { } } } + + fn struct_client(&self) -> TokenStream2 { + let &Self { + vis, + ref client_ident, + // service_ident, + ref service_ty, + .. + } = self; + + let key_field = match service_ty { + ServiceType::Service => quote! {}, + ServiceType::Object | ServiceType::Workflow => quote! { + key: String, + }, + }; + + let into_client_impl = match service_ty { + ServiceType::Service => { + quote! { + impl<'ctx> ::restate_sdk::context::IntoServiceClient<'ctx> for #client_ident<'ctx> { + fn create_client(ctx: &'ctx ::restate_sdk::endpoint::ContextInternal) -> Self { + Self { ctx } + } + } + } + } + ServiceType::Object => quote! { + impl<'ctx> ::restate_sdk::context::IntoObjectClient<'ctx> for #client_ident<'ctx> { + fn create_client(ctx: &'ctx ::restate_sdk::endpoint::ContextInternal, key: String) -> Self { + Self { ctx, key } + } + } + }, + ServiceType::Workflow => quote! { + impl<'ctx> ::restate_sdk::context::IntoWorkflowClient<'ctx> for #client_ident<'ctx> { + fn create_client(ctx: &'ctx ::restate_sdk::endpoint::ContextInternal, key: String) -> Self { + Self { ctx, key } + } + } + }, + }; + + quote! { + /// Struct exposing the client to invoke [#service_ident] from another service. + #vis struct #client_ident<'ctx> { + ctx: &'ctx ::restate_sdk::endpoint::ContextInternal, + #key_field + } + + #into_client_impl + } + } + + fn impl_client(&self) -> TokenStream2 { + let &Self { + vis, + ref client_ident, + service_ident, + handlers, + restate_name, + service_ty, + .. + } = self; + + let service_literal = Literal::string(restate_name); + + let handlers_fns = handlers.iter().map(|handler| { + let handler_ident = &handler.ident; + let handler_literal = Literal::string(&handler.restate_name); + + let argument = match &handler.arg { + None => quote! {}, + Some(PatType { + ty, .. + }) => quote! { req: #ty } + }; + let argument_ty = match &handler.arg { + None => quote! { () }, + Some(PatType { + ty, .. + }) => quote! { #ty } + }; + let res_ty = &handler.output; + let input = match &handler.arg { + None => quote! { () }, + Some(_) => quote! { req } + }; + let request_target = match service_ty { + ServiceType::Service => quote! { + ::restate_sdk::context::RequestTarget::service(#service_literal, #handler_literal) + }, + ServiceType::Object => quote! { + ::restate_sdk::context::RequestTarget::object(#service_literal, &self.key, #handler_literal) + }, + ServiceType::Workflow => quote! { + ::restate_sdk::context::RequestTarget::workflow(#service_literal, &self.key, #handler_literal) + } + }; + + quote! { + #vis fn #handler_ident(&self, #argument) -> ::restate_sdk::context::Request<'ctx, #argument_ty, #res_ty> { + self.ctx.request(#request_target, #input) + } + } + }); + + let doc_msg = format!( + "Struct exposing the client to invoke [`{service_ident}`] from another service." + ); + quote! { + #[doc = #doc_msg] + impl<'ctx> #client_ident<'ctx> { + #( #handlers_fns )* + } + } + } } impl<'a> ToTokens for ServiceGenerator<'a> { @@ -232,6 +345,8 @@ impl<'a> ToTokens for ServiceGenerator<'a> { self.struct_serve(), self.impl_service_for_serve(), self.impl_discoverable(), + self.struct_client(), + self.impl_client(), ]); } } diff --git a/src/context.rs b/src/context.rs deleted file mode 100644 index 5cc05cc..0000000 --- a/src/context.rs +++ /dev/null @@ -1,329 +0,0 @@ -use crate::endpoint::{ContextInternal, InputMetadata}; -use crate::errors::{HandlerResult, TerminalError}; -use crate::serde::{Deserialize, Serialize}; -use std::fmt; -use std::future::Future; -use std::time::Duration; - -pub struct Context<'a> { - inner: &'a ContextInternal, -} - -impl<'a> From<(&'a ContextInternal, InputMetadata)> for Context<'a> { - fn from(value: (&'a ContextInternal, InputMetadata)) -> Self { - Self { inner: value.0 } - } -} - -pub struct SharedObjectContext<'a> { - key: String, - pub(crate) inner: &'a ContextInternal, -} - -impl<'a> SharedObjectContext<'a> { - pub fn key(&self) -> &str { - &self.key - } -} - -impl<'a> From<(&'a ContextInternal, InputMetadata)> for SharedObjectContext<'a> { - fn from(value: (&'a ContextInternal, InputMetadata)) -> Self { - Self { - key: value.1.key, - inner: value.0, - } - } -} - -pub struct ObjectContext<'a> { - key: String, - pub(crate) inner: &'a ContextInternal, -} - -impl<'a> ObjectContext<'a> { - pub fn key(&self) -> &str { - &self.key - } -} - -impl<'a> From<(&'a ContextInternal, InputMetadata)> for ObjectContext<'a> { - fn from(value: (&'a ContextInternal, InputMetadata)) -> Self { - Self { - key: value.1.key, - inner: value.0, - } - } -} - -pub struct SharedWorkflowContext<'a> { - key: String, - pub(crate) inner: &'a ContextInternal, -} - -impl<'a> From<(&'a ContextInternal, InputMetadata)> for SharedWorkflowContext<'a> { - fn from(value: (&'a ContextInternal, InputMetadata)) -> Self { - Self { - key: value.1.key, - inner: value.0, - } - } -} - -impl<'a> SharedWorkflowContext<'a> { - pub fn key(&self) -> &str { - &self.key - } -} - -pub struct WorkflowContext<'a> { - key: String, - pub(crate) inner: &'a ContextInternal, -} - -impl<'a> From<(&'a ContextInternal, InputMetadata)> for WorkflowContext<'a> { - fn from(value: (&'a ContextInternal, InputMetadata)) -> Self { - Self { - key: value.1.key, - inner: value.0, - } - } -} - -impl<'a> WorkflowContext<'a> { - pub fn key(&self) -> &str { - &self.key - } -} - -// Little macro to simplify implementing context methods on all the context structs -macro_rules! impl_context_method { - ([$ctx:ident, $($morectx:ident),*]; $($sig:tt)*) => { - impl_context_method!(@render_impl $ctx; $($sig)*); - impl_context_method!([$($morectx),*]; $($sig)*); - }; - ([$ctx:ident]; $($sig:tt)*) => { - impl_context_method!(@render_impl $ctx; $($sig)*); - }; - (@render_impl $ctx:ident; #[doc = $doc:expr] async fn $name:ident $(< $( $lt:tt $( : $clt:tt $(+ $dlt:tt )* )? ),+ >)? ($($param:ident : $ty:ty),*) -> $ret:ty $(where $( $wlt:tt $( : $wclt:tt $(+ $wdlt:tt )* )? ),+ )?) => { - impl<'a> $ctx<'a> { - #[doc = $doc] - pub fn $name $(< $( $lt $( : $clt $(+ $dlt )* )? ),+ >)? (&self, $($param: $ty),*) -> impl Future + 'a { - self.inner.$name($($param),*) - } - } - }; - (@render_impl $ctx:ident; #[doc = $doc:expr] fn $name:ident $(< $( $lt:tt $( : $clt:tt $(+ $dlt:tt )* )? ),+ >)? ($($param:ident : $ty:ty),*) -> $ret:ty $(where $( $wlt:tt $( : $wclt:tt $(+ $wdlt:tt )* )? ),+ )?) => { - impl<'a> $ctx<'a> { - #[doc = $doc] - pub fn $name $(< $( $lt $( : $clt $(+ $dlt )* )? ),+ >)? (&self, $($param: $ty),*) -> $ret { - self.inner.$name($($param),*) - } - } - }; -} - -// State read methods -impl_context_method!( - [SharedObjectContext, ObjectContext, SharedWorkflowContext, WorkflowContext]; - /// Get state - async fn get(key: &str) -> Result, TerminalError> -); -impl_context_method!( - [SharedObjectContext, ObjectContext, SharedWorkflowContext, WorkflowContext]; - /// Get state - async fn get_keys() -> Result, TerminalError> -); - -// State write methods -impl_context_method!( - [ObjectContext, WorkflowContext]; - /// Set state - fn set(key: &str, t: T) -> () -); -impl_context_method!( - [ObjectContext, WorkflowContext]; - /// Clear state - fn clear(key: &str) -> () -); -impl_context_method!( - [ObjectContext, WorkflowContext]; - /// Clear state - fn clear_all() -> () -); - -// Sleep -impl_context_method!( - [Context, SharedObjectContext, ObjectContext, SharedWorkflowContext, WorkflowContext]; - /// Sleep using Restate - async fn sleep(duration: Duration) -> Result<(), TerminalError> -); - -// Calls -#[derive(Debug, Clone)] -pub enum RequestTarget { - Service { - name: String, - handler: String, - }, - Object { - name: String, - key: String, - handler: String, - }, - Workflow { - name: String, - key: String, - handler: String, - }, -} - -impl RequestTarget { - pub fn service(name: impl Into, handler: impl Into) -> Self { - Self::Service { - name: name.into(), - handler: handler.into(), - } - } - - pub fn object( - name: impl Into, - key: impl Into, - handler: impl Into, - ) -> Self { - Self::Object { - name: name.into(), - key: key.into(), - handler: handler.into(), - } - } - - pub fn workflow( - name: impl Into, - key: impl Into, - handler: impl Into, - ) -> Self { - Self::Workflow { - name: name.into(), - key: key.into(), - handler: handler.into(), - } - } -} - -impl fmt::Display for RequestTarget { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - RequestTarget::Service { name, handler } => write!(f, "{name}/{handler}"), - RequestTarget::Object { name, key, handler } => write!(f, "{name}/{key}/{handler}"), - RequestTarget::Workflow { name, key, handler } => write!(f, "{name}/{key}/{handler}"), - } - } -} - -impl_context_method!( - [Context, SharedObjectContext, ObjectContext, SharedWorkflowContext, WorkflowContext]; - /// Call another Restate service - async fn call(request_target: RequestTarget, req: Req) -> Result -); -impl_context_method!( - [Context, SharedObjectContext, ObjectContext, SharedWorkflowContext, WorkflowContext]; - /// Call another Restate service one way - fn send(request_target: RequestTarget, req: Req, delay: Option) -> () -); - -// Awakeables -impl_context_method!( - [Context, SharedObjectContext, ObjectContext, SharedWorkflowContext, WorkflowContext]; - /// Create an awakeable - fn awakeable() -> (String, impl Future> + Send + Sync) -); -impl_context_method!( - [Context, SharedObjectContext, ObjectContext, SharedWorkflowContext, WorkflowContext]; - /// Resolve an awakeable - fn resolve_awakeable(key: &str, t: T) -> () -); -impl_context_method!( - [Context, SharedObjectContext, ObjectContext, SharedWorkflowContext, WorkflowContext]; - /// Resolve an awakeable - fn reject_awakeable(key: &str, failure: TerminalError) -> () -); - -// Promises -impl_context_method!( - [SharedWorkflowContext, WorkflowContext]; - /// Create a promise - async fn promise(key: &str) -> Result -); -impl_context_method!( - [SharedWorkflowContext, WorkflowContext]; - /// Peek a promise - async fn peek_promise(key: &str) -> Result, TerminalError> -); -impl_context_method!( - [SharedWorkflowContext, WorkflowContext]; - /// Resolve a promise - fn resolve_promise(key: &str, t: T) -> () -); -impl_context_method!( - [Context, SharedObjectContext, ObjectContext, SharedWorkflowContext, WorkflowContext]; - /// Resolve a promise - fn reject_promise(key: &str, failure: TerminalError) -> () -); - -// Run -pub trait RunClosure { - type Output: Deserialize + Serialize + 'static; - type Fut: Future>; - - fn run(self) -> Self::Fut; -} - -impl RunClosure for F -where - F: FnOnce() -> Fut, - Fut: Future>, - O: Deserialize + Serialize + 'static, -{ - type Output = O; - type Fut = Fut; - - fn run(self) -> Self::Fut { - self() - } -} - -// ad-hoc macro to copy paste run -macro_rules! impl_run_method { - ([$ctx:ident, $($morectx:ident),*]) => { - impl_run_method!(@render_impl $ctx); - impl_run_method!([$($morectx),*]); - }; - ([$ctx:ident]) => { - impl_run_method!(@render_impl $ctx); - }; - (@render_impl $ctx:ident) => { - impl<'a> $ctx<'a> { - /// Run a non-deterministic operation - pub fn run( - &self, - name: &'a str, - run_closure: R, - ) -> impl Future> + 'a - where - R: RunClosure + Send + Sync + 'a, - T: Serialize + Deserialize, - F: Future> + Send + Sync + 'a, - { - self.inner.run(name, run_closure) - } - } - }; -} - -impl_run_method!([ - Context, - SharedObjectContext, - ObjectContext, - SharedWorkflowContext, - WorkflowContext -]); diff --git a/src/context/mod.rs b/src/context/mod.rs new file mode 100644 index 0000000..6e2f185 --- /dev/null +++ b/src/context/mod.rs @@ -0,0 +1,327 @@ +use crate::endpoint::{ContextInternal, InputMetadata}; +use crate::errors::{HandlerResult, TerminalError}; +use crate::serde::{Deserialize, Serialize}; +use std::future::Future; +use std::time::Duration; + +mod request; +mod run; + +pub use request::{Request, RequestTarget}; +pub use run::RunClosure; + +pub struct Context<'ctx> { + inner: &'ctx ContextInternal, +} + +impl<'ctx> From<(&'ctx ContextInternal, InputMetadata)> for Context<'ctx> { + fn from(value: (&'ctx ContextInternal, InputMetadata)) -> Self { + Self { inner: value.0 } + } +} + +pub struct SharedObjectContext<'ctx> { + key: String, + pub(crate) inner: &'ctx ContextInternal, +} + +impl<'ctx> SharedObjectContext<'ctx> { + pub fn key(&self) -> &str { + &self.key + } +} + +impl<'ctx> From<(&'ctx ContextInternal, InputMetadata)> for SharedObjectContext<'ctx> { + fn from(value: (&'ctx ContextInternal, InputMetadata)) -> Self { + Self { + key: value.1.key, + inner: value.0, + } + } +} + +pub struct ObjectContext<'ctx> { + key: String, + pub(crate) inner: &'ctx ContextInternal, +} + +impl<'ctx> ObjectContext<'ctx> { + pub fn key(&self) -> &str { + &self.key + } +} + +impl<'ctx> From<(&'ctx ContextInternal, InputMetadata)> for ObjectContext<'ctx> { + fn from(value: (&'ctx ContextInternal, InputMetadata)) -> Self { + Self { + key: value.1.key, + inner: value.0, + } + } +} + +pub struct SharedWorkflowContext<'ctx> { + key: String, + pub(crate) inner: &'ctx ContextInternal, +} + +impl<'ctx> From<(&'ctx ContextInternal, InputMetadata)> for SharedWorkflowContext<'ctx> { + fn from(value: (&'ctx ContextInternal, InputMetadata)) -> Self { + Self { + key: value.1.key, + inner: value.0, + } + } +} + +impl<'ctx> SharedWorkflowContext<'ctx> { + pub fn key(&self) -> &str { + &self.key + } +} + +pub struct WorkflowContext<'ctx> { + key: String, + pub(crate) inner: &'ctx ContextInternal, +} + +impl<'ctx> From<(&'ctx ContextInternal, InputMetadata)> for WorkflowContext<'ctx> { + fn from(value: (&'ctx ContextInternal, InputMetadata)) -> Self { + Self { + key: value.1.key, + inner: value.0, + } + } +} + +impl<'ctx> WorkflowContext<'ctx> { + pub fn key(&self) -> &str { + &self.key + } +} + +pub trait ContextTimers<'ctx>: private::SealedGetInnerContext<'ctx> { + /// Sleep using Restate + fn sleep(&self, duration: Duration) -> impl Future> + 'ctx { + private::SealedGetInnerContext::inner_context(self).sleep(duration) + } +} + +impl<'ctx, CTX: private::SealedGetInnerContext<'ctx>> ContextTimers<'ctx> for CTX {} + +pub trait ContextClient<'ctx>: private::SealedGetInnerContext<'ctx> { + fn request( + &self, + request_target: RequestTarget, + req: Req, + ) -> Request<'ctx, Req, Res> { + Request::new(self.inner_context(), request_target, req) + } + + fn service_client(&self) -> C + where + C: IntoServiceClient<'ctx>, + { + C::create_client(self.inner_context()) + } + + fn object_client(&self, key: impl Into) -> C + where + C: IntoObjectClient<'ctx>, + { + C::create_client(self.inner_context(), key.into()) + } + + fn workflow_client(&self, key: impl Into) -> C + where + C: IntoWorkflowClient<'ctx>, + { + C::create_client(self.inner_context(), key.into()) + } +} + +pub trait IntoServiceClient<'ctx>: Sized { + fn create_client(ctx: &'ctx ContextInternal) -> Self; +} + +pub trait IntoObjectClient<'ctx>: Sized { + fn create_client(ctx: &'ctx ContextInternal, key: String) -> Self; +} + +pub trait IntoWorkflowClient<'ctx>: Sized { + fn create_client(ctx: &'ctx ContextInternal, key: String) -> Self; +} + +impl<'ctx, CTX: private::SealedGetInnerContext<'ctx>> ContextClient<'ctx> for CTX {} + +pub trait ContextAwakeables<'ctx>: private::SealedGetInnerContext<'ctx> { + /// Create an awakeable + fn awakeable( + &self, + ) -> ( + String, + impl Future> + Send + Sync + 'ctx, + ) { + self.inner_context().awakeable() + } + + /// Resolve an awakeable + fn resolve_awakeable(&self, key: &str, t: T) { + self.inner_context().resolve_awakeable(key, t) + } + + /// Resolve an awakeable + fn reject_awakeable(&self, key: &str, failure: TerminalError) { + self.inner_context().reject_awakeable(key, failure) + } +} + +impl<'ctx, CTX: private::SealedGetInnerContext<'ctx>> ContextAwakeables<'ctx> for CTX {} + +pub trait ContextSideEffects<'ctx>: private::SealedGetInnerContext<'ctx> { + /// Run a non-deterministic operation + fn run( + &self, + name: &'ctx str, + run_closure: R, + ) -> impl Future> + 'ctx + where + R: RunClosure + Send + Sync + 'ctx, + T: Serialize + Deserialize, + F: Future> + Send + Sync + 'ctx, + { + self.inner_context().run(name, run_closure) + } +} + +impl<'ctx, CTX: private::SealedGetInnerContext<'ctx>> ContextSideEffects<'ctx> for CTX {} + +pub trait ContextReadState<'ctx>: private::SealedGetInnerContext<'ctx> { + /// Get state + fn get( + &self, + key: &str, + ) -> impl Future, TerminalError>> + 'ctx { + self.inner_context().get(key) + } + + /// Get state keys + fn get_keys(&self) -> impl Future, TerminalError>> + 'ctx { + self.inner_context().get_keys() + } +} + +impl<'ctx, CTX: private::SealedGetInnerContext<'ctx> + private::SealedCanReadState> + ContextReadState<'ctx> for CTX +{ +} + +pub trait ContextWriteState<'ctx>: private::SealedGetInnerContext<'ctx> { + /// Set state + fn set(&self, key: &str, t: T) { + self.inner_context().set(key, t) + } + + /// Clear state + fn clear(&self, key: &str) { + self.inner_context().clear(key) + } + + /// Clear all state + fn clear_all(&self) { + self.inner_context().clear_all() + } +} + +impl<'ctx, CTX: private::SealedGetInnerContext<'ctx> + private::SealedCanWriteState> + ContextWriteState<'ctx> for CTX +{ +} + +pub trait ContextPromises<'ctx>: private::SealedGetInnerContext<'ctx> { + /// Create a promise + fn promise( + &'ctx self, + key: &str, + ) -> impl Future> + 'ctx { + self.inner_context().promise(key) + } + + /// Peek a promise + fn peek_promise( + &self, + key: &str, + ) -> impl Future, TerminalError>> + 'ctx { + self.inner_context().peek_promise(key) + } + + /// Resolve a promise + fn resolve_promise(&self, key: &str, t: T) { + self.inner_context().resolve_promise(key, t) + } + + /// Resolve a promise + fn reject_promise(&self, key: &str, failure: TerminalError) { + self.inner_context().reject_promise(key, failure) + } +} + +impl<'ctx, CTX: private::SealedGetInnerContext<'ctx> + private::SealedCanUsePromises> + ContextPromises<'ctx> for CTX +{ +} + +mod private { + use super::*; + + pub trait SealedGetInnerContext<'ctx> { + fn inner_context(&self) -> &'ctx ContextInternal; + } + + // Context capabilities + pub trait SealedCanReadState {} + pub trait SealedCanWriteState {} + pub trait SealedCanUsePromises {} + + impl<'ctx> SealedGetInnerContext<'ctx> for Context<'ctx> { + fn inner_context(&self) -> &'ctx ContextInternal { + self.inner + } + } + + impl<'ctx> SealedGetInnerContext<'ctx> for SharedObjectContext<'ctx> { + fn inner_context(&self) -> &'ctx ContextInternal { + self.inner + } + } + + impl SealedCanReadState for SharedObjectContext<'_> {} + + impl<'ctx> SealedGetInnerContext<'ctx> for ObjectContext<'ctx> { + fn inner_context(&self) -> &'ctx ContextInternal { + self.inner + } + } + + impl SealedCanReadState for ObjectContext<'_> {} + impl SealedCanWriteState for ObjectContext<'_> {} + + impl<'ctx> SealedGetInnerContext<'ctx> for SharedWorkflowContext<'ctx> { + fn inner_context(&self) -> &'ctx ContextInternal { + self.inner + } + } + + impl SealedCanReadState for SharedWorkflowContext<'_> {} + impl SealedCanUsePromises for SharedWorkflowContext<'_> {} + + impl<'ctx> SealedGetInnerContext<'ctx> for WorkflowContext<'ctx> { + fn inner_context(&self) -> &'ctx ContextInternal { + self.inner + } + } + + impl SealedCanReadState for WorkflowContext<'_> {} + impl SealedCanWriteState for WorkflowContext<'_> {} + impl SealedCanUsePromises for WorkflowContext<'_> {} +} diff --git a/src/context/request.rs b/src/context/request.rs new file mode 100644 index 0000000..48c0dee --- /dev/null +++ b/src/context/request.rs @@ -0,0 +1,108 @@ +use crate::endpoint::ContextInternal; +use crate::errors::TerminalError; +use crate::serde::{Deserialize, Serialize}; +use std::fmt; +use std::future::Future; +use std::marker::PhantomData; +use std::time::Duration; + +#[derive(Debug, Clone)] +pub enum RequestTarget { + Service { + name: String, + handler: String, + }, + Object { + name: String, + key: String, + handler: String, + }, + Workflow { + name: String, + key: String, + handler: String, + }, +} + +impl RequestTarget { + pub fn service(name: impl Into, handler: impl Into) -> Self { + Self::Service { + name: name.into(), + handler: handler.into(), + } + } + + pub fn object( + name: impl Into, + key: impl Into, + handler: impl Into, + ) -> Self { + Self::Object { + name: name.into(), + key: key.into(), + handler: handler.into(), + } + } + + pub fn workflow( + name: impl Into, + key: impl Into, + handler: impl Into, + ) -> Self { + Self::Workflow { + name: name.into(), + key: key.into(), + handler: handler.into(), + } + } +} + +impl fmt::Display for RequestTarget { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + RequestTarget::Service { name, handler } => write!(f, "{name}/{handler}"), + RequestTarget::Object { name, key, handler } => write!(f, "{name}/{key}/{handler}"), + RequestTarget::Workflow { name, key, handler } => write!(f, "{name}/{key}/{handler}"), + } + } +} + +pub struct Request<'a, Req, Res> { + ctx: &'a ContextInternal, + request_target: RequestTarget, + req: Req, + res: PhantomData, +} + +impl<'a, Req, Res> Request<'a, Req, Res> { + pub(crate) fn new(ctx: &'a ContextInternal, request_target: RequestTarget, req: Req) -> Self { + Self { + ctx, + request_target, + req, + res: PhantomData, + } + } + + pub fn call(self) -> impl Future> + Send + where + Req: Serialize + 'static, + Res: Deserialize + 'static, + { + self.ctx.call(self.request_target, self.req) + } + + pub fn send(self) + where + Req: Serialize + 'static, + { + self.ctx.send(self.request_target, self.req, None) + } + + pub fn send_with_delay(self, duration: Duration) + where + Req: Serialize + 'static, + { + self.ctx.send(self.request_target, self.req, Some(duration)) + } +} diff --git a/src/context/run.rs b/src/context/run.rs new file mode 100644 index 0000000..645d9c6 --- /dev/null +++ b/src/context/run.rs @@ -0,0 +1,25 @@ +use crate::errors::HandlerResult; +use crate::serde::{Deserialize, Serialize}; +use std::future::Future; + +// Run +pub trait RunClosure { + type Output: Deserialize + Serialize + 'static; + type Fut: Future>; + + fn run(self) -> Self::Fut; +} + +impl RunClosure for F +where + F: FnOnce() -> Fut, + Fut: Future>, + O: Deserialize + Serialize + 'static, +{ + type Output = O; + type Fut = Fut; + + fn run(self) -> Self::Fut { + self() + } +} diff --git a/src/endpoint/context.rs b/src/endpoint/context.rs index de44498..f8c7e34 100644 --- a/src/endpoint/context.rs +++ b/src/endpoint/context.rs @@ -1,4 +1,4 @@ -use crate::context::{RequestTarget, RunClosure}; +use crate::context::{Request, RequestTarget, RunClosure}; use crate::endpoint::futures::{InterceptErrorFuture, TrapFuture}; use crate::endpoint::handler_state::HandlerStateNotifier; use crate::endpoint::{Error, ErrorInner, InputReceiver, OutputSender}; @@ -263,6 +263,10 @@ impl ContextInternal { InterceptErrorFuture::new(self.clone(), poll_future.map_err(Error)) } + pub fn request(&self, request_target: RequestTarget, req: Req) -> Request { + Request::new(self, request_target, req) + } + pub fn call( &self, request_target: RequestTarget, diff --git a/src/lib.rs b/src/lib.rs index 724a32a..2388443 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -15,7 +15,9 @@ pub mod prelude { pub use crate::http::HyperServer; pub use crate::context::{ - Context, ObjectContext, SharedObjectContext, SharedWorkflowContext, WorkflowContext, + Context, ContextAwakeables, ContextClient, ContextPromises, ContextReadState, + ContextSideEffects, ContextTimers, ContextWriteState, ObjectContext, Request, + SharedObjectContext, SharedWorkflowContext, WorkflowContext, }; pub use crate::endpoint::Endpoint; pub use crate::errors::{HandlerError, HandlerResult, TerminalError}; diff --git a/test-services/exclusions.yaml b/test-services/exclusions.yaml index f9e3440..0909535 100644 --- a/test-services/exclusions.yaml +++ b/test-services/exclusions.yaml @@ -10,8 +10,6 @@ exclusions: - "dev.restate.sdktesting.tests.UserErrors" "default": - "dev.restate.sdktesting.tests.AwaitTimeout" - - "dev.restate.sdktesting.tests.CallOrdering" - - "dev.restate.sdktesting.tests.CancelInvocation" - "dev.restate.sdktesting.tests.Ingress" - "dev.restate.sdktesting.tests.KillInvocation" - "dev.restate.sdktesting.tests.ServiceToServiceCommunication" @@ -24,8 +22,6 @@ exclusions: - "dev.restate.sdktesting.tests.Sleep" "singleThreadSinglePartition": - "dev.restate.sdktesting.tests.AwaitTimeout" - - "dev.restate.sdktesting.tests.CallOrdering" - - "dev.restate.sdktesting.tests.CancelInvocation" - "dev.restate.sdktesting.tests.Ingress" - "dev.restate.sdktesting.tests.KillInvocation" - "dev.restate.sdktesting.tests.ServiceToServiceCommunication" diff --git a/test-services/src/cancel_test.rs b/test-services/src/cancel_test.rs new file mode 100644 index 0000000..d05b6c5 --- /dev/null +++ b/test-services/src/cancel_test.rs @@ -0,0 +1,98 @@ +use crate::awakeable_holder; +use anyhow::anyhow; +use restate_sdk::prelude::*; +use serde::{Deserialize, Serialize}; +use std::time::Duration; + +#[derive(Serialize, Deserialize)] +#[serde(rename_all = "SCREAMING_SNAKE_CASE")] +pub(crate) enum BlockingOperation { + Call, + Sleep, + Awakeable, +} + +#[restate_sdk::object] +#[name = "CancelTestRunner"] +pub(crate) trait CancelTestRunner { + #[name = "startTest"] + async fn start_test(op: Json) -> HandlerResult<()>; + #[name = "verifyTest"] + async fn verify_test() -> HandlerResult; +} + +pub(crate) struct CancelTestRunnerImpl; + +const CANCELED: &str = "canceled"; + +impl CancelTestRunner for CancelTestRunnerImpl { + async fn start_test( + &self, + context: ObjectContext<'_>, + op: Json, + ) -> HandlerResult<()> { + let this = context.object_client::(""); + + match this.block(op).call().await { + Ok(_) => Err(anyhow!("Block succeeded, this is unexpected").into()), + Err(e) if e.code() == 409 => { + context.set(CANCELED, true); + Ok(()) + } + Err(e) => Err(e.into()), + } + } + + async fn verify_test(&self, context: ObjectContext<'_>) -> HandlerResult { + Ok(context.get::(CANCELED).await?.unwrap_or(false)) + } +} + +#[restate_sdk::object] +#[name = "CancelTestBlockingService"] +pub(crate) trait CancelTestBlockingService { + #[name = "block"] + async fn block(op: Json) -> HandlerResult<()>; + #[name = "isUnlocked"] + async fn is_unlocked() -> HandlerResult<()>; +} + +pub(crate) struct CancelTestBlockingServiceImpl; + +impl CancelTestBlockingService for CancelTestBlockingServiceImpl { + async fn block( + &self, + context: ObjectContext<'_>, + op: Json, + ) -> HandlerResult<()> { + let this = context.object_client::(""); + let awakeable_holder_client = + context.object_client::("cancel"); + + let (awk_id, awakeable) = context.awakeable::(); + awakeable_holder_client.hold(awk_id).call().await?; + awakeable.await?; + + match &op.0 { + BlockingOperation::Call => { + this.block(op).call().await?; + } + BlockingOperation::Sleep => { + context + .sleep(Duration::from_secs(60 * 60 * 24 * 1024)) + .await?; + } + BlockingOperation::Awakeable => { + let (_, uncompletable) = context.awakeable::(); + uncompletable.await?; + } + } + + Ok(()) + } + + async fn is_unlocked(&self, _: ObjectContext<'_>) -> HandlerResult<()> { + // no-op + Ok(()) + } +} diff --git a/test-services/src/main.rs b/test-services/src/main.rs index 3a0b813..a094b54 100644 --- a/test-services/src/main.rs +++ b/test-services/src/main.rs @@ -1,5 +1,6 @@ mod awakeable_holder; mod block_and_wait_workflow; +mod cancel_test; mod counter; mod list_object; mod map_object; @@ -38,6 +39,16 @@ async fn main() { block_and_wait_workflow::BlockAndWaitWorkflowImpl, )) } + if services == "*" || services.contains("CancelTestRunner") { + builder = builder.with_service(cancel_test::CancelTestRunner::serve( + cancel_test::CancelTestRunnerImpl, + )) + } + if services == "*" || services.contains("CancelTestBlockingService") { + builder = builder.with_service(cancel_test::CancelTestBlockingService::serve( + cancel_test::CancelTestBlockingServiceImpl, + )) + } HyperServer::new(builder.build()) .listen_and_serve(format!("0.0.0.0:{port}").parse().unwrap()) diff --git a/test-services/src/proxy.rs b/test-services/src/proxy.rs index 2fbe234..2812f69 100644 --- a/test-services/src/proxy.rs +++ b/test-services/src/proxy.rs @@ -59,7 +59,7 @@ impl Proxy for ProxyImpl { ctx: Context<'_>, Json(req): Json, ) -> HandlerResult>> { - Ok(ctx.call(req.to_target(), req.message).await?) + Ok(ctx.request(req.to_target(), req.message).call().await?) } async fn one_way_call( @@ -67,11 +67,14 @@ impl Proxy for ProxyImpl { ctx: Context<'_>, Json(req): Json, ) -> HandlerResult<()> { - ctx.send( - req.to_target(), - req.message, - req.delay_millis.map(Duration::from_millis), - ); + let request = ctx.request::<_, ()>(req.to_target(), req.message); + + if let Some(delay_millis) = req.delay_millis { + request.send_with_delay(Duration::from_millis(delay_millis)); + } else { + request.send(); + } + Ok(()) } @@ -83,15 +86,16 @@ impl Proxy for ProxyImpl { let mut futures: Vec, TerminalError>>> = vec![]; for req in requests { + let restate_req = + ctx.request::<_, Vec>(req.proxy_request.to_target(), req.proxy_request.message); if req.one_way_call { - ctx.send( - req.proxy_request.to_target(), - req.proxy_request.message, - req.proxy_request.delay_millis.map(Duration::from_millis), - ); + if let Some(delay_millis) = req.proxy_request.delay_millis { + restate_req.send_with_delay(Duration::from_millis(delay_millis)); + } else { + restate_req.send(); + } } else { - let fut = ctx - .call::<_, Vec>(req.proxy_request.to_target(), req.proxy_request.message); + let fut = restate_req.call(); if req.await_at_the_end { futures.push(fut.boxed()) }