@@ -461,6 +461,26 @@ pub trait RngUtil {
461
461
* ~~~
462
462
*/
463
463
fn shuffle_mut < T > ( & mut self , values : & mut [ T ] ) ;
464
+
465
+ /**
466
+ * Sample up to `n` values from an iterator.
467
+ *
468
+ * # Example
469
+ *
470
+ * ~~~ {.rust}
471
+ *
472
+ * use std::rand;
473
+ * use std::rand::RngUtil;
474
+ *
475
+ * fn main() {
476
+ * let mut rng = rand::rng();
477
+ * let vals = range(1, 100).to_owned_vec();
478
+ * let sample = rng.sample(vals.iter(), 5);
479
+ * printfln!(sample);
480
+ * }
481
+ * ~~~
482
+ */
483
+ fn sample < A , T : Iterator < A > > ( & mut self , iter : T , n : uint ) -> ~[ A ] ;
464
484
}
465
485
466
486
/// Extension methods for random number generators
@@ -607,6 +627,23 @@ impl<R: Rng> RngUtil for R {
607
627
values. swap ( i, self . gen_uint_range ( 0 u, i + 1 u) ) ;
608
628
}
609
629
}
630
+
631
+ /// Randomly sample up to `n` elements from an iterator
632
+ fn sample < A , T : Iterator < A > > ( & mut self , iter : T , n : uint ) -> ~[ A ] {
633
+ let mut reservoir : ~[ A ] = vec:: with_capacity ( n) ;
634
+ for ( i, elem) in iter. enumerate ( ) {
635
+ if i < n {
636
+ reservoir. push ( elem) ;
637
+ loop
638
+ }
639
+
640
+ let k = self . gen_uint_range ( 0 , i + 1 ) ;
641
+ if k < reservoir. len ( ) {
642
+ reservoir[ k] = elem
643
+ }
644
+ }
645
+ reservoir
646
+ }
610
647
}
611
648
612
649
/// Create a random number generator with a default algorithm and seed.
@@ -914,6 +951,7 @@ pub fn random<T: Rand>() -> T {
914
951
915
952
#[ cfg( test) ]
916
953
mod test {
954
+ use iterator:: { Iterator , range} ;
917
955
use option:: { Option , Some } ;
918
956
use super :: * ;
919
957
@@ -1130,6 +1168,24 @@ mod test {
1130
1168
}
1131
1169
}
1132
1170
}
1171
+
1172
+ #[ test]
1173
+ fn test_sample ( ) {
1174
+ let MIN_VAL = 1 ;
1175
+ let MAX_VAL = 100 ;
1176
+
1177
+ let mut r = rng ( ) ;
1178
+ let vals = range ( MIN_VAL , MAX_VAL ) . to_owned_vec ( ) ;
1179
+ let small_sample = r. sample ( vals. iter ( ) , 5 ) ;
1180
+ let large_sample = r. sample ( vals. iter ( ) , vals. len ( ) + 5 ) ;
1181
+
1182
+ assert_eq ! ( small_sample. len( ) , 5 ) ;
1183
+ assert_eq ! ( large_sample. len( ) , vals. len( ) ) ;
1184
+
1185
+ assert ! ( small_sample. iter( ) . all( |e| {
1186
+ * * e >= MIN_VAL && * * e <= MAX_VAL
1187
+ } ) ) ;
1188
+ }
1133
1189
}
1134
1190
1135
1191
#[ cfg( test) ]
0 commit comments