diff --git a/Cargo.lock b/Cargo.lock
index 2f7f1b2b45b..a4751a43982 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -593,6 +593,7 @@ dependencies = [
"axum",
"futures-util",
"http",
+ "http-body",
"hyper",
"percent-encoding",
"sentry-core",
diff --git a/conduit-axum/Cargo.toml b/conduit-axum/Cargo.toml
index 800dec466ed..590b43cc748 100644
--- a/conduit-axum/Cargo.toml
+++ b/conduit-axum/Cargo.toml
@@ -12,6 +12,7 @@ rust-version = "1.56.0"
axum = "=0.6.2"
hyper = { version = "=0.14.23", features = ["server", "stream"] }
http = "=0.2.8"
+http-body = "=0.4.5"
percent-encoding = "=2.2.0"
sentry-core = { version = "=0.29.1", features = ["client"] }
thiserror = "=1.0.38"
diff --git a/conduit-axum/src/conduit.rs b/conduit-axum/src/conduit.rs
index 603e8baf86f..1059698ba90 100644
--- a/conduit-axum/src/conduit.rs
+++ b/conduit-axum/src/conduit.rs
@@ -1,12 +1,13 @@
-use axum::async_trait;
use axum::body::Bytes;
use axum::extract::FromRequest;
+use axum::response::IntoResponse;
+use axum::{async_trait, RequestExt};
+use http_body::LengthLimitError;
use hyper::Body;
use std::error::Error;
use std::io::Cursor;
use std::ops::{Deref, DerefMut};
-use crate::fallback::check_content_length;
use crate::response::AxumResponse;
use crate::server_error_response;
pub use http::{header, Extensions, HeaderMap, Method, Request, Response, StatusCode, Uri};
@@ -54,16 +55,31 @@ where
type Rejection = AxumResponse;
async fn from_request(req: Request
, _state: &S) -> Result {
- check_content_length(&req)?;
+ let request = match req.with_limited_body() {
+ Ok(req) => {
+ let (parts, body) = req.into_parts();
- let (parts, body) = req.into_parts();
+ let bytes = hyper::body::to_bytes(body).await.map_err(|err| {
+ if err.downcast_ref::().is_some() {
+ StatusCode::BAD_REQUEST.into_response()
+ } else {
+ server_error_response(&*err)
+ }
+ })?;
- let full_body = match hyper::body::to_bytes(body).await {
- Ok(body) => body,
- Err(err) => return Err(server_error_response(&err)),
+ Request::from_parts(parts, bytes)
+ }
+ Err(req) => {
+ let (parts, body) = req.into_parts();
+
+ let bytes = hyper::body::to_bytes(body)
+ .await
+ .map_err(|err| server_error_response(&err))?;
+
+ Request::from_parts(parts, bytes)
+ }
};
- let request = Request::from_parts(parts, Cursor::new(full_body));
- Ok(ConduitRequest(request))
+ Ok(ConduitRequest(request.map(Cursor::new)))
}
}
diff --git a/conduit-axum/src/fallback.rs b/conduit-axum/src/fallback.rs
index 7d7acd71320..7cbe6761b85 100644
--- a/conduit-axum/src/fallback.rs
+++ b/conduit-axum/src/fallback.rs
@@ -3,19 +3,10 @@ use crate::response::AxumResponse;
use std::error::Error;
-use axum::body::{Body, HttpBody};
use axum::extract::Extension;
use axum::response::IntoResponse;
-use http::header::CONTENT_LENGTH;
use http::StatusCode;
-use hyper::Request;
-use tracing::{error, warn};
-
-/// The maximum size allowed in the `Content-Length` header
-///
-/// Chunked requests may grow to be larger over time if that much data is actually sent.
-/// See the usage section of the README if you plan to use this server in production.
-const MAX_CONTENT_LENGTH: u64 = 128 * 1024 * 1024; // 128 MB
+use tracing::error;
#[derive(Clone, Debug)]
pub struct ErrorField(pub String);
@@ -42,42 +33,3 @@ pub fn server_error_response(error: &E) -> AxumResponse {
)
.into_response()
}
-
-/// Check for `Content-Length` values that are invalid or too large
-///
-/// If a `Content-Length` is provided then `hyper::body::to_bytes()` may try to allocate a buffer
-/// of this size upfront, leading to a process abort and denial of service to other clients.
-///
-/// This only checks for requests that claim to be too large. If the request is chunked then it
-/// is possible to allocate larger chunks of memory over time, by actually sending large volumes of
-/// data. Request sizes must be limited higher in the stack to protect against this type of attack.
-pub(crate) fn check_content_length(request: &Request) -> Result<(), AxumResponse> {
- fn bad_request(message: &str) -> AxumResponse {
- warn!("Bad request: Content-Length {}", message);
- StatusCode::BAD_REQUEST.into_response()
- }
-
- if let Some(content_length) = request.headers().get(CONTENT_LENGTH) {
- let content_length = match content_length.to_str() {
- Ok(some) => some,
- Err(_) => return Err(bad_request("not ASCII")),
- };
-
- let content_length = match content_length.parse::() {
- Ok(some) => some,
- Err(_) => return Err(bad_request("not a u64")),
- };
-
- if content_length > MAX_CONTENT_LENGTH {
- return Err(bad_request("too large"));
- }
- }
-
- // A duplicate check, aligning with the specific impl of `hyper::body::to_bytes`
- // (at the time of this writing)
- if request.size_hint().lower() > MAX_CONTENT_LENGTH {
- return Err(bad_request("size_hint().lower() too large"));
- }
-
- Ok(())
-}
diff --git a/conduit-axum/src/tests.rs b/conduit-axum/src/tests.rs
index fcdf536b0a0..821ceccaff2 100644
--- a/conduit-axum/src/tests.rs
+++ b/conduit-axum/src/tests.rs
@@ -1,4 +1,5 @@
use crate::{server_error_response, spawn_blocking, ConduitRequest, HandlerResult, ServiceError};
+use axum::extract::DefaultBodyLimit;
use axum::response::IntoResponse;
use axum::Router;
use http::header::HeaderName;
@@ -111,7 +112,9 @@ async fn spawn_http_server() -> (
let (quit_tx, quit_rx) = oneshot::channel::<()>();
let addr = ([127, 0, 0, 1], 0).into();
- let router = Router::new().fallback(ok_result);
+ let router = Router::new()
+ .fallback(ok_result)
+ .layer(DefaultBodyLimit::max(4096));
let make_service = router.into_make_service();
let server = hyper::Server::bind(&addr).serve(make_service);
@@ -125,8 +128,7 @@ async fn spawn_http_server() -> (
#[tokio::test]
async fn content_length_too_large() {
- const ACTUAL_BODY_SIZE: usize = 10_000;
- const CLAIMED_CONTENT_LENGTH: u64 = 11_111_111_111_111_111_111;
+ const ACTUAL_BODY_SIZE: usize = 4097;
let (url, server, quit_tx) = spawn_http_server().await;
@@ -136,10 +138,7 @@ async fn content_length_too_large() {
.send_data(vec![0; ACTUAL_BODY_SIZE].into())
.await
.unwrap();
- let req = hyper::Request::put(url)
- .header(hyper::header::CONTENT_LENGTH, CLAIMED_CONTENT_LENGTH)
- .body(body)
- .unwrap();
+ let req = hyper::Request::put(url).body(body).unwrap();
let resp = client
.request(req)
diff --git a/src/router.rs b/src/router.rs
index f5fcea9dd47..53c993918c4 100644
--- a/src/router.rs
+++ b/src/router.rs
@@ -1,3 +1,4 @@
+use axum::extract::DefaultBodyLimit;
use axum::response::IntoResponse;
use axum::routing::{delete, get, post, put};
use axum::Router;
@@ -7,12 +8,17 @@ use crate::controllers::*;
use crate::util::errors::not_found;
use crate::Env;
+const MAX_PUBLISH_CONTENT_LENGTH: usize = 128 * 1024 * 1024; // 128 MB
+
pub fn build_axum_router(state: AppState) -> Router {
let mut router = Router::new()
// Route used by both `cargo search` and the frontend
.route("/api/v1/crates", get(krate::search::search))
// Routes used by `cargo`
- .route("/api/v1/crates/new", put(krate::publish::publish))
+ .route(
+ "/api/v1/crates/new",
+ put(krate::publish::publish).layer(DefaultBodyLimit::max(MAX_PUBLISH_CONTENT_LENGTH)),
+ )
.route(
"/api/v1/crates/:crate_id/owners",
get(krate::owners::owners)