Skip to content

Commit c21a8fd

Browse files
committed
TST+BF: Expand tests and fix bugs
1 parent cb3222b commit c21a8fd

File tree

2 files changed

+69
-32
lines changed

2 files changed

+69
-32
lines changed

nibabel/metasum.py

Lines changed: 44 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def get_mask(self, value):
125125
return res
126126
return self._val_bitarrs[value].copy()
127127

128-
def num_indices(self, value, mask=None):
128+
def count(self, value, mask=None):
129129
'''Number of indices for the given `value`'''
130130
if mask is not None:
131131
if len(mask) != self.n_input:
@@ -136,15 +136,15 @@ def num_indices(self, value, mask=None):
136136
if mask is None:
137137
return self._n_input
138138
return mask.count()
139-
unique_idx = self._unique_vals.get(_NoValue)
139+
unique_idx = self._unique_vals.get(value, _NoValue)
140140
if unique_idx is not _NoValue:
141141
if mask is not None:
142142
if mask[unique_idx]:
143143
return 1
144144
return 0
145145
return 1
146146
if mask is not None:
147-
return (self._val_bitarrs[value] & mask).count
147+
return (self._val_bitarrs[value] & mask).count()
148148
return self._val_bitarrs[value].count()
149149

150150
def get_value(self, idx):
@@ -169,14 +169,14 @@ def to_list(self):
169169

170170
def extend(self, values):
171171
'''Add more values to the end of any existing ones'''
172-
curr_size = self._n_input
172+
init_size = self._n_input
173173
if isinstance(values, ValueIndices):
174174
other_is_vi = True
175175
other_size = values._n_input
176176
else:
177177
other_is_vi = False
178178
other_size = len(values)
179-
final_size = curr_size + other_size
179+
final_size = init_size + other_size
180180
for ba in self._val_bitarrs.values():
181181
ba.extend(zeros(other_size))
182182
if other_is_vi:
@@ -185,7 +185,7 @@ def extend(self, values):
185185
self._extend_const(values)
186186
return
187187
else:
188-
self._rm_const()
188+
self._rm_const(final_size)
189189
elif values._const_val is not _NoValue:
190190
cval = values._const_val
191191
other_unique = {}
@@ -199,40 +199,49 @@ def extend(self, values):
199199
other_unique = values._unique_vals
200200
other_bitarrs = values._val_bitarrs
201201
for val, other_idx in other_unique.items():
202-
self._ingest_single(val, final_size, curr_size, other_idx)
202+
self._ingest_single(val, final_size, init_size, other_idx)
203203
for val, other_ba in other_bitarrs.items():
204204
curr_ba = self._val_bitarrs.get(val)
205205
if curr_ba is None:
206206
curr_idx = self._unique_vals.get(val)
207207
if curr_idx is None:
208-
if curr_size == 0:
208+
if init_size == 0:
209209
new_ba = other_ba.copy()
210210
else:
211-
new_ba = zeros(curr_size)
211+
new_ba = zeros(init_size)
212212
new_ba.extend(other_ba)
213213
else:
214-
new_ba = zeros(curr_size)
214+
new_ba = zeros(init_size)
215215
new_ba[curr_idx] = True
216216
new_ba.extend(other_ba)
217217
del self._unique_vals[val]
218218
self._val_bitarrs[val] = new_ba
219219
else:
220-
curr_ba[curr_size:] |= other_ba
220+
curr_ba[init_size:] |= other_ba
221+
self._n_input += other_ba.count()
221222
else:
222223
for other_idx, val in enumerate(values):
223-
self._ingest_single(val, final_size, curr_size, other_idx)
224-
self._n_input = final_size
224+
self._ingest_single(val, final_size, init_size, other_idx)
225+
assert self._n_input == final_size
225226

226227
def append(self, value):
227228
'''Append another value as input'''
228229
if self._const_val == value:
229230
self._n_input += 1
230231
return
231232
elif self._const_val is not _NoValue:
232-
self._rm_const()
233+
self._rm_const(self._n_input + 1)
234+
self._unique_vals[value] = self._n_input
235+
self._n_input += 1
236+
return
237+
if self._n_input == 0:
238+
self._const_val = value
239+
self._n_input += 1
240+
return
233241
curr_size = self._n_input
234242
found = False
235243
for val, bitarr in self._val_bitarrs.items():
244+
assert len(bitarr) == self._n_input
236245
if val == value:
237246
found = True
238247
bitarr.append(True)
@@ -318,7 +327,7 @@ def get_block_size(self):
318327
if rem != 0:
319328
return None
320329
for val in self.values():
321-
if self.num_indices(val) != block_size:
330+
if self.count(val) != block_size:
322331
return None
323332
return block_size
324333

@@ -335,32 +344,43 @@ def _extract_indices(self, ba):
335344
except ValueError:
336345
return
337346
yield curr_idx
338-
start = curr_idx
347+
start = curr_idx + 1
339348

340-
def _ingest_single(self, val, final_size, curr_size, other_idx):
349+
def _ingest_single(self, val, final_size, init_size, other_idx):
341350
'''Helper to ingest single value from another collection'''
351+
if val == self._const_val:
352+
self._n_input += 1
353+
return
354+
elif self._const_val is not _NoValue:
355+
self._rm_const(final_size)
356+
if self._n_input == 0:
357+
self._const_val = val
358+
self._n_input += 1
359+
return
360+
342361
curr_ba = self._val_bitarrs.get(val)
343362
if curr_ba is None:
344363
curr_idx = self._unique_vals.get(val)
345364
if curr_idx is None:
346-
self._unique_vals[val] = curr_size + other_idx
365+
self._unique_vals[val] = init_size + other_idx
347366
else:
348367
new_ba = zeros(final_size)
349368
new_ba[curr_idx] = True
350-
new_ba[curr_size + other_idx] = True
369+
new_ba[init_size + other_idx] = True
351370
self._val_bitarrs[val] = new_ba
352371
del self._unique_vals[val]
353372
else:
354-
curr_ba[curr_size + other_idx] = True
373+
curr_ba[init_size + other_idx] = True
374+
self._n_input += 1
355375

356-
def _rm_const(self):
376+
def _rm_const(self, final_size):
357377
assert self._const_val is not _NoValue
358378
if self._n_input == 1:
359379
self._unique_vals[self._const_val] = 0
360380
else:
361-
self._val_bitarrs[self._const_val] = bitarray(self._n_input)
362-
self._val_bitarrs[self._const_val].setall(1)
363-
self._const_val == _NoValue
381+
self._val_bitarrs[self._const_val] = zeros(final_size)
382+
self._val_bitarrs[self._const_val][:self._n_input] = True
383+
self._const_val = _NoValue
364384

365385
def _extend_const(self, other):
366386
if self._const_val != other._const_val:

nibabel/tests/test_metasum.py

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,16 @@
1313

1414

1515
@pytest.mark.parametrize("in_list", vidx_test_patterns)
16-
def test_value_indices_rt(in_list):
16+
def test_value_indices_basics(in_list):
1717
'''Test we can roundtrip list -> ValueIndices -> list'''
1818
vidx = ValueIndices(in_list)
19+
assert vidx.n_input == len(in_list)
20+
assert len(vidx) == len(set(in_list))
21+
assert sorted(vidx.values()) == sorted(list(set(in_list)))
22+
for val in vidx.values():
23+
assert vidx.count(val) == in_list.count(val)
24+
for in_idx in vidx[val]:
25+
assert in_list[in_idx] == val
1926
out_list = vidx.to_list()
2027
assert in_list == out_list
2128

@@ -40,22 +47,32 @@ def test_value_indices_append_extend(in_list):
4047
assert vidx.to_list() == in_list + in_list
4148

4249

43-
metasum_test_dicts = (({'key1': 0, 'key2': 'a', 'key3': 3.0},
44-
{'key1': 2, 'key2': 'c', 'key3': 1.0},
45-
{'key1': 1, 'key2': 'b', 'key3': 2.0},
50+
metasum_test_dicts = (({'u1': 0, 'u2': 'a', 'u3': 3.0, 'c1': True, 'r1': 5},
51+
{'u1': 2, 'u2': 'c', 'u3': 1.0, 'c1': True, 'r1': 5},
52+
{'u1': 1, 'u2': 'b', 'u3': 2.0, 'c1': True, 'r1': 7},
4653
),
47-
({'key1': 0, 'key2': 'a', 'key3': 3.0},
48-
{'key1': 2, 'key2': 'c'},
49-
{'key1': 1, 'key2': 'b', 'key3': 2.0},
54+
({'u1': 0, 'u2': 'a', 'u3': 3.0, 'c1': True, 'r1': 5},
55+
{'u1': 2, 'u2': 'c', 'c1': True, 'r1': 5},
56+
{'u1': 1, 'u2': 'b', 'u3': 2.0, 'c1': True},
5057
),
5158
)
5259

5360

5461
@pytest.mark.parametrize("in_dicts", metasum_test_dicts)
55-
def test_meta_summary_rt(in_dicts):
62+
def test_meta_summary_basics(in_dicts):
5663
msum = MetaSummary()
64+
all_keys = set()
5765
for in_dict in in_dicts:
5866
msum.append(in_dict)
67+
for key in in_dict.keys():
68+
all_keys.add(key)
69+
assert all_keys == set(msum.keys())
70+
for key in msum.const_keys():
71+
assert key.startswith('c')
72+
for key in msum.unique_keys():
73+
assert key.startswith('u')
74+
for key in msum.repeating_keys():
75+
assert key.startswith('r')
5976
for in_idx in range(len(in_dicts)):
6077
out_dict = msum.get_meta(in_idx)
6178
assert out_dict == in_dicts[in_idx]

0 commit comments

Comments
 (0)