diff --git a/Cargo.toml b/Cargo.toml index 6e7b0d2..2ed98fd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,16 +7,18 @@ license = "MIT" repository = "https://github.com/restatedev/sdk-rust" [features] -default = ["http"] -http = ["hyper", "http-body-util", "hyper-util", "tokio/net", "tokio/signal", "restate-sdk-shared-core/http"] +default = ["http_server", "rand", "uuid"] +http_server = ["hyper", "http-body-util", "hyper-util", "tokio/net", "tokio/signal", "restate-sdk-shared-core/http"] [dependencies] bytes = "1.6.1" futures = "0.3" +http = "1.1.0" http-body-util = { version = "0.1", optional = true } hyper = { version = "1.4.1", optional = true, features = ["server", "http2"] } hyper-util = { version = "0.1", features = ["tokio", "server", "server-graceful", "http2"], optional = true } pin-project-lite = "0.2" +rand = { version = "0.8.5", optional = true } regress = "0.10" restate-sdk-macros = { version = "0.1.0", path = "macros" } restate-sdk-shared-core = { version = "0.0.5" } @@ -26,6 +28,7 @@ thiserror = "1.0.63" tokio = { version = "1", default-features = false, features = ["sync", "macros"] } tower-service = "0.3" tracing = "0.1" +uuid = { version = "1.10.0", optional = true } [dev-dependencies] tokio = { version = "1", features = ["full"] } diff --git a/src/context/mod.rs b/src/context/mod.rs index 6e2f185..d785599 100644 --- a/src/context/mod.rs +++ b/src/context/mod.rs @@ -9,19 +9,44 @@ mod run; pub use request::{Request, RequestTarget}; pub use run::RunClosure; +pub type HeaderMap = http::HeaderMap; pub struct Context<'ctx> { + random_seed: u64, + #[cfg(feature = "rand")] + std_rng: rand::prelude::StdRng, + headers: HeaderMap, inner: &'ctx ContextInternal, } +impl<'ctx> Context<'ctx> { + pub fn headers(&self) -> &HeaderMap { + &self.headers + } + + pub fn headers_mut(&mut self) -> &HeaderMap { + &mut self.headers + } +} + impl<'ctx> From<(&'ctx ContextInternal, InputMetadata)> for Context<'ctx> { fn from(value: (&'ctx ContextInternal, InputMetadata)) -> Self { - Self { inner: value.0 } + Self { + random_seed: value.1.random_seed, + #[cfg(feature = "rand")] + std_rng: rand::prelude::SeedableRng::seed_from_u64(value.1.random_seed), + headers: value.1.headers, + inner: value.0, + } } } pub struct SharedObjectContext<'ctx> { key: String, + random_seed: u64, + #[cfg(feature = "rand")] + std_rng: rand::prelude::StdRng, + headers: HeaderMap, pub(crate) inner: &'ctx ContextInternal, } @@ -29,12 +54,24 @@ impl<'ctx> SharedObjectContext<'ctx> { pub fn key(&self) -> &str { &self.key } + + pub fn headers(&self) -> &HeaderMap { + &self.headers + } + + pub fn headers_mut(&mut self) -> &HeaderMap { + &mut self.headers + } } impl<'ctx> From<(&'ctx ContextInternal, InputMetadata)> for SharedObjectContext<'ctx> { fn from(value: (&'ctx ContextInternal, InputMetadata)) -> Self { Self { key: value.1.key, + random_seed: value.1.random_seed, + #[cfg(feature = "rand")] + std_rng: rand::prelude::SeedableRng::seed_from_u64(value.1.random_seed), + headers: value.1.headers, inner: value.0, } } @@ -42,6 +79,10 @@ impl<'ctx> From<(&'ctx ContextInternal, InputMetadata)> for SharedObjectContext< pub struct ObjectContext<'ctx> { key: String, + random_seed: u64, + #[cfg(feature = "rand")] + std_rng: rand::prelude::StdRng, + headers: HeaderMap, pub(crate) inner: &'ctx ContextInternal, } @@ -49,12 +90,24 @@ impl<'ctx> ObjectContext<'ctx> { pub fn key(&self) -> &str { &self.key } + + pub fn headers(&self) -> &HeaderMap { + &self.headers + } + + pub fn headers_mut(&mut self) -> &HeaderMap { + &mut self.headers + } } impl<'ctx> From<(&'ctx ContextInternal, InputMetadata)> for ObjectContext<'ctx> { fn from(value: (&'ctx ContextInternal, InputMetadata)) -> Self { Self { key: value.1.key, + random_seed: value.1.random_seed, + #[cfg(feature = "rand")] + std_rng: rand::prelude::SeedableRng::seed_from_u64(value.1.random_seed), + headers: value.1.headers, inner: value.0, } } @@ -62,6 +115,10 @@ impl<'ctx> From<(&'ctx ContextInternal, InputMetadata)> for ObjectContext<'ctx> pub struct SharedWorkflowContext<'ctx> { key: String, + random_seed: u64, + #[cfg(feature = "rand")] + std_rng: rand::prelude::StdRng, + headers: HeaderMap, pub(crate) inner: &'ctx ContextInternal, } @@ -69,6 +126,10 @@ impl<'ctx> From<(&'ctx ContextInternal, InputMetadata)> for SharedWorkflowContex fn from(value: (&'ctx ContextInternal, InputMetadata)) -> Self { Self { key: value.1.key, + random_seed: value.1.random_seed, + #[cfg(feature = "rand")] + std_rng: rand::prelude::SeedableRng::seed_from_u64(value.1.random_seed), + headers: value.1.headers, inner: value.0, } } @@ -78,10 +139,22 @@ impl<'ctx> SharedWorkflowContext<'ctx> { pub fn key(&self) -> &str { &self.key } + + pub fn headers(&self) -> &HeaderMap { + &self.headers + } + + pub fn headers_mut(&mut self) -> &HeaderMap { + &mut self.headers + } } pub struct WorkflowContext<'ctx> { key: String, + random_seed: u64, + #[cfg(feature = "rand")] + std_rng: rand::prelude::StdRng, + headers: HeaderMap, pub(crate) inner: &'ctx ContextInternal, } @@ -89,6 +162,10 @@ impl<'ctx> From<(&'ctx ContextInternal, InputMetadata)> for WorkflowContext<'ctx fn from(value: (&'ctx ContextInternal, InputMetadata)) -> Self { Self { key: value.1.key, + random_seed: value.1.random_seed, + #[cfg(feature = "rand")] + std_rng: rand::prelude::SeedableRng::seed_from_u64(value.1.random_seed), + headers: value.1.headers, inner: value.0, } } @@ -98,18 +175,26 @@ impl<'ctx> WorkflowContext<'ctx> { pub fn key(&self) -> &str { &self.key } + + pub fn headers(&self) -> &HeaderMap { + &self.headers + } + + pub fn headers_mut(&mut self) -> &HeaderMap { + &mut self.headers + } } -pub trait ContextTimers<'ctx>: private::SealedGetInnerContext<'ctx> { +pub trait ContextTimers<'ctx>: private::SealedContext<'ctx> { /// Sleep using Restate fn sleep(&self, duration: Duration) -> impl Future> + 'ctx { - private::SealedGetInnerContext::inner_context(self).sleep(duration) + private::SealedContext::inner_context(self).sleep(duration) } } -impl<'ctx, CTX: private::SealedGetInnerContext<'ctx>> ContextTimers<'ctx> for CTX {} +impl<'ctx, CTX: private::SealedContext<'ctx>> ContextTimers<'ctx> for CTX {} -pub trait ContextClient<'ctx>: private::SealedGetInnerContext<'ctx> { +pub trait ContextClient<'ctx>: private::SealedContext<'ctx> { fn request( &self, request_target: RequestTarget, @@ -152,9 +237,9 @@ pub trait IntoWorkflowClient<'ctx>: Sized { fn create_client(ctx: &'ctx ContextInternal, key: String) -> Self; } -impl<'ctx, CTX: private::SealedGetInnerContext<'ctx>> ContextClient<'ctx> for CTX {} +impl<'ctx, CTX: private::SealedContext<'ctx>> ContextClient<'ctx> for CTX {} -pub trait ContextAwakeables<'ctx>: private::SealedGetInnerContext<'ctx> { +pub trait ContextAwakeables<'ctx>: private::SealedContext<'ctx> { /// Create an awakeable fn awakeable( &self, @@ -176,9 +261,9 @@ pub trait ContextAwakeables<'ctx>: private::SealedGetInnerContext<'ctx> { } } -impl<'ctx, CTX: private::SealedGetInnerContext<'ctx>> ContextAwakeables<'ctx> for CTX {} +impl<'ctx, CTX: private::SealedContext<'ctx>> ContextAwakeables<'ctx> for CTX {} -pub trait ContextSideEffects<'ctx>: private::SealedGetInnerContext<'ctx> { +pub trait ContextSideEffects<'ctx>: private::SealedContext<'ctx> { /// Run a non-deterministic operation fn run( &self, @@ -192,11 +277,26 @@ pub trait ContextSideEffects<'ctx>: private::SealedGetInnerContext<'ctx> { { self.inner_context().run(name, run_closure) } + + fn random_seed(&self) -> u64 { + private::SealedContext::random_seed(self) + } + + #[cfg(feature = "rand")] + fn rand(&mut self) -> &mut rand::prelude::StdRng { + private::SealedContext::rand(self) + } + + #[cfg(all(feature = "rand", feature = "uuid"))] + fn rand_uuid(&mut self) -> uuid::Uuid { + let rand = private::SealedContext::rand(self); + uuid::Uuid::from_u64_pair(rand::RngCore::next_u64(rand), rand::RngCore::next_u64(rand)) + } } -impl<'ctx, CTX: private::SealedGetInnerContext<'ctx>> ContextSideEffects<'ctx> for CTX {} +impl<'ctx, CTX: private::SealedContext<'ctx>> ContextSideEffects<'ctx> for CTX {} -pub trait ContextReadState<'ctx>: private::SealedGetInnerContext<'ctx> { +pub trait ContextReadState<'ctx>: private::SealedContext<'ctx> { /// Get state fn get( &self, @@ -211,12 +311,12 @@ pub trait ContextReadState<'ctx>: private::SealedGetInnerContext<'ctx> { } } -impl<'ctx, CTX: private::SealedGetInnerContext<'ctx> + private::SealedCanReadState> - ContextReadState<'ctx> for CTX +impl<'ctx, CTX: private::SealedContext<'ctx> + private::SealedCanReadState> ContextReadState<'ctx> + for CTX { } -pub trait ContextWriteState<'ctx>: private::SealedGetInnerContext<'ctx> { +pub trait ContextWriteState<'ctx>: private::SealedContext<'ctx> { /// Set state fn set(&self, key: &str, t: T) { self.inner_context().set(key, t) @@ -233,12 +333,12 @@ pub trait ContextWriteState<'ctx>: private::SealedGetInnerContext<'ctx> { } } -impl<'ctx, CTX: private::SealedGetInnerContext<'ctx> + private::SealedCanWriteState> - ContextWriteState<'ctx> for CTX +impl<'ctx, CTX: private::SealedContext<'ctx> + private::SealedCanWriteState> ContextWriteState<'ctx> + for CTX { } -pub trait ContextPromises<'ctx>: private::SealedGetInnerContext<'ctx> { +pub trait ContextPromises<'ctx>: private::SealedContext<'ctx> { /// Create a promise fn promise( &'ctx self, @@ -266,16 +366,21 @@ pub trait ContextPromises<'ctx>: private::SealedGetInnerContext<'ctx> { } } -impl<'ctx, CTX: private::SealedGetInnerContext<'ctx> + private::SealedCanUsePromises> - ContextPromises<'ctx> for CTX +impl<'ctx, CTX: private::SealedContext<'ctx> + private::SealedCanUsePromises> ContextPromises<'ctx> + for CTX { } mod private { use super::*; - pub trait SealedGetInnerContext<'ctx> { + pub trait SealedContext<'ctx> { fn inner_context(&self) -> &'ctx ContextInternal; + + fn random_seed(&self) -> u64; + + #[cfg(feature = "rand")] + fn rand(&mut self) -> &mut rand::prelude::StdRng; } // Context capabilities @@ -283,42 +388,87 @@ mod private { pub trait SealedCanWriteState {} pub trait SealedCanUsePromises {} - impl<'ctx> SealedGetInnerContext<'ctx> for Context<'ctx> { + impl<'ctx> SealedContext<'ctx> for Context<'ctx> { fn inner_context(&self) -> &'ctx ContextInternal { self.inner } + + fn random_seed(&self) -> u64 { + self.random_seed + } + + #[cfg(feature = "rand")] + fn rand(&mut self) -> &mut rand::prelude::StdRng { + &mut self.std_rng + } } - impl<'ctx> SealedGetInnerContext<'ctx> for SharedObjectContext<'ctx> { + impl<'ctx> SealedContext<'ctx> for SharedObjectContext<'ctx> { fn inner_context(&self) -> &'ctx ContextInternal { self.inner } + + fn random_seed(&self) -> u64 { + self.random_seed + } + + #[cfg(feature = "rand")] + fn rand(&mut self) -> &mut rand::prelude::StdRng { + &mut self.std_rng + } } impl SealedCanReadState for SharedObjectContext<'_> {} - impl<'ctx> SealedGetInnerContext<'ctx> for ObjectContext<'ctx> { + impl<'ctx> SealedContext<'ctx> for ObjectContext<'ctx> { fn inner_context(&self) -> &'ctx ContextInternal { self.inner } + + fn random_seed(&self) -> u64 { + self.random_seed + } + + #[cfg(feature = "rand")] + fn rand(&mut self) -> &mut rand::prelude::StdRng { + &mut self.std_rng + } } impl SealedCanReadState for ObjectContext<'_> {} impl SealedCanWriteState for ObjectContext<'_> {} - impl<'ctx> SealedGetInnerContext<'ctx> for SharedWorkflowContext<'ctx> { + impl<'ctx> SealedContext<'ctx> for SharedWorkflowContext<'ctx> { fn inner_context(&self) -> &'ctx ContextInternal { self.inner } + + fn random_seed(&self) -> u64 { + self.random_seed + } + + #[cfg(feature = "rand")] + fn rand(&mut self) -> &mut rand::prelude::StdRng { + &mut self.std_rng + } } impl SealedCanReadState for SharedWorkflowContext<'_> {} impl SealedCanUsePromises for SharedWorkflowContext<'_> {} - impl<'ctx> SealedGetInnerContext<'ctx> for WorkflowContext<'ctx> { + impl<'ctx> SealedContext<'ctx> for WorkflowContext<'ctx> { fn inner_context(&self) -> &'ctx ContextInternal { self.inner } + + fn random_seed(&self) -> u64 { + self.random_seed + } + + #[cfg(feature = "rand")] + fn rand(&mut self) -> &mut rand::prelude::StdRng { + &mut self.std_rng + } } impl SealedCanReadState for WorkflowContext<'_> {} diff --git a/src/context/request.rs b/src/context/request.rs index 48c0dee..0b7812c 100644 --- a/src/context/request.rs +++ b/src/context/request.rs @@ -67,7 +67,7 @@ impl fmt::Display for RequestTarget { } } -pub struct Request<'a, Req, Res> { +pub struct Request<'a, Req, Res = ()> { ctx: &'a ContextInternal, request_target: RequestTarget, req: Req, diff --git a/src/endpoint/context.rs b/src/endpoint/context.rs index f8c7e34..970391e 100644 --- a/src/endpoint/context.rs +++ b/src/endpoint/context.rs @@ -8,9 +8,10 @@ use bytes::Bytes; use futures::future::Either; use futures::{FutureExt, TryFutureExt}; use restate_sdk_shared_core::{ - AsyncResultHandle, CoreVM, Header, Input, NonEmptyValue, RunEnterResult, SuspendedOrVMError, - TakeOutputResult, Target, VMError, Value, VM, + AsyncResultHandle, CoreVM, NonEmptyValue, RunEnterResult, SuspendedOrVMError, TakeOutputResult, + Target, VMError, Value, VM, }; +use std::collections::HashMap; use std::future::{ready, Future}; use std::pin::Pin; use std::sync::{Arc, Mutex}; @@ -38,6 +39,12 @@ impl ContextInternalInner { handler_state, } } + + pub(super) fn fail(&mut self, e: Error) { + self.vm + .notify_error(e.0.to_string().into(), format!("{:#}", e.0).into()); + self.handler_state.mark_error(e); + } } #[derive(Clone)] @@ -84,18 +91,7 @@ pub struct InputMetadata { pub invocation_id: String, pub random_seed: u64, pub key: String, - pub headers: Vec
, -} - -impl From for InputMetadata { - fn from(value: Input) -> Self { - Self { - invocation_id: value.invocation_id, - random_seed: value.random_seed, - key: value.key, - headers: value.headers, - } - } + pub headers: http::HeaderMap, } impl From for Target { @@ -137,6 +133,18 @@ impl ContextInternal { .sys_input() .map_err(ErrorInner::VM) .and_then(|raw_input| { + let headers = http::HeaderMap::::try_from( + &raw_input + .headers + .into_iter() + .map(|h| (h.key.to_string(), h.value.to_string())) + .collect::>(), + ) + .map_err(|e| ErrorInner::Deserialization { + syscall: "input_headers", + err: e.into(), + })?; + Ok(( T::deserialize(&mut (raw_input.input.into())).map_err(|e| { ErrorInner::Deserialization { @@ -148,7 +156,7 @@ impl ContextInternal { invocation_id: raw_input.invocation_id, random_seed: raw_input.random_seed, key: raw_input.key, - headers: raw_input.headers, + headers, }, )) }); @@ -159,7 +167,7 @@ impl ContextInternal { return Either::Left(ready(i)); } Err(e) => { - inner_lock.handler_state.mark_error_inner(e); + inner_lock.fail(e.into()); drop(inner_lock); } } @@ -222,12 +230,13 @@ impl ContextInternal { let _ = inner_lock.vm.sys_state_set(key.to_owned(), b.to_vec()); } Err(e) => { - inner_lock - .handler_state - .mark_error_inner(ErrorInner::Serialization { + inner_lock.fail( + ErrorInner::Serialization { syscall: "set_state", err: Box::new(e), - }); + } + .into(), + ); } } } @@ -277,12 +286,13 @@ impl ContextInternal { let input = match Req::serialize(&req) { Ok(t) => t, Err(e) => { - inner_lock - .handler_state - .mark_error_inner(ErrorInner::Serialization { + inner_lock.fail( + ErrorInner::Serialization { syscall: "call", err: Box::new(e), - }); + } + .into(), + ); return Either::Right(TrapFuture::default()); } }; @@ -334,12 +344,13 @@ impl ContextInternal { .sys_send(request_target.into(), t.to_vec(), delay); } Err(e) => { - inner_lock - .handler_state - .mark_error_inner(ErrorInner::Serialization { + inner_lock.fail( + ErrorInner::Serialization { syscall: "call", err: Box::new(e), - }); + } + .into(), + ); } }; } @@ -398,12 +409,13 @@ impl ContextInternal { .sys_complete_awakeable(id.to_owned(), NonEmptyValue::Success(b.to_vec())); } Err(e) => { - inner_lock - .handler_state - .mark_error_inner(ErrorInner::Serialization { + inner_lock.fail( + ErrorInner::Serialization { syscall: "resolve_awakeable", err: Box::new(e), - }); + } + .into(), + ); } } } @@ -480,12 +492,13 @@ impl ContextInternal { .sys_complete_promise(name.to_owned(), NonEmptyValue::Success(b.to_vec())); } Err(e) => { - inner_lock - .handler_state - .mark_error_inner(ErrorInner::Serialization { + inner_lock.fail( + ErrorInner::Serialization { syscall: "resolve_promise", err: Box::new(e), - }); + } + .into(), + ); } } } @@ -589,20 +602,19 @@ impl ContextInternal { Ok(success) => match T::serialize(&success) { Ok(t) => NonEmptyValue::Success(t.to_vec()), Err(e) => { - inner_lock - .handler_state - .mark_error_inner(ErrorInner::Serialization { + inner_lock.fail( + ErrorInner::Serialization { syscall: "output", err: Box::new(e), - }); + } + .into(), + ); return; } }, Err(e) => match e.0 { HandlerErrorInner::Retryable(err) => { - inner_lock - .handler_state - .mark_error_inner(ErrorInner::HandlerResult { err }); + inner_lock.fail(ErrorInner::HandlerResult { err }.into()); return; } HandlerErrorInner::Terminal(t) => NonEmptyValue::Failure(TerminalError(t).into()), @@ -628,7 +640,7 @@ impl ContextInternal { } pub(super) fn fail(&self, e: Error) { - must_lock!(self.inner).handler_state.mark_error(e); + must_lock!(self.inner).fail(e) } fn create_poll_future( diff --git a/src/endpoint/handler_state.rs b/src/endpoint/handler_state.rs index 71c38ee..59f34d0 100644 --- a/src/endpoint/handler_state.rs +++ b/src/endpoint/handler_state.rs @@ -1,4 +1,4 @@ -use crate::endpoint::{Error, ErrorInner}; +use crate::endpoint::Error; use tokio::sync::oneshot; pub(super) struct HandlerStateNotifier { @@ -17,8 +17,4 @@ impl HandlerStateNotifier { } // Some other operation already marked this handler as errored. } - - pub(super) fn mark_error_inner(&mut self, err: ErrorInner) { - self.mark_error(Error(err)) - } } diff --git a/src/endpoint/mod.rs b/src/endpoint/mod.rs index 788d7aa..dfcf28f 100644 --- a/src/endpoint/mod.rs +++ b/src/endpoint/mod.rs @@ -100,8 +100,8 @@ enum ErrorInner { UnknownServiceHandler(String, String), #[error("Error when processing the request: {0:?}")] VM(#[from] VMError), - #[error("Cannot read header '{0}', reason: {1}")] - Header(&'static str, #[source] BoxError), + #[error("Cannot convert header '{0}', reason: {1}")] + Header(String, #[source] BoxError), #[error("Cannot reply to discovery, got accept header '{0}' but currently supported discovery is {DISCOVERY_CONTENT_TYPE}")] BadDiscovery(String), #[error("Bad path '{0}', expected either '/discover' or '/invoke/service/handler'")] @@ -237,7 +237,7 @@ impl Endpoint { if parts.last() == Some(&"discover") { let accept_header = headers .extract("accept") - .map_err(|e| ErrorInner::Header("accept", Box::new(e)))?; + .map_err(|e| ErrorInner::Header("accept".to_owned(), Box::new(e)))?; if accept_header.is_some() { let accept = accept_header.unwrap(); if !accept.contains("application/vnd.restate.endpointmanifest.v1+json") { diff --git a/src/lib.rs b/src/lib.rs index 2388443..1b173fd 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,19 +4,19 @@ pub mod service; pub mod context; pub mod discovery; pub mod errors; -#[cfg(feature = "http")] +#[cfg(feature = "http_server")] pub mod http; pub mod serde; pub use restate_sdk_macros::{object, service, workflow}; pub mod prelude { - #[cfg(feature = "http")] + #[cfg(feature = "http_server")] pub use crate::http::HyperServer; pub use crate::context::{ Context, ContextAwakeables, ContextClient, ContextPromises, ContextReadState, - ContextSideEffects, ContextTimers, ContextWriteState, ObjectContext, Request, + ContextSideEffects, ContextTimers, ContextWriteState, HeaderMap, ObjectContext, Request, SharedObjectContext, SharedWorkflowContext, WorkflowContext, }; pub use crate::endpoint::Endpoint; diff --git a/test-services/Cargo.toml b/test-services/Cargo.toml index 0b2a469..4383f90 100644 --- a/test-services/Cargo.toml +++ b/test-services/Cargo.toml @@ -11,4 +11,6 @@ tracing-subscriber = "0.3" futures = "0.3" restate-sdk = { path = ".." } serde = { version = "1", features = ["derive"] } -tracing = "0.1.40" \ No newline at end of file +tracing = "0.1.40" +rand = "0.8.5" +uuid = "1.10.0" \ No newline at end of file diff --git a/test-services/exclusions.yaml b/test-services/exclusions.yaml index 0909535..84f3d74 100644 --- a/test-services/exclusions.yaml +++ b/test-services/exclusions.yaml @@ -1,32 +1,7 @@ exclusions: "alwaysSuspending": - "dev.restate.sdktesting.tests.AwaitTimeout" - - "dev.restate.sdktesting.tests.ServiceToServiceCommunication" - - "dev.restate.sdktesting.tests.SideEffect" - - "dev.restate.sdktesting.tests.Sleep" - - "dev.restate.sdktesting.tests.SleepWithFailures" - - "dev.restate.sdktesting.tests.UpgradeWithInFlightInvocation" - - "dev.restate.sdktesting.tests.UpgradeWithNewInvocation" - - "dev.restate.sdktesting.tests.UserErrors" "default": - "dev.restate.sdktesting.tests.AwaitTimeout" - - "dev.restate.sdktesting.tests.Ingress" - - "dev.restate.sdktesting.tests.KillInvocation" - - "dev.restate.sdktesting.tests.ServiceToServiceCommunication" - - "dev.restate.sdktesting.tests.Sleep" - - "dev.restate.sdktesting.tests.SleepWithFailures" - - "dev.restate.sdktesting.tests.UpgradeWithInFlightInvocation" - - "dev.restate.sdktesting.tests.UpgradeWithNewInvocation" - - "dev.restate.sdktesting.tests.UserErrors" - "persistedTimers": - - "dev.restate.sdktesting.tests.Sleep" "singleThreadSinglePartition": - "dev.restate.sdktesting.tests.AwaitTimeout" - - "dev.restate.sdktesting.tests.Ingress" - - "dev.restate.sdktesting.tests.KillInvocation" - - "dev.restate.sdktesting.tests.ServiceToServiceCommunication" - - "dev.restate.sdktesting.tests.Sleep" - - "dev.restate.sdktesting.tests.SleepWithFailures" - - "dev.restate.sdktesting.tests.UpgradeWithInFlightInvocation" - - "dev.restate.sdktesting.tests.UpgradeWithNewInvocation" - - "dev.restate.sdktesting.tests.UserErrors" diff --git a/test-services/src/failing.rs b/test-services/src/failing.rs new file mode 100644 index 0000000..b8b5ec0 --- /dev/null +++ b/test-services/src/failing.rs @@ -0,0 +1,97 @@ +use anyhow::anyhow; +use restate_sdk::prelude::*; +use std::sync::atomic::{AtomicI32, Ordering}; +use std::sync::Arc; + +#[restate_sdk::object] +#[name = "Failing"] +pub(crate) trait Failing { + #[name = "terminallyFailingCall"] + async fn terminally_failing_call(error_message: String) -> HandlerResult<()>; + #[name = "callTerminallyFailingCall"] + async fn call_terminally_failing_call(error_message: String) -> HandlerResult; + #[name = "failingCallWithEventualSuccess"] + async fn failing_call_with_eventual_success() -> HandlerResult; + #[name = "failingSideEffectWithEventualSuccess"] + async fn failing_side_effect_with_eventual_success() -> HandlerResult; + #[name = "terminallyFailingSideEffect"] + async fn terminally_failing_side_effect(error_message: String) -> HandlerResult<()>; +} + +#[derive(Clone, Default)] +pub(crate) struct FailingImpl { + eventual_success_calls: Arc, + eventual_success_side_effects: Arc, +} + +impl Failing for FailingImpl { + async fn terminally_failing_call( + &self, + _: ObjectContext<'_>, + error_message: String, + ) -> HandlerResult<()> { + Err(TerminalError::new(error_message).into()) + } + + async fn call_terminally_failing_call( + &self, + mut context: ObjectContext<'_>, + error_message: String, + ) -> HandlerResult { + let uuid = context.rand_uuid().to_string(); + context + .object_client::(uuid) + .terminally_failing_call(error_message) + .call() + .await?; + + unreachable!("This should be unreachable") + } + + async fn failing_call_with_eventual_success(&self, _: ObjectContext<'_>) -> HandlerResult { + let current_attempt = self.eventual_success_calls.fetch_add(1, Ordering::SeqCst) + 1; + + if current_attempt >= 4 { + self.eventual_success_calls.store(0, Ordering::SeqCst); + Ok(current_attempt) + } else { + Err(anyhow!("Failed at attempt ${current_attempt}").into()) + } + } + + async fn failing_side_effect_with_eventual_success( + &self, + context: ObjectContext<'_>, + ) -> HandlerResult { + let cloned_eventual_side_effect_calls = Arc::clone(&self.eventual_success_side_effects); + let success_attempt = context + .run("failing_side_effect", || async move { + let current_attempt = + cloned_eventual_side_effect_calls.fetch_add(1, Ordering::SeqCst) + 1; + + if current_attempt >= 4 { + cloned_eventual_side_effect_calls.store(0, Ordering::SeqCst); + Ok(current_attempt) + } else { + Err(anyhow!("Failed at attempt ${current_attempt}").into()) + } + }) + .await?; + + Ok(success_attempt) + } + + async fn terminally_failing_side_effect( + &self, + context: ObjectContext<'_>, + error_message: String, + ) -> HandlerResult<()> { + context + .run("failing_side_effect", || async move { + Err::<(), _>(TerminalError::new(error_message).into()) + }) + .await?; + + unreachable!("This should be unreachable") + } +} diff --git a/test-services/src/kill_test.rs b/test-services/src/kill_test.rs new file mode 100644 index 0000000..7446bfc --- /dev/null +++ b/test-services/src/kill_test.rs @@ -0,0 +1,57 @@ +use crate::awakeable_holder; +use restate_sdk::prelude::*; + +#[restate_sdk::service] +#[name = "KillTestRunner"] +pub(crate) trait KillTestRunner { + #[name = "startCallTree"] + async fn start_call_tree() -> HandlerResult<()>; +} + +pub(crate) struct KillTestRunnerImpl; + +impl KillTestRunner for KillTestRunnerImpl { + async fn start_call_tree(&self, context: Context<'_>) -> HandlerResult<()> { + context + .object_client::("") + .recursive_call() + .call() + .await?; + Ok(()) + } +} + +#[restate_sdk::object] +#[name = "KillTestSingleton"] +pub(crate) trait KillTestSingleton { + #[name = "recursiveCall"] + async fn recursive_call() -> HandlerResult<()>; + #[name = "isUnlocked"] + async fn is_unlocked() -> HandlerResult<()>; +} + +pub(crate) struct KillTestSingletonImpl; + +impl KillTestSingleton for KillTestSingletonImpl { + async fn recursive_call(&self, context: ObjectContext<'_>) -> HandlerResult<()> { + let awakeable_holder_client = + context.object_client::("kill"); + + let (awk_id, awakeable) = context.awakeable::<()>(); + awakeable_holder_client.hold(awk_id).send(); + awakeable.await?; + + context + .object_client::("") + .recursive_call() + .call() + .await?; + + Ok(()) + } + + async fn is_unlocked(&self, _: ObjectContext<'_>) -> HandlerResult<()> { + // no-op + Ok(()) + } +} diff --git a/test-services/src/main.rs b/test-services/src/main.rs index a094b54..33e6d93 100644 --- a/test-services/src/main.rs +++ b/test-services/src/main.rs @@ -2,9 +2,13 @@ mod awakeable_holder; mod block_and_wait_workflow; mod cancel_test; mod counter; +mod failing; +mod kill_test; mod list_object; mod map_object; +mod non_deterministic; mod proxy; +mod test_utils_service; use restate_sdk::prelude::{Endpoint, HyperServer}; use std::env; @@ -49,6 +53,29 @@ async fn main() { cancel_test::CancelTestBlockingServiceImpl, )) } + if services == "*" || services.contains("Failing") { + builder = builder.with_service(failing::Failing::serve(failing::FailingImpl::default())) + } + if services == "*" || services.contains("KillTestRunner") { + builder = builder.with_service(kill_test::KillTestRunner::serve( + kill_test::KillTestRunnerImpl, + )) + } + if services == "*" || services.contains("KillTestSingleton") { + builder = builder.with_service(kill_test::KillTestSingleton::serve( + kill_test::KillTestSingletonImpl, + )) + } + if services == "*" || services.contains("NonDeterministic") { + builder = builder.with_service(non_deterministic::NonDeterministic::serve( + non_deterministic::NonDeterministicImpl::default(), + )) + } + if services == "*" || services.contains("TestUtilsService") { + builder = builder.with_service(test_utils_service::TestUtilsService::serve( + test_utils_service::TestUtilsServiceImpl, + )) + } HyperServer::new(builder.build()) .listen_and_serve(format!("0.0.0.0:{port}").parse().unwrap()) diff --git a/test-services/src/non_deterministic.rs b/test-services/src/non_deterministic.rs new file mode 100644 index 0000000..7614bf8 --- /dev/null +++ b/test-services/src/non_deterministic.rs @@ -0,0 +1,96 @@ +use crate::counter::CounterClient; +use restate_sdk::prelude::*; +use std::collections::HashMap; +use std::sync::Arc; +use std::time::Duration; +use tokio::sync::Mutex; + +#[restate_sdk::object] +#[name = "NonDeterministic"] +pub(crate) trait NonDeterministic { + #[name = "eitherSleepOrCall"] + async fn either_sleep_or_call() -> HandlerResult<()>; + #[name = "callDifferentMethod"] + async fn call_different_method() -> HandlerResult<()>; + #[name = "backgroundInvokeWithDifferentTargets"] + async fn background_invoke_with_different_targets() -> HandlerResult<()>; + #[name = "setDifferentKey"] + async fn set_different_key() -> HandlerResult<()>; +} + +#[derive(Clone, Default)] +pub(crate) struct NonDeterministicImpl(Arc>>); + +const STATE_A: &str = "a"; +const STATE_B: &str = "b"; + +impl NonDeterministic for NonDeterministicImpl { + async fn either_sleep_or_call(&self, context: ObjectContext<'_>) -> HandlerResult<()> { + if self.do_left_action(&context).await { + context.sleep(Duration::from_millis(100)).await?; + } else { + context + .object_client::("abc") + .get() + .call() + .await?; + } + Self::sleep_then_increment_counter(&context).await + } + + async fn call_different_method(&self, context: ObjectContext<'_>) -> HandlerResult<()> { + if self.do_left_action(&context).await { + context + .object_client::("abc") + .get() + .call() + .await?; + } else { + context + .object_client::("abc") + .reset() + .call() + .await?; + } + Self::sleep_then_increment_counter(&context).await + } + + async fn background_invoke_with_different_targets( + &self, + context: ObjectContext<'_>, + ) -> HandlerResult<()> { + if self.do_left_action(&context).await { + context.object_client::("abc").get().send(); + } else { + context.object_client::("abc").reset().send(); + } + Self::sleep_then_increment_counter(&context).await + } + + async fn set_different_key(&self, context: ObjectContext<'_>) -> HandlerResult<()> { + if self.do_left_action(&context).await { + context.set(STATE_A, "my-state".to_owned()); + } else { + context.set(STATE_B, "my-state".to_owned()); + } + Self::sleep_then_increment_counter(&context).await + } +} + +impl NonDeterministicImpl { + async fn do_left_action(&self, ctx: &ObjectContext<'_>) -> bool { + let mut counts = self.0.lock().await; + *(counts + .entry(ctx.key().to_owned()) + .and_modify(|i| *i += 1) + .or_default()) + % 2 + == 1 + } + + async fn sleep_then_increment_counter(ctx: &ObjectContext<'_>) -> HandlerResult<()> { + ctx.sleep(Duration::from_millis(100)).await?; + ctx.object_client::(ctx.key()).add(1).send(); + Ok(()) + } +} diff --git a/test-services/src/proxy.rs b/test-services/src/proxy.rs index 2812f69..36954f6 100644 --- a/test-services/src/proxy.rs +++ b/test-services/src/proxy.rs @@ -59,7 +59,11 @@ impl Proxy for ProxyImpl { ctx: Context<'_>, Json(req): Json, ) -> HandlerResult>> { - Ok(ctx.request(req.to_target(), req.message).call().await?) + Ok(ctx + .request::, Vec>(req.to_target(), req.message) + .call() + .await? + .into()) } async fn one_way_call( diff --git a/test-services/src/test_utils_service.rs b/test-services/src/test_utils_service.rs new file mode 100644 index 0000000..cbfc688 --- /dev/null +++ b/test-services/src/test_utils_service.rs @@ -0,0 +1,187 @@ +use crate::awakeable_holder; +use crate::list_object::ListObjectClient; +use futures::future::BoxFuture; +use futures::FutureExt; +use restate_sdk::prelude::*; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::sync::atomic::{AtomicU8, Ordering}; +use std::sync::Arc; +use std::time::Duration; + +#[derive(Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub(crate) struct CreateAwakeableAndAwaitItRequest { + awakeable_key: String, + await_timeout: Option, +} + +#[derive(Serialize, Deserialize)] +#[serde(tag = "type")] +#[serde(rename_all_fields = "camelCase")] +pub(crate) enum CreateAwakeableAndAwaitItResponse { + #[serde(rename = "timeout")] + Timeout, + #[serde(rename = "result")] + Result { value: String }, +} + +#[derive(Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub(crate) struct InterpretRequest { + list_name: String, + commands: Vec, +} + +#[derive(Serialize, Deserialize)] +#[serde(tag = "type")] +#[serde(rename_all_fields = "camelCase")] +pub(crate) enum InterpretCommand { + #[serde(rename = "createAwakeableAndAwaitIt")] + CreateAwakeableAndAwaitIt { awakeable_key: String }, + #[serde(rename = "getEnvVariable")] + GetEnvVariable { env_name: String }, +} + +#[restate_sdk::service] +#[name = "TestUtilsService"] +pub(crate) trait TestUtilsService { + #[name = "echo"] + async fn echo(input: String) -> HandlerResult; + #[name = "uppercaseEcho"] + async fn uppercase_echo(input: String) -> HandlerResult; + #[name = "echoHeaders"] + async fn echo_headers() -> HandlerResult>>; + #[name = "createAwakeableAndAwaitIt"] + async fn create_awakeable_and_await_it( + req: Json, + ) -> HandlerResult>; + #[name = "sleepConcurrently"] + async fn sleep_concurrently(millis_durations: Json>) -> HandlerResult<()>; + #[name = "countExecutedSideEffects"] + async fn count_executed_side_effects(increments: u32) -> HandlerResult; + #[name = "getEnvVariable"] + async fn get_env_variable(env: String) -> HandlerResult; + #[name = "interpretCommands"] + async fn interpret_commands(req: Json) -> HandlerResult<()>; +} + +pub(crate) struct TestUtilsServiceImpl; + +impl TestUtilsService for TestUtilsServiceImpl { + async fn echo(&self, _: Context<'_>, input: String) -> HandlerResult { + Ok(input) + } + + async fn uppercase_echo(&self, _: Context<'_>, input: String) -> HandlerResult { + Ok(input.to_ascii_uppercase()) + } + + async fn echo_headers( + &self, + context: Context<'_>, + ) -> HandlerResult>> { + let mut headers = HashMap::new(); + for k in context.headers().keys() { + headers.insert( + k.as_str().to_owned(), + context.headers().get(k).unwrap().clone(), + ); + } + + Ok(headers.into()) + } + + async fn create_awakeable_and_await_it( + &self, + context: Context<'_>, + Json(req): Json, + ) -> HandlerResult> { + if req.await_timeout.is_some() { + unimplemented!("await timeout is not yet implemented"); + } + + let (awk_id, awakeable) = context.awakeable::(); + + context + .object_client::(req.awakeable_key) + .hold(awk_id) + .call() + .await?; + let value = awakeable.await?; + + Ok(CreateAwakeableAndAwaitItResponse::Result { value }.into()) + } + + async fn sleep_concurrently( + &self, + context: Context<'_>, + millis_durations: Json>, + ) -> HandlerResult<()> { + let mut futures: Vec>> = vec![]; + + for duration in millis_durations.into_inner() { + futures.push(context.sleep(Duration::from_millis(duration)).boxed()); + } + + for fut in futures { + fut.await?; + } + + Ok(()) + } + + async fn count_executed_side_effects( + &self, + context: Context<'_>, + increments: u32, + ) -> HandlerResult { + let counter: Arc = Default::default(); + + for _ in 0..increments { + let counter_clone = Arc::clone(&counter); + context + .run("count", || async { + counter_clone.fetch_add(1, Ordering::SeqCst); + Ok(()) + }) + .await?; + } + + Ok(counter.load(Ordering::SeqCst) as u32) + } + + async fn get_env_variable(&self, _: Context<'_>, env: String) -> HandlerResult { + Ok(std::env::var(env).ok().unwrap_or_default()) + } + + async fn interpret_commands( + &self, + context: Context<'_>, + Json(req): Json, + ) -> HandlerResult<()> { + let list_client = context.object_client::(req.list_name); + + for cmd in req.commands { + match cmd { + InterpretCommand::CreateAwakeableAndAwaitIt { awakeable_key } => { + let (awk_id, awakeable) = context.awakeable::(); + context + .object_client::(awakeable_key) + .hold(awk_id) + .call() + .await?; + let value = awakeable.await?; + list_client.append(value).send(); + } + InterpretCommand::GetEnvVariable { env_name } => { + list_client + .append(std::env::var(env_name).ok().unwrap_or_default()) + .send(); + } + } + } + + Ok(()) + } +}