From af8a315ebf46e7ce73a33ab67745b3c2d9f68859 Mon Sep 17 00:00:00 2001 From: eth3lbert Date: Mon, 15 Jan 2024 17:26:07 +0800 Subject: [PATCH 1/5] krate/search: Extract query filtering logic into reusable function Enable consistent filtering for both data retrieval and count queries, and ensure accurate totals in seek-based pagination. --- src/controllers/krate/search.rs | 350 ++++++++++++++++++-------------- 1 file changed, 203 insertions(+), 147 deletions(-) diff --git a/src/controllers/krate/search.rs b/src/controllers/krate/search.rs index c0f3d744944..c32ad61f0db 100644 --- a/src/controllers/krate/search.rs +++ b/src/controllers/krate/search.rs @@ -5,6 +5,7 @@ use diesel::dsl::*; use diesel::sql_types::Array; use diesel_full_text_search::*; use indexmap::IndexMap; +use once_cell::sync::OnceCell; use crate::controllers::cargo_prelude::*; use crate::controllers::helpers::Paginate; @@ -40,31 +41,46 @@ use crate::sql::{array_agg, canon_crate_name, lower}; /// for them. pub async fn search(app: AppState, req: Parts) -> AppResult> { spawn_blocking(move || { - use diesel::sql_types::{Bool, Text}; + use diesel::sql_types::Bool; let params = req.query(); - let sort = params.get("sort").map(|s| &**s); - let include_yanked = params - .get("include_yanked") + let option_param = |s| params.get(s).map(|v| v.as_str()); + let sort = option_param("sort"); + let include_yanked = option_param("include_yanked") .map(|s| s == "yes") .unwrap_or(true); // Remove 0x00 characters from the query string because Postgres can not // handle them and will return an error, which would cause us to throw // an Internal Server Error ourselves. 
- let q_string = params.get("q").map(|q| q.replace('\u{0}', "")); + let q_string = option_param("q").map(|q| q.replace('\u{0}', "")); + + let filter_params = FilterParams { + q_string: q_string.as_deref(), + include_yanked, + category: option_param("category"), + all_keywords: option_param("all_keywords"), + keyword: option_param("keyword"), + letter: option_param("letter"), + user_id: option_param("user_id").and_then(|s| s.parse::().ok()), + team_id: option_param("team_id").and_then(|s| s.parse::().ok()), + following: option_param("following").is_some(), + has_ids: option_param("ids[]").is_some(), + ..Default::default() + }; let selection = ( ALL_COLUMNS, false.into_sql::(), recent_crate_downloads::downloads.nullable(), ); - let mut query = crates::table - .left_join(recent_crate_downloads::table) - .select(selection) - .into_boxed(); - let mut supports_seek = true; + let conn = &mut *app.db_read()?; + let mut supports_seek = filter_params.supports_seek(); + let mut query = filter_params + .make_query(&req, conn)? 
+ .left_join(recent_crate_downloads::table) + .select(selection); if let Some(q_string) = &q_string { // Searching with a query string always puts the exact match at the start of the results, @@ -72,16 +88,7 @@ pub async fn search(app: AppState, req: Parts) -> AppResult> { supports_seek = false; if !q_string.is_empty() { - let sort = params.get("sort").map(|s| &**s).unwrap_or("relevance"); - - let q = sql::("plainto_tsquery('english', ") - .bind::(q_string) - .sql(")"); - query = query.filter( - q.clone() - .matches(crates::textsearchable_index_col) - .or(Crate::loosly_matches_name(q_string)), - ); + let sort = sort.unwrap_or("relevance"); query = query.select(( ALL_COLUMNS, @@ -91,138 +98,16 @@ pub async fn search(app: AppState, req: Parts) -> AppResult> { query = query.order(Crate::with_name(q_string).desc()); if sort == "relevance" { + let q = to_tsquery_with_search_config( + configuration::TsConfigurationByName("english"), + q_string, + ); let rank = ts_rank_cd(crates::textsearchable_index_col, q); query = query.then_order_by(rank.desc()) } } } - if let Some(cat) = params.get("category") { - // Calculating the total number of results with filters is not supported yet. - supports_seek = false; - - query = query.filter( - crates::id.eq_any( - crates_categories::table - .select(crates_categories::crate_id) - .inner_join(categories::table) - .filter( - categories::slug - .eq(cat) - .or(categories::slug.like(format!("{cat}::%"))), - ), - ), - ); - } - - let conn = &mut *app.db_read()?; - - if let Some(kws) = params.get("all_keywords") { - // Calculating the total number of results with filters is not supported yet. 
- supports_seek = false; - - let names: Vec<_> = kws - .split_whitespace() - .map(|name| name.to_lowercase()) - .collect(); - - query = query.filter( - // FIXME: Just use `.contains` in Diesel 2.0 - // https://github.com/diesel-rs/diesel/issues/2066 - Contains::new( - crates_keywords::table - .inner_join(keywords::table) - .filter(crates_keywords::crate_id.eq(crates::id)) - .select(array_agg(keywords::keyword)) - .single_value(), - names.into_sql::>(), - ), - ); - } else if let Some(kw) = params.get("keyword") { - // Calculating the total number of results with filters is not supported yet. - supports_seek = false; - - query = query.filter( - crates::id.eq_any( - crates_keywords::table - .select(crates_keywords::crate_id) - .inner_join(keywords::table) - .filter(lower(keywords::keyword).eq(lower(kw))), - ), - ); - } else if let Some(letter) = params.get("letter") { - // Calculating the total number of results with filters is not supported yet. - supports_seek = false; - - let pattern = format!( - "{}%", - letter - .chars() - .next() - .ok_or_else(|| bad_request("letter value must contain 1 character"))? - .to_lowercase() - .collect::() - ); - query = query.filter(canon_crate_name(crates::name).like(pattern)); - } else if let Some(user_id) = params.get("user_id").and_then(|s| s.parse::().ok()) { - // Calculating the total number of results with filters is not supported yet. - supports_seek = false; - - query = query.filter( - crates::id.eq_any( - CrateOwner::by_owner_kind(OwnerKind::User) - .select(crate_owners::crate_id) - .filter(crate_owners::owner_id.eq(user_id)), - ), - ); - } else if let Some(team_id) = params.get("team_id").and_then(|s| s.parse::().ok()) { - // Calculating the total number of results with filters is not supported yet. 
- supports_seek = false; - - query = query.filter( - crates::id.eq_any( - CrateOwner::by_owner_kind(OwnerKind::Team) - .select(crate_owners::crate_id) - .filter(crate_owners::owner_id.eq(team_id)), - ), - ); - } else if params.get("following").is_some() { - // Calculating the total number of results with filters is not supported yet. - supports_seek = false; - - let user_id = AuthCheck::default().check(&req, conn)?.user_id(); - - query = query.filter( - crates::id.eq_any( - follows::table - .select(follows::crate_id) - .filter(follows::user_id.eq(user_id)), - ), - ); - } else if params.get("ids[]").is_some() { - // Calculating the total number of results with filters is not supported yet. - supports_seek = false; - - let query_bytes = req.uri.query().unwrap_or("").as_bytes(); - let ids: Vec<_> = url::form_urlencoded::parse(query_bytes) - .filter(|(key, _)| key == "ids[]") - .map(|(_, value)| value.to_string()) - .collect(); - - query = query.filter(crates::name.eq_any(ids)); - } - - if !include_yanked { - // Calculating the total number of results with filters is not supported yet. - supports_seek = false; - - query = query.filter(exists( - versions::table - .filter(versions::crate_id.eq(crates::id)) - .filter(versions::yanked.eq(false)), - )); - } - // Any sort other than 'relevance' (default) would ignore exact crate name matches if sort == Some("downloads") { // Custom sorting is not supported yet with seek. @@ -280,8 +165,9 @@ pub async fn search(app: AppState, req: Parts) -> AppResult> { // // If this becomes a problem in the future the crates count could be denormalized, at least // for the filterless happy path. + let count_query = filter_params.make_query(&req, conn)?.count(); let total: i64 = info_span!("db.query", message = "SELECT COUNT(*) FROM crates") - .in_scope(|| crates::table.count().get_result(conn))?; + .in_scope(|| count_query.get_result(conn))?; let results: Vec<(Crate, bool, Option)> = info_span!("db.query", message = "SELECT ... 
FROM crates") @@ -356,4 +242,174 @@ pub async fn search(app: AppState, req: Parts) -> AppResult> { .await } +#[derive(Default)] +struct FilterParams<'a> { + q_string: Option<&'a str>, + include_yanked: bool, + category: Option<&'a str>, + all_keywords: Option<&'a str>, + keyword: Option<&'a str>, + letter: Option<&'a str>, + user_id: Option, + team_id: Option, + following: bool, + has_ids: bool, + _auth_user_id: OnceCell, + _ids: OnceCell>>, +} + +impl<'a> FilterParams<'a> { + fn ids(&self, req: &Parts) -> Option<&[String]> { + self._ids + .get_or_init(|| { + if self.has_ids { + let query_bytes = req.uri.query().unwrap_or("").as_bytes(); + let v = url::form_urlencoded::parse(query_bytes) + .filter(|(key, _)| key == "ids[]") + .map(|(_, value)| value.to_string()) + .collect::>(); + Some(v) + } else { + None + } + }) + .as_deref() + } + + fn authed_user_id(&self, req: &Parts, conn: &mut PgConnection) -> AppResult<&i32> { + self._auth_user_id.get_or_try_init(|| { + let user_id = AuthCheck::default().check(req, conn)?.user_id(); + Ok(user_id) + }) + } + + fn supports_seek(&self) -> bool { + // Calculating the total number of results with filters is supported, but seek-based pagination with filters is not supported yet. 
+ !(self.q_string.is_some() + || self.category.is_some() + || self.all_keywords.is_some() + || self.keyword.is_some() + || self.letter.is_some() + || self.user_id.is_some() + || self.team_id.is_some() + || self.following + || self.has_ids + || !self.include_yanked) + } + + fn make_query( + &'a self, + req: &Parts, + conn: &mut PgConnection, + ) -> AppResult> { + use diesel::sql_types::Text; + let mut query = crates::table.into_boxed(); + + if let Some(q_string) = self.q_string { + if !q_string.is_empty() { + let q = to_tsquery_with_search_config( + configuration::TsConfigurationByName("english"), + q_string, + ); + query = query.filter( + q.matches(crates::textsearchable_index_col) + .or(Crate::loosly_matches_name(q_string)), + ); + } + } + + if let Some(cat) = self.category { + query = query.filter( + crates::id.eq_any( + crates_categories::table + .select(crates_categories::crate_id) + .inner_join(categories::table) + .filter( + categories::slug + .eq(cat) + .or(categories::slug.like(format!("{cat}::%"))), + ), + ), + ); + } + + if let Some(kws) = self.all_keywords { + let names: Vec<_> = kws + .split_whitespace() + .map(|name| name.to_lowercase()) + .collect(); + + query = query.filter( + // FIXME: Just use `.contains` in Diesel 2.0 + // https://github.com/diesel-rs/diesel/issues/2066 + Contains::new( + crates_keywords::table + .inner_join(keywords::table) + .filter(crates_keywords::crate_id.eq(crates::id)) + .select(array_agg(keywords::keyword)) + .single_value(), + names.into_sql::>(), + ), + ); + } else if let Some(kw) = self.keyword { + query = query.filter( + crates::id.eq_any( + crates_keywords::table + .select(crates_keywords::crate_id) + .inner_join(keywords::table) + .filter(lower(keywords::keyword).eq(lower(kw))), + ), + ); + } else if let Some(letter) = self.letter { + let pattern = format!( + "{}%", + letter + .chars() + .next() + .ok_or_else(|| bad_request("letter value must contain 1 character"))? 
+ .to_lowercase() + .collect::() + ); + query = query.filter(canon_crate_name(crates::name).like(pattern)); + } else if let Some(user_id) = self.user_id { + query = query.filter( + crates::id.eq_any( + CrateOwner::by_owner_kind(OwnerKind::User) + .select(crate_owners::crate_id) + .filter(crate_owners::owner_id.eq(user_id)), + ), + ); + } else if let Some(team_id) = self.team_id { + query = query.filter( + crates::id.eq_any( + CrateOwner::by_owner_kind(OwnerKind::Team) + .select(crate_owners::crate_id) + .filter(crate_owners::owner_id.eq(team_id)), + ), + ); + } else if self.following { + let user_id = self.authed_user_id(req, conn)?; + query = query.filter( + crates::id.eq_any( + follows::table + .select(follows::crate_id) + .filter(follows::user_id.eq(user_id)), + ), + ); + } else if self.ids(req).is_some() { + query = query.filter(crates::name.eq_any(self.ids(req).unwrap())); + } + + if !self.include_yanked { + query = query.filter(exists( + versions::table + .filter(versions::crate_id.eq(crates::id)) + .filter(versions::yanked.eq(false)), + )); + } + + Ok(query) + } +} + diesel::infix_operator!(Contains, "@>"); From 0324b647679c2328e0ed5025b6da8983424d9c9f Mon Sep 17 00:00:00 2001 From: eth3lbert Date: Mon, 15 Jan 2024 17:31:21 +0800 Subject: [PATCH 2/5] krate/search: improve performance with count subquery --- src/controllers/helpers/pagination.rs | 67 +++++++++++++++++++++++++++ src/controllers/krate/search.rs | 5 +- 2 files changed, 71 insertions(+), 1 deletion(-) diff --git a/src/controllers/helpers/pagination.rs b/src/controllers/helpers/pagination.rs index 0d29bac4aa6..2e587800015 100644 --- a/src/controllers/helpers/pagination.rs +++ b/src/controllers/helpers/pagination.rs @@ -146,6 +146,18 @@ pub(crate) trait Paginate: Sized { options, } } + + fn pages_pagination_with_count_query( + self, + options: PaginationOptions, + count_query: C, + ) -> PaginatedQueryWithCountSubq { + PaginatedQueryWithCountSubq { + query: self, + count_query, + options, + } + } } 
impl Paginate for T {} @@ -303,6 +315,61 @@ pub(crate) fn decode_seek Deserialize<'a>>(seek: &str) -> anyhow::Res Ok(decoded) } +#[derive(Debug)] +pub(crate) struct PaginatedQueryWithCountSubq { + query: T, + count_query: C, + options: PaginationOptions, +} + +impl QueryId for PaginatedQueryWithCountSubq { + const HAS_STATIC_QUERY_ID: bool = false; + type QueryId = (); +} + +impl< + T: Query, + C: Query + QueryDsl + diesel::query_dsl::methods::SelectDsl, + > Query for PaginatedQueryWithCountSubq +{ + type SqlType = (T::SqlType, BigInt); +} + +impl RunQueryDsl for PaginatedQueryWithCountSubq {} + +impl QueryFragment for PaginatedQueryWithCountSubq +where + T: QueryFragment, + C: QueryFragment, +{ + fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, Pg>) -> QueryResult<()> { + out.push_sql("SELECT *, ("); + self.count_query.walk_ast(out.reborrow())?; + out.push_sql(") FROM ("); + self.query.walk_ast(out.reborrow())?; + out.push_sql(") t LIMIT "); + out.push_bind_param::(&self.options.per_page)?; + if let Some(offset) = self.options.offset() { + out.push_sql(format!(" OFFSET {offset}").as_str()); + } + Ok(()) + } +} + +impl PaginatedQueryWithCountSubq { + pub(crate) fn load<'a, U>(self, conn: &mut PgConnection) -> QueryResult> + where + Self: LoadQuery<'a, PgConnection, WithCount>, + { + let options = self.options.clone(); + let records_and_total = self.internal_load(conn)?.collect::>()?; + Ok(Paginated { + records_and_total, + options, + }) + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/controllers/krate/search.rs b/src/controllers/krate/search.rs index c32ad61f0db..197bbed7822 100644 --- a/src/controllers/krate/search.rs +++ b/src/controllers/krate/search.rs @@ -186,7 +186,10 @@ pub async fn search(app: AppState, req: Parts) -> AppResult> { (total, next_page, None, results, conn) } else { - let query = query.pages_pagination(pagination); + let query = query.pages_pagination_with_count_query( + pagination, + filter_params.make_query(&req, 
conn)?.count(), + ); let data: Paginated<(Crate, bool, Option)> = info_span!("db.query", message = "SELECT ..., COUNT(*) FROM crates") .in_scope(|| query.load(conn))?; From 179351a01d0aa27a27e0b3060bfe0a73491f42fa Mon Sep 17 00:00:00 2001 From: eth3lbert Date: Wed, 17 Jan 2024 13:33:09 +0800 Subject: [PATCH 3/5] tests/routes/crates/list: Assert total in the test suite --- src/tests/routes/crates/list.rs | 64 ++++++++++++++++++++++++++------- 1 file changed, 52 insertions(+), 12 deletions(-) diff --git a/src/tests/routes/crates/list.rs b/src/tests/routes/crates/list.rs index e3e84622f10..f370b50ed4e 100644 --- a/src/tests/routes/crates/list.rs +++ b/src/tests/routes/crates/list.rs @@ -68,21 +68,53 @@ fn index_queries() { // Query containing a space assert_eq!(anon.search("q=foo%20kw3").meta.total, 1); - assert_eq!(anon.search_by_user_id(user.id).crates.len(), 4); - assert_eq!(anon.search_by_user_id(0).crates.len(), 0); + let json = anon.search_by_user_id(user.id); + assert_eq!(json.crates.len(), 4); + assert_eq!(json.meta.total, 4); + + let json = anon.search_by_user_id(0); + assert_eq!(json.crates.len(), 0); + assert_eq!(json.meta.total, 0); + + let json = anon.search("letter=F"); + assert_eq!(json.crates.len(), 2); + assert_eq!(json.meta.total, 2); + + let json = anon.search("letter=B"); + assert_eq!(json.crates.len(), 1); + assert_eq!(json.meta.total, 1); + + let json = anon.search("letter=b"); + assert_eq!(json.crates.len(), 1); + assert_eq!(json.meta.total, 1); + + let json = anon.search("letter=c"); + assert_eq!(json.crates.len(), 0); + assert_eq!(json.meta.total, 0); + + let json = anon.search("keyword=kw1"); + assert_eq!(json.crates.len(), 3); + assert_eq!(json.meta.total, 3); + + let json = anon.search("keyword=KW1"); + assert_eq!(json.crates.len(), 3); + assert_eq!(json.meta.total, 3); + + let json = anon.search("keyword=kw2"); + assert_eq!(json.crates.len(), 0); + assert_eq!(json.meta.total, 0); - assert_eq!(anon.search("letter=F").crates.len(), 2); - 
assert_eq!(anon.search("letter=B").crates.len(), 1); - assert_eq!(anon.search("letter=b").crates.len(), 1); - assert_eq!(anon.search("letter=c").crates.len(), 0); + let json = anon.search("all_keywords=kw1%20kw3"); + assert_eq!(json.crates.len(), 1); + assert_eq!(json.meta.total, 1); - assert_eq!(anon.search("keyword=kw1").crates.len(), 3); - assert_eq!(anon.search("keyword=KW1").crates.len(), 3); - assert_eq!(anon.search("keyword=kw2").crates.len(), 0); - assert_eq!(anon.search("all_keywords=kw1%20kw3").crates.len(), 1); + let json = anon.search("q=foo&keyword=kw1"); + assert_eq!(json.crates.len(), 1); + assert_eq!(json.meta.total, 1); - assert_eq!(anon.search("q=foo&keyword=kw1").crates.len(), 1); - assert_eq!(anon.search("q=foo2&keyword=kw1").crates.len(), 0); + let json = anon.search("q=foo2&keyword=kw1"); + assert_eq!(json.crates.len(), 0); + assert_eq!(json.meta.total, 0); app.db(|conn| { new_category("Category 1", "cat1", "Category 1 crates") @@ -727,6 +759,10 @@ fn pagination_links_included_if_applicable() { Some("?letter=p&page=2&per_page=1".to_string()), page3.meta.prev_page ); + assert!([page1.meta.total, page2.meta.total, page3.meta.total] + .iter() + .all(|w| *w == 3)); + assert_eq!(page4.meta.total, 0); } #[test] @@ -791,10 +827,12 @@ fn test_pages_work_even_with_seek_based_pagination() { // The next_page returned by the request is seek-based let first = anon.search("per_page=1"); assert!(first.meta.next_page.unwrap().contains("seek=")); + assert_eq!(first.meta.total, 3); // Calling with page=2 will revert to offset-based pagination let second = anon.search("page=2&per_page=1"); assert!(second.meta.next_page.unwrap().contains("page=3")); + assert_eq!(second.meta.total, 3); } #[test] @@ -844,6 +882,7 @@ fn crates_by_user_id() { let response = user.search_by_user_id(id); assert_eq!(response.crates.len(), 1); + assert_eq!(response.meta.total, 1); } #[test] @@ -858,4 +897,5 @@ fn crates_by_user_id_not_including_deleted_owners() { let response = 
anon.search_by_user_id(user.id); assert_eq!(response.crates.len(), 0); + assert_eq!(response.meta.total, 0); } From d53a93fd4ccea54917c53c13d7ca4b332edf8f00 Mon Sep 17 00:00:00 2001 From: eth3lbert Date: Fri, 2 Feb 2024 13:22:32 +0800 Subject: [PATCH 4/5] Fix the issue of search not working with spaces in them (rust-lang#8052) --- src/controllers/krate/search.rs | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/src/controllers/krate/search.rs b/src/controllers/krate/search.rs index 197bbed7822..afcdba826a9 100644 --- a/src/controllers/krate/search.rs +++ b/src/controllers/krate/search.rs @@ -2,7 +2,7 @@ use crate::auth::AuthCheck; use diesel::dsl::*; -use diesel::sql_types::Array; +use diesel::sql_types::{Array, Text}; use diesel_full_text_search::*; use indexmap::IndexMap; use once_cell::sync::OnceCell; @@ -98,10 +98,9 @@ pub async fn search(app: AppState, req: Parts) -> AppResult> { query = query.order(Crate::with_name(q_string).desc()); if sort == "relevance" { - let q = to_tsquery_with_search_config( - configuration::TsConfigurationByName("english"), - q_string, - ); + let q = sql::("plainto_tsquery('english', ") + .bind::(q_string) + .sql(")"); let rank = ts_rank_cd(crates::textsearchable_index_col, q); query = query.then_order_by(rank.desc()) } @@ -305,15 +304,13 @@ impl<'a> FilterParams<'a> { req: &Parts, conn: &mut PgConnection, ) -> AppResult> { - use diesel::sql_types::Text; let mut query = crates::table.into_boxed(); if let Some(q_string) = self.q_string { if !q_string.is_empty() { - let q = to_tsquery_with_search_config( - configuration::TsConfigurationByName("english"), - q_string, - ); + let q = sql::("plainto_tsquery('english', ") + .bind::(q_string) + .sql(")"); query = query.filter( q.matches(crates::textsearchable_index_col) .or(Crate::loosly_matches_name(q_string)), From 4a95592690211bea4c0283107ed4c7f348daea5c Mon Sep 17 00:00:00 2001 From: eth3lbert Date: Fri, 2 Feb 2024 23:47:44 +0800 Subject: [PATCH 5/5] 
Update src/controllers/helpers/pagination.rs Co-authored-by: Carol (Nichols || Goulding) <193874+carols10cents@users.noreply.github.com> --- src/controllers/helpers/pagination.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/controllers/helpers/pagination.rs b/src/controllers/helpers/pagination.rs index 2e587800015..327b4bc3186 100644 --- a/src/controllers/helpers/pagination.rs +++ b/src/controllers/helpers/pagination.rs @@ -350,6 +350,8 @@ where out.push_sql(") t LIMIT "); out.push_bind_param::(&self.options.per_page)?; if let Some(offset) = self.options.offset() { + // Injection safety: `offset()` returns `Option`, so this interpolation is constrained to known + // valid values and is not vulnerable to user injection attacks. out.push_sql(format!(" OFFSET {offset}").as_str()); } Ok(())