diff --git a/examples/counter.rs b/examples/counter.rs index b38b163..4861ae2 100644 --- a/examples/counter.rs +++ b/examples/counter.rs @@ -3,10 +3,10 @@ use restate_sdk::prelude::*; #[restate_sdk::object] trait Counter { #[shared] - async fn get() -> HandlerResult; - async fn add(val: u64) -> HandlerResult; - async fn increment() -> HandlerResult; - async fn reset() -> HandlerResult<()>; + async fn get() -> Result; + async fn add(val: u64) -> Result; + async fn increment() -> Result; + async fn reset() -> Result<(), TerminalError>; } struct CounterImpl; @@ -14,22 +14,22 @@ struct CounterImpl; const COUNT: &str = "count"; impl Counter for CounterImpl { - async fn get(&self, ctx: SharedObjectContext<'_>) -> HandlerResult { + async fn get(&self, ctx: SharedObjectContext<'_>) -> Result { Ok(ctx.get::(COUNT).await?.unwrap_or(0)) } - async fn add(&self, ctx: ObjectContext<'_>, val: u64) -> HandlerResult { + async fn add(&self, ctx: ObjectContext<'_>, val: u64) -> Result { let current = ctx.get::(COUNT).await?.unwrap_or(0); let new = current + val; ctx.set(COUNT, new); Ok(new) } - async fn increment(&self, ctx: ObjectContext<'_>) -> HandlerResult { + async fn increment(&self, ctx: ObjectContext<'_>) -> Result { self.add(ctx, 1).await } - async fn reset(&self, ctx: ObjectContext<'_>) -> HandlerResult<()> { + async fn reset(&self, ctx: ObjectContext<'_>) -> Result<(), TerminalError> { ctx.clear(COUNT); Ok(()) } diff --git a/examples/failures.rs b/examples/failures.rs index 0ab4251..62379cf 100644 --- a/examples/failures.rs +++ b/examples/failures.rs @@ -4,7 +4,7 @@ use restate_sdk::prelude::*; #[restate_sdk::service] trait FailureExample { #[name = "doRun"] - async fn do_run() -> HandlerResult<()>; + async fn do_run() -> Result<(), TerminalError>; } struct FailureExampleImpl; @@ -14,14 +14,14 @@ struct FailureExampleImpl; struct MyError; impl FailureExample for FailureExampleImpl { - async fn do_run(&self, context: Context<'_>) -> HandlerResult<()> { + async fn do_run(&self, context: Context<'_>) -> Result<(), TerminalError> { context .run(|| async move { if rand::thread_rng().next_u32() % 4 == 0 { - return Err(TerminalError::new("Failed!!!").into()); + Err(TerminalError::new("Failed!!!"))? } - Err(MyError.into()) + Err(MyError)? }) .await?; diff --git a/examples/greeter.rs b/examples/greeter.rs index 3d4011b..ec07e0f 100644 --- a/examples/greeter.rs +++ b/examples/greeter.rs @@ -1,14 +1,15 @@ use restate_sdk::prelude::*; +use std::convert::Infallible; #[restate_sdk::service] trait Greeter { - async fn greet(name: String) -> HandlerResult; + async fn greet(name: String) -> Result; } struct GreeterImpl; impl Greeter for GreeterImpl { - async fn greet(&self, _: Context<'_>, name: String) -> HandlerResult { + async fn greet(&self, _: Context<'_>, name: String) -> Result { Ok(format!("Greetings {name}")) } } diff --git a/examples/run.rs b/examples/run.rs index f6127d7..e87a979 100644 --- a/examples/run.rs +++ b/examples/run.rs @@ -3,13 +3,16 @@ use std::collections::HashMap; #[restate_sdk::service] trait RunExample { - async fn do_run() -> HandlerResult>>; + async fn do_run() -> Result>, HandlerError>; } struct RunExampleImpl(reqwest::Client); impl RunExample for RunExampleImpl { - async fn do_run(&self, context: Context<'_>) -> HandlerResult>> { + async fn do_run( + &self, + context: Context<'_>, + ) -> Result>, HandlerError> { let res = context .run(|| async move { let req = self.0.get("https://httpbin.org/ip").build()?; diff --git a/macros/src/ast.rs b/macros/src/ast.rs index 9756762..7a9aadf 100644 --- a/macros/src/ast.rs +++ b/macros/src/ast.rs @@ -145,7 +145,8 @@ pub(crate) struct Handler { pub(crate) restate_name: String, pub(crate) ident: Ident, pub(crate) arg: Option, - pub(crate) output: Type, + pub(crate) output_ok: Type, + pub(crate) output_err: Type, } impl Parse for Handler { @@ -192,17 +193,18 @@ impl Parse for Handler { let return_type: ReturnType = input.parse()?; input.parse::()?; - let output: Type = match &return_type { - ReturnType::Default => { - parse_quote!(()) - } + let (ok_ty, err_ty) = match &return_type { + ReturnType::Default => return Err(Error::new( + return_type.span(), + "The return type cannot be empty, only Result or restate_sdk::prelude::HandlerResult is supported as return type", + )), ReturnType::Type(_, ty) => { - if let Some(ty) = extract_handler_result_parameter(ty) { - ty + if let Some((ok_ty, err_ty)) = extract_handler_result_parameter(ty) { + (ok_ty, err_ty) } else { return Err(Error::new( return_type.span(), - "Only restate_sdk::prelude::HandlerResult is supported as return type", + "Only Result or restate_sdk::prelude::HandlerResult is supported as return type", )); } } @@ -229,7 +231,8 @@ impl Parse for Handler { restate_name, ident, arg: args.pop(), - output, + output_ok: ok_ty, + output_err: err_ty, }) } } @@ -263,14 +266,16 @@ fn read_literal_attribute_name(attr: &Attribute) -> Result> { .transpose() } -fn extract_handler_result_parameter(ty: &Type) -> Option { +fn extract_handler_result_parameter(ty: &Type) -> Option<(Type, Type)> { let path = match ty { Type::Path(ty) => &ty.path, _ => return None, }; let last = path.segments.last().unwrap(); - if last.ident != "HandlerResult" { + let is_result = last.ident == "Result"; + let is_handler_result = last.ident == "HandlerResult"; + if !is_result && !is_handler_result { return None; } @@ -279,12 +284,22 @@ fn extract_handler_result_parameter(ty: &Type) -> Option { _ => return None, }; - if bracketed.args.len() != 1 { - return None; - } - - match &bracketed.args[0] { - GenericArgument::Type(arg) => Some(arg.clone()), - _ => None, + if is_handler_result && bracketed.args.len() == 1 { + match &bracketed.args[0] { + GenericArgument::Type(arg) => Some(( + arg.clone(), + parse_quote!(::restate_sdk::prelude::HandlerError), + )), + _ => None, + } + } else if is_result && bracketed.args.len() == 2 { + match (&bracketed.args[0], &bracketed.args[1]) { + (GenericArgument::Type(ok_arg), GenericArgument::Type(err_arg)) => { + Some((ok_arg.clone(), err_arg.clone())) + } + _ => None, + } + } else { + None } } diff --git a/macros/src/gen.rs b/macros/src/gen.rs index f67435d..a882b4d 100644 --- a/macros/src/gen.rs +++ b/macros/src/gen.rs @@ -55,7 +55,7 @@ impl<'a> ServiceGenerator<'a> { let handler_fns = handlers .iter() .map( - |Handler { attrs, ident, arg, is_shared, output, .. }| { + |Handler { attrs, ident, arg, is_shared, output_ok, output_err, .. }| { let args = arg.iter(); let ctx = match (&service_ty, is_shared) { @@ -68,7 +68,7 @@ impl<'a> ServiceGenerator<'a> { quote! { #( #attrs )* - fn #ident(&self, context: #ctx, #( #args ),*) -> impl std::future::Future> + ::core::marker::Send; + fn #ident(&self, context: #ctx, #( #args ),*) -> impl std::future::Future> + ::core::marker::Send; } }, ); @@ -130,7 +130,7 @@ impl<'a> ServiceGenerator<'a> { quote! { #handler_literal => { #get_input_and_call - let res = fut.await; + let res = fut.await.map_err(::restate_sdk::errors::HandlerError::from); ctx.handle_handler_result(res); ctx.end(); Ok(()) @@ -302,7 +302,7 @@ impl<'a> ServiceGenerator<'a> { ty, .. }) => quote! { #ty } }; - let res_ty = &handler.output; + let res_ty = &handler.output_ok; let input = match &handler.arg { None => quote! { () }, Some(_) => quote! { req } diff --git a/src/endpoint/context.rs b/src/endpoint/context.rs index 892ba2e..af5dfc1 100644 --- a/src/endpoint/context.rs +++ b/src/endpoint/context.rs @@ -527,14 +527,14 @@ impl ContextInternal { .sys_complete_promise(id.to_owned(), NonEmptyValue::Failure(failure.into())); } - pub fn run<'a, Run, Fut, Res>( + pub fn run<'a, Run, Fut, Out>( &'a self, run_closure: Run, - ) -> impl crate::context::RunFuture> + Send + Sync + 'a + ) -> impl crate::context::RunFuture> + Send + Sync + 'a where - Run: RunClosure + Send + Sync + 'a, - Fut: Future> + Send + Sync + 'a, - Res: Serialize + Deserialize + 'static, + Run: RunClosure + Send + Sync + 'a, + Fut: Future> + Send + Sync + 'a, + Out: Serialize + Deserialize + 'static, { let this = Arc::clone(&self.inner); @@ -631,12 +631,12 @@ impl RunFuture { } } -impl crate::context::RunFuture, Error>> - for RunFuture +impl crate::context::RunFuture, Error>> + for RunFuture where - Run: RunClosure + Send + Sync, - Fut: Future> + Send + Sync, - Ret: Serialize + Deserialize, + Run: RunClosure + Send + Sync, + Fut: Future> + Send + Sync, + Out: Serialize + Deserialize, { fn with_retry_policy(mut self, retry_policy: RunRetryPolicy) -> Self { self.retry_policy = RetryPolicy::Exponential { @@ -655,13 +655,13 @@ where } } -impl Future for RunFuture +impl Future for RunFuture where - Run: RunClosure + Send + Sync, - Res: Serialize + Deserialize, - Fut: Future> + Send + Sync, + Run: RunClosure + Send + Sync, + Out: Serialize + Deserialize, + Fut: Future> + Send + Sync, { - type Output = Result, Error>; + type Output = Result, Error>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let mut this = self.project(); @@ -681,7 +681,7 @@ where // Enter the side effect match enter_result.map_err(ErrorInner::VM)? { RunEnterResult::Executed(NonEmptyValue::Success(mut v)) => { - let t = Res::deserialize(&mut v).map_err(|e| { + let t = Out::deserialize(&mut v).map_err(|e| { ErrorInner::Deserialization { syscall: "run", err: Box::new(e), @@ -707,7 +707,7 @@ where } RunStateProj::ClosureRunning { start_time, fut } => { let res = match ready!(fut.poll(cx)) { - Ok(t) => RunExitResult::Success(Res::serialize(&t).map_err(|e| { + Ok(t) => RunExitResult::Success(Out::serialize(&t).map_err(|e| { ErrorInner::Serialization { syscall: "run", err: Box::new(e), @@ -752,7 +752,7 @@ where } .into()), Value::Success(mut s) => { - let t = Res::deserialize(&mut s).map_err(|e| { + let t = Out::deserialize(&mut s).map_err(|e| { ErrorInner::Deserialization { syscall: "run", err: Box::new(e), diff --git a/src/errors.rs b/src/errors.rs index 276f903..dc7f2ff 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -144,6 +144,4 @@ impl From for Failure { } /// Result type for a Restate handler. -/// -/// All Restate handlers *MUST* use this type as return type for their handlers. pub type HandlerResult = Result; diff --git a/test-services/src/failing.rs b/test-services/src/failing.rs index a6249e1..bdcad82 100644 --- a/test-services/src/failing.rs +++ b/test-services/src/failing.rs @@ -72,7 +72,7 @@ impl Failing for FailingImpl { error_message: String, ) -> HandlerResult<()> { context - .run(|| async move { Err::<(), _>(TerminalError::new(error_message).into()) }) + .run(|| async move { Err(TerminalError::new(error_message))? }) .await?; unreachable!("This should be unreachable") @@ -92,7 +92,7 @@ impl Failing for FailingImpl { cloned_counter.store(0, Ordering::SeqCst); Ok(current_attempt) } else { - Err(anyhow!("Failed at attempt {current_attempt}").into()) + Err(anyhow!("Failed at attempt {current_attempt}"))? } }) .with_retry_policy( diff --git a/tests/service.rs b/tests/service.rs index 0fffe07..55bdec1 100644 --- a/tests/service.rs +++ b/tests/service.rs @@ -10,6 +10,12 @@ trait MyService { async fn no_output() -> HandlerResult<()>; async fn no_input_no_output() -> HandlerResult<()>; + + async fn std_result() -> Result<(), std::io::Error>; + + async fn std_result_with_terminal_error() -> Result<(), TerminalError>; + + async fn std_result_with_handler_error() -> Result<(), HandlerError>; } #[restate_sdk::object]