diff --git a/Cargo.lock b/Cargo.lock index 2f7f1b2b45b..a4751a43982 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -593,6 +593,7 @@ dependencies = [ "axum", "futures-util", "http", + "http-body", "hyper", "percent-encoding", "sentry-core", diff --git a/conduit-axum/Cargo.toml b/conduit-axum/Cargo.toml index 800dec466ed..590b43cc748 100644 --- a/conduit-axum/Cargo.toml +++ b/conduit-axum/Cargo.toml @@ -12,6 +12,7 @@ rust-version = "1.56.0" axum = "=0.6.2" hyper = { version = "=0.14.23", features = ["server", "stream"] } http = "=0.2.8" +http-body = "=0.4.5" percent-encoding = "=2.2.0" sentry-core = { version = "=0.29.1", features = ["client"] } thiserror = "=1.0.38" diff --git a/conduit-axum/src/conduit.rs b/conduit-axum/src/conduit.rs index 603e8baf86f..1059698ba90 100644 --- a/conduit-axum/src/conduit.rs +++ b/conduit-axum/src/conduit.rs @@ -1,12 +1,13 @@ -use axum::async_trait; use axum::body::Bytes; use axum::extract::FromRequest; +use axum::response::IntoResponse; +use axum::{async_trait, RequestExt}; +use http_body::LengthLimitError; use hyper::Body; use std::error::Error; use std::io::Cursor; use std::ops::{Deref, DerefMut}; -use crate::fallback::check_content_length; use crate::response::AxumResponse; use crate::server_error_response; pub use http::{header, Extensions, HeaderMap, Method, Request, Response, StatusCode, Uri}; @@ -54,16 +55,31 @@ where type Rejection = AxumResponse; async fn from_request(req: Request, _state: &S) -> Result { - check_content_length(&req)?; + let request = match req.with_limited_body() { + Ok(req) => { + let (parts, body) = req.into_parts(); - let (parts, body) = req.into_parts(); + let bytes = hyper::body::to_bytes(body).await.map_err(|err| { + if err.downcast_ref::().is_some() { + StatusCode::BAD_REQUEST.into_response() + } else { + server_error_response(&*err) + } + })?; - let full_body = match hyper::body::to_bytes(body).await { - Ok(body) => body, - Err(err) => return Err(server_error_response(&err)), + Request::from_parts(parts, bytes) + } + Err(req) => { + let (parts, body) = req.into_parts(); + + let bytes = hyper::body::to_bytes(body) + .await + .map_err(|err| server_error_response(&err))?; + + Request::from_parts(parts, bytes) + } }; - let request = Request::from_parts(parts, Cursor::new(full_body)); - Ok(ConduitRequest(request)) + Ok(ConduitRequest(request.map(Cursor::new))) } } diff --git a/conduit-axum/src/fallback.rs b/conduit-axum/src/fallback.rs index 7d7acd71320..7cbe6761b85 100644 --- a/conduit-axum/src/fallback.rs +++ b/conduit-axum/src/fallback.rs @@ -3,19 +3,10 @@ use crate::response::AxumResponse; use std::error::Error; -use axum::body::{Body, HttpBody}; use axum::extract::Extension; use axum::response::IntoResponse; -use http::header::CONTENT_LENGTH; use http::StatusCode; -use hyper::Request; -use tracing::{error, warn}; - -/// The maximum size allowed in the `Content-Length` header -/// -/// Chunked requests may grow to be larger over time if that much data is actually sent. -/// See the usage section of the README if you plan to use this server in production. -const MAX_CONTENT_LENGTH: u64 = 128 * 1024 * 1024; // 128 MB +use tracing::error; #[derive(Clone, Debug)] pub struct ErrorField(pub String); @@ -42,42 +33,3 @@ pub fn server_error_response(error: &E) -> AxumResponse { ) .into_response() } - -/// Check for `Content-Length` values that are invalid or too large -/// -/// If a `Content-Length` is provided then `hyper::body::to_bytes()` may try to allocate a buffer -/// of this size upfront, leading to a process abort and denial of service to other clients. -/// -/// This only checks for requests that claim to be too large. If the request is chunked then it -/// is possible to allocate larger chunks of memory over time, by actually sending large volumes of -/// data. Request sizes must be limited higher in the stack to protect against this type of attack. -pub(crate) fn check_content_length(request: &Request) -> Result<(), AxumResponse> { - fn bad_request(message: &str) -> AxumResponse { - warn!("Bad request: Content-Length {}", message); - StatusCode::BAD_REQUEST.into_response() - } - - if let Some(content_length) = request.headers().get(CONTENT_LENGTH) { - let content_length = match content_length.to_str() { - Ok(some) => some, - Err(_) => return Err(bad_request("not ASCII")), - }; - - let content_length = match content_length.parse::() { - Ok(some) => some, - Err(_) => return Err(bad_request("not a u64")), - }; - - if content_length > MAX_CONTENT_LENGTH { - return Err(bad_request("too large")); - } - } - - // A duplicate check, aligning with the specific impl of `hyper::body::to_bytes` - // (at the time of this writing) - if request.size_hint().lower() > MAX_CONTENT_LENGTH { - return Err(bad_request("size_hint().lower() too large")); - } - - Ok(()) -} diff --git a/conduit-axum/src/tests.rs b/conduit-axum/src/tests.rs index fcdf536b0a0..821ceccaff2 100644 --- a/conduit-axum/src/tests.rs +++ b/conduit-axum/src/tests.rs @@ -1,4 +1,5 @@ use crate::{server_error_response, spawn_blocking, ConduitRequest, HandlerResult, ServiceError}; +use axum::extract::DefaultBodyLimit; use axum::response::IntoResponse; use axum::Router; use http::header::HeaderName; @@ -111,7 +112,9 @@ async fn spawn_http_server() -> ( let (quit_tx, quit_rx) = oneshot::channel::<()>(); let addr = ([127, 0, 0, 1], 0).into(); - let router = Router::new().fallback(ok_result); + let router = Router::new() + .fallback(ok_result) + .layer(DefaultBodyLimit::max(4096)); let make_service = router.into_make_service(); let server = hyper::Server::bind(&addr).serve(make_service); @@ -125,8 +128,7 @@ async fn spawn_http_server() -> ( #[tokio::test] async fn content_length_too_large() { - const ACTUAL_BODY_SIZE: usize = 10_000; - const CLAIMED_CONTENT_LENGTH: u64 = 11_111_111_111_111_111_111; + const ACTUAL_BODY_SIZE: usize = 4097; let (url, server, quit_tx) = spawn_http_server().await; @@ -136,10 +138,7 @@ async fn content_length_too_large() { .send_data(vec![0; ACTUAL_BODY_SIZE].into()) .await .unwrap(); - let req = hyper::Request::put(url) - .header(hyper::header::CONTENT_LENGTH, CLAIMED_CONTENT_LENGTH) - .body(body) - .unwrap(); + let req = hyper::Request::put(url).body(body).unwrap(); let resp = client .request(req) diff --git a/src/router.rs b/src/router.rs index f5fcea9dd47..53c993918c4 100644 --- a/src/router.rs +++ b/src/router.rs @@ -1,3 +1,4 @@ +use axum::extract::DefaultBodyLimit; use axum::response::IntoResponse; use axum::routing::{delete, get, post, put}; use axum::Router; @@ -7,12 +8,17 @@ use crate::controllers::*; use crate::util::errors::not_found; use crate::Env; +const MAX_PUBLISH_CONTENT_LENGTH: usize = 128 * 1024 * 1024; // 128 MB + pub fn build_axum_router(state: AppState) -> Router { let mut router = Router::new() // Route used by both `cargo search` and the frontend .route("/api/v1/crates", get(krate::search::search)) // Routes used by `cargo` - .route("/api/v1/crates/new", put(krate::publish::publish)) + .route( + "/api/v1/crates/new", + put(krate::publish::publish).layer(DefaultBodyLimit::max(MAX_PUBLISH_CONTENT_LENGTH)), + ) .route( "/api/v1/crates/:crate_id/owners", get(krate::owners::owners)