From 49c8f69184eb8cef1311eafd1bc259c9a461bcc8 Mon Sep 17 00:00:00 2001 From: slinkydeveloper Date: Wed, 21 Aug 2024 17:52:05 +0200 Subject: [PATCH] Moved hyper support in a separate module. One can now use the hyper support without necessarily using our server, nor tokio net. --- Cargo.toml | 7 ++- README.md | 2 +- examples/counter.rs | 2 +- examples/failures.rs | 2 +- examples/greeter.rs | 2 +- examples/run.rs | 2 +- src/http_server.rs | 79 ++++++++++++++++++++++++ src/{http.rs => hyper.rs} | 88 +++------------------------ src/lib.rs | 6 +- test-services/src/main.rs | 4 +- tests/ui/shared_handler_in_service.rs | 2 +- 11 files changed, 104 insertions(+), 92 deletions(-) create mode 100644 src/http_server.rs rename src/{http.rs => hyper.rs} (61%) diff --git a/Cargo.toml b/Cargo.toml index 2ed98fd..5c11426 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,14 +8,15 @@ repository = "https://github.com/restatedev/sdk-rust" [features] default = ["http_server", "rand", "uuid"] -http_server = ["hyper", "http-body-util", "hyper-util", "tokio/net", "tokio/signal", "restate-sdk-shared-core/http"] +hyper = ["dep:hyper", "http-body-util", "restate-sdk-shared-core/http"] +http_server = ["hyper", "hyper/server", "hyper/http2", "hyper-util", "tokio/net", "tokio/signal"] [dependencies] bytes = "1.6.1" futures = "0.3" http = "1.1.0" http-body-util = { version = "0.1", optional = true } -hyper = { version = "1.4.1", optional = true, features = ["server", "http2"] } +hyper = { version = "1.4.1", optional = true} hyper-util = { version = "0.1", features = ["tokio", "server", "server-graceful", "http2"], optional = true } pin-project-lite = "0.2" rand = { version = "0.8.5", optional = true } @@ -25,7 +26,7 @@ restate-sdk-shared-core = { version = "0.0.5" } serde = "1.0" serde_json = "1.0" thiserror = "1.0.63" -tokio = { version = "1", default-features = false, features = ["sync", "macros"] } +tokio = { version = "1", default-features = false, features = ["sync"] } tower-service = "0.3" tracing = "0.1" uuid = { version = "1.10.0", optional = true } diff --git a/README.md b/README.md index 7bc6d61..796e4a6 100644 --- a/README.md +++ b/README.md @@ -51,7 +51,7 @@ impl Greeter for GreeterImpl { async fn main() { // To enable logging/tracing // tracing_subscriber::fmt::init(); - HyperServer::new( + HttpServer::new( Endpoint::builder() .with_service(GreeterImpl.serve()) .build(), diff --git a/examples/counter.rs b/examples/counter.rs index 0d4e53a..b38b163 100644 --- a/examples/counter.rs +++ b/examples/counter.rs @@ -38,7 +38,7 @@ impl Counter for CounterImpl { #[tokio::main] async fn main() { tracing_subscriber::fmt::init(); - HyperServer::new( + HttpServer::new( Endpoint::builder() .with_service(CounterImpl.serve()) .build(), diff --git a/examples/failures.rs b/examples/failures.rs index 364ceac..bdc2c14 100644 --- a/examples/failures.rs +++ b/examples/failures.rs @@ -32,7 +32,7 @@ impl FailureExample for FailureExampleImpl { #[tokio::main] async fn main() { tracing_subscriber::fmt::init(); - HyperServer::new( + HttpServer::new( Endpoint::builder() .with_service(FailureExampleImpl.serve()) .build(), diff --git a/examples/greeter.rs b/examples/greeter.rs index 2c62270..3d4011b 100644 --- a/examples/greeter.rs +++ b/examples/greeter.rs @@ -16,7 +16,7 @@ impl Greeter for GreeterImpl { #[tokio::main] async fn main() { tracing_subscriber::fmt::init(); - HyperServer::new( + HttpServer::new( Endpoint::builder() .with_service(GreeterImpl.serve()) .build(), diff --git a/examples/run.rs b/examples/run.rs index 259c260..2df394a 100644 --- a/examples/run.rs +++ b/examples/run.rs @@ -33,7 +33,7 @@ impl RunExample for RunExampleImpl { #[tokio::main] async fn main() { tracing_subscriber::fmt::init(); - HyperServer::new( + HttpServer::new( Endpoint::builder() .with_service(RunExampleImpl(reqwest::Client::new()).serve()) .build(), diff --git a/src/http_server.rs b/src/http_server.rs new file mode 100644 index 0000000..95a604b --- /dev/null +++ b/src/http_server.rs @@ -0,0 +1,79 @@ +use crate::endpoint::Endpoint; +use crate::hyper::HyperEndpoint; +use futures::FutureExt; +use hyper::server::conn::http2; +use hyper_util::rt::{TokioExecutor, TokioIo}; +use std::future::Future; +use std::net::SocketAddr; +use std::time::Duration; +use tokio::net::TcpListener; +use tracing::{info, warn}; + +pub struct HttpServer { + endpoint: Endpoint, +} + +impl From for HttpServer { + fn from(endpoint: Endpoint) -> Self { + Self { endpoint } + } +} + +impl HttpServer { + pub fn new(endpoint: Endpoint) -> Self { + Self { endpoint } + } + + pub async fn listen_and_serve(self, addr: SocketAddr) { + let listener = TcpListener::bind(addr).await.expect("listener can bind"); + self.serve(listener).await; + } + + pub async fn serve(self, listener: TcpListener) { + self.serve_with_cancel(listener, tokio::signal::ctrl_c().map(|_| ())) + .await; + } + + pub async fn serve_with_cancel(self, listener: TcpListener, cancel_signal_future: impl Future) { + let endpoint = HyperEndpoint::new(self.endpoint); + let graceful = hyper_util::server::graceful::GracefulShutdown::new(); + + // when this signal completes, start shutdown + let mut signal = std::pin::pin!(cancel_signal_future); + + info!("Starting listening on {}", listener.local_addr().unwrap()); + + // Our server accept loop + loop { + tokio::select! { + Ok((stream, remote)) = listener.accept() => { + let endpoint = endpoint.clone(); + + let conn = http2::Builder::new(TokioExecutor::default()) + .serve_connection(TokioIo::new(stream), endpoint); + + let fut = graceful.watch(conn); + + tokio::spawn(async move { + if let Err(e) = fut.await { + warn!("Error serving connection {remote}: {:?}", e); + } + }); + }, + _ = &mut signal => { + info!("Shutting down"); + // stop the accept loop + break; + } + } + } + + // Wait graceful shutdown + tokio::select! { + _ = graceful.shutdown() => {}, + _ = tokio::time::sleep(Duration::from_secs(10)) => { + warn!("Timed out waiting for all connections to close"); + } + } + } +} diff --git a/src/http.rs b/src/hyper.rs similarity index 61% rename from src/http.rs rename to src/hyper.rs index a97d9b4..21e35df 100644 --- a/src/http.rs +++ b/src/hyper.rs @@ -3,98 +3,29 @@ use crate::endpoint::{Endpoint, InputReceiver, OutputSender}; use bytes::Bytes; use futures::future::BoxFuture; use futures::{FutureExt, TryStreamExt}; +use http::header::CONTENT_TYPE; +use http::{response, Request, Response}; use http_body_util::{BodyExt, Either, Full}; use hyper::body::{Body, Frame, Incoming}; -use hyper::header::CONTENT_TYPE; -use hyper::http::response; -use hyper::server::conn::http2; use hyper::service::Service; -use hyper::{Request, Response}; -use hyper_util::rt::{TokioExecutor, TokioIo}; use restate_sdk_shared_core::ResponseHead; use std::convert::Infallible; -use std::future::{ready, Future, Ready}; -use std::net::SocketAddr; +use std::future::{ready, Ready}; use std::ops::Deref; use std::pin::Pin; use std::task::{ready, Context, Poll}; -use std::time::Duration; -use tokio::net::TcpListener; use tokio::sync::mpsc; -use tracing::{info, warn}; +use tracing::warn; -pub struct HyperServer { - endpoint: Endpoint, -} - -impl From for HyperServer { - fn from(endpoint: Endpoint) -> Self { - Self { endpoint } - } -} +#[derive(Clone)] +pub struct HyperEndpoint(Endpoint); -impl HyperServer { +impl HyperEndpoint { pub fn new(endpoint: Endpoint) -> Self { - Self { endpoint } - } - - pub async fn listen_and_serve(self, addr: SocketAddr) { - let listener = TcpListener::bind(addr).await.expect("listener can bind"); - self.serve(listener).await; - } - - pub async fn serve(self, listener: TcpListener) { - self.serve_with_cancel(listener, tokio::signal::ctrl_c().map(|_| ())) - .await; - } - - pub async fn serve_with_cancel(self, listener: TcpListener, cancel_signal_future: impl Future) { - let endpoint = HyperEndpoint(self.endpoint); - let graceful = hyper_util::server::graceful::GracefulShutdown::new(); - - // when this signal completes, start shutdown - let mut signal = std::pin::pin!(cancel_signal_future); - - info!("Starting listening on {}", listener.local_addr().unwrap()); - - // Our server accept loop - loop { - tokio::select! { - Ok((stream, remote)) = listener.accept() => { - let endpoint = endpoint.clone(); - - let conn = http2::Builder::new(TokioExecutor::default()) - .serve_connection(TokioIo::new(stream), endpoint); - - let fut = graceful.watch(conn); - - tokio::spawn(async move { - if let Err(e) = fut.await { - warn!("Error serving connection {remote}: {:?}", e); - } - }); - }, - _ = &mut signal => { - info!("Shutting down"); - // stop the accept loop - break; - } - } - } - - // Wait graceful shutdown - tokio::select! { - _ = graceful.shutdown() => {}, - _ = tokio::time::sleep(Duration::from_secs(10)) => { - warn!("Timed out waiting for all connections to close"); - } - } + Self(endpoint) } } -#[derive(Clone)] -struct HyperEndpoint(Endpoint); - impl Service> for HyperEndpoint { type Response = Response, BidiStreamRunner>>; type Error = endpoint::Error; @@ -155,8 +86,7 @@ fn response_builder_from_response_head(response_head: ResponseHead) -> response: response_builder } -// TODO use pin_project -struct BidiStreamRunner { +pub struct BidiStreamRunner { fut: Option>>, output_rx: mpsc::UnboundedReceiver, end_stream: bool, diff --git a/src/lib.rs b/src/lib.rs index 1b173fd..4c22024 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,14 +5,16 @@ pub mod context; pub mod discovery; pub mod errors; #[cfg(feature = "http_server")] -pub mod http; +pub mod http_server; +#[cfg(feature = "hyper")] +pub mod hyper; pub mod serde; pub use restate_sdk_macros::{object, service, workflow}; pub mod prelude { #[cfg(feature = "http_server")] - pub use crate::http::HyperServer; + pub use crate::http_server::HttpServer; pub use crate::context::{ Context, ContextAwakeables, ContextClient, ContextPromises, ContextReadState, diff --git a/test-services/src/main.rs b/test-services/src/main.rs index 33e6d93..2c9bc59 100644 --- a/test-services/src/main.rs +++ b/test-services/src/main.rs @@ -10,7 +10,7 @@ mod non_deterministic; mod proxy; mod test_utils_service; -use restate_sdk::prelude::{Endpoint, HyperServer}; +use restate_sdk::prelude::{Endpoint, HttpServer}; use std::env; #[tokio::main] @@ -77,7 +77,7 @@ async fn main() { )) } - HyperServer::new(builder.build()) + HttpServer::new(builder.build()) .listen_and_serve(format!("0.0.0.0:{port}").parse().unwrap()) .await; } diff --git a/tests/ui/shared_handler_in_service.rs b/tests/ui/shared_handler_in_service.rs index f8b5980..ef98a45 100644 --- a/tests/ui/shared_handler_in_service.rs +++ b/tests/ui/shared_handler_in_service.rs @@ -17,7 +17,7 @@ impl SharedHandlerInService for SharedHandlerInServiceImpl { #[tokio::main] async fn main() { tracing_subscriber::fmt::init(); - HyperServer::new( + HttpServer::new( Endpoint::builder() .with_service(SharedHandlerInServiceImpl.serve()) .build(),