diff --git a/Cargo.toml b/Cargo.toml index af1f95d..69fe447 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -60,4 +60,4 @@ syn = "2.0" typify = { version = "0.1.0" } [workspace] -members = ["macros", "test-services", "test-env"] +members = ["macros", "test-services", "testcontainers"] diff --git a/README.md b/README.md index 6af7b1a..0aac168 100644 --- a/README.md +++ b/README.md @@ -66,7 +66,7 @@ Just configure it as usual through [`tracing_subscriber`](https://docs.rs/tracin ### Testing The SDK uses [Testcontainers](https://rust.testcontainers.org/) to support integration testing using a Docker-deployed restate server. -The `restate-sdk-test-env` crate provides a framework for initializing the test environment, and an integration test example in `test-env/tests/test_container.rs`. +The `restate-sdk-testcontainers` crate provides a framework for initializing the test environment, and an integration test example in `testcontainers/tests/test_container.rs`. ```rust #[tokio::test] diff --git a/justfile b/justfile index c9f0df3..d52df38 100644 --- a/justfile +++ b/justfile @@ -64,7 +64,7 @@ print-target: @echo {{ _resolved_target }} test: (_target-installed target) - cargo nextest run {{ _target-option }} --all-features + cargo nextest run {{ _target-option }} --all-features --workspace doctest: cargo test --doc diff --git a/src/endpoint/mod.rs b/src/endpoint/mod.rs index c32d0a2..f554911 100644 --- a/src/endpoint/mod.rs +++ b/src/endpoint/mod.rs @@ -284,6 +284,14 @@ impl Endpoint { let parts: Vec<&str> = path.split('/').collect(); + if parts.last() == Some(&"health") { + return Ok(Response::ReplyNow { + status_code: 200, + headers: vec![], + body: Bytes::new(), + }); + } + if parts.last() == Some(&"discover") { let accept_header = headers .extract("accept") diff --git a/test-env/src/lib.rs b/test-env/src/lib.rs deleted file mode 100644 index ea10e77..0000000 --- a/test-env/src/lib.rs +++ /dev/null @@ -1,321 +0,0 @@ -use restate_sdk::{ - errors::HandlerError, - prelude::{Endpoint, HttpServer}, -}; -use serde::{Deserialize, Serialize}; -use std::time::Duration; -use testcontainers::{ - core::{IntoContainerPort, WaitFor}, - runners::AsyncRunner, - ContainerAsync, ContainerRequest, GenericImage, ImageExt, -}; -use tokio::{ - io::AsyncBufReadExt, - net::TcpListener, - task::{self, JoinHandle}, -}; -use tracing::{error, info, warn}; - -// addapted from from restate-admin-rest-model crate version 1.1.6 -#[derive(Serialize, Deserialize, Debug)] -pub struct RegisterDeploymentRequestHttp { - uri: String, - additional_headers: Option>, - use_http_11: bool, - force: bool, - dry_run: bool, -} - -#[derive(Serialize, Deserialize, Debug)] -pub struct RegisterDeploymentRequestLambda { - arn: String, - assume_role_arn: Option, - force: bool, - dry_run: bool, -} - -#[derive(Serialize, Deserialize, Debug)] -struct VersionResponse { - version: String, - min_admin_api_version: u32, - max_admin_api_version: u32, -} - -pub struct TestContainerBuilder { - container_name: String, - container_tag: String, - logging: bool, -} - -impl Default for TestContainerBuilder { - fn default() -> Self { - TestContainerBuilder { - container_name: "docker.io/restatedev/restate".to_string(), - container_tag: "latest".to_string(), - logging: false, - } - } -} - -impl TestContainerBuilder { - pub fn with_container_logging(mut self) -> TestContainerBuilder { - self.logging = true; - self - } - - pub fn with_container( - mut self, - container_name: String, - container_tag: String, - ) -> TestContainerBuilder { - self.container_name = container_name; - self.container_tag = container_tag; - - self - } - - pub fn build(self) -> TestContainer { - TestContainer { - builder: self, - container: None, - endpoint_task: None, - endpoint_server_ip: None, - endpoint_server_port: None, - endpoint_server_url: None, - ingress_url: None, - stdout_logging: None, - stderr_logging: None, - } - } -} - -pub struct TestContainer { - builder: TestContainerBuilder, - container: Option>, - endpoint_task: Option>, - endpoint_server_ip: Option, - endpoint_server_port: Option, - endpoint_server_url: Option, - ingress_url: Option, - stdout_logging: Option>, - stderr_logging: Option>, -} - -impl Default for TestContainer { - fn default() -> Self { - TestContainerBuilder::default().build() - } -} - -impl TestContainer { - pub fn builder() -> TestContainerBuilder { - TestContainerBuilder::default() - } - - pub async fn start(mut self, endpoint: Endpoint) -> Result { - self.serve_endpoint(endpoint).await?; - self.start_container().await?; - let registered = self.register_endpoint().await; - if registered.is_err() { - return Err(anyhow::anyhow!("Failed to register endpoint")); - } - - Ok(self) - } - - async fn serve_endpoint(&mut self, endpoint: Endpoint) -> Result<(), anyhow::Error> { - info!("Starting endpoint server..."); - - // use port 0 to allow tokio to assign unused port number - let host_address = "127.0.0.1:0".to_string(); - let listener = TcpListener::bind(host_address) - .await - .expect("listener can bind"); - self.endpoint_server_ip = Some(listener.local_addr().unwrap().ip().to_string()); - self.endpoint_server_port = Some(listener.local_addr().unwrap().port()); - self.endpoint_server_url = Some(format!( - "http://{}:{}", - self.endpoint_server_ip.as_ref().unwrap(), - self.endpoint_server_port.as_ref().unwrap() - )); - - // boot endpoint server - self.endpoint_task = Some(tokio::spawn(async move { - HttpServer::new(endpoint).serve(listener).await; - })); - - let client = reqwest::Client::builder().http2_prior_knowledge().build()?; - - // wait for endpoint server to respond - let mut retries = 0; - while client - .get(format!( - "{}/discover", - self.endpoint_server_url.as_ref().unwrap() - )) - .header("accept", "application/vnd.restate.endpointmanifest.v1+json") - .send() - .await - .is_err() - { - tokio::time::sleep(Duration::from_millis(100)).await; - - warn!("retrying endpoint server"); - - retries += 1; - if retries > 10 { - return Err(anyhow::anyhow!("endpoint server failed to start")); - } - } - - info!( - "endpoint server: {}", - self.endpoint_server_url.as_ref().unwrap() - ); - - Ok(()) - } - - async fn start_container(&mut self) -> Result<(), anyhow::Error> { - let image = GenericImage::new(&self.builder.container_name, &self.builder.container_tag) - .with_exposed_port(9070.tcp()) - .with_exposed_port(8080.tcp()) - .with_wait_for(WaitFor::message_on_stdout("Ingress HTTP listening")); - - // have to expose entire host network because testcontainer-rs doesn't implement selective SSH port forward from host - // see https://github.com/testcontainers/testcontainers-rs/issues/535 - self.container = Some( - ContainerRequest::from(image) - .with_host( - "host.docker.internal", - testcontainers::core::Host::HostGateway, - ) - .start() - .await?, - ); - - if self.builder.logging { - let container_stdout = self.container.as_ref().unwrap().stdout(true); - let mut stdout_lines = container_stdout.lines(); - - // Spawn a task to copy data from the AsyncBufRead to stdout - let stdout_logging = task::spawn(async move { - while let Some(line) = stdout_lines.next_line().await.transpose() { - match line { - Ok(line) => { - // Log each line using tracing - info!("{}", line); - } - Err(e) => { - // Log any errors - error!("Error reading from container stream: {}", e); - break; - } - } - } - }); - - self.stderr_logging = Some(stdout_logging); - - let container_stderr = self.container.as_ref().unwrap().stderr(true); - let mut stderr_lines = container_stderr.lines(); - - // Spawn a task to copy data from the AsyncBufRead to stderr - let stderr_logging = task::spawn(async move { - while let Some(line) = stderr_lines.next_line().await.transpose() { - match line { - Ok(line) => { - // Log each line using tracing - error!("{}", line); - } - Err(e) => { - // Log any errors - error!("Error reading from container stream: {}", e); - break; - } - } - } - }); - - self.stderr_logging = Some(stderr_logging); - } - - let host = self.container.as_ref().unwrap().get_host().await?; - let ports = self.container.as_ref().unwrap().ports().await?; - - let admin_port = ports.map_to_host_port_ipv4(9070.tcp()).unwrap(); - - let admin_url = format!("http://{}:{}/version", host, admin_port); - reqwest::get(admin_url) - .await? - .json::() - .await?; - - Ok(()) - } - - async fn register_endpoint(&mut self) -> Result<(), HandlerError> { - info!( - "registering endpoint server: {}", - self.endpoint_server_url.as_ref().unwrap() - ); - - let host = self.container.as_ref().unwrap().get_host().await?; - let ports = self.container.as_ref().unwrap().ports().await?; - - let admin_port = ports.map_to_host_port_ipv4(9070.tcp()).unwrap(); - - let client = reqwest::Client::builder().http2_prior_knowledge().build()?; - - let deployment_uri: String = format!( - "http://host.docker.internal:{}/", - self.endpoint_server_port.unwrap() - ); - let deployment_payload = RegisterDeploymentRequestHttp { - uri: deployment_uri, - additional_headers: None, - use_http_11: false, - force: false, - dry_run: false, - }; //, additional_headers: (), use_http_11: (), force: (), dry_run: () } - - let register_admin_url = format!("http://{}:{}/deployments", host, admin_port); - - client - .post(register_admin_url) - .json(&deployment_payload) - .send() - .await?; - - let ingress_port = ports.map_to_host_port_ipv4(8080.tcp()).unwrap(); - self.ingress_url = Some(format!("http://{}:{}", host, ingress_port)); - - info!("ingress url: {}", self.ingress_url.as_ref().unwrap()); - - Ok(()) - } - - pub fn ingress_url(&self) -> String { - self.ingress_url.clone().unwrap() - } -} - -impl Drop for TestContainer { - fn drop(&mut self) { - // testcontainers-rs already implements stop/rm on drop] - // https://docs.rs/testcontainers/latest/testcontainers/ - - // clean up tokio tasks - if self.endpoint_task.is_some() { - self.endpoint_task.take().unwrap().abort(); - } - - if self.stdout_logging.is_some() { - self.stdout_logging.take().unwrap().abort(); - } - - if self.stderr_logging.is_some() { - self.stderr_logging.take().unwrap().abort(); - } - } -} diff --git a/test-env/Cargo.toml b/testcontainers/Cargo.toml similarity index 67% rename from test-env/Cargo.toml rename to testcontainers/Cargo.toml index 206d44e..7cd2279 100644 --- a/test-env/Cargo.toml +++ b/testcontainers/Cargo.toml @@ -1,8 +1,8 @@ [package] -name = "restate-sdk-test-env" +name = "restate-sdk-testcontainers" version = "0.4.0" edition = "2021" -description = "Test Utilities for Restate SDK for Rust" +description = "Restate SDK Testcontainers utilities" license = "MIT" repository = "https://github.com/restatedev/sdk-rust" rust-version = "1.76.0" @@ -10,12 +10,11 @@ rust-version = "1.76.0" [dependencies] anyhow = "1.0.95" -nu-ansi-term = "0.50.1" +futures = "0.3.31" reqwest = { version= "0.12.12", features = ["json"] } restate-sdk = { version = "0.4.0", path = "../" } serde = "1.0.217" -serde_json = "1.0.138" -testcontainers = "0.23.1" +testcontainers = { version = "0.23.3", features = ["http_wait"] } tokio = "1.43.0" tracing = "0.1.41" tracing-subscriber = "0.3.19" diff --git a/testcontainers/src/lib.rs b/testcontainers/src/lib.rs new file mode 100644 index 0000000..30dfeec --- /dev/null +++ b/testcontainers/src/lib.rs @@ -0,0 +1,292 @@ +use anyhow::Context; +use futures::FutureExt; +use restate_sdk::prelude::{Endpoint, HttpServer}; +use serde::{Deserialize, Serialize}; +use testcontainers::core::wait::HttpWaitStrategy; +use testcontainers::{ + core::{IntoContainerPort, WaitFor}, + runners::AsyncRunner, + ContainerAsync, ContainerRequest, GenericImage, ImageExt, +}; +use tokio::{io::AsyncBufReadExt, net::TcpListener, task}; +use tracing::{error, info, warn}; + +// From restate-admin-rest-model +#[derive(Serialize, Deserialize, Debug)] +pub struct RegisterDeploymentRequestHttp { + uri: String, + additional_headers: Option>, + use_http_11: bool, + force: bool, + dry_run: bool, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct RegisterDeploymentRequestLambda { + arn: String, + assume_role_arn: Option, + force: bool, + dry_run: bool, +} + +#[derive(Serialize, Deserialize, Debug)] +struct VersionResponse { + version: String, + min_admin_api_version: u32, + max_admin_api_version: u32, +} + +pub struct TestEnvironment { + container_name: String, + container_tag: String, + logging: bool, +} + +impl Default for TestEnvironment { + fn default() -> Self { + Self { + container_name: "docker.io/restatedev/restate".to_string(), + container_tag: "latest".to_string(), + logging: false, + } + } +} + +impl TestEnvironment { + // --- Builder methods + + pub fn new() -> Self { + Self::default() + } + + pub fn with_container_logging(mut self) -> Self { + self.logging = true; + self + } + + pub fn with_container(mut self, container_name: String, container_tag: String) -> Self { + self.container_name = container_name; + self.container_tag = container_tag; + + self + } + + // --- Start method + + pub async fn start(self, endpoint: Endpoint) -> Result { + let started_endpoint = StartedEndpoint::serve_endpoint(endpoint).await?; + let started_restate_container = StartedRestateContainer::start_container(&self).await?; + if let Err(e) = started_restate_container + .register_endpoint(&started_endpoint) + .await + { + return Err(anyhow::anyhow!("Failed to register endpoint: {e}")); + } + + Ok(StartedTestEnvironment { + _started_endpoint: started_endpoint, + started_restate_container, + }) + } +} + +struct StartedEndpoint { + port: u16, + _cancel_tx: tokio::sync::oneshot::Sender<()>, +} + +impl StartedEndpoint { + async fn serve_endpoint(endpoint: Endpoint) -> Result { + info!("Starting endpoint server..."); + + // 0.0.0.0:0 will listen on a random port, both IPv4 and IPv6 + let host_address = "0.0.0.0:0".to_string(); + let listener = TcpListener::bind(host_address) + .await + .expect("listener can bind"); + let listening_addr = listener.local_addr()?; + let endpoint_server_url = + format!("http://{}:{}", listening_addr.ip(), listening_addr.port()); + + // Start endpoint server + let (cancel_tx, cancel_rx) = tokio::sync::oneshot::channel(); + tokio::spawn(async move { + HttpServer::new(endpoint) + .serve_with_cancel(listener, cancel_rx) + .await; + }); + + let client = reqwest::Client::builder().http2_prior_knowledge().build()?; + + // wait for endpoint server to respond + let mut retries = 0; + loop { + match client + .get(format!("{endpoint_server_url}/health",)) + .send() + .await + { + Ok(res) if res.status().is_success() => break, + Ok(res) => { + warn!("Error when waiting for service endpoint server to be healthy, got response {}", res.status()); + retries += 1; + if retries > 10 { + anyhow::bail!("Service endpoint server failed to start") + } + } + Err(err) => { + warn!("Error when waiting for service endpoint server to be healthy, got error {}", err); + retries += 1; + if retries > 10 { + anyhow::bail!("Service endpoint server failed to start") + } + } + } + } + + info!("Service endpoint server listening at: {endpoint_server_url}",); + + Ok(StartedEndpoint { + port: listening_addr.port(), + _cancel_tx: cancel_tx, + }) + } +} + +struct StartedRestateContainer { + _cancel_tx: tokio::sync::oneshot::Sender<()>, + container: ContainerAsync, + ingress_url: String, +} + +impl StartedRestateContainer { + async fn start_container( + test_environment: &TestEnvironment, + ) -> Result { + let image = GenericImage::new( + &test_environment.container_name, + &test_environment.container_tag, + ) + .with_exposed_port(8080.tcp()) + .with_exposed_port(9070.tcp()) + .with_wait_for(WaitFor::Http( + HttpWaitStrategy::new("/restate/health") + .with_port(8080.tcp()) + .with_response_matcher(|res| res.status().is_success()), + )) + .with_wait_for(WaitFor::Http( + HttpWaitStrategy::new("/health") + .with_port(9070.tcp()) + .with_response_matcher(|res| res.status().is_success()), + )); + + // Start container + let container = ContainerRequest::from(image) + // have to expose entire host network because testcontainer-rs doesn't implement selective SSH port forward from host + // see https://github.com/testcontainers/testcontainers-rs/issues/535 + .with_host( + "host.docker.internal", + testcontainers::core::Host::HostGateway, + ) + .start() + .await?; + + let (cancel_tx, cancel_rx) = tokio::sync::oneshot::channel(); + if test_environment.logging { + let container_stdout = container.stdout(true); + let mut stdout_lines = container_stdout.lines(); + let container_stderr = container.stderr(true); + let mut stderr_lines = container_stderr.lines(); + + // Spawn a task to copy data from the AsyncBufRead to stdout + task::spawn(async move { + tokio::pin!(cancel_rx); + loop { + tokio::select! { + Some(stdout_line) = stdout_lines.next_line().map(|res| res.transpose()) => { + match stdout_line { + Ok(line) => info!("{}", line), + Err(e) => { + error!("Error reading stdout from container stream: {}", e); + break; + } + } + }, + Some(stderr_line) = stderr_lines.next_line().map(|res| res.transpose()) => { + match stderr_line { + Ok(line) => warn!("{}", line), + Err(e) => { + error!("Error reading stderr from container stream: {}", e); + break; + } + } + } + _ = &mut cancel_rx => { + break; + } + } + } + }); + } + + // Resolve ingress url + let host = container.get_host().await?; + let ports = container.ports().await?; + let ingress_port = ports.map_to_host_port_ipv4(8080.tcp()).unwrap(); + let ingress_url = format!("http://{}:{}", host, ingress_port); + + info!("Restate container started, listening on requests at {ingress_url}"); + + Ok(StartedRestateContainer { + _cancel_tx: cancel_tx, + container, + ingress_url, + }) + } + + async fn register_endpoint(&self, endpoint: &StartedEndpoint) -> Result<(), anyhow::Error> { + let host = self.container.get_host().await?; + let ports = self.container.ports().await?; + let admin_port = ports.map_to_host_port_ipv4(9070.tcp()).unwrap(); + + let client = reqwest::Client::builder().http2_prior_knowledge().build()?; + + let deployment_uri: String = format!("http://host.docker.internal:{}/", endpoint.port); + let deployment_payload = RegisterDeploymentRequestHttp { + uri: deployment_uri, + additional_headers: None, + use_http_11: false, + force: false, + dry_run: false, + }; + + let register_admin_url = format!("http://{}:{}/deployments", host, admin_port); + + let response = client + .post(register_admin_url) + .json(&deployment_payload) + .send() + .await + .context("Error when trying to register the service endpoint")?; + + if !response.status().is_success() { + anyhow::bail!( + "Got non success status code when trying to register the service endpoint: {}", + response.status() + ) + } + + Ok(()) + } +} + +pub struct StartedTestEnvironment { + _started_endpoint: StartedEndpoint, + started_restate_container: StartedRestateContainer, +} + +impl StartedTestEnvironment { + pub fn ingress_url(&self) -> String { + self.started_restate_container.ingress_url.clone() + } +} diff --git a/test-env/tests/test_container.rs b/testcontainers/tests/test_container.rs similarity index 81% rename from test-env/tests/test_container.rs rename to testcontainers/tests/test_container.rs index 647cf83..5de3dd5 100644 --- a/test-env/tests/test_container.rs +++ b/testcontainers/tests/test_container.rs @@ -1,9 +1,8 @@ use reqwest::StatusCode; use restate_sdk::prelude::*; -use restate_sdk_test_env::TestContainer; +use restate_sdk_testcontainers::TestEnvironment; use tracing::info; -// Should compile #[restate_sdk::service] trait MyService { async fn my_handler() -> HandlerResult; @@ -40,30 +39,27 @@ async fn test_container() { let endpoint = Endpoint::builder().bind(MyServiceImpl.serve()).build(); - // simple test container intialization with default configuration + // simple test container initialization with default configuration //let test_container = TestContainer::default().start(endpoint).await.unwrap(); // custom test container initialization with builder - let test_container = TestContainer::builder() - // optional passthrough logging from the resstate server testcontainer + let test_environment = TestEnvironment::new() + // optional passthrough logging from the restate server testcontainers // prints container logs to tracing::info level .with_container_logging() .with_container( "docker.io/restatedev/restate".to_string(), "latest".to_string(), ) - .build() .start(endpoint) .await .unwrap(); - let ingress_url = test_container.ingress_url(); + let ingress_url = test_environment.ingress_url(); // call container ingress url for /MyService/my_handler let response = reqwest::Client::new() .post(format!("{}/MyService/my_handler", ingress_url)) - .header("Accept", "application/json") - .header("Content-Type", "*/*") .header("idempotency-key", "abc") .send() .await