diff --git a/src/controllers/helpers/pagination.rs b/src/controllers/helpers/pagination.rs index 0a396589af8..0c194fd8294 100644 --- a/src/controllers/helpers/pagination.rs +++ b/src/controllers/helpers/pagination.rs @@ -402,19 +402,49 @@ impl PaginatedQueryWithCountSubq { } macro_rules! seek { + // Field struct + (@variant_struct $vis:vis $variant:ident { + $($(#[$field_meta:meta])? $field:ident: $ty:ty),* $(,)? + }) => { + paste::item! { + #[derive(Debug, Default, Deserialize, PartialEq)] + #[serde(from = $variant "Helper")] + $vis struct $variant { + $($(#[$field_meta])? pub(super) $field: $ty),* + } + + #[derive(Debug, Default, Deserialize, Serialize, PartialEq)] + struct [<$variant Helper>]($($(#[$field_meta])? pub(super) $ty),*); + + impl From<[<$variant Helper>]> for $variant { + fn from(helper: [<$variant Helper>]) -> Self { + let [<$variant Helper>]($($field,)*) = helper; + Self { $($field,)* } + } + } + + impl serde::Serialize for $variant { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let helper = [<$variant Helper>]($(self.$field,)*); + serde::Serialize::serialize(&helper, serializer) + } + } + } + }; ( $vis:vis enum $name:ident { $( - $variant:ident($($(#[$field_meta:meta])? $ty:ty),*) + $variant:ident $fields:tt, )* } ) => { + $( + seek!(@variant_struct $vis $variant $fields); + )* paste::item! { - $( - #[derive(Debug, Default, Deserialize, Serialize, PartialEq)] - $vis struct $variant($($(#[$field_meta])? pub(super) $ty),*); - )* - #[derive(Debug, Deserialize, Serialize, PartialEq)] #[serde(untagged)] $vis enum [<$name Payload>] { @@ -578,17 +608,27 @@ mod tests { mod seek { use chrono::naive::serde::ts_microseconds; - seek! { + seek!( pub(super) enum Seek { - Id(i32) - New(#[serde(with="ts_microseconds")] chrono::NaiveDateTime, i32) - RecentDownloads(Option, i32) + Id { + id: i32, + }, + New { + #[serde(with = "ts_microseconds")] + dt: chrono::NaiveDateTime, + id: i32, + }, + RecentDownloads { + downloads: Option, + id: i32, + }, } - } + ); } #[test] fn test_seek_macro_encode_and_decode() { + use chrono::naive::serde::ts_microseconds; use chrono::{NaiveDate, NaiveDateTime}; use seek::*; @@ -601,8 +641,9 @@ mod tests { assert_eq!(decoded, expect); }; + let id = 1234; let seek = Seek::Id; - let payload = SeekPayload::Id(Id(1234)); + let payload = SeekPayload::Id(Id { id }); let query = format!("seek={}", encode_seek(&payload).unwrap()); assert_decode_after(seek, &query, Some(payload)); @@ -611,12 +652,13 @@ mod tests { .and_hms_opt(9, 10, 11) .unwrap(); let seek = Seek::New; - let payload = SeekPayload::New(New(dt, 1234)); + let payload = SeekPayload::New(New { dt, id }); let query = format!("seek={}", encode_seek(&payload).unwrap()); assert_decode_after(seek, &query, Some(payload)); + let downloads = Some(5678); let seek = Seek::RecentDownloads; - let payload = SeekPayload::RecentDownloads(RecentDownloads(Some(5678), 1234)); + let payload = SeekPayload::RecentDownloads(RecentDownloads { downloads, id }); let query = format!("seek={}", encode_seek(&payload).unwrap()); assert_decode_after(seek, &query, Some(payload)); @@ -624,7 +666,7 @@ mod tests { assert_decode_after(seek, "", None); let seek = Seek::Id; - let payload = SeekPayload::RecentDownloads(RecentDownloads(Some(5678), 1234)); + let payload = SeekPayload::RecentDownloads(RecentDownloads { downloads, id }); let query = format!("seek={}", encode_seek(payload).unwrap()); let pagination = PaginationOptions::builder() .enable_seek(true) @@ -634,23 +676,38 @@ mod tests { assert_eq!(error.to_string(), "invalid seek parameter"); let response = error.response(); assert_eq!(response.status(), StatusCode::BAD_REQUEST); + + // Ensures it still encodes compactly with a field struct + #[derive(Debug, Default, Serialize, PartialEq)] + struct NewTuple( + #[serde(with = "ts_microseconds")] chrono::NaiveDateTime, + i32, + ); + assert_eq!( + encode_seek(NewTuple(dt, id)).unwrap(), + encode_seek(SeekPayload::New(New { dt, id })).unwrap() + ); } #[test] fn test_seek_macro_conv() { use chrono::{NaiveDate, NaiveDateTime}; use seek::*; - - assert_eq!(Seek::from(SeekPayload::Id(Id(1234))), Seek::Id); + let id = 1234; + assert_eq!(Seek::from(SeekPayload::Id(Id { id })), Seek::Id); let dt: NaiveDateTime = NaiveDate::from_ymd_opt(2016, 7, 8) .unwrap() .and_hms_opt(9, 10, 11) .unwrap(); - assert_eq!(Seek::from(SeekPayload::New(New(dt, 1234))), Seek::New); + assert_eq!(Seek::from(SeekPayload::New(New { dt, id })), Seek::New); + let downloads = None; assert_eq!( - Seek::from(SeekPayload::RecentDownloads(RecentDownloads(None, 1234))), + Seek::from(SeekPayload::RecentDownloads(RecentDownloads { + downloads, + id + })), Seek::RecentDownloads ); } diff --git a/src/controllers/krate/search.rs b/src/controllers/krate/search.rs index de00e028855..1c81cdddff0 100644 --- a/src/controllers/krate/search.rs +++ b/src/controllers/krate/search.rs @@ -412,12 +412,12 @@ impl<'a> FilterParams<'a> { .single_value() }; let conditions: Vec> = match *seek_payload { - SeekPayload::Name(Name(id)) => { + SeekPayload::Name(Name { id }) => { // Equivalent of: // `WHERE name > name'` vec![Box::new(crates::name.nullable().gt(crate_name_by_id(id)))] } - SeekPayload::New(New(created_at, id)) => { + SeekPayload::New(New { created_at, id }) => { // Equivalent of: // `WHERE (created_at = created_at' AND id < id') OR created_at < created_at'` vec![ @@ -430,7 +430,7 @@ impl<'a> FilterParams<'a> { Box::new(crates::created_at.lt(created_at).nullable()), ] } - SeekPayload::RecentUpdates(RecentUpdates(updated_at, id)) => { + SeekPayload::RecentUpdates(RecentUpdates { updated_at, id }) => { // Equivalent of: // `WHERE (updated_at = updated_at' AND id < id') OR updated_at < updated_at'` vec![ @@ -443,7 +443,10 @@ impl<'a> FilterParams<'a> { Box::new(crates::updated_at.lt(updated_at).nullable()), ] } - SeekPayload::RecentDownloads(RecentDownloads(recent_downloads, id)) => { + SeekPayload::RecentDownloads(RecentDownloads { + recent_downloads, + id, + }) => { // Equivalent of: // for recent_downloads is not None: // `WHERE (recent_downloads = recent_downloads' AND id < id') @@ -477,7 +480,7 @@ impl<'a> FilterParams<'a> { } } } - SeekPayload::Downloads(Downloads(downloads, id)) => { + SeekPayload::Downloads(Downloads { downloads, id }) => { // Equivalent of: // `WHERE (downloads = downloads' AND id < id') OR downloads < downloads'` vec![ @@ -490,7 +493,7 @@ impl<'a> FilterParams<'a> { Box::new(crate_downloads::downloads.lt(downloads).nullable()), ] } - SeekPayload::Query(Query(exact_match, id)) => { + SeekPayload::Query(Query { exact_match, id }) => { // Equivalent of: // `WHERE (exact_match = exact_match' AND name < name') OR exact_match < // exact_match'` @@ -506,7 +509,11 @@ impl<'a> FilterParams<'a> { Box::new(name_exact_match.lt(exact_match).nullable()), ] } - SeekPayload::Relevance(Relevance(exact, rank_in, id)) => { + SeekPayload::Relevance(Relevance { + exact_match: exact, + rank: rank_in, + id, + }) => { // Equivalent of: // `WHERE (exact_match = exact_match' AND rank = rank' AND name > name') // OR (exact_match = exact_match' AND rank < rank') @@ -551,37 +558,74 @@ mod seek { use crate::models::Crate; use chrono::naive::serde::ts_microseconds; - seek! { + seek!( pub enum Seek { - Name(i32) - New(#[serde(with="ts_microseconds")] chrono::NaiveDateTime, i32) - RecentUpdates(#[serde(with="ts_microseconds")] chrono::NaiveDateTime, i32) - RecentDownloads(Option, i32) - Downloads(i64, i32) - Query(bool, i32) - Relevance(bool, f32, i32) + Name { + id: i32, + }, + New { + #[serde(with = "ts_microseconds")] + created_at: chrono::NaiveDateTime, + id: i32, + }, + RecentUpdates { + #[serde(with = "ts_microseconds")] + updated_at: chrono::NaiveDateTime, + id: i32, + }, + RecentDownloads { + recent_downloads: Option, + id: i32, + }, + Downloads { + downloads: i64, + id: i32, + }, + Query { + exact_match: bool, + id: i32, + }, + Relevance { + exact_match: bool, + rank: f32, + id: i32, + }, } - } + ); impl Seek { pub(crate) fn to_payload( &self, record: &(Crate, bool, i64, Option, f32), ) -> SeekPayload { + let ( + Crate { + id, + updated_at, + created_at, + .. + }, + exact_match, + downloads, + recent_downloads, + rank, + ) = *record; + match *self { - Seek::Name => SeekPayload::Name(Name(record.0.id)), - Seek::New => SeekPayload::New(New(record.0.created_at, record.0.id)), - Seek::RecentUpdates => { - SeekPayload::RecentUpdates(RecentUpdates(record.0.updated_at, record.0.id)) - } - Seek::RecentDownloads => { - SeekPayload::RecentDownloads(RecentDownloads(record.3, record.0.id)) - } - Seek::Downloads => SeekPayload::Downloads(Downloads(record.2, record.0.id)), - Seek::Query => SeekPayload::Query(Query(record.1, record.0.id)), - Seek::Relevance => { - SeekPayload::Relevance(Relevance(record.1, record.4, record.0.id)) - } + Seek::Name => SeekPayload::Name(Name { id }), + Seek::New => SeekPayload::New(New { created_at, id }), + Seek::RecentUpdates => SeekPayload::RecentUpdates(RecentUpdates { updated_at, id }), + Seek::RecentDownloads => SeekPayload::RecentDownloads(RecentDownloads { + recent_downloads, + id, + }), + Seek::Downloads => SeekPayload::Downloads(Downloads { downloads, id }), + Seek::Query => SeekPayload::Query(Query { exact_match, id }), + Seek::Relevance => SeekPayload::Relevance(Relevance { + exact_match, + rank, + id, + }), } } } diff --git a/src/controllers/krate/versions.rs b/src/controllers/krate/versions.rs index fda781cf936..07eaea4d2d8 100644 --- a/src/controllers/krate/versions.rs +++ b/src/controllers/krate/versions.rs @@ -92,7 +92,7 @@ fn list_by_date( !matches!(&options.page, Page::Numeric(_)), "?page= is not supported" ); - if let Some(SeekPayload::Date(Date(created_at, id))) = Seek::Date.after(&options.page)? { + if let Some(SeekPayload::Date(Date { created_at, id })) = Seek::Date.after(&options.page)? { query = query.filter( versions::created_at .eq(created_at) @@ -169,7 +169,7 @@ fn list_by_semver( "?page= is not supported" ); let mut idx = Some(0); - if let Some(SeekPayload::Semver(Semver(id))) = Seek::Semver.after(&options.page)? { + if let Some(SeekPayload::Semver(Semver { id })) = Seek::Semver.after(&options.page)? { idx = sorted_versions .get_index_of(&id) .filter(|i| i + 1 < sorted_versions.len()) @@ -234,18 +234,25 @@ mod seek { // We might consider refactoring this to use named fields, which would be clearer and more // flexible. It's also worth noting that we currently encode seek compactly as a Vec, which // doesn't include field names. - seek! { + seek!( pub enum Seek { - Semver(i32) - Date(#[serde(with="ts_microseconds")] chrono::NaiveDateTime, i32) + Semver { + id: i32, + }, + Date { + #[serde(with = "ts_microseconds")] + created_at: chrono::NaiveDateTime, + id: i32, + }, } - } + ); impl Seek { pub(crate) fn to_payload(&self, record: &(Version, Option)) -> SeekPayload { + let (Version { id, created_at, .. }, _) = *record; match *self { - Seek::Semver => SeekPayload::Semver(Semver(record.0.id)), - Seek::Date => SeekPayload::Date(Date(record.0.created_at, record.0.id)), + Seek::Semver => SeekPayload::Semver(Semver { id }), + Seek::Date => SeekPayload::Date(Date { created_at, id }), } } }