Skip to content

Commit 2910d50

Browse files
More improvements, now a standalone Result can be returned
1 parent 54d4e55 commit 2910d50

File tree

10 files changed

+83
-60
lines changed

10 files changed

+83
-60
lines changed

examples/counter.rs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,33 +3,33 @@ use restate_sdk::prelude::*;
33
#[restate_sdk::object]
44
trait Counter {
55
#[shared]
6-
async fn get() -> HandlerResult<u64>;
7-
async fn add(val: u64) -> HandlerResult<u64>;
8-
async fn increment() -> HandlerResult<u64>;
9-
async fn reset() -> HandlerResult<()>;
6+
async fn get() -> Result<u64, TerminalError>;
7+
async fn add(val: u64) -> Result<u64, TerminalError>;
8+
async fn increment() -> Result<u64, TerminalError>;
9+
async fn reset() -> Result<(), TerminalError>;
1010
}
1111

1212
struct CounterImpl;
1313

1414
const COUNT: &str = "count";
1515

1616
impl Counter for CounterImpl {
17-
async fn get(&self, ctx: SharedObjectContext<'_>) -> HandlerResult<u64> {
17+
async fn get(&self, ctx: SharedObjectContext<'_>) -> Result<u64, TerminalError> {
1818
Ok(ctx.get::<u64>(COUNT).await?.unwrap_or(0))
1919
}
2020

21-
async fn add(&self, ctx: ObjectContext<'_>, val: u64) -> HandlerResult<u64> {
21+
async fn add(&self, ctx: ObjectContext<'_>, val: u64) -> Result<u64, TerminalError> {
2222
let current = ctx.get::<u64>(COUNT).await?.unwrap_or(0);
2323
let new = current + val;
2424
ctx.set(COUNT, new);
2525
Ok(new)
2626
}
2727

28-
async fn increment(&self, ctx: ObjectContext<'_>) -> HandlerResult<u64> {
28+
async fn increment(&self, ctx: ObjectContext<'_>) -> Result<u64, TerminalError> {
2929
self.add(ctx, 1).await
3030
}
3131

32-
async fn reset(&self, ctx: ObjectContext<'_>) -> HandlerResult<()> {
32+
async fn reset(&self, ctx: ObjectContext<'_>) -> Result<(), TerminalError> {
3333
ctx.clear(COUNT);
3434
Ok(())
3535
}

examples/failures.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use restate_sdk::prelude::*;
44
#[restate_sdk::service]
55
trait FailureExample {
66
#[name = "doRun"]
7-
async fn do_run() -> HandlerResult<()>;
7+
async fn do_run() -> Result<(), TerminalError>;
88
}
99

1010
struct FailureExampleImpl;
@@ -14,14 +14,14 @@ struct FailureExampleImpl;
1414
struct MyError;
1515

1616
impl FailureExample for FailureExampleImpl {
17-
async fn do_run(&self, context: Context<'_>) -> HandlerResult<()> {
17+
async fn do_run(&self, context: Context<'_>) -> Result<(), TerminalError> {
1818
context
1919
.run(|| async move {
2020
if rand::thread_rng().next_u32() % 4 == 0 {
21-
return Err(TerminalError::new("Failed!!!").into());
21+
Err(TerminalError::new("Failed!!!"))?
2222
}
2323

24-
Err(MyError.into())
24+
Err(MyError)?
2525
})
2626
.await?;
2727

examples/greeter.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
use restate_sdk::prelude::*;
2+
use std::convert::Infallible;
23

34
#[restate_sdk::service]
45
trait Greeter {
5-
async fn greet(name: String) -> HandlerResult<String>;
6+
async fn greet(name: String) -> Result<String, Infallible>;
67
}
78

89
struct GreeterImpl;
910

1011
impl Greeter for GreeterImpl {
11-
async fn greet(&self, _: Context<'_>, name: String) -> HandlerResult<String> {
12+
async fn greet(&self, _: Context<'_>, name: String) -> Result<String, Infallible> {
1213
Ok(format!("Greetings {name}"))
1314
}
1415
}

examples/run.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,16 @@ use std::collections::HashMap;
33

44
#[restate_sdk::service]
55
trait RunExample {
6-
async fn do_run() -> HandlerResult<Json<HashMap<String, String>>>;
6+
async fn do_run() -> Result<Json<HashMap<String, String>>, HandlerError>;
77
}
88

99
struct RunExampleImpl(reqwest::Client);
1010

1111
impl RunExample for RunExampleImpl {
12-
async fn do_run(&self, context: Context<'_>) -> HandlerResult<Json<HashMap<String, String>>> {
12+
async fn do_run(
13+
&self,
14+
context: Context<'_>,
15+
) -> Result<Json<HashMap<String, String>>, HandlerError> {
1316
let res = context
1417
.run(|| async move {
1518
let req = self.0.get("https://httpbin.org/ip").build()?;

macros/src/ast.rs

Lines changed: 33 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,8 @@ pub(crate) struct Handler {
145145
pub(crate) restate_name: String,
146146
pub(crate) ident: Ident,
147147
pub(crate) arg: Option<PatType>,
148-
pub(crate) output: Type,
148+
pub(crate) output_ok: Type,
149+
pub(crate) output_err: Type,
149150
}
150151

151152
impl Parse for Handler {
@@ -192,17 +193,18 @@ impl Parse for Handler {
192193
let return_type: ReturnType = input.parse()?;
193194
input.parse::<Token![;]>()?;
194195

195-
let output: Type = match &return_type {
196-
ReturnType::Default => {
197-
parse_quote!(())
198-
}
196+
let (ok_ty, err_ty) = match &return_type {
197+
ReturnType::Default => return Err(Error::new(
198+
return_type.span(),
199+
"The return type cannot be empty, only Result or restate_sdk::prelude::HandlerResult is supported as return type",
200+
)),
199201
ReturnType::Type(_, ty) => {
200-
if let Some(ty) = extract_handler_result_parameter(ty) {
201-
ty
202+
if let Some((ok_ty, err_ty)) = extract_handler_result_parameter(ty) {
203+
(ok_ty, err_ty)
202204
} else {
203205
return Err(Error::new(
204206
return_type.span(),
205-
"Only restate_sdk::prelude::HandlerResult is supported as return type",
207+
"Only Result or restate_sdk::prelude::HandlerResult is supported as return type",
206208
));
207209
}
208210
}
@@ -229,7 +231,8 @@ impl Parse for Handler {
229231
restate_name,
230232
ident,
231233
arg: args.pop(),
232-
output,
234+
output_ok: ok_ty,
235+
output_err: err_ty,
233236
})
234237
}
235238
}
@@ -263,14 +266,16 @@ fn read_literal_attribute_name(attr: &Attribute) -> Result<Option<String>> {
263266
.transpose()
264267
}
265268

266-
fn extract_handler_result_parameter(ty: &Type) -> Option<Type> {
269+
fn extract_handler_result_parameter(ty: &Type) -> Option<(Type, Type)> {
267270
let path = match ty {
268271
Type::Path(ty) => &ty.path,
269272
_ => return None,
270273
};
271274

272275
let last = path.segments.last().unwrap();
273-
if last.ident != "HandlerResult" {
276+
let is_result = last.ident == "Result";
277+
let is_handler_result = last.ident == "HandlerResult";
278+
if !is_result && !is_handler_result {
274279
return None;
275280
}
276281

@@ -279,12 +284,22 @@ fn extract_handler_result_parameter(ty: &Type) -> Option<Type> {
279284
_ => return None,
280285
};
281286

282-
if bracketed.args.len() != 1 {
283-
return None;
284-
}
285-
286-
match &bracketed.args[0] {
287-
GenericArgument::Type(arg) => Some(arg.clone()),
288-
_ => None,
287+
if is_handler_result && bracketed.args.len() == 1 {
288+
match &bracketed.args[0] {
289+
GenericArgument::Type(arg) => Some((
290+
arg.clone(),
291+
parse_quote!(::restate_sdk::prelude::HandlerError),
292+
)),
293+
_ => None,
294+
}
295+
} else if is_result && bracketed.args.len() == 2 {
296+
match (&bracketed.args[0], &bracketed.args[1]) {
297+
(GenericArgument::Type(ok_arg), GenericArgument::Type(err_arg)) => {
298+
Some((ok_arg.clone(), err_arg.clone()))
299+
}
300+
_ => None,
301+
}
302+
} else {
303+
None
289304
}
290305
}

macros/src/gen.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ impl<'a> ServiceGenerator<'a> {
5555
let handler_fns = handlers
5656
.iter()
5757
.map(
58-
|Handler { attrs, ident, arg, is_shared, output, .. }| {
58+
|Handler { attrs, ident, arg, is_shared, output_ok, output_err, .. }| {
5959
let args = arg.iter();
6060

6161
let ctx = match (&service_ty, is_shared) {
@@ -68,7 +68,7 @@ impl<'a> ServiceGenerator<'a> {
6868

6969
quote! {
7070
#( #attrs )*
71-
fn #ident(&self, context: #ctx, #( #args ),*) -> impl std::future::Future<Output=::restate_sdk::prelude::HandlerResult<#output>> + ::core::marker::Send;
71+
fn #ident(&self, context: #ctx, #( #args ),*) -> impl std::future::Future<Output=Result<#output_ok, #output_err>> + ::core::marker::Send;
7272
}
7373
},
7474
);
@@ -130,7 +130,7 @@ impl<'a> ServiceGenerator<'a> {
130130
quote! {
131131
#handler_literal => {
132132
#get_input_and_call
133-
let res = fut.await;
133+
let res = fut.await.map_err(::restate_sdk::errors::HandlerError::from);
134134
ctx.handle_handler_result(res);
135135
ctx.end();
136136
Ok(())
@@ -302,7 +302,7 @@ impl<'a> ServiceGenerator<'a> {
302302
ty, ..
303303
}) => quote! { #ty }
304304
};
305-
let res_ty = &handler.output;
305+
let res_ty = &handler.output_ok;
306306
let input = match &handler.arg {
307307
None => quote! { () },
308308
Some(_) => quote! { req }

src/endpoint/context.rs

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -527,14 +527,14 @@ impl ContextInternal {
527527
.sys_complete_promise(id.to_owned(), NonEmptyValue::Failure(failure.into()));
528528
}
529529

530-
pub fn run<'a, Run, Fut, Res>(
530+
pub fn run<'a, Run, Fut, Out>(
531531
&'a self,
532532
run_closure: Run,
533-
) -> impl crate::context::RunFuture<Result<Res, TerminalError>> + Send + Sync + 'a
533+
) -> impl crate::context::RunFuture<Result<Out, TerminalError>> + Send + Sync + 'a
534534
where
535-
Run: RunClosure<Fut = Fut, Output = Res> + Send + Sync + 'a,
536-
Fut: Future<Output = HandlerResult<Res>> + Send + Sync + 'a,
537-
Res: Serialize + Deserialize + 'static,
535+
Run: RunClosure<Fut = Fut, Output = Out> + Send + Sync + 'a,
536+
Fut: Future<Output = HandlerResult<Out>> + Send + Sync + 'a,
537+
Out: Serialize + Deserialize + 'static,
538538
{
539539
let this = Arc::clone(&self.inner);
540540

@@ -631,12 +631,12 @@ impl<Run, Fut, Ret> RunFuture<Run, Fut, Ret> {
631631
}
632632
}
633633

634-
impl<Run, Fut, Ret> crate::context::RunFuture<Result<Result<Ret, TerminalError>, Error>>
635-
for RunFuture<Run, Fut, Ret>
634+
impl<Run, Fut, Out> crate::context::RunFuture<Result<Result<Out, TerminalError>, Error>>
635+
for RunFuture<Run, Fut, Out>
636636
where
637-
Run: RunClosure<Fut = Fut, Output = Ret> + Send + Sync,
638-
Fut: Future<Output = HandlerResult<Ret>> + Send + Sync,
639-
Ret: Serialize + Deserialize,
637+
Run: RunClosure<Fut = Fut, Output = Out> + Send + Sync,
638+
Fut: Future<Output = HandlerResult<Out>> + Send + Sync,
639+
Out: Serialize + Deserialize,
640640
{
641641
fn with_retry_policy(mut self, retry_policy: RunRetryPolicy) -> Self {
642642
self.retry_policy = RetryPolicy::Exponential {
@@ -655,13 +655,13 @@ where
655655
}
656656
}
657657

658-
impl<Run, Fut, Res> Future for RunFuture<Run, Fut, Res>
658+
impl<Run, Fut, Out> Future for RunFuture<Run, Fut, Out>
659659
where
660-
Run: RunClosure<Fut = Fut, Output = Res> + Send + Sync,
661-
Res: Serialize + Deserialize,
662-
Fut: Future<Output = HandlerResult<Res>> + Send + Sync,
660+
Run: RunClosure<Fut = Fut, Output = Out> + Send + Sync,
661+
Out: Serialize + Deserialize,
662+
Fut: Future<Output = HandlerResult<Out>> + Send + Sync,
663663
{
664-
type Output = Result<Result<Res, TerminalError>, Error>;
664+
type Output = Result<Result<Out, TerminalError>, Error>;
665665

666666
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
667667
let mut this = self.project();
@@ -681,7 +681,7 @@ where
681681
// Enter the side effect
682682
match enter_result.map_err(ErrorInner::VM)? {
683683
RunEnterResult::Executed(NonEmptyValue::Success(mut v)) => {
684-
let t = Res::deserialize(&mut v).map_err(|e| {
684+
let t = Out::deserialize(&mut v).map_err(|e| {
685685
ErrorInner::Deserialization {
686686
syscall: "run",
687687
err: Box::new(e),
@@ -707,7 +707,7 @@ where
707707
}
708708
RunStateProj::ClosureRunning { start_time, fut } => {
709709
let res = match ready!(fut.poll(cx)) {
710-
Ok(t) => RunExitResult::Success(Res::serialize(&t).map_err(|e| {
710+
Ok(t) => RunExitResult::Success(Out::serialize(&t).map_err(|e| {
711711
ErrorInner::Serialization {
712712
syscall: "run",
713713
err: Box::new(e),
@@ -752,7 +752,7 @@ where
752752
}
753753
.into()),
754754
Value::Success(mut s) => {
755-
let t = Res::deserialize(&mut s).map_err(|e| {
755+
let t = Out::deserialize(&mut s).map_err(|e| {
756756
ErrorInner::Deserialization {
757757
syscall: "run",
758758
err: Box::new(e),

src/errors.rs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,4 @@ impl From<TerminalError> for Failure {
144144
}
145145

146146
/// Result type for a Restate handler.
147-
///
148-
/// All Restate handlers *MUST* use this type as return type for their handlers.
149147
pub type HandlerResult<T> = Result<T, HandlerError>;

test-services/src/failing.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ impl Failing for FailingImpl {
7272
error_message: String,
7373
) -> HandlerResult<()> {
7474
context
75-
.run(|| async move { Err::<(), _>(TerminalError::new(error_message).into()) })
75+
.run(|| async move { Err(TerminalError::new(error_message))? })
7676
.await?;
7777

7878
unreachable!("This should be unreachable")
@@ -92,7 +92,7 @@ impl Failing for FailingImpl {
9292
cloned_counter.store(0, Ordering::SeqCst);
9393
Ok(current_attempt)
9494
} else {
95-
Err(anyhow!("Failed at attempt {current_attempt}").into())
95+
Err(anyhow!("Failed at attempt {current_attempt}"))?
9696
}
9797
})
9898
.with_retry_policy(

tests/service.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,12 @@ trait MyService {
1010
async fn no_output() -> HandlerResult<()>;
1111

1212
async fn no_input_no_output() -> HandlerResult<()>;
13+
14+
async fn std_result() -> Result<(), std::io::Error>;
15+
16+
async fn std_result_with_terminal_error() -> Result<(), TerminalError>;
17+
18+
async fn std_result_with_handler_error() -> Result<(), HandlerError>;
1319
}
1420

1521
#[restate_sdk::object]

0 commit comments

Comments
 (0)