Skip to content

Replay aware logger #48

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,16 @@ license = "MIT"
repository = "https://github.com/restatedev/sdk-rust"
rust-version = "1.76.0"

[[example]]
name = "tracing"
path = "examples/tracing.rs"
required-features = ["tracing-span-filter"]

[features]
default = ["http_server", "rand", "uuid"]
default = ["http_server", "rand", "uuid", "tracing-span-filter"]
hyper = ["dep:hyper", "http-body-util", "restate-sdk-shared-core/http"]
http_server = ["hyper", "hyper/server", "hyper/http2", "hyper-util", "tokio/net", "tokio/signal", "tokio/macros"]

tracing-span-filter = ["dep:tracing-subscriber"]

[dependencies]
bytes = "1.6.1"
Expand All @@ -31,11 +36,12 @@ thiserror = "1.0.63"
tokio = { version = "1", default-features = false, features = ["sync"] }
tower-service = "0.3"
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["registry"], optional = true }
uuid = { version = "1.10.0", optional = true }

[dev-dependencies]
tokio = { version = "1", features = ["full"] }
tracing-subscriber = "0.3"
tracing-subscriber = { version = "0.3", features = ["env-filter", "registry"] }
trybuild = "1.0"
reqwest = { version = "0.12", features = ["json"] }
rand = "0.8.5"
Expand Down
37 changes: 37 additions & 0 deletions examples/tracing.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
use restate_sdk::prelude::*;
use std::time::Duration;
use tracing::info;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, Layer};

// Example service definition. The `restate_sdk::service` attribute macro processes this trait;
// `GreeterImpl` below provides the implementation.
#[restate_sdk::service]
trait Greeter {
// Single handler: takes a `name` and resolves to a `String` or a `HandlerError`.
async fn greet(name: String) -> Result<String, HandlerError>;
}

// Unit struct (no fields) that implements the `Greeter` service in this example.
struct GreeterImpl;

impl Greeter for GreeterImpl {
    /// Logs once, sleeps, logs again, then returns a greeting for `name`.
    ///
    /// The two log statements make the effect of the replay-aware filter visible.
    async fn greet(&self, ctx: Context<'_>, name: String) -> Result<String, HandlerError> {
        // More than the suspension timeout, to trigger a replay.
        let nap = Duration::from_secs(61);

        info!("Before sleep");
        ctx.sleep(nap).await?;
        info!("After sleep");

        let greeting = format!("Greetings {name}");
        Ok(greeting)
    }
}

/// Installs a formatting subscriber filtered by `RUST_LOG` (falling back to
/// `info,restate_sdk=debug`) and by the replay-aware filter, then serves `GreeterImpl`
/// over HTTP on `0.0.0.0:9080`.
#[tokio::main]
async fn main() {
    // Use the environment-provided directives when available, otherwise the defaults.
    let env_filter = tracing_subscriber::EnvFilter::try_from_default_env()
        .unwrap_or_else(|_| "info,restate_sdk=debug".into());

    // Both filters are attached to the same fmt layer.
    let fmt_layer = tracing_subscriber::fmt::layer()
        .with_filter(env_filter)
        .with_filter(restate_sdk::filter::ReplayAwareFilter);
    tracing_subscriber::registry().with(fmt_layer).init();

    let endpoint = Endpoint::builder().bind(GreeterImpl.serve()).build();
    let address = "0.0.0.0:9080".parse().unwrap();
    HttpServer::new(endpoint).listen_and_serve(address).await;
}
37 changes: 35 additions & 2 deletions src/endpoint/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ pub struct ContextInternalInner {
pub(crate) read: InputReceiver,
pub(crate) write: OutputSender,
pub(super) handler_state: HandlerStateNotifier,

/// Last value recorded in the span's replaying field. We remember it here because setting the field might be expensive (it's guarded behind locks and other stuff), so it is only recorded again when the value flips.
/// For details, see [ContextInternalInner::maybe_flip_span_replaying_field]
pub(super) span_replaying_field_state: bool,
}

impl ContextInternalInner {
Expand All @@ -44,17 +48,29 @@ impl ContextInternalInner {
read,
write,
handler_state,
span_replaying_field_state: false,
}
}

/// Fails the invocation with `e`.
///
/// Syncs the span replaying field first, then notifies the VM with a 500 error whose
/// message is the error's display form and whose stacktrace is its alternate (`{:#}`)
/// form, and finally marks the handler state as errored.
pub(super) fn fail(&mut self, e: Error) {
    self.maybe_flip_span_replaying_field();

    let message = e.0.to_string();
    let stacktrace: Cow<'static, str> = Cow::Owned(format!("{:#}", e.0));
    let core_error = CoreError::new(500u16, message).with_stacktrace(stacktrace);

    self.vm.notify_error(core_error, None);
    self.handler_state.mark_error(e);
}

/// Keeps the current span's `restate.sdk.is_replaying` field in sync with the VM.
///
/// The field is recorded only when the VM's replaying state differs from the cached
/// `span_replaying_field_state`; otherwise this is a no-op.
pub(super) fn maybe_flip_span_replaying_field(&mut self) {
    let is_replaying = self.vm.is_replaying();
    if is_replaying == self.span_replaying_field_state {
        // Span field already reflects the VM state, nothing to record.
        return;
    }

    tracing::Span::current().record("restate.sdk.is_replaying", is_replaying);
    self.span_replaying_field_state = is_replaying;
}
}

/// Internal context interface.
Expand Down Expand Up @@ -190,6 +206,7 @@ impl ContextInternal {
},
))
});
inner_lock.maybe_flip_span_replaying_field();

match input_result {
Ok(Ok(i)) => {
Expand Down Expand Up @@ -223,6 +240,7 @@ impl ContextInternal {
) -> impl Future<Output = Result<Option<T>, TerminalError>> + Send {
let mut inner_lock = must_lock!(self.inner);
let handle = unwrap_or_trap!(inner_lock, inner_lock.vm.sys_state_get(key.to_owned()));
inner_lock.maybe_flip_span_replaying_field();

let poll_future = get_async_result(Arc::clone(&self.inner), handle).map(|res| match res {
Ok(Value::Void) => Ok(Ok(None)),
Expand All @@ -246,6 +264,7 @@ impl ContextInternal {
pub fn get_keys(&self) -> impl Future<Output = Result<Vec<String>, TerminalError>> + Send {
let mut inner_lock = must_lock!(self.inner);
let handle = unwrap_or_trap!(inner_lock, inner_lock.vm.sys_state_get_keys());
inner_lock.maybe_flip_span_replaying_field();

let poll_future = get_async_result(Arc::clone(&self.inner), handle).map(|res| match res {
Ok(Value::Failure(f)) => Ok(Err(f.into())),
Expand All @@ -266,6 +285,7 @@ impl ContextInternal {
match t.serialize() {
Ok(b) => {
let _ = inner_lock.vm.sys_state_set(key.to_owned(), b);
inner_lock.maybe_flip_span_replaying_field();
}
Err(e) => {
inner_lock.fail(Error::serialization("set_state", e));
Expand All @@ -274,11 +294,15 @@ impl ContextInternal {
}

/// Clears the state entry stored under `key`.
///
/// The VM call's result is intentionally ignored, as in the original code. Afterwards the
/// span replaying field is synced with the VM state.
pub fn clear(&self, key: &str) {
    // Lock once and issue a single clear; the pre-change one-liner that also appeared
    // here would have issued the clear a second time.
    let mut inner_lock = must_lock!(self.inner);
    let _ = inner_lock.vm.sys_state_clear(key.to_string());
    inner_lock.maybe_flip_span_replaying_field();
}

/// Clears all state entries.
///
/// The VM call's result is intentionally ignored, as in the original code. Afterwards the
/// span replaying field is synced with the VM state.
pub fn clear_all(&self) {
    // Lock once and issue a single clear-all; the pre-change one-liner that also appeared
    // here would have issued it a second time.
    let mut inner_lock = must_lock!(self.inner);
    let _ = inner_lock.vm.sys_state_clear_all();
    inner_lock.maybe_flip_span_replaying_field();
}

pub fn sleep(
Expand All @@ -293,6 +317,7 @@ impl ContextInternal {
inner_lock,
inner_lock.vm.sys_sleep(now + sleep_duration, Some(now))
);
inner_lock.maybe_flip_span_replaying_field();

let poll_future = get_async_result(Arc::clone(&self.inner), handle).map(|res| match res {
Ok(Value::Void) => Ok(Ok(())),
Expand Down Expand Up @@ -328,6 +353,7 @@ impl ContextInternal {
);

let call_handle = unwrap_or_trap!(inner_lock, inner_lock.vm.sys_call(target, input));
inner_lock.maybe_flip_span_replaying_field();
drop(inner_lock);

// Let's prepare the two futures here
Expand Down Expand Up @@ -411,6 +437,7 @@ impl ContextInternal {
return Either::Right(TrapFuture::<()>::default());
}
};
inner_lock.maybe_flip_span_replaying_field();
drop(inner_lock);

let invocation_id_fut = InterceptErrorFuture::new(
Expand Down Expand Up @@ -452,6 +479,7 @@ impl ContextInternal {
) {
let mut inner_lock = must_lock!(self.inner);
let maybe_awakeable_id_and_handle = inner_lock.vm.sys_awakeable();
inner_lock.maybe_flip_span_replaying_field();

let (awakeable_id, handle) = match maybe_awakeable_id_and_handle {
Ok((s, handle)) => (s, handle),
Expand Down Expand Up @@ -512,6 +540,7 @@ impl ContextInternal {
) -> impl Future<Output = Result<T, TerminalError>> + Send {
let mut inner_lock = must_lock!(self.inner);
let handle = unwrap_or_trap!(inner_lock, inner_lock.vm.sys_get_promise(name.to_owned()));
inner_lock.maybe_flip_span_replaying_field();
drop(inner_lock);

let poll_future = get_async_result(Arc::clone(&self.inner), handle).map(|res| match res {
Expand All @@ -537,6 +566,7 @@ impl ContextInternal {
) -> impl Future<Output = Result<Option<T>, TerminalError>> + Send {
let mut inner_lock = must_lock!(self.inner);
let handle = unwrap_or_trap!(inner_lock, inner_lock.vm.sys_peek_promise(name.to_owned()));
inner_lock.maybe_flip_span_replaying_field();
drop(inner_lock);

let poll_future = get_async_result(Arc::clone(&self.inner), handle).map(|res| match res {
Expand Down Expand Up @@ -625,6 +655,7 @@ impl ContextInternal {
};

let _ = inner_lock.vm.sys_write_output(res_to_write);
inner_lock.maybe_flip_span_replaying_field();
}

pub fn end(&self) {
Expand Down Expand Up @@ -859,6 +890,7 @@ impl<InvIdFut: Future<Output = Result<String, TerminalError>> + Send> Invocation
let inv_id = cloned_invocation_id_fut.await?;
let mut inner_lock = must_lock!(cloned_ctx);
let _ = inner_lock.vm.sys_cancel_invocation(inv_id);
inner_lock.maybe_flip_span_replaying_field();
drop(inner_lock);
Ok(())
}
Expand Down Expand Up @@ -903,6 +935,7 @@ where
let inv_id = cloned_invocation_id_fut.await?;
let mut inner_lock = must_lock!(cloned_ctx);
let _ = inner_lock.vm.sys_cancel_invocation(inv_id);
inner_lock.maybe_flip_span_replaying_field();
drop(inner_lock);
Ok(())
}
Expand Down
3 changes: 3 additions & 0 deletions src/endpoint/futures/async_result_poll.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,9 @@ impl Future for VmAsyncResultPollFuture {
}
};

// DoProgress might cause a flip of the replaying state
inner_lock.maybe_flip_span_replaying_field();

// At this point let's try to take the notification
match inner_lock.vm.take_notification(handle) {
Ok(Some(v)) => return Poll::Ready(Ok(v)),
Expand Down
9 changes: 9 additions & 0 deletions src/endpoint/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use std::future::poll_fn;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use tracing::{info_span, Instrument};

const DISCOVERY_CONTENT_TYPE: &str = "application/vnd.restate.endpointmanifest.v1+json";

Expand Down Expand Up @@ -368,6 +369,13 @@ impl BidiStreamRunner {
.get(&self.svc_name)
.expect("service must exist at this point");

let span = info_span!(
"restate_sdk_endpoint_handle",
"rpc.system" = "restate",
"rpc.service" = self.svc_name,
"rpc.method" = self.handler_name,
"restate.sdk.is_replaying" = false
);
handle(
input_rx,
output_tx,
Expand All @@ -376,6 +384,7 @@ impl BidiStreamRunner {
self.handler_name,
svc,
)
.instrument(span)
.await
}
}
Expand Down
90 changes: 90 additions & 0 deletions src/filter.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
//! Replay aware tracing filter.

use std::fmt::Debug;
use tracing::{
field::{Field, Visit},
span::{Attributes, Record},
Event, Id, Metadata, Subscriber,
};
use tracing_subscriber::{
layer::{Context, Filter},
registry::LookupSpan,
Layer,
};

/// Name of the span the filter looks for in an event's scope.
const ENDPOINT_SPAN_NAME: &str = "restate_sdk_endpoint_handle";

/// Name of the boolean span field carrying the replay state.
const REPLAYING_FIELD_NAME: &str = "restate.sdk.is_replaying";

/// Span extension storing the last replay state seen for an endpoint span.
#[derive(Debug)]
struct ReplayField(bool);

/// Visitor extracting the replaying field from a set of span values.
///
/// Stays `None` until the field is actually visited, so callers can distinguish
/// "field not part of these values" from "field set to `false`".
struct ReplayFieldVisitor(Option<bool>);

impl Visit for ReplayFieldVisitor {
    fn record_bool(&mut self, field: &Field, value: bool) {
        if field.name() == REPLAYING_FIELD_NAME {
            self.0 = Some(value);
        }
    }

    // Required by `Visit`; values that are not booleans are ignored.
    fn record_debug(&mut self, _field: &Field, _value: &dyn Debug) {}
}

/// Replay aware tracing filter.
///
/// Use this filter to skip tracing events in the service while replaying:
///
/// ```rust,no_run
/// use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, Layer};
/// tracing_subscriber::registry()
///     .with(
///         tracing_subscriber::fmt::layer()
///             // Default Env filter to read RUST_LOG
///             .with_filter(tracing_subscriber::EnvFilter::from_default_env())
///             // Replay aware filter
///             .with_filter(restate_sdk::filter::ReplayAwareFilter)
///     )
///     .init();
/// ```
#[derive(Debug, Clone, Copy, Default)]
pub struct ReplayAwareFilter;

impl<S: Subscriber + for<'lookup> LookupSpan<'lookup>> Filter<S> for ReplayAwareFilter {
    /// Never filters on metadata alone; the decision is taken per event in `event_enabled`.
    fn enabled(&self, _meta: &Metadata<'_>, _cx: &Context<'_, S>) -> bool {
        true
    }

    /// Disables the event when it is emitted inside an endpoint span whose stored replay
    /// state is `true`. Events outside such a span are always enabled.
    fn event_enabled(&self, event: &Event<'_>, cx: &Context<'_, S>) -> bool {
        let Some(scope) = cx.event_scope(event) else {
            return true;
        };

        for span in scope.from_root() {
            if span.name() != ENDPOINT_SPAN_NAME {
                continue;
            }
            if let Some(replay) = span.extensions().get::<ReplayField>() {
                return !replay.0;
            }
        }
        true
    }

    /// Stores the initial replay state of a newly created endpoint span.
    ///
    /// When the span attributes do not contain the replaying field, the state defaults to `false`.
    fn on_new_span(&self, attrs: &Attributes<'_>, id: &Id, ctx: Context<'_, S>) {
        let Some(span) = ctx.span(id) else {
            return;
        };
        if span.name() != ENDPOINT_SPAN_NAME {
            return;
        }

        let mut visitor = ReplayFieldVisitor(None);
        attrs.record(&mut visitor);
        span.extensions_mut()
            .replace::<ReplayField>(ReplayField(visitor.0.unwrap_or(false)));
    }

    /// Updates the stored replay state when the replaying field is recorded on an endpoint span.
    ///
    /// Records that do not contain the replaying field leave the stored state untouched;
    /// otherwise recording any unrelated field would reset the state to `false`.
    fn on_record(&self, id: &Id, values: &Record<'_>, ctx: Context<'_, S>) {
        let Some(span) = ctx.span(id) else {
            return;
        };
        if span.name() != ENDPOINT_SPAN_NAME {
            return;
        }

        let mut visitor = ReplayFieldVisitor(None);
        values.record(&mut visitor);
        if let Some(is_replaying) = visitor.0 {
            span.extensions_mut()
                .replace::<ReplayField>(ReplayField(is_replaying));
        }
    }
}

impl<S: Subscriber> Layer<S> for ReplayAwareFilter {}
6 changes: 5 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,9 @@
//! }
//! ```
//!
//! For more information, have a look at the [tracing subscriber doc](https://docs.rs/tracing-subscriber/latest/tracing_subscriber/fmt/index.html#filtering-events-with-environment-variables).
//! You can filter out logs emitted *while a handler is being replayed* by configuring the [filter::ReplayAwareFilter].
//!
//! For more information about tracing and logging, have a look at the [tracing subscriber doc](https://docs.rs/tracing-subscriber/latest/tracing_subscriber/fmt/index.html#filtering-events-with-environment-variables).
//!
//! Next, have a look at the other [SDK features](#features).
//!
Expand All @@ -218,6 +220,8 @@ pub mod service;
pub mod context;
pub mod discovery;
pub mod errors;
#[cfg(feature = "tracing-span-filter")]
pub mod filter;
#[cfg(feature = "http_server")]
pub mod http_server;
#[cfg(feature = "hyper")]
Expand Down
Loading