Skip to content

Commit 37de572

Browse files
forsaken628 and TCeason authored
fix: fix segfault in list_domain (#16979)
* fix list_domain

  Signed-off-by: coldWater <forsaken628@gmail.com>

* fix

  Signed-off-by: coldWater <forsaken628@gmail.com>

---------

Signed-off-by: coldWater <forsaken628@gmail.com>
Co-authored-by: TCeason <33082201+TCeason@users.noreply.github.com>
1 parent f40b8c3 commit 37de572

File tree

2 files changed: +94 −45 lines changed

src/query/pipeline/transforms/src/processors/transforms/sort/k_way_merge_sort_partition.rs

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -150,10 +150,9 @@ where
150150
}
151151

152152
fn calc_partition_point(&self) -> Partition {
153-
let mut candidate = Candidate::new(&self.rows, EndDomain {
154-
min: self.min_task,
155-
max: self.max_task,
156-
});
153+
let mut candidate =
154+
Candidate::new(&self.rows, EndDomain::new(self.min_task, self.max_task));
155+
157156
candidate.init();
158157

159158
// if candidate.is_small_task() {

src/query/pipeline/transforms/src/processors/transforms/sort/list_domain.rs

Lines changed: 91 additions & 41 deletions
Original file line number | Diff line number | Diff line change
@@ -74,7 +74,7 @@ where Self: Debug
7474

7575
#[derive(Debug)]
7676
pub struct Partition {
77-
pub ends: Vec<(usize, usize)>,
77+
pub ends: Vec<(usize, usize)>, // index, partition point
7878
pub total: usize,
7979
}
8080

@@ -117,8 +117,11 @@ where
117117
T: List + 'a,
118118
T::Item<'a>: Debug,
119119
{
120+
// The cut point value.
120121
target: T::Item<'a>,
122+
// The domain of partition point in each list, and also the domain of the generated task size in each list.
121123
domains: Vec<EndDomain>,
124+
// The size domain of the generated task if the current target is used as the cut point.
122125
sum: EndDomain,
123126
}
124127

@@ -136,6 +139,7 @@ where T: List
136139
}
137140

138141
pub fn init(&mut self) -> bool {
142+
// Take the smallest first and the smallest last of all the lists as the initial target range.
139143
let target: (Option<T::Item<'a>>, Option<T::Item<'a>>) =
140144
self.all_list.iter().fold((None, None), |(min, max), ls| {
141145
let min = match (min, ls.first()) {
@@ -151,9 +155,8 @@ where T: List
151155

152156
(min, max)
153157
});
154-
let (min_target, max_target) = if let (Some(min), Some(max)) = target {
155-
(min, max)
156-
} else {
158+
let (Some(min_target), Some(max_target)) = target else {
159+
// invalid empty input
157160
return false;
158161
};
159162

@@ -189,7 +192,7 @@ where T: List
189192

190193
pub fn is_small_task(&mut self) -> bool {
191194
loop {
192-
let sum = self.do_search_max(Some(8));
195+
let sum = self.reduce_max_domain(Some(8));
193196
match self.expect.overlaps(sum) {
194197
Overlap::Left => return true,
195198
Overlap::Right => return false,
@@ -203,7 +206,7 @@ where T: List
203206
for _ in 0..max_iter {
204207
match self.overlaps() {
205208
(_, _, Overlap::Cross) => {
206-
let sum = self.do_search_max(Some(n));
209+
let sum = self.reduce_max_domain(Some(n));
207210
if self.is_finish(sum) {
208211
return Partition::new(self.max_target.unwrap());
209212
}
@@ -221,7 +224,7 @@ where T: List
221224
Some(Overlap::Cross),
222225
Overlap::Right,
223226
) => {
224-
let sum = self.do_search_mid(Some(n));
227+
let sum = self.reduce_mid_domain(Some(n));
225228
match self.expect.overlaps(sum) {
226229
Overlap::Right => self.cut_right(),
227230
Overlap::Left if matches!(min_overlap, Overlap::Left) => self.cut_left(),
@@ -232,7 +235,7 @@ where T: List
232235
}
233236
}
234237
(Overlap::Cross, Some(Overlap::Left), Overlap::Right) => {
235-
let sum = self.do_search_min(Some(n));
238+
let sum = self.reduce_min_domain(Some(n));
236239
match self.expect.overlaps(sum) {
237240
Overlap::Left => self.cut_left(),
238241
Overlap::Cross if sum.done() => {
@@ -251,19 +254,19 @@ where T: List
251254
};
252255
}
253256

254-
self.do_search_max(None);
257+
self.reduce_max_domain(None);
255258
Partition::new(self.max_target.unwrap())
256259
}
257260

258-
fn do_search_max(&mut self, n: Option<usize>) -> EndDomain {
261+
fn reduce_max_domain(&mut self, n: Option<usize>) -> EndDomain {
259262
do_search(self.all_list, self.max_target.as_mut().unwrap(), n)
260263
}
261264

262-
fn do_search_min(&mut self, n: Option<usize>) -> EndDomain {
265+
fn reduce_min_domain(&mut self, n: Option<usize>) -> EndDomain {
263266
do_search(self.all_list, self.min_target.as_mut().unwrap(), n)
264267
}
265268

266-
fn do_search_mid(&mut self, n: Option<usize>) -> EndDomain {
269+
fn reduce_mid_domain(&mut self, n: Option<usize>) -> EndDomain {
267270
do_search(self.all_list, self.mid_target.as_mut().unwrap(), n)
268271
}
269272

@@ -290,11 +293,8 @@ where T: List
290293
if max_domain.is_zero() {
291294
continue;
292295
}
293-
let five = EndDomain {
294-
min: min_domain.min,
295-
max: max_domain.min,
296-
}
297-
.five_point();
296+
297+
let five = min_domain.merge(max_domain).five_point();
298298
for v in five.into_iter().filter_map(|i| {
299299
let v = ls.index(i);
300300
if v >= *min_target && v <= *max_target {
@@ -336,6 +336,7 @@ where T: List
336336
}
337337

338338
fn overlaps(&self) -> (Overlap, Option<Overlap>, Overlap) {
339+
// Compare the expected task size domain with the task size domains of min_target, mid_target and max_target.
339340
(
340341
self.expect.overlaps(self.min_target.as_ref().unwrap().sum),
341342
self.mid_target
@@ -392,11 +393,16 @@ where
392393

393394
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
394395
pub struct EndDomain {
395-
pub min: usize,
396-
pub max: usize,
396+
min: usize,
397+
max: usize,
397398
}
398399

399400
impl EndDomain {
401+
pub fn new(min: usize, max: usize) -> EndDomain {
402+
assert!(min <= max);
403+
EndDomain { min, max }
404+
}
405+
400406
fn done(&self) -> bool {
401407
self.min == self.max
402408
}
@@ -453,6 +459,13 @@ impl EndDomain {
453459
],
454460
}
455461
}
462+
463+
fn merge(&self, other: &EndDomain) -> EndDomain {
464+
EndDomain {
465+
min: self.min.min(other.min),
466+
max: self.max.max(other.max),
467+
}
468+
}
456469
}
457470

458471
#[derive(Debug)]
@@ -481,15 +494,18 @@ impl std::iter::Sum for EndDomain {
481494

482495
impl From<std::ops::RangeInclusive<usize>> for EndDomain {
483496
fn from(value: std::ops::RangeInclusive<usize>) -> Self {
484-
EndDomain {
485-
min: *value.start(),
486-
max: *value.end(),
487-
}
497+
EndDomain::new(*value.start(), *value.end())
488498
}
489499
}
490500

491501
#[cfg(test)]
492502
mod tests {
503+
use std::iter::repeat_with;
504+
505+
use rand::rngs::StdRng;
506+
use rand::Rng;
507+
use rand::SeedableRng;
508+
493509
use super::*;
494510

495511
impl List for &[i32] {
@@ -525,35 +541,67 @@ mod tests {
525541
run_test(&all_list, (5..=10).into(), 10);
526542
}
527543

544+
{
545+
let all_list = issue_16923();
546+
547+
let all_list: Vec<_> = all_list.iter().map(|v| v.as_slice()).collect();
548+
run_test(&all_list, (5..=100).into(), 20);
549+
}
550+
528551
for _ in 0..100 {
529-
let all_list = rand_data();
552+
let all_list = rand_data(rand::random());
530553
let all_list: Vec<_> = all_list.iter().map(|v| v.as_slice()).collect();
531554

532-
run_test(&all_list, (5..=10).into(), 10)
555+
run_test(&all_list, (5..=100).into(), 20)
533556
}
534557
}
535558

536-
fn rand_data() -> Vec<Vec<i32>> {
537-
use rand::Rng;
538-
let mut rng = rand::thread_rng();
559+
fn rand_data(seed: u64) -> Vec<Vec<i32>> {
560+
let mut rng = StdRng::seed_from_u64(seed);
539561

540-
(0..5)
541-
.map(|_| {
542-
let rows: usize = rng.gen_range(0..=20);
543-
let mut data = (0..rows)
544-
.map(|_| rng.gen_range(0..=1000))
545-
.collect::<Vec<_>>();
546-
data.sort();
547-
data
548-
})
549-
.collect::<Vec<_>>()
562+
let list = rng.gen_range(1..=10);
563+
repeat_with(|| {
564+
let rows = rng.gen_range(0..=40);
565+
let mut data = repeat_with(|| rng.gen_range(0..=1000))
566+
.take(rows)
567+
.collect::<Vec<_>>();
568+
data.sort();
569+
data
570+
})
571+
.take(list)
572+
.collect::<Vec<_>>()
573+
}
574+
575+
fn issue_16923() -> Vec<Vec<i32>> {
576+
vec![
577+
vec![6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6],
578+
vec![
579+
3, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 8, 13, 13, 13, 13, 13, 13,
580+
13, 13, 13, 13, 13, 13, 13, 13, 13, 18,
581+
],
582+
vec![6, 6, 6, 6, 6],
583+
vec![
584+
2, 4, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, 11, 12, 14, 15, 16, 19,
585+
],
586+
vec![
587+
6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6,
588+
],
589+
vec![
590+
1, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6,
591+
11, 12, 14, 15, 17, 18, 21, 22, 24, 25, 27,
592+
],
593+
vec![
594+
0, 9, 10, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 20,
595+
23, 26, 27, 27, 27, 27, 27, 27, 27, 28,
596+
],
597+
]
550598
}
551599

552600
fn run_test(all_list: &[&[i32]], expect_size: EndDomain, max_iter: usize) {
553601
let mut candidate = Candidate::new(all_list, expect_size);
554602

555603
let got = if candidate.init() {
556-
candidate.calc_partition(3, max_iter)
604+
candidate.calc_partition(4, max_iter)
557605
} else {
558606
let sum: usize = all_list.iter().map(|ls| ls.len()).sum();
559607
assert_eq!(sum, 0);
@@ -574,11 +622,13 @@ mod tests {
574622
(ls[..end].last(), ls[end..].first())
575623
})
576624
.fold((None, None), |acc, (end, start)| {
577-
(acc.0.max(end), match (acc.1, start) {
625+
let max_end = acc.0.max(end);
626+
let min_start = match (acc.1, start) {
578627
(None, None) => None,
579628
(None, v @ Some(_)) | (v @ Some(_), None) => v,
580629
(Some(a), Some(b)) => Some(a.min(b)),
581-
})
630+
};
631+
(max_end, min_start)
582632
});
583633
match x {
584634
(Some(a), Some(b)) => assert!(a < b, "all_list {all_list:?}"),

0 commit comments

Comments
 (0)