Skip to content

Commit cb3222b

Browse files
committed
WIP: Basics mostly working, needs more testing and finish ndSort
1 parent d5b588d commit cb3222b

File tree

2 files changed

+245
-63
lines changed

2 files changed

+245
-63
lines changed

nibabel/metasum.py

Lines changed: 182 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,18 @@
66
# copyright and license terms.
77
#
88
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
9-
'''Aggregate information for mutliple images
9+
'''Memory efficient tracking of meta data dicts with repeating elements
1010
'''
11+
from dataclasses import dataclass
12+
from enum import IntEnum
13+
1114
from bitarray import bitarray, frozenbitarray
12-
from bitarry.utils import zeroes
15+
from bitarray.util import zeros
1316

1417

15-
class FloatCanon(object):
18+
class FloatCanon:
1619
'''Look up a canonical float that we compare equal to'''
20+
1721
def __init__(self, n_digits=6):
1822
self._n_digits = n_digits
1923
self._offset = 0.5 * (10 ** -n_digits)
@@ -39,7 +43,9 @@ def get(self, val):
3943

4044
# TODO: Integrate some value canonicalization filtering? Or just require the
4145
# user to do that themselves?
42-
class ValueIndices(object):
46+
47+
48+
class ValueIndices:
4349
"""Track indices of values in sequence.
4450
4551
If values repeat frequently then memory usage can be dramatically improved.
@@ -114,19 +120,31 @@ def get_mask(self, value):
114120
return res
115121
idx = self._unique_vals.get(value)
116122
if idx is not None:
117-
res = zeroes(self._n_inpuf)
123+
res = zeros(self._n_inpuf)
118124
res[idx] = 1
119125
return res
120126
return self._val_bitarrs[value].copy()
121127

122-
def num_indices(self, value):
128+
def num_indices(self, value, mask=None):
123129
'''Number of indices for the given `value`'''
130+
if mask is not None:
131+
if len(mask) != self.n_input:
132+
raise ValueError("Mask length must match input length")
124133
if self._const_val is not _NoValue:
125134
if self._const_val != value:
126135
raise KeyError()
127-
return self._n_input
128-
if value in self._unique_vals:
136+
if mask is None:
137+
return self._n_input
138+
return mask.count()
139+
unique_idx = self._unique_vals.get(_NoValue)
140+
if unique_idx is not _NoValue:
141+
if mask is not None:
142+
if mask[unique_idx]:
143+
return 1
144+
return 0
129145
return 1
146+
if mask is not None:
147+
return (self._val_bitarrs[value] & mask).count
130148
return self._val_bitarrs[value].count()
131149

132150
def get_value(self, idx):
@@ -138,13 +156,17 @@ def get_value(self, idx):
138156
for val, vidx in self._unique_vals.items():
139157
if vidx == idx:
140158
return val
141-
bit_idx = zeroes(self._n_input)
159+
bit_idx = zeros(self._n_input)
142160
bit_idx[idx] = 1
143161
for val, ba in self._val_bitarrs.items():
144-
if (ba | bit_idx).any():
162+
if (ba & bit_idx).any():
145163
return val
146164
assert False
147165

166+
def to_list(self):
167+
'''Convert back to a list of values'''
168+
return [self.get_value(i) for i in range(self.n_input)]
169+
148170
def extend(self, values):
149171
'''Add more values to the end of any existing ones'''
150172
curr_size = self._n_input
@@ -156,7 +178,7 @@ def extend(self, values):
156178
other_size = len(values)
157179
final_size = curr_size + other_size
158180
for ba in self._val_bitarrs.values():
159-
ba.extend(zeroes(other_size))
181+
ba.extend(zeros(other_size))
160182
if other_is_vi:
161183
if self._const_val is not _NoValue:
162184
if values._const_val is not _NoValue:
@@ -186,10 +208,10 @@ def extend(self, values):
186208
if curr_size == 0:
187209
new_ba = other_ba.copy()
188210
else:
189-
new_ba = zeroes(curr_size)
211+
new_ba = zeros(curr_size)
190212
new_ba.extend(other_ba)
191213
else:
192-
new_ba = zeroes(curr_size)
214+
new_ba = zeros(curr_size)
193215
new_ba[curr_idx] = True
194216
new_ba.extend(other_ba)
195217
del self._unique_vals[val]
@@ -221,13 +243,20 @@ def append(self, value):
221243
if curr_idx is None:
222244
self._unique_vals[value] = curr_size
223245
else:
224-
new_ba = zeroes(curr_size + 1)
246+
new_ba = zeros(curr_size + 1)
225247
new_ba[curr_idx] = True
226248
new_ba[curr_size] = True
227249
self._val_bitarrs[value] = new_ba
228250
del self._unique_vals[value]
229251
self._n_input += 1
230252

253+
def reverse(self):
254+
'''Reverse the indices in place'''
255+
for val, idx in self._unique_vals.items():
256+
self._unique_vals[val] = self._n_input - idx - 1
257+
for val, bitarr in self._val_bitarrs.items():
258+
bitarr.reverse()
259+
231260
def argsort(self, reverse=False):
232261
'''Return array of indices in order that sorts the values'''
233262
if self._const_val is not _NoValue:
@@ -248,6 +277,18 @@ def argsort(self, reverse=False):
248277
res_idx += 1
249278
return res
250279

280+
def reorder(self, order):
281+
'''Reorder the indices in place'''
282+
if len(order) != self._n_input:
283+
raise ValueError("The 'order' has the incorrect length")
284+
for val, idx in self._unique_vals.items():
285+
self._unique_vals[val] = order.index(idx)
286+
for val, bitarr in self._val_bitarrs.items():
287+
new_ba = zeros(self._n_input)
288+
for idx in self._extract_indices(bitarr):
289+
new_ba[order.index(idx)] = True
290+
self._val_bitarrs[val] = new_ba
291+
251292
def is_covariant(self, other):
252293
'''True if `other` has values that vary the same way ours do
253294
@@ -267,35 +308,30 @@ def is_covariant(self, other):
267308
return False
268309
return True
269310

270-
def is_blocked(self, block_factor=None):
271-
'''True if each value has the same number of indices
311+
def get_block_size(self):
312+
'''Return size of even blocks of values, or None if values aren't "blocked"
272313
273-
If `block_factor` is not None we also test that it evenly divides the
274-
block size.
314+
The number of values must evenly divide the number of inputs into the block size,
315+
with each value appearing that same number of times.
275316
'''
276317
block_size, rem = divmod(self._n_input, len(self))
277318
if rem != 0:
278-
return False
279-
if block_factor is not None and block_size % block_factor != 0:
280-
return False
319+
return None
281320
for val in self.values():
282321
if self.num_indices(val) != block_size:
283-
return False
284-
return True
322+
return None
323+
return block_size
285324

286325
def is_subpartition(self, other):
287-
'''True if we have more values and they nest within values from other
288-
289-
290-
'''
326+
''''''
291327

292328
def _extract_indices(self, ba):
293329
'''Generate integer indices from bitarray representation'''
294330
start = 0
295331
while True:
296332
try:
297333
# TODO: Is this the most efficient approach?
298-
curr_idx = ba.index(True, start=start)
334+
curr_idx = ba.index(True, start)
299335
except ValueError:
300336
return
301337
yield curr_idx
@@ -309,10 +345,10 @@ def _ingest_single(self, val, final_size, curr_size, other_idx):
309345
if curr_idx is None:
310346
self._unique_vals[val] = curr_size + other_idx
311347
else:
312-
new_ba = zeroes(final_size)
348+
new_ba = zeros(final_size)
313349
new_ba[curr_idx] = True
314350
new_ba[curr_size + other_idx] = True
315-
self._val_bitarrs = new_ba
351+
self._val_bitarrs[val] = new_ba
316352
del self._unique_vals[val]
317353
else:
318354
curr_ba[curr_size + other_idx] = True
@@ -351,13 +387,33 @@ def _extend_const(self, other):
351387
_MissingKey = object()
352388

353389

390+
class DimTypes(IntEnum):
391+
'''Enmerate the three types of nD dimensions'''
392+
SLICE = 1
393+
TIME = 2
394+
PARAM = 3
395+
396+
397+
@dataclass
398+
class DimIndex:
399+
'''Specify an nD index'''
400+
dim_type: DimTypes
401+
402+
key: str
403+
404+
405+
class NdSortError(Exception):
406+
'''Raised when the data cannot be sorted into an nD array as specified'''
407+
408+
354409
class MetaSummary:
355410
'''Summarize a sequence of dicts, tracking how individual keys vary
356411
357412
The assumption is that for any key many values will be constant, or at
358413
least repeated, and thus we can reduce memory consumption by only storing
359414
the value once along with the indices it appears at.
360415
'''
416+
361417
def __init__(self):
362418
self._v_idxs = {}
363419
self._n_input = 0
@@ -380,9 +436,6 @@ def append(self, meta):
380436
self._v_idxs[key] = v_idx
381437
self._n_input += 1
382438

383-
def extend(self, metas):
384-
pass # TODO
385-
386439
def keys(self):
387440
'''Generate all known keys'''
388441
return self._v_idxs.keys()
@@ -412,20 +465,26 @@ def repeating_keys(self):
412465
if 1 < len(v_idx) < n_input:
413466
yield key
414467

415-
def repeating_groups(self, block_only=False, block_factor=None):
416-
'''Generate groups of repeating keys that vary with the same pattern
468+
def covariant_groups(self, keys=None, block_only=False):
469+
'''Generate groups of keys that vary with the same pattern
417470
'''
418-
n_input = self._n_input
419-
if n_input <= 1:
420-
# If there is only one element, consider all keys as const
421-
return
422-
# TODO: Can we sort so grouped v_idxs are sequential?
423-
# - Sort by num values isn't sufficient
424-
curr_group = []
425-
for key, v_idx in self._v_idxs.items():
426-
if 1 < len(v_idx) < n_input:
427-
if v_idx.is_even(block_factor):
428-
pass # TODO
471+
if keys is None:
472+
keys = self.keys()
473+
groups = []
474+
for key in keys:
475+
v_idx = self._v_idxs[key]
476+
if len(groups) == 0:
477+
groups.append((key, v_idx))
478+
continue
479+
for group in groups:
480+
if group[0][1].is_covariant(v_idx):
481+
group.append(key)
482+
break
483+
else:
484+
groups.append((key, v_idx))
485+
for group in groups:
486+
group[0] = group[0][0]
487+
return groups
429488

430489
def get_meta(self, idx):
431490
'''Get the full dict at the given index'''
@@ -439,26 +498,86 @@ def get_meta(self, idx):
439498

440499
def get_val(self, idx, key, default=None):
441500
'''Get the value at `idx` for the `key`, or return `default``'''
442-
res = self._v_idxs[key].get_value(key)
501+
res = self._v_idxs[key].get_value(idx)
443502
if res is _MissingKey:
444503
return default
445504
return res
446505

447-
def nd_sort(self, dim_keys=None):
448-
'''Produce indices ordered so as to fill an n-D array'''
506+
def reorder(self, order):
507+
'''Reorder indices in place'''
508+
for v_idx in self._v_idxs.values():
509+
v_idx.reorder(order)
449510

450-
class SummaryTree:
451-
'''Groups incoming meta data and creates hierarchy of related groups
452-
453-
Each leaf node in the tree is a `MetaSummary`
454-
'''
455-
def __init__(self, group_keys):
456-
self._group_keys = group_keys
457-
self._group_summaries= {}
458-
459-
def add(self, meta):
460-
pass
461-
462-
def groups(self):
463-
'''Generate the groups and their meta summaries'''
511+
def nd_sort(self, dims):
512+
'''Produce linear indices to fill nD array as specified by `dims`
464513
514+
Assumes each input corresponds to a 2D or 3D array, and the combined
515+
array is 3D+
516+
'''
517+
# Make sure dims aren't completely invalid
518+
if len(dims) == 0:
519+
raise ValueError("At least one dimension must be specified")
520+
last_dim = None
521+
for dim in dims:
522+
if last_dim is not None:
523+
if last_dim.dim_type > dim.dim_type:
524+
# TODO: This only allows PARAM dimensions at the end, which I guess is reasonable?
525+
raise ValueError("Invalid dimension order")
526+
elif last_dim.dim_type == dim.dim_type and dim.dim_type != DimTypes.PARAM:
527+
raise ValueError("There can be at most one each of SLICE and TIME dimensions")
528+
last_dim = dim
529+
530+
# Pull out info about different types of dims
531+
n_slices = None
532+
n_vol = None
533+
time_dim = None
534+
param_dims = []
535+
n_params = []
536+
total_params = 1
537+
shape = []
538+
curr_size = 1
539+
for dim in dims:
540+
dim_vidx = self._v_idxs[dim.key]
541+
dim_type = dim.dim_type
542+
if dim_type is DimTypes.SLICE:
543+
n_slices = len(dim_vidx)
544+
n_vol = dim_vidx.get_block_size()
545+
if n_vol is None:
546+
raise NdSortError("There are missing or extra slices")
547+
shape.append(n_slices)
548+
curr_size *= n_slices
549+
elif dim_type is DimTypes.TIME:
550+
time_dim = dim
551+
elif dim_type is DimTypes.PARAM:
552+
if dim_vidx.get_block_size() is None:
553+
raise NdSortError(f"The parameter {dim.key} doesn't evenly divide inputs")
554+
param_dims.append(dim)
555+
n_param = len(dim_vidx)
556+
n_params.append(n_param)
557+
total_params *= n_param
558+
if n_vol is None:
559+
n_vol = self._n_input
560+
561+
# Size of the time dimension must be infered from the size of the other dims
562+
n_time = 1
563+
if time_dim is not None:
564+
n_time, rem = divmod(n_vol, total_params)
565+
if rem != 0:
566+
raise NdSortError(f"The combined parameters don't evenly divide inputs")
567+
shape.append(n_time)
568+
curr_size *= n_time
569+
570+
# Complete the "shape", and do a more detailed check that our param dims make sense
571+
for dim, n_param in zip(param_dims, n_params):
572+
dim_vidx = self._v_idxs[dim.key]
573+
if dim_vidx.get_block_size() != curr_size:
574+
raise NdSortError(f"The parameter {dim.key} doesn't evenly divide inputs")
575+
shape.append(n_param)
576+
curr_size *= n_param
577+
578+
# Extract dim keys for each input and do the actual sort
579+
sort_keys = [(idx, tuple(self.get_val(idx, dim.key) for dim in reversed(dims)))
580+
for idx in range(self._n_input)]
581+
sort_keys.sort(key=lambda x: x[1])
582+
583+
# TODO: Finish this

0 commit comments

Comments
 (0)