Skip to content

Add RngUtil::sample() method for reservoir sampling from iterators #8491

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions src/libstd/rand.rs
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,26 @@ pub trait RngUtil {
* ~~~
*/
fn shuffle_mut<T>(&mut self, values: &mut [T]);

/**
* Sample up to `n` values from an iterator.
*
* The iterator is taken by value and the selected elements are returned
* in a newly allocated vector. If the iterator yields fewer than `n`
* elements, the result holds all of them and is therefore shorter
* than `n`.
*
* # Arguments
*
* * `iter` - the iterator to draw elements from
* * `n` - the maximum number of elements to return
*
* # Example
*
* ~~~ {.rust}
*
* use std::rand;
* use std::rand::RngUtil;
*
* fn main() {
*     let mut rng = rand::rng();
*     let vals = range(1, 100).to_owned_vec();
*     let sample = rng.sample(vals.iter(), 5);
*     printfln!(sample);
* }
* ~~~
*/
fn sample<A, T: Iterator<A>>(&mut self, iter: T, n: uint) -> ~[A];
}

/// Extension methods for random number generators
Expand Down Expand Up @@ -607,6 +627,23 @@ impl<R: Rng> RngUtil for R {
values.swap(i, self.gen_uint_range(0u, i + 1u));
}
}

/// Draw at most `n` elements at random from `iter`.
///
/// The first `n` elements fill the result vector; every later element
/// may overwrite one randomly chosen slot. When the iterator yields
/// fewer than `n` elements, all of them are returned.
fn sample<A, T: Iterator<A>>(&mut self, iter: T, n: uint) -> ~[A] {
    let mut picked : ~[A] = vec::with_capacity(n);
    for (seen, item) in iter.enumerate() {
        if seen < n {
            // Still filling the reservoir: keep the element unconditionally.
            picked.push(item);
        } else {
            // Draw a slot index based on how many elements have been seen;
            // the element is kept only if the index lands inside the
            // reservoir. NOTE(review): presumably `gen_uint_range(0, seen + 1)`
            // is uniform over the half-open range [0, seen + 1) -- confirm.
            let slot = self.gen_uint_range(0, seen + 1);
            if slot < picked.len() {
                picked[slot] = item;
            }
        }
    }
    picked
}
}

/// Create a random number generator with a default algorithm and seed.
Expand Down Expand Up @@ -914,6 +951,7 @@ pub fn random<T: Rand>() -> T {

#[cfg(test)]
mod test {
use iterator::{Iterator, range};
use option::{Option, Some};
use super::*;

Expand Down Expand Up @@ -1130,6 +1168,24 @@ mod test {
}
}
}

#[test]
fn test_sample() {
    // Bounds passed to `range` to build the population being sampled.
    let lo = 1;
    let hi = 100;

    let mut generator = rng();
    let population = range(lo, hi).to_owned_vec();

    // Ask once for fewer elements than exist, once for more than exist.
    let few = generator.sample(population.iter(), 5);
    let many = generator.sample(population.iter(), population.len() + 5);

    // A small request is honoured exactly; an oversized request is
    // capped at the number of elements the iterator produced.
    assert_eq!(few.len(), 5);
    assert_eq!(many.len(), population.len());

    // Every sampled element must lie within the bounds given to `range`.
    assert!(few.iter().all(|x| lo <= **x && **x <= hi));
}
}

#[cfg(test)]
Expand Down