@@ -24,22 +24,6 @@ use std::hash;
24
24
use { Collection , Mutable , Set , MutableSet } ;
25
25
use vec:: Vec ;
26
26
27
- /**
28
- * A mask that has a 1 for each defined bit in the n'th element of a `Bitv`,
29
- * assuming n bits.
30
- */
31
- #[ inline]
32
- fn big_mask ( nbits : uint , elem : uint ) -> uint {
33
- let rmd = nbits % uint:: BITS ;
34
- let nelems = ( nbits + uint:: BITS - 1 ) / uint:: BITS ;
35
-
36
- if elem < nelems - 1 || rmd == 0 {
37
- !0
38
- } else {
39
- ( 1 << rmd) - 1
40
- }
41
- }
42
-
43
27
/// The bitvector type
44
28
///
45
29
/// # Example
@@ -75,35 +59,47 @@ pub struct Bitv {
75
59
nbits : uint
76
60
}
77
61
78
- struct Words < ' a > {
62
+ struct MaskWords < ' a > {
79
63
iter : slice:: Items < ' a , uint > ,
64
+ next_word : Option < & ' a uint > ,
65
+ last_word_mask : uint ,
80
66
offset : uint
81
67
}
82
68
83
- impl < ' a > Iterator < ( uint , uint ) > for Words < ' a > {
69
+ impl < ' a > Iterator < ( uint , uint ) > for MaskWords < ' a > {
84
70
/// Returns (offset, word)
85
71
fn next < ' a > ( & ' a mut self ) -> Option < ( uint , uint ) > {
86
- let ret = self . iter . next ( ) . map ( |& n| ( self . offset , n) ) ;
87
- self . offset += 1 ;
88
- ret
72
+ let ret = self . next_word ;
73
+ match ret {
74
+ Some ( & w) => {
75
+ self . next_word = self . iter . next ( ) ;
76
+ self . offset += 1 ;
77
+ // The last word may need to be masked
78
+ if self . next_word . is_none ( ) {
79
+ Some ( ( self . offset - 1 , w & self . last_word_mask ) )
80
+ } else {
81
+ Some ( ( self . offset - 1 , w) )
82
+ }
83
+ } ,
84
+ None => None
85
+ }
89
86
}
90
87
}
91
88
92
89
impl Bitv {
93
90
#[ inline]
94
- fn process ( & mut self , other : & Bitv , nbits : uint ,
95
- op : |uint , uint| -> uint) -> bool {
91
+ fn process ( & mut self , other : & Bitv , op : |uint , uint| -> uint) -> bool {
96
92
let len = other. storage . len ( ) ;
97
93
assert_eq ! ( self . storage. len( ) , len) ;
98
94
let mut changed = false ;
99
- for ( i , ( a , b ) ) in self . storage . mut_iter ( )
100
- . zip ( other . storage . iter ( ) )
101
- . enumerate ( ) {
102
- let mask = big_mask ( nbits , i ) ;
103
- let w0 = * a & mask ;
104
- let w1 = * b & mask ;
105
- let w = op ( w0 , w1 ) & mask ;
106
- if w0 != w {
95
+ // Notice: `a` is *not* masked here, which is fine as long as
96
+ // `op` is a bitwise operation, since any bits that should've
97
+ // been masked were fine to change anyway. `b` is masked to
98
+ // make sure its unmasked bits do not cause damage.
99
+ for ( a , ( _ , b ) ) in self . storage . mut_iter ( )
100
+ . zip ( other . mask_words ( 0 ) ) {
101
+ let w = op ( * a , b ) ;
102
+ if * a != w {
107
103
changed = true ;
108
104
* a = w;
109
105
}
@@ -112,10 +108,20 @@ impl Bitv {
112
108
}
113
109
114
110
#[ inline]
115
- #[ inline]
116
- fn words < ' a > ( & ' a self , start : uint ) -> Words < ' a > {
117
- Words {
118
- iter : self . storage . slice_from ( start) . iter ( ) ,
111
+ fn mask_words < ' a > ( & ' a self , mut start : uint ) -> MaskWords < ' a > {
112
+ if start > self . storage . len ( ) {
113
+ start = self . storage . len ( ) ;
114
+ }
115
+ let mut iter = self . storage . slice_from ( start) . iter ( ) ;
116
+ MaskWords {
117
+ next_word : iter. next ( ) ,
118
+ iter : iter,
119
+ last_word_mask : {
120
+ let rem = self . nbits % uint:: BITS ;
121
+ if rem > 0 {
122
+ ( 1 << rem) - 1
123
+ } else { !0 }
124
+ } ,
119
125
offset : start
120
126
}
121
127
}
@@ -124,15 +130,8 @@ impl Bitv {
124
130
/// to `init`.
125
131
pub fn new ( nbits : uint , init : bool ) -> Bitv {
126
132
Bitv {
127
- storage : {
128
- let nelems = ( nbits + uint:: BITS - 1 ) / uint:: BITS ;
129
- let mut v = Vec :: from_elem ( nelems, if init { !0 u } else { 0 u } ) ;
130
- // Zero out any remainder bits
131
- if nbits % uint:: BITS > 0 {
132
- * v. get_mut ( nelems - 1 ) &= ( 1 << nbits % uint:: BITS ) - 1 ;
133
- }
134
- v
135
- } ,
133
+ storage : Vec :: from_elem ( ( nbits + uint:: BITS - 1 ) / uint:: BITS ,
134
+ if init { !0 u } else { 0 u } ) ,
136
135
nbits : nbits
137
136
}
138
137
}
@@ -145,8 +144,7 @@ impl Bitv {
145
144
*/
146
145
#[ inline]
147
146
pub fn union ( & mut self , other : & Bitv ) -> bool {
148
- let nbits = self . nbits ;
149
- self . process ( other, nbits, |w1, w2| w1 | w2)
147
+ self . process ( other, |w1, w2| w1 | w2)
150
148
}
151
149
152
150
/**
@@ -157,8 +155,7 @@ impl Bitv {
157
155
*/
158
156
#[ inline]
159
157
pub fn intersect ( & mut self , other : & Bitv ) -> bool {
160
- let nbits = self . nbits ;
161
- self . process ( other, nbits, |w1, w2| w1 & w2)
158
+ self . process ( other, |w1, w2| w1 & w2)
162
159
}
163
160
164
161
/**
@@ -169,8 +166,7 @@ impl Bitv {
169
166
*/
170
167
#[ inline]
171
168
pub fn assign ( & mut self , other : & Bitv ) -> bool {
172
- let nbits = self . nbits ;
173
- self . process ( other, nbits, |_, w| w)
169
+ self . process ( other, |_, w| w)
174
170
}
175
171
176
172
/// Retrieve the value at index `i`
@@ -227,20 +223,18 @@ impl Bitv {
227
223
*/
228
224
#[ inline]
229
225
pub fn difference ( & mut self , other : & Bitv ) -> bool {
230
- let nbits = self . nbits ;
231
- self . process ( other, nbits, |w1, w2| w1 & !w2)
226
+ self . process ( other, |w1, w2| w1 & !w2)
232
227
}
233
228
234
229
/// Returns `true` if all bits are 1
235
230
#[ inline]
236
231
pub fn all ( & self ) -> bool {
237
- for ( i, & elem) in self . storage . iter ( ) . enumerate ( ) {
238
- let mask = big_mask ( self . nbits , i) ;
239
- if elem & mask != mask {
240
- return false ;
241
- }
242
- }
243
- true
232
+ let mut last_word = !0 u;
233
+ // Check that every word but the last is all-ones...
234
+ self . mask_words ( 0 ) . all ( |( _, elem) |
235
+ { let tmp = last_word; last_word = elem; tmp == !0 u } ) &&
236
+ // ...and that the last word is ones as far as it needs to be
237
+ ( last_word == ( ( 1 << self . nbits % uint:: BITS ) - 1 ) || last_word == !0 u)
244
238
}
245
239
246
240
/// Returns an iterator over the elements of the vector in order.
@@ -265,13 +259,7 @@ impl Bitv {
265
259
266
260
/// Returns `true` if all bits are 0
267
261
pub fn none ( & self ) -> bool {
268
- for ( i, & elem) in self . storage . iter ( ) . enumerate ( ) {
269
- let mask = big_mask ( self . nbits , i) ;
270
- if elem & mask != 0 {
271
- return false ;
272
- }
273
- }
274
- true
262
+ self . mask_words ( 0 ) . all ( |( _, w) | w == 0 )
275
263
}
276
264
277
265
#[ inline]
@@ -397,8 +385,8 @@ impl fmt::Show for Bitv {
397
385
impl < S : hash:: Writer > hash:: Hash < S > for Bitv {
398
386
fn hash ( & self , state : & mut S ) {
399
387
self . nbits . hash ( state) ;
400
- for ( i , elem) in self . storage . iter ( ) . enumerate ( ) {
401
- ( elem & big_mask ( self . nbits , i ) ) . hash ( state) ;
388
+ for ( _ , elem) in self . mask_words ( 0 ) {
389
+ elem. hash ( state) ;
402
390
}
403
391
}
404
392
}
@@ -409,13 +397,7 @@ impl cmp::PartialEq for Bitv {
409
397
if self . nbits != other. nbits {
410
398
return false ;
411
399
}
412
- for ( i, ( & w1, & w2) ) in self . storage . iter ( ) . zip ( other. storage . iter ( ) ) . enumerate ( ) {
413
- let mask = big_mask ( self . nbits , i) ;
414
- if w1 & mask != w2 & mask {
415
- return false ;
416
- }
417
- }
418
- true
400
+ self . mask_words ( 0 ) . zip ( other. mask_words ( 0 ) ) . all ( |( ( _, w1) , ( _, w2) ) | w1 == w2)
419
401
}
420
402
}
421
403
@@ -546,7 +528,7 @@ impl BitvSet {
546
528
// Unwrap Bitvs
547
529
let & BitvSet ( ref mut self_bitv) = self ;
548
530
let & BitvSet ( ref other_bitv) = other;
549
- for ( i, w) in other_bitv. words ( 0 ) {
531
+ for ( i, w) in other_bitv. mask_words ( 0 ) {
550
532
let old = * self_bitv. storage . get ( i) ;
551
533
let new = f ( old, w) ;
552
534
* self_bitv. storage . get_mut ( i) = new;
@@ -563,7 +545,7 @@ impl BitvSet {
563
545
let n = bitv. storage . iter ( ) . rev ( ) . take_while ( |& & n| n == 0 ) . count ( ) ;
564
546
// Truncate
565
547
let trunc_len = cmp:: max ( old_len - n, 1 ) ;
566
- bitv. storage . truncate ( cmp :: max ( old_len - n , 1 ) ) ;
548
+ bitv. storage . truncate ( trunc_len ) ;
567
549
bitv. nbits = trunc_len * uint:: BITS ;
568
550
}
569
551
@@ -710,6 +692,12 @@ impl MutableSet<uint> for BitvSet {
710
692
}
711
693
let & BitvSet ( ref mut bitv) = self ;
712
694
if value >= bitv. nbits {
695
+ // If we are increasing nbits, make sure we mask out any previously-unconsidered bits
696
+ let old_rem = bitv. nbits % uint:: BITS ;
697
+ if old_rem != 0 {
698
+ let old_last_word = ( bitv. nbits + uint:: BITS - 1 ) / uint:: BITS - 1 ;
699
+ * bitv. storage . get_mut ( old_last_word) &= ( 1 << old_rem) - 1 ;
700
+ }
713
701
bitv. nbits = value + 1 ;
714
702
}
715
703
bitv. set ( value, true ) ;
@@ -733,10 +721,10 @@ impl BitvSet {
733
721
/// and w1/w2 are the words coming from the two vectors self, other.
734
722
fn commons < ' a > ( & ' a self , other : & ' a BitvSet )
735
723
-> Map < ( ( uint , uint ) , ( uint , uint ) ) , ( uint , uint , uint ) ,
736
- Zip < Words < ' a > , Words < ' a > > > {
724
+ Zip < MaskWords < ' a > , MaskWords < ' a > > > {
737
725
let & BitvSet ( ref self_bitv) = self ;
738
726
let & BitvSet ( ref other_bitv) = other;
739
- self_bitv. words ( 0 ) . zip ( other_bitv. words ( 0 ) )
727
+ self_bitv. mask_words ( 0 ) . zip ( other_bitv. mask_words ( 0 ) )
740
728
. map ( |( ( i, w1) , ( _, w2) ) | ( i * uint:: BITS , w1, w2) )
741
729
}
742
730
@@ -748,17 +736,17 @@ impl BitvSet {
748
736
/// is true if the word comes from `self`, and `false` if it comes from
749
737
/// `other`.
750
738
fn outliers < ' a > ( & ' a self , other : & ' a BitvSet )
751
- -> Map < ( uint , uint ) , ( bool , uint , uint ) , Words < ' a > > {
739
+ -> Map < ( uint , uint ) , ( bool , uint , uint ) , MaskWords < ' a > > {
752
740
let slen = self . capacity ( ) / uint:: BITS ;
753
741
let olen = other. capacity ( ) / uint:: BITS ;
754
742
let & BitvSet ( ref self_bitv) = self ;
755
743
let & BitvSet ( ref other_bitv) = other;
756
744
757
745
if olen < slen {
758
- self_bitv. words ( olen)
746
+ self_bitv. mask_words ( olen)
759
747
. map ( |( i, w) | ( true , i * uint:: BITS , w) )
760
748
} else {
761
- other_bitv. words ( slen)
749
+ other_bitv. mask_words ( slen)
762
750
. map ( |( i, w) | ( false , i * uint:: BITS , w) )
763
751
}
764
752
}
@@ -1250,16 +1238,32 @@ mod tests {
1250
1238
} ) ;
1251
1239
}
1252
1240
1241
+ #[ test]
1242
+ fn test_bitv_masking ( ) {
1243
+ let b = Bitv :: new ( 140 , true ) ;
1244
+ let mut bs = BitvSet :: from_bitv ( b) ;
1245
+ assert ! ( bs. contains( & 139 ) ) ;
1246
+ assert ! ( !bs. contains( & 140 ) ) ;
1247
+ assert ! ( bs. insert( 150 ) ) ;
1248
+ assert ! ( !bs. contains( & 140 ) ) ;
1249
+ assert ! ( !bs. contains( & 149 ) ) ;
1250
+ assert ! ( bs. contains( & 150 ) ) ;
1251
+ assert ! ( !bs. contains( & 151 ) ) ;
1252
+ }
1253
+
1253
1254
#[ test]
1254
1255
fn test_bitv_set_basic ( ) {
1255
1256
let mut b = BitvSet :: new ( ) ;
1256
1257
assert ! ( b. insert( 3 ) ) ;
1257
1258
assert ! ( !b. insert( 3 ) ) ;
1258
1259
assert ! ( b. contains( & 3 ) ) ;
1260
+ assert ! ( b. insert( 4 ) ) ;
1261
+ assert ! ( !b. insert( 4 ) ) ;
1262
+ assert ! ( b. contains( & 3 ) ) ;
1259
1263
assert ! ( b. insert( 400 ) ) ;
1260
1264
assert ! ( !b. insert( 400 ) ) ;
1261
1265
assert ! ( b. contains( & 400 ) ) ;
1262
- assert_eq ! ( b. len( ) , 2 ) ;
1266
+ assert_eq ! ( b. len( ) , 3 ) ;
1263
1267
}
1264
1268
1265
1269
#[ test]
0 commit comments