Skip to content

Commit 0754b83

Browse files
authored
Merge pull request #5885 from Turbo87/parts-trait
Implement `RequestPartsExt` trait
2 parents ba583c8 + aa66d7e commit 0754b83

File tree

7 files changed

+93
-18
lines changed

7 files changed

+93
-18
lines changed

src/auth.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use crate::controllers;
2+
use crate::controllers::util::RequestPartsExt;
23
use crate::middleware::app::RequestApp;
34
use crate::middleware::log_request::CustomMetadataRequestExt;
45
use crate::middleware::session::RequestSession;
@@ -8,7 +9,7 @@ use crate::util::errors::{
89
account_locked, forbidden, internal, AppError, AppResult, InsecurelyGeneratedTokenRevoked,
910
};
1011
use chrono::Utc;
11-
use http::{header, Request};
12+
use http::header;
1213

1314
#[derive(Debug, Clone)]
1415
pub struct AuthCheck {
@@ -54,7 +55,7 @@ impl AuthCheck {
5455
}
5556
}
5657

57-
pub fn check<B>(&self, request: &Request<B>) -> AppResult<Authentication> {
58+
pub fn check<T: RequestPartsExt>(&self, request: &T) -> AppResult<Authentication> {
5859
let auth = authenticate_user(request)?;
5960

6061
if let Some(token) = auth.api_token() {
@@ -151,7 +152,7 @@ impl Authentication {
151152
}
152153
}
153154

154-
fn authenticate_user<B>(req: &Request<B>) -> AppResult<Authentication> {
155+
fn authenticate_user<T: RequestPartsExt>(req: &T) -> AppResult<Authentication> {
155156
controllers::util::verify_origin(req)?;
156157

157158
let conn = req.app().db_write()?;

src/controllers.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ mod prelude {
2020
pub use http::{header, Request, StatusCode};
2121

2222
pub use super::conduit_axum::conduit_compat;
23+
use crate::controllers::util::RequestPartsExt;
2324
pub use crate::middleware::app::RequestApp;
2425
pub use crate::util::errors::{cargo_err, AppError, AppResult, BoxedAppError};
2526
use indexmap::IndexMap;
@@ -34,7 +35,7 @@ mod prelude {
3435
fn query_with_params(&self, params: IndexMap<String, String>) -> String;
3536
}
3637

37-
impl<B> RequestUtils for Request<B> {
38+
impl<T: RequestPartsExt> RequestUtils for T {
3839
fn query(&self) -> IndexMap<String, String> {
3940
url::form_urlencoded::parse(self.uri().query().unwrap_or("").as_bytes())
4041
.into_owned()

src/controllers/helpers/pagination.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use crate::config::Server;
22
use crate::controllers::prelude::*;
3+
use crate::controllers::util::RequestPartsExt;
34
use crate::middleware::log_request::CustomMetadataRequestExt;
45
use crate::models::helpers::with_count::*;
56
use crate::util::errors::{bad_request, AppResult};
@@ -72,7 +73,7 @@ impl PaginationOptionsBuilder {
7273
self
7374
}
7475

75-
pub(crate) fn gather<B>(self, req: &Request<B>) -> AppResult<PaginationOptions> {
76+
pub(crate) fn gather<T: RequestPartsExt>(self, req: &T) -> AppResult<PaginationOptions> {
7677
let params = req.query();
7778
let page_param = params.get("page");
7879
let seek_param = params.get("seek");

src/controllers/util.rs

Lines changed: 75 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
use super::prelude::*;
22
use crate::util::errors::{forbidden, internal, AppError, AppResult};
3-
use http::Request;
3+
use http::request::Parts;
4+
use http::{Extensions, HeaderMap, HeaderValue, Method, Request, Uri, Version};
45

56
/// The Origin header (https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Origin)
67
/// is sent with CORS requests and POST requests, and indicates where the request comes from.
78
/// We don't want to accept authenticated requests that originated from other sites, so this
89
/// function returns an error if the Origin header doesn't match what we expect "this site" to
910
/// be: https://crates.io in production, or http://localhost:port/ in development.
10-
pub fn verify_origin<B>(req: &Request<B>) -> AppResult<()> {
11+
pub fn verify_origin<T: RequestPartsExt>(req: &T) -> AppResult<()> {
1112
let headers = req.headers();
1213
let allowed_origins = &req.app().config.allowed_origins;
1314

@@ -23,3 +24,75 @@ pub fn verify_origin<B>(req: &Request<B>) -> AppResult<()> {
2324
}
2425
Ok(())
2526
}
27+
28+
pub trait RequestPartsExt {
29+
fn method(&self) -> &Method;
30+
fn uri(&self) -> &Uri;
31+
fn version(&self) -> Version;
32+
fn headers(&self) -> &HeaderMap<HeaderValue>;
33+
fn extensions(&self) -> &Extensions;
34+
fn extensions_mut(&mut self) -> &mut Extensions;
35+
}
36+
37+
impl RequestPartsExt for Parts {
38+
fn method(&self) -> &Method {
39+
&self.method
40+
}
41+
fn uri(&self) -> &Uri {
42+
&self.uri
43+
}
44+
fn version(&self) -> Version {
45+
self.version
46+
}
47+
fn headers(&self) -> &HeaderMap<HeaderValue> {
48+
&self.headers
49+
}
50+
fn extensions(&self) -> &Extensions {
51+
&self.extensions
52+
}
53+
fn extensions_mut(&mut self) -> &mut Extensions {
54+
&mut self.extensions
55+
}
56+
}
57+
58+
impl<B> RequestPartsExt for Request<B> {
59+
fn method(&self) -> &Method {
60+
self.method()
61+
}
62+
fn uri(&self) -> &Uri {
63+
self.uri()
64+
}
65+
fn version(&self) -> Version {
66+
self.version()
67+
}
68+
fn headers(&self) -> &HeaderMap<HeaderValue> {
69+
self.headers()
70+
}
71+
fn extensions(&self) -> &Extensions {
72+
self.extensions()
73+
}
74+
fn extensions_mut(&mut self) -> &mut Extensions {
75+
self.extensions_mut()
76+
}
77+
}
78+
79+
impl RequestPartsExt for ConduitRequest {
80+
fn method(&self) -> &Method {
81+
self.0.method()
82+
}
83+
fn uri(&self) -> &Uri {
84+
self.0.uri()
85+
}
86+
fn version(&self) -> Version {
87+
self.0.version()
88+
}
89+
fn headers(&self) -> &HeaderMap<HeaderValue> {
90+
self.0.headers()
91+
}
92+
fn extensions(&self) -> &Extensions {
93+
self.0.extensions()
94+
}
95+
fn extensions_mut(&mut self) -> &mut Extensions {
96+
self.0.extensions_mut()
97+
}
98+
}

src/middleware/app.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use axum::response::Response;
44
use http::Request;
55

66
use crate::app::AppState;
7+
use crate::controllers::util::RequestPartsExt;
78

89
/// `axum` middleware that injects the `AppState` instance into the `Request` extensions.
910
pub async fn add_app_state_extension<B>(
@@ -21,7 +22,7 @@ pub trait RequestApp {
2122
fn app(&self) -> &AppState;
2223
}
2324

24-
impl<T> RequestApp for Request<T> {
25+
impl<T: RequestPartsExt> RequestApp for T {
2526
fn app(&self) -> &AppState {
2627
self.extensions().get::<AppState>().expect("Missing app")
2728
}

src/middleware/log_request.rs

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
//! Log all requests in a format similar to Heroku's router, but with additional
22
//! information that we care about like User-Agent
33
4+
use crate::controllers::util::RequestPartsExt;
45
use crate::headers::{XRealIp, XRequestId};
56
use crate::middleware::normalize_path::OriginalPath;
67
use axum::headers::UserAgent;
@@ -140,23 +141,19 @@ pub async fn log_requests<B>(
140141
pub struct CustomMetadata(Arc<Mutex<Vec<(&'static str, String)>>>);
141142

142143
pub trait CustomMetadataRequestExt {
144+
fn add_custom_metadata<V: Display>(&self, key: &'static str, value: V);
145+
}
146+
147+
impl<T: RequestPartsExt> CustomMetadataRequestExt for T {
143148
fn add_custom_metadata<V: Display>(&self, key: &'static str, value: V) {
144-
if let Some(metadata) = self.metadata_extension() {
149+
if let Some(metadata) = self.extensions().get::<CustomMetadata>() {
145150
if let Ok(mut metadata) = metadata.lock() {
146151
metadata.push((key, value.to_string()));
147152
}
148153
}
149154

150155
sentry::configure_scope(|scope| scope.set_extra(key, value.to_string().into()));
151156
}
152-
153-
fn metadata_extension(&self) -> Option<&CustomMetadata>;
154-
}
155-
156-
impl<B> CustomMetadataRequestExt for Request<B> {
157-
fn metadata_extension(&self) -> Option<&CustomMetadata> {
158-
self.extensions().get::<CustomMetadata>()
159-
}
160157
}
161158

162159
struct LogLine<'f, 'g> {

src/middleware/session.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use crate::controllers::util::RequestPartsExt;
12
use axum::middleware::Next;
23
use axum::response::{IntoResponse, Response};
34
use axum_extra::extract::SignedCookieJar;
@@ -63,7 +64,7 @@ pub trait RequestSession {
6364
fn session_remove(&mut self, key: &str) -> Option<String>;
6465
}
6566

66-
impl<T> RequestSession for Request<T> {
67+
impl<T: RequestPartsExt> RequestSession for T {
6768
fn session_get(&self, key: &str) -> Option<String> {
6869
let session = self
6970
.extensions()

0 commit comments

Comments
 (0)