diff --git a/Cargo.toml b/Cargo.toml index a5104b2..09227c7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,11 +7,16 @@ license = "MIT" repository = "https://github.com/restatedev/sdk-rust" rust-version = "1.76.0" +[[example]] +name = "tracing" +path = "examples/tracing.rs" +required-features = ["tracing-span-filter"] + [features] -default = ["http_server", "rand", "uuid"] +default = ["http_server", "rand", "uuid", "tracing-span-filter"] hyper = ["dep:hyper", "http-body-util", "restate-sdk-shared-core/http"] http_server = ["hyper", "hyper/server", "hyper/http2", "hyper-util", "tokio/net", "tokio/signal", "tokio/macros"] - +tracing-span-filter = ["dep:tracing-subscriber"] [dependencies] bytes = "1.6.1" @@ -31,11 +36,12 @@ thiserror = "1.0.63" tokio = { version = "1", default-features = false, features = ["sync"] } tower-service = "0.3" tracing = "0.1" +tracing-subscriber = { version = "0.3", features = ["registry"], optional = true } uuid = { version = "1.10.0", optional = true } [dev-dependencies] tokio = { version = "1", features = ["full"] } -tracing-subscriber = "0.3" +tracing-subscriber = { version = "0.3", features = ["env-filter", "registry"] } trybuild = "1.0" reqwest = { version = "0.12", features = ["json"] } rand = "0.8.5" diff --git a/examples/tracing.rs b/examples/tracing.rs new file mode 100644 index 0000000..19f6995 --- /dev/null +++ b/examples/tracing.rs @@ -0,0 +1,37 @@ +use restate_sdk::prelude::*; +use std::time::Duration; +use tracing::info; +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, Layer}; + +#[restate_sdk::service] +trait Greeter { + async fn greet(name: String) -> Result; +} + +struct GreeterImpl; + +impl Greeter for GreeterImpl { + async fn greet(&self, ctx: Context<'_>, name: String) -> Result { + info!("Before sleep"); + ctx.sleep(Duration::from_secs(61)).await?; // More than suspension timeout to trigger replay + info!("After sleep"); + Ok(format!("Greetings {name}")) + } +} + +#[tokio::main] +async fn main() { + let env_filter = tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| "info,restate_sdk=debug".into()); + let replay_filter = restate_sdk::filter::ReplayAwareFilter; + tracing_subscriber::registry() + .with( + tracing_subscriber::fmt::layer() + .with_filter(env_filter) + .with_filter(replay_filter), + ) + .init(); + HttpServer::new(Endpoint::builder().bind(GreeterImpl.serve()).build()) + .listen_and_serve("0.0.0.0:9080".parse().unwrap()) + .await; +} diff --git a/src/endpoint/context.rs b/src/endpoint/context.rs index 89dd52c..66e2bfb 100644 --- a/src/endpoint/context.rs +++ b/src/endpoint/context.rs @@ -30,6 +30,10 @@ pub struct ContextInternalInner { pub(crate) read: InputReceiver, pub(crate) write: OutputSender, pub(super) handler_state: HandlerStateNotifier, + + /// We remember here the state of the span replaying field state, because setting it might be expensive (it's guarded behind locks and other stuff). + /// For details, see [ContextInternalInner::maybe_flip_span_replaying_field] + pub(super) span_replaying_field_state: bool, } impl ContextInternalInner { @@ -44,10 +48,12 @@ impl ContextInternalInner { read, write, handler_state, + span_replaying_field_state: false, } } pub(super) fn fail(&mut self, e: Error) { + self.maybe_flip_span_replaying_field(); self.vm.notify_error( CoreError::new(500u16, e.0.to_string()) .with_stacktrace(Cow::Owned(format!("{:#}", e.0))), @@ -55,6 +61,16 @@ impl ContextInternalInner { ); self.handler_state.mark_error(e); } + + pub(super) fn maybe_flip_span_replaying_field(&mut self) { + if !self.span_replaying_field_state && self.vm.is_replaying() { + tracing::Span::current().record("restate.sdk.is_replaying", true); + self.span_replaying_field_state = true; + } else if self.span_replaying_field_state && !self.vm.is_replaying() { + tracing::Span::current().record("restate.sdk.is_replaying", false); + self.span_replaying_field_state = false; + } + } } /// Internal context interface. @@ -190,6 +206,7 @@ impl ContextInternal { }, )) }); + inner_lock.maybe_flip_span_replaying_field(); match input_result { Ok(Ok(i)) => { @@ -223,6 +240,7 @@ impl ContextInternal { ) -> impl Future, TerminalError>> + Send { let mut inner_lock = must_lock!(self.inner); let handle = unwrap_or_trap!(inner_lock, inner_lock.vm.sys_state_get(key.to_owned())); + inner_lock.maybe_flip_span_replaying_field(); let poll_future = get_async_result(Arc::clone(&self.inner), handle).map(|res| match res { Ok(Value::Void) => Ok(Ok(None)), @@ -246,6 +264,7 @@ impl ContextInternal { pub fn get_keys(&self) -> impl Future, TerminalError>> + Send { let mut inner_lock = must_lock!(self.inner); let handle = unwrap_or_trap!(inner_lock, inner_lock.vm.sys_state_get_keys()); + inner_lock.maybe_flip_span_replaying_field(); let poll_future = get_async_result(Arc::clone(&self.inner), handle).map(|res| match res { Ok(Value::Failure(f)) => Ok(Err(f.into())), @@ -266,6 +285,7 @@ impl ContextInternal { match t.serialize() { Ok(b) => { let _ = inner_lock.vm.sys_state_set(key.to_owned(), b); + inner_lock.maybe_flip_span_replaying_field(); } Err(e) => { inner_lock.fail(Error::serialization("set_state", e)); @@ -274,11 +294,15 @@ impl ContextInternal { } pub fn clear(&self, key: &str) { - let _ = must_lock!(self.inner).vm.sys_state_clear(key.to_string()); + let mut inner_lock = must_lock!(self.inner); + let _ = inner_lock.vm.sys_state_clear(key.to_string()); + inner_lock.maybe_flip_span_replaying_field(); } pub fn clear_all(&self) { - let _ = must_lock!(self.inner).vm.sys_state_clear_all(); + let mut inner_lock = must_lock!(self.inner); + let _ = inner_lock.vm.sys_state_clear_all(); + inner_lock.maybe_flip_span_replaying_field(); } pub fn sleep( @@ -293,6 +317,7 @@ impl ContextInternal { inner_lock, inner_lock.vm.sys_sleep(now + sleep_duration, Some(now)) ); + inner_lock.maybe_flip_span_replaying_field(); let poll_future = get_async_result(Arc::clone(&self.inner), handle).map(|res| match res { Ok(Value::Void) => Ok(Ok(())), @@ -328,6 +353,7 @@ impl ContextInternal { ); let call_handle = unwrap_or_trap!(inner_lock, inner_lock.vm.sys_call(target, input)); + inner_lock.maybe_flip_span_replaying_field(); drop(inner_lock); // Let's prepare the two futures here @@ -411,6 +437,7 @@ impl ContextInternal { return Either::Right(TrapFuture::<()>::default()); } }; + inner_lock.maybe_flip_span_replaying_field(); drop(inner_lock); let invocation_id_fut = InterceptErrorFuture::new( @@ -452,6 +479,7 @@ impl ContextInternal { ) { let mut inner_lock = must_lock!(self.inner); let maybe_awakeable_id_and_handle = inner_lock.vm.sys_awakeable(); + inner_lock.maybe_flip_span_replaying_field(); let (awakeable_id, handle) = match maybe_awakeable_id_and_handle { Ok((s, handle)) => (s, handle), @@ -512,6 +540,7 @@ impl ContextInternal { ) -> impl Future> + Send { let mut inner_lock = must_lock!(self.inner); let handle = unwrap_or_trap!(inner_lock, inner_lock.vm.sys_get_promise(name.to_owned())); + inner_lock.maybe_flip_span_replaying_field(); drop(inner_lock); let poll_future = get_async_result(Arc::clone(&self.inner), handle).map(|res| match res { @@ -537,6 +566,7 @@ impl ContextInternal { ) -> impl Future, TerminalError>> + Send { let mut inner_lock = must_lock!(self.inner); let handle = unwrap_or_trap!(inner_lock, inner_lock.vm.sys_peek_promise(name.to_owned())); + inner_lock.maybe_flip_span_replaying_field(); drop(inner_lock); let poll_future = get_async_result(Arc::clone(&self.inner), handle).map(|res| match res { @@ -625,6 +655,7 @@ impl ContextInternal { }; let _ = inner_lock.vm.sys_write_output(res_to_write); + inner_lock.maybe_flip_span_replaying_field(); } pub fn end(&self) { @@ -859,6 +890,7 @@ impl> + Send> Invocation let inv_id = cloned_invocation_id_fut.await?; let mut inner_lock = must_lock!(cloned_ctx); let _ = inner_lock.vm.sys_cancel_invocation(inv_id); + inner_lock.maybe_flip_span_replaying_field(); drop(inner_lock); Ok(()) } @@ -903,6 +935,7 @@ where let inv_id = cloned_invocation_id_fut.await?; let mut inner_lock = must_lock!(cloned_ctx); let _ = inner_lock.vm.sys_cancel_invocation(inv_id); + inner_lock.maybe_flip_span_replaying_field(); drop(inner_lock); Ok(()) } diff --git a/src/endpoint/futures/async_result_poll.rs b/src/endpoint/futures/async_result_poll.rs index 63eaadb..6d84abb 100644 --- a/src/endpoint/futures/async_result_poll.rs +++ b/src/endpoint/futures/async_result_poll.rs @@ -128,6 +128,9 @@ impl Future for VmAsyncResultPollFuture { } }; + // DoProgress might cause a flip of the replaying state + inner_lock.maybe_flip_span_replaying_field(); + // At this point let's try to take the notification match inner_lock.vm.take_notification(handle) { Ok(Some(v)) => return Poll::Ready(Ok(v)), diff --git a/src/endpoint/mod.rs b/src/endpoint/mod.rs index 0bea694..241e45a 100644 --- a/src/endpoint/mod.rs +++ b/src/endpoint/mod.rs @@ -18,6 +18,7 @@ use std::future::poll_fn; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; +use tracing::{info_span, Instrument}; const DISCOVERY_CONTENT_TYPE: &str = "application/vnd.restate.endpointmanifest.v1+json"; @@ -368,6 +369,13 @@ impl BidiStreamRunner { .get(&self.svc_name) .expect("service must exist at this point"); + let span = info_span!( + "restate_sdk_endpoint_handle", + "rpc.system" = "restate", + "rpc.service" = self.svc_name, + "rpc.method" = self.handler_name, + "restate.sdk.is_replaying" = false + ); handle( input_rx, output_tx, @@ -376,6 +384,7 @@ impl BidiStreamRunner { self.handler_name, svc, ) + .instrument(span) .await } } diff --git a/src/filter.rs b/src/filter.rs new file mode 100644 index 0000000..c3ae61e --- /dev/null +++ b/src/filter.rs @@ -0,0 +1,90 @@ +//! Replay aware tracing filter. + +use std::fmt::Debug; +use tracing::{ + field::{Field, Visit}, + span::{Attributes, Record}, + Event, Id, Metadata, Subscriber, +}; +use tracing_subscriber::{ + layer::{Context, Filter}, + registry::LookupSpan, + Layer, +}; + +#[derive(Debug)] +struct ReplayField(bool); + +struct ReplayFieldVisitor(bool); + +impl Visit for ReplayFieldVisitor { + fn record_bool(&mut self, field: &Field, value: bool) { + if field.name().eq("restate.sdk.is_replaying") { + self.0 = value; + } + } + + fn record_debug(&mut self, _field: &Field, _value: &dyn Debug) {} +} + +/// Replay aware tracing filter. +/// +/// Use this filter to skip tracing events in the service while replaying: +/// +/// ```rust,no_run +/// use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, Layer}; +/// tracing_subscriber::registry() +/// .with( +/// tracing_subscriber::fmt::layer() +/// // Default Env filter to read RUST_LOG +/// .with_filter(tracing_subscriber::EnvFilter::from_default_env()) +/// // Replay aware filter +/// .with_filter(restate_sdk::filter::ReplayAwareFilter) +/// ) +/// .init(); +/// ``` +pub struct ReplayAwareFilter; + +impl LookupSpan<'lookup>> Filter for ReplayAwareFilter { + fn enabled(&self, _meta: &Metadata<'_>, _cx: &Context<'_, S>) -> bool { + true + } + + fn event_enabled(&self, event: &Event<'_>, cx: &Context<'_, S>) -> bool { + if let Some(scope) = cx.event_scope(event) { + let iterator = scope.from_root(); + for span in iterator { + if span.name() == "restate_sdk_endpoint_handle" { + if let Some(replay) = span.extensions().get::() { + return !replay.0; + } + } + } + } + true + } + + fn on_new_span(&self, attrs: &Attributes<'_>, id: &Id, ctx: Context<'_, S>) { + if let Some(span) = ctx.span(id) { + if span.name() == "restate_sdk_endpoint_handle" { + let mut visitor = ReplayFieldVisitor(false); + attrs.record(&mut visitor); + let mut extensions = span.extensions_mut(); + extensions.replace::(ReplayField(visitor.0)); + } + } + } + + fn on_record(&self, id: &Id, values: &Record<'_>, ctx: Context<'_, S>) { + if let Some(span) = ctx.span(id) { + if span.name() == "restate_sdk_endpoint_handle" { + let mut visitor = ReplayFieldVisitor(false); + values.record(&mut visitor); + let mut extensions = span.extensions_mut(); + extensions.replace::(ReplayField(visitor.0)); + } + } + } +} + +impl Layer for ReplayAwareFilter {} diff --git a/src/lib.rs b/src/lib.rs index 842211b..533e192 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -207,7 +207,9 @@ //! } //! ``` //! -//! For more information, have a look at the [tracing subscriber doc](https://docs.rs/tracing-subscriber/latest/tracing_subscriber/fmt/index.html#filtering-events-with-environment-variables). +//! You can filter logs *when a handler is being replayed* configuring the [filter::ReplayAwareFilter]. +//! +//! For more information about tracing and logging, have a look at the [tracing subscriber doc](https://docs.rs/tracing-subscriber/latest/tracing_subscriber/fmt/index.html#filtering-events-with-environment-variables). //! //! Next, have a look at the other [SDK features](#features). //! @@ -218,6 +220,8 @@ pub mod service; pub mod context; pub mod discovery; pub mod errors; +#[cfg(feature = "tracing-span-filter")] +pub mod filter; #[cfg(feature = "http_server")] pub mod http_server; #[cfg(feature = "hyper")]