Skip to content

Commit a698b81

Browse files
committed
collections::bitv: ensure correct masking behaviour
The internal masking behaviour for `Bitv` is now defined as:

- Any entirely unused words in self.storage must be all zeroes.
- Any partially used words may have anything at all in their unused bits.

This means:

- When decreasing self.nbits, care must be taken that any no-longer-used words are zeroed out.
- When increasing self.nbits, care must be taken that any newly-unmasked bits are set to their correct values.
- When reading words, care should be taken that the values of unused bits are not used. (Preferably, use `Bitv::mask_words`, which zeroes them out for you.)

The old behaviour was that every unused bit was always set to zero. The problem with this is that unused bits are almost never read, so forgetting to do this will result in very subtle and hard-to-track-down bugs. This way the responsibility for masking falls on the places which might cause unused bits to be read: for now, this is only `Bitv::mask_words` and `BitvSet::insert`.
1 parent 2d23319 commit a698b81

File tree

1 file changed

+87
-83
lines changed

1 file changed

+87
-83
lines changed

src/libcollections/bitv.rs

Lines changed: 87 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -24,22 +24,6 @@ use std::hash;
2424
use {Collection, Mutable, Set, MutableSet};
2525
use vec::Vec;
2626

27-
/**
28-
* A mask that has a 1 for each defined bit in the n'th element of a `Bitv`,
29-
* assuming n bits.
30-
*/
31-
#[inline]
32-
fn big_mask(nbits: uint, elem: uint) -> uint {
33-
let rmd = nbits % uint::BITS;
34-
let nelems = (nbits + uint::BITS - 1) / uint::BITS;
35-
36-
if elem < nelems - 1 || rmd == 0 {
37-
!0
38-
} else {
39-
(1 << rmd) - 1
40-
}
41-
}
42-
4327
/// The bitvector type
4428
///
4529
/// # Example
@@ -75,35 +59,47 @@ pub struct Bitv {
7559
nbits: uint
7660
}
7761

78-
struct Words<'a> {
62+
struct MaskWords<'a> {
7963
iter: slice::Items<'a, uint>,
64+
next_word: Option<&'a uint>,
65+
last_word_mask: uint,
8066
offset: uint
8167
}
8268

83-
impl<'a> Iterator<(uint, uint)> for Words<'a> {
69+
impl<'a> Iterator<(uint, uint)> for MaskWords<'a> {
8470
/// Returns (offset, word)
8571
fn next<'a>(&'a mut self) -> Option<(uint, uint)> {
86-
let ret = self.iter.next().map(|&n| (self.offset, n));
87-
self.offset += 1;
88-
ret
72+
let ret = self.next_word;
73+
match ret {
74+
Some(&w) => {
75+
self.next_word = self.iter.next();
76+
self.offset += 1;
77+
// The last word may need to be masked
78+
if self.next_word.is_none() {
79+
Some((self.offset - 1, w & self.last_word_mask))
80+
} else {
81+
Some((self.offset - 1, w))
82+
}
83+
},
84+
None => None
85+
}
8986
}
9087
}
9188

9289
impl Bitv {
9390
#[inline]
94-
fn process(&mut self, other: &Bitv, nbits: uint,
95-
op: |uint, uint| -> uint) -> bool {
91+
fn process(&mut self, other: &Bitv, op: |uint, uint| -> uint) -> bool {
9692
let len = other.storage.len();
9793
assert_eq!(self.storage.len(), len);
9894
let mut changed = false;
99-
for (i, (a, b)) in self.storage.mut_iter()
100-
.zip(other.storage.iter())
101-
.enumerate() {
102-
let mask = big_mask(nbits, i);
103-
let w0 = *a & mask;
104-
let w1 = *b & mask;
105-
let w = op(w0, w1) & mask;
106-
if w0 != w {
95+
// Notice: `a` is *not* masked here, which is fine as long as
96+
// `op` is a bitwise operation, since any bits that should've
97+
// been masked were fine to change anyway. `b` is masked to
98+
// make sure its unmasked bits do not cause damage.
99+
for (a, (_, b)) in self.storage.mut_iter()
100+
.zip(other.mask_words(0)) {
101+
let w = op(*a, b);
102+
if *a != w {
107103
changed = true;
108104
*a = w;
109105
}
@@ -112,10 +108,20 @@ impl Bitv {
112108
}
113109

114110
#[inline]
115-
#[inline]
116-
fn words<'a>(&'a self, start: uint) -> Words<'a> {
117-
Words {
118-
iter: self.storage.slice_from(start).iter(),
111+
fn mask_words<'a>(&'a self, mut start: uint) -> MaskWords<'a> {
112+
if start > self.storage.len() {
113+
start = self.storage.len();
114+
}
115+
let mut iter = self.storage.slice_from(start).iter();
116+
MaskWords {
117+
next_word: iter.next(),
118+
iter: iter,
119+
last_word_mask: {
120+
let rem = self.nbits % uint::BITS;
121+
if rem > 0 {
122+
(1 << rem) - 1
123+
} else { !0 }
124+
},
119125
offset: start
120126
}
121127
}
@@ -124,15 +130,8 @@ impl Bitv {
124130
/// to `init`.
125131
pub fn new(nbits: uint, init: bool) -> Bitv {
126132
Bitv {
127-
storage: {
128-
let nelems = (nbits + uint::BITS - 1) / uint::BITS;
129-
let mut v = Vec::from_elem(nelems, if init { !0u } else { 0u });
130-
// Zero out any remainder bits
131-
if nbits % uint::BITS > 0 {
132-
*v.get_mut(nelems - 1) &= (1 << nbits % uint::BITS) - 1;
133-
}
134-
v
135-
},
133+
storage: Vec::from_elem((nbits + uint::BITS - 1) / uint::BITS,
134+
if init { !0u } else { 0u }),
136135
nbits: nbits
137136
}
138137
}
@@ -145,8 +144,7 @@ impl Bitv {
145144
*/
146145
#[inline]
147146
pub fn union(&mut self, other: &Bitv) -> bool {
148-
let nbits = self.nbits;
149-
self.process(other, nbits, |w1, w2| w1 | w2)
147+
self.process(other, |w1, w2| w1 | w2)
150148
}
151149

152150
/**
@@ -157,8 +155,7 @@ impl Bitv {
157155
*/
158156
#[inline]
159157
pub fn intersect(&mut self, other: &Bitv) -> bool {
160-
let nbits = self.nbits;
161-
self.process(other, nbits, |w1, w2| w1 & w2)
158+
self.process(other, |w1, w2| w1 & w2)
162159
}
163160

164161
/**
@@ -169,8 +166,7 @@ impl Bitv {
169166
*/
170167
#[inline]
171168
pub fn assign(&mut self, other: &Bitv) -> bool {
172-
let nbits = self.nbits;
173-
self.process(other, nbits, |_, w| w)
169+
self.process(other, |_, w| w)
174170
}
175171

176172
/// Retrieve the value at index `i`
@@ -227,20 +223,18 @@ impl Bitv {
227223
*/
228224
#[inline]
229225
pub fn difference(&mut self, other: &Bitv) -> bool {
230-
let nbits = self.nbits;
231-
self.process(other, nbits, |w1, w2| w1 & !w2)
226+
self.process(other, |w1, w2| w1 & !w2)
232227
}
233228

234229
/// Returns `true` if all bits are 1
235230
#[inline]
236231
pub fn all(&self) -> bool {
237-
for (i, &elem) in self.storage.iter().enumerate() {
238-
let mask = big_mask(self.nbits, i);
239-
if elem & mask != mask {
240-
return false;
241-
}
242-
}
243-
true
232+
let mut last_word = !0u;
233+
// Check that every word but the last is all-ones...
234+
self.mask_words(0).all(|(_, elem)|
235+
{ let tmp = last_word; last_word = elem; tmp == !0u }) &&
236+
// ...and that the last word is ones as far as it needs to be
237+
(last_word == ((1 << self.nbits % uint::BITS) - 1) || last_word == !0u)
244238
}
245239

246240
/// Returns an iterator over the elements of the vector in order.
@@ -265,13 +259,7 @@ impl Bitv {
265259

266260
/// Returns `true` if all bits are 0
267261
pub fn none(&self) -> bool {
268-
for (i, &elem) in self.storage.iter().enumerate() {
269-
let mask = big_mask(self.nbits, i);
270-
if elem & mask != 0 {
271-
return false;
272-
}
273-
}
274-
true
262+
self.mask_words(0).all(|(_, w)| w == 0)
275263
}
276264

277265
#[inline]
@@ -397,8 +385,8 @@ impl fmt::Show for Bitv {
397385
impl<S: hash::Writer> hash::Hash<S> for Bitv {
398386
fn hash(&self, state: &mut S) {
399387
self.nbits.hash(state);
400-
for (i, elem) in self.storage.iter().enumerate() {
401-
(elem & big_mask(self.nbits, i)).hash(state);
388+
for (_, elem) in self.mask_words(0) {
389+
elem.hash(state);
402390
}
403391
}
404392
}
@@ -409,13 +397,7 @@ impl cmp::PartialEq for Bitv {
409397
if self.nbits != other.nbits {
410398
return false;
411399
}
412-
for (i, (&w1, &w2)) in self.storage.iter().zip(other.storage.iter()).enumerate() {
413-
let mask = big_mask(self.nbits, i);
414-
if w1 & mask != w2 & mask {
415-
return false;
416-
}
417-
}
418-
true
400+
self.mask_words(0).zip(other.mask_words(0)).all(|((_, w1), (_, w2))| w1 == w2)
419401
}
420402
}
421403

@@ -546,7 +528,7 @@ impl BitvSet {
546528
// Unwrap Bitvs
547529
let &BitvSet(ref mut self_bitv) = self;
548530
let &BitvSet(ref other_bitv) = other;
549-
for (i, w) in other_bitv.words(0) {
531+
for (i, w) in other_bitv.mask_words(0) {
550532
let old = *self_bitv.storage.get(i);
551533
let new = f(old, w);
552534
*self_bitv.storage.get_mut(i) = new;
@@ -563,7 +545,7 @@ impl BitvSet {
563545
let n = bitv.storage.iter().rev().take_while(|&&n| n == 0).count();
564546
// Truncate
565547
let trunc_len = cmp::max(old_len - n, 1);
566-
bitv.storage.truncate(cmp::max(old_len - n, 1));
548+
bitv.storage.truncate(trunc_len);
567549
bitv.nbits = trunc_len * uint::BITS;
568550
}
569551

@@ -710,6 +692,12 @@ impl MutableSet<uint> for BitvSet {
710692
}
711693
let &BitvSet(ref mut bitv) = self;
712694
if value >= bitv.nbits {
695+
// If we are increasing nbits, make sure we mask out any previously-unconsidered bits
696+
let old_rem = bitv.nbits % uint::BITS;
697+
if old_rem != 0 {
698+
let old_last_word = (bitv.nbits + uint::BITS - 1) / uint::BITS - 1;
699+
*bitv.storage.get_mut(old_last_word) &= (1 << old_rem) - 1;
700+
}
713701
bitv.nbits = value + 1;
714702
}
715703
bitv.set(value, true);
@@ -733,10 +721,10 @@ impl BitvSet {
733721
/// and w1/w2 are the words coming from the two vectors self, other.
734722
fn commons<'a>(&'a self, other: &'a BitvSet)
735723
-> Map<((uint, uint), (uint, uint)), (uint, uint, uint),
736-
Zip<Words<'a>, Words<'a>>> {
724+
Zip<MaskWords<'a>, MaskWords<'a>>> {
737725
let &BitvSet(ref self_bitv) = self;
738726
let &BitvSet(ref other_bitv) = other;
739-
self_bitv.words(0).zip(other_bitv.words(0))
727+
self_bitv.mask_words(0).zip(other_bitv.mask_words(0))
740728
.map(|((i, w1), (_, w2))| (i * uint::BITS, w1, w2))
741729
}
742730

@@ -748,17 +736,17 @@ impl BitvSet {
748736
/// is true if the word comes from `self`, and `false` if it comes from
749737
/// `other`.
750738
fn outliers<'a>(&'a self, other: &'a BitvSet)
751-
-> Map<(uint, uint), (bool, uint, uint), Words<'a>> {
739+
-> Map<(uint, uint), (bool, uint, uint), MaskWords<'a>> {
752740
let slen = self.capacity() / uint::BITS;
753741
let olen = other.capacity() / uint::BITS;
754742
let &BitvSet(ref self_bitv) = self;
755743
let &BitvSet(ref other_bitv) = other;
756744

757745
if olen < slen {
758-
self_bitv.words(olen)
746+
self_bitv.mask_words(olen)
759747
.map(|(i, w)| (true, i * uint::BITS, w))
760748
} else {
761-
other_bitv.words(slen)
749+
other_bitv.mask_words(slen)
762750
.map(|(i, w)| (false, i * uint::BITS, w))
763751
}
764752
}
@@ -1250,16 +1238,32 @@ mod tests {
12501238
});
12511239
}
12521240

1241+
#[test]
1242+
fn test_bitv_masking() {
1243+
let b = Bitv::new(140, true);
1244+
let mut bs = BitvSet::from_bitv(b);
1245+
assert!(bs.contains(&139));
1246+
assert!(!bs.contains(&140));
1247+
assert!(bs.insert(150));
1248+
assert!(!bs.contains(&140));
1249+
assert!(!bs.contains(&149));
1250+
assert!(bs.contains(&150));
1251+
assert!(!bs.contains(&151));
1252+
}
1253+
12531254
#[test]
12541255
fn test_bitv_set_basic() {
12551256
let mut b = BitvSet::new();
12561257
assert!(b.insert(3));
12571258
assert!(!b.insert(3));
12581259
assert!(b.contains(&3));
1260+
assert!(b.insert(4));
1261+
assert!(!b.insert(4));
1262+
assert!(b.contains(&3));
12591263
assert!(b.insert(400));
12601264
assert!(!b.insert(400));
12611265
assert!(b.contains(&400));
1262-
assert_eq!(b.len(), 2);
1266+
assert_eq!(b.len(), 3);
12631267
}
12641268

12651269
#[test]

0 commit comments

Comments
 (0)