6
6
# copyright and license terms.
7
7
#
8
8
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
9
- '''Aggregate information for mutliple images
9
+ '''Memory efficient tracking of meta data dicts with repeating elements
10
10
'''
11
+ from dataclasses import dataclass
12
+ from enum import IntEnum
13
+
11
14
from bitarray import bitarray , frozenbitarray
12
- from bitarry . utils import zeroes
15
+ from bitarray . util import zeros
13
16
14
17
15
- class FloatCanon ( object ) :
18
+ class FloatCanon :
16
19
'''Look up a canonical float that we compare equal to'''
20
+
17
21
def __init__ (self , n_digits = 6 ):
18
22
self ._n_digits = n_digits
19
23
self ._offset = 0.5 * (10 ** - n_digits )
@@ -39,7 +43,9 @@ def get(self, val):
39
43
40
44
# TODO: Integrate some value canonicalization filtering? Or just require the
41
45
# user to do that themselves?
42
- class ValueIndices (object ):
46
+
47
+
48
+ class ValueIndices :
43
49
"""Track indices of values in sequence.
44
50
45
51
If values repeat frequently then memory usage can be dramatically improved.
@@ -114,19 +120,31 @@ def get_mask(self, value):
114
120
return res
115
121
idx = self ._unique_vals .get (value )
116
122
if idx is not None :
117
- res = zeroes (self ._n_inpuf )
123
+ res = zeros (self ._n_inpuf )
118
124
res [idx ] = 1
119
125
return res
120
126
return self ._val_bitarrs [value ].copy ()
121
127
122
- def num_indices (self , value ):
128
def num_indices(self, value, mask=None):
    '''Number of indices for the given `value`

    If `mask` is given (presumably a bitarray with one bit per input) only
    the indices where the mask is set are counted.

    Raises ValueError if the length of `mask` doesn't match the number of
    inputs, and KeyError if `value` is not one of the tracked values.
    '''
    if mask is not None and len(mask) != self._n_input:
        raise ValueError("Mask length must match input length")
    if self._const_val is not _NoValue:
        # Every index holds the same single value
        if self._const_val != value:
            raise KeyError(value)
        return self._n_input if mask is None else mask.count()
    # A value seen only once is stored as a bare index instead of a bitarray.
    # Look up `value` itself, with the sentinel as the "not found" marker.
    unique_idx = self._unique_vals.get(value, _NoValue)
    if unique_idx is not _NoValue:
        if mask is None:
            return 1
        return 1 if mask[unique_idx] else 0
    val_bits = self._val_bitarrs[value]
    if mask is None:
        return val_bits.count()
    return (val_bits & mask).count()
131
149
132
150
def get_value (self , idx ):
@@ -138,13 +156,17 @@ def get_value(self, idx):
138
156
for val , vidx in self ._unique_vals .items ():
139
157
if vidx == idx :
140
158
return val
141
- bit_idx = zeroes (self ._n_input )
159
+ bit_idx = zeros (self ._n_input )
142
160
bit_idx [idx ] = 1
143
161
for val , ba in self ._val_bitarrs .items ():
144
- if (ba | bit_idx ).any ():
162
+ if (ba & bit_idx ).any ():
145
163
return val
146
164
assert False
147
165
166
def to_list(self):
    '''Expand the compact representation into a plain list

    Element `i` of the result is whatever `get_value(i)` returns, for each
    of the `n_input` indices in order.
    '''
    return list(map(self.get_value, range(self.n_input)))
169
+
148
170
def extend (self , values ):
149
171
'''Add more values to the end of any existing ones'''
150
172
curr_size = self ._n_input
@@ -156,7 +178,7 @@ def extend(self, values):
156
178
other_size = len (values )
157
179
final_size = curr_size + other_size
158
180
for ba in self ._val_bitarrs .values ():
159
- ba .extend (zeroes (other_size ))
181
+ ba .extend (zeros (other_size ))
160
182
if other_is_vi :
161
183
if self ._const_val is not _NoValue :
162
184
if values ._const_val is not _NoValue :
@@ -186,10 +208,10 @@ def extend(self, values):
186
208
if curr_size == 0 :
187
209
new_ba = other_ba .copy ()
188
210
else :
189
- new_ba = zeroes (curr_size )
211
+ new_ba = zeros (curr_size )
190
212
new_ba .extend (other_ba )
191
213
else :
192
- new_ba = zeroes (curr_size )
214
+ new_ba = zeros (curr_size )
193
215
new_ba [curr_idx ] = True
194
216
new_ba .extend (other_ba )
195
217
del self ._unique_vals [val ]
@@ -221,13 +243,20 @@ def append(self, value):
221
243
if curr_idx is None :
222
244
self ._unique_vals [value ] = curr_size
223
245
else :
224
- new_ba = zeroes (curr_size + 1 )
246
+ new_ba = zeros (curr_size + 1 )
225
247
new_ba [curr_idx ] = True
226
248
new_ba [curr_size ] = True
227
249
self ._val_bitarrs [value ] = new_ba
228
250
del self ._unique_vals [value ]
229
251
self ._n_input += 1
230
252
253
def reverse(self):
    '''Flip the order of the indices in place

    Index `i` becomes index `n_input - 1 - i` for every stored value.
    '''
    last = self._n_input - 1
    # Values that occur once are stored as a bare index; mirror each one
    mirrored = {val: last - idx for val, idx in self._unique_vals.items()}
    self._unique_vals.update(mirrored)
    # Repeated values are stored as bitarrays, which can reverse themselves
    for bits in self._val_bitarrs.values():
        bits.reverse()
259
+
231
260
def argsort (self , reverse = False ):
232
261
'''Return array of indices in order that sorts the values'''
233
262
if self ._const_val is not _NoValue :
@@ -248,6 +277,18 @@ def argsort(self, reverse=False):
248
277
res_idx += 1
249
278
return res
250
279
280
def reorder(self, order):
    '''Reorder the indices in place

    `order` must have one entry per input. The value currently stored at
    index `order[i]` ends up at index `i`.

    Raises ValueError if `order` has the wrong length, or if an index that
    is in use doesn't appear in it. Nothing is modified when an error is
    raised.
    '''
    n_input = self._n_input
    if len(order) != n_input:
        raise ValueError("The 'order' has the incorrect length")
    # Invert `order` once instead of calling `order.index` for every stored
    # index, which is quadratic and needs `order` to be a list or tuple.
    # Keep the first position of a repeated entry, as `order.index` would.
    new_pos = {}
    for pos, old_idx in enumerate(order):
        new_pos.setdefault(old_idx, pos)

    def _moved(old_idx):
        try:
            return new_pos[old_idx]
        except KeyError:
            raise ValueError(f"Index {old_idx} is missing from 'order'") from None

    # Compute everything before touching our state, so a bad `order` can't
    # leave us half-updated
    new_unique = {val: _moved(idx) for val, idx in self._unique_vals.items()}
    new_bitarrs = {}
    for val, bitarr in self._val_bitarrs.items():
        new_ba = zeros(n_input)
        for idx in self._extract_indices(bitarr):
            new_ba[_moved(idx)] = True
        new_bitarrs[val] = new_ba
    self._unique_vals.update(new_unique)
    self._val_bitarrs.update(new_bitarrs)
291
+
251
292
def is_covariant (self , other ):
252
293
'''True if `other` has values that vary the same way ours do
253
294
@@ -267,35 +308,30 @@ def is_covariant(self, other):
267
308
return False
268
309
return True
269
310
270
- def is_blocked (self , block_factor = None ):
271
- '''True if each value has the same number of indices
311
def get_block_size(self):
    '''Return size of even blocks of values, or None if values aren't "blocked"

    The values are "blocked" when `len(self)` evenly divides the number of
    inputs and every value appears exactly that many (the block size)
    times. An instance with no values is never blocked.
    '''
    n_vals = len(self)
    if n_vals == 0:
        # Nothing is tracked yet; bail out rather than divide by zero
        return None
    block_size, rem = divmod(self._n_input, n_vals)
    if rem != 0:
        return None
    if any(self.num_indices(val) != block_size for val in self.values()):
        return None
    return block_size
285
324
286
325
def is_subpartition(self, other):
    '''Not implemented yet: there is no body, so this always returns None

    NOTE(review): an earlier revision documented the intent as "True if we
    have more values and they nest within values from other" -- confirm the
    intended contract before implementing.
    '''
291
327
292
328
def _extract_indices (self , ba ):
293
329
'''Generate integer indices from bitarray representation'''
294
330
start = 0
295
331
while True :
296
332
try :
297
333
# TODO: Is this the most efficient approach?
298
- curr_idx = ba .index (True , start = start )
334
+ curr_idx = ba .index (True , start )
299
335
except ValueError :
300
336
return
301
337
yield curr_idx
@@ -309,10 +345,10 @@ def _ingest_single(self, val, final_size, curr_size, other_idx):
309
345
if curr_idx is None :
310
346
self ._unique_vals [val ] = curr_size + other_idx
311
347
else :
312
- new_ba = zeroes (final_size )
348
+ new_ba = zeros (final_size )
313
349
new_ba [curr_idx ] = True
314
350
new_ba [curr_size + other_idx ] = True
315
- self ._val_bitarrs = new_ba
351
+ self ._val_bitarrs [ val ] = new_ba
316
352
del self ._unique_vals [val ]
317
353
else :
318
354
curr_ba [curr_size + other_idx ] = True
@@ -351,13 +387,33 @@ def _extend_const(self, other):
351
387
_MissingKey = object ()
352
388
353
389
390
class DimTypes(IntEnum):
    '''Enumerate the three types of nD dimensions

    The integer values are ordered: `MetaSummary.nd_sort` rejects a list of
    dimensions in which a type is followed by one with a smaller value.
    '''

    SLICE = 1
    TIME = 2
    PARAM = 3
395
+
396
+
397
@dataclass
class DimIndex:
    '''Specify one dimension of an nD index'''

    # Which kind of dimension this is
    dim_type: DimTypes

    # Meta data key whose values index along this dimension
    key: str
403
+
404
+
405
class NdSortError(Exception):
    '''Signals that the data can't be arranged into the requested nD array'''
407
+
408
+
354
409
class MetaSummary :
355
410
'''Summarize a sequence of dicts, tracking how individual keys vary
356
411
357
412
The assumption is that for any key many values will be constant, or at
358
413
least repeated, and thus we can reduce memory consumption by only storing
359
414
the value once along with the indices it appears at.
360
415
'''
416
+
361
417
def __init__ (self ):
362
418
self ._v_idxs = {}
363
419
self ._n_input = 0
@@ -380,9 +436,6 @@ def append(self, meta):
380
436
self ._v_idxs [key ] = v_idx
381
437
self ._n_input += 1
382
438
383
- def extend (self , metas ):
384
- pass # TODO
385
-
386
439
def keys(self):
    '''Generate all known keys

    Returns the `keys()` view of the internal key -> indices mapping rather
    than a copy.
    '''
    return self._v_idxs.keys()
@@ -412,20 +465,26 @@ def repeating_keys(self):
412
465
if 1 < len (v_idx ) < n_input :
413
466
yield key
414
467
415
- def repeating_groups (self , block_only = False , block_factor = None ):
416
- '''Generate groups of repeating keys that vary with the same pattern
468
def covariant_groups(self, keys=None, block_only=False):
    '''Generate groups of keys that vary with the same pattern

    Each of the `keys` (all known keys if None) is put into the first group
    whose founding key it is covariant with, as judged by `is_covariant`.
    If there is no such group it starts a new one.

    Returns a list of groups. Each group is a list of keys in the order
    they were visited, and the groups are ordered by their founding key.

    NOTE(review): `block_only` is accepted but currently has no effect.
    '''
    if keys is None:
        keys = self.keys()
    groups = []
    # The indices of the founding key for each group, kept in a parallel
    # list so that the groups themselves only ever hold keys
    founders = []
    for key in keys:
        v_idx = self._v_idxs[key]
        for founder, group in zip(founders, groups):
            if founder.is_covariant(v_idx):
                group.append(key)
                break
        else:
            founders.append(v_idx)
            groups.append([key])
    return groups
429
488
430
489
def get_meta (self , idx ):
431
490
'''Get the full dict at the given index'''
@@ -439,26 +498,86 @@ def get_meta(self, idx):
439
498
440
499
def get_val(self, idx, key, default=None):
    '''Get the value at `idx` for the `key`, or return `default`

    The `default` is returned when the value stored at `idx` is the
    `_MissingKey` sentinel, and also when `key` is not known at all.
    '''
    v_idx = self._v_idxs.get(key)
    if v_idx is None:
        # Never saw this key, so it is missing at every index
        return default
    res = v_idx.get_value(idx)
    return default if res is _MissingKey else res
446
505
447
- def nd_sort (self , dim_keys = None ):
448
- '''Produce indices ordered so as to fill an n-D array'''
506
def reorder(self, order):
    '''Apply the same reordering, in place, to every tracked key

    The `order` is passed unchanged to the `reorder` method of the indices
    stored for each key.
    '''
    for key in self._v_idxs:
        self._v_idxs[key].reorder(order)
449
510
450
- class SummaryTree :
451
- '''Groups incoming meta data and creates hierarchy of related groups
452
-
453
- Each leaf node in the tree is a `MetaSummary`
454
- '''
455
- def __init__ (self , group_keys ):
456
- self ._group_keys = group_keys
457
- self ._group_summaries = {}
458
-
459
- def add (self , meta ):
460
- pass
461
-
462
- def groups (self ):
463
- '''Generate the groups and their meta summaries'''
511
def nd_sort(self, dims):
    '''Produce linear indices to fill nD array as specified by `dims`

    Assumes each input corresponds to a 2D or 3D array, and the combined
    array is 3D+.

    The `dims` are `DimIndex`-like objects (with `dim_type` and `key`
    attributes). Raises ValueError if the `dims` themselves are invalid and
    NdSortError if the data doesn't fit the requested dimensions.

    NOTE: this is unfinished. The dims are validated and the sort order is
    computed, but nothing is returned yet.
    '''
    # Reject dims that are invalid regardless of the data
    if len(dims) == 0:
        raise ValueError("At least one dimension must be specified")
    prev = None
    for dim in dims:
        if prev is not None:
            if prev.dim_type > dim.dim_type:
                # TODO: This only allows PARAM dimensions at the end, which I guess is reasonable?
                raise ValueError("Invalid dimension order")
            if prev.dim_type == dim.dim_type and dim.dim_type != DimTypes.PARAM:
                raise ValueError("There can be at most one each of SLICE and TIME dimensions")
        prev = dim

    # Gather what each type of dim tells us
    n_vol = None
    time_dim = None
    param_info = []
    total_params = 1
    shape = []
    curr_size = 1
    for dim in dims:
        vidx = self._v_idxs[dim.key]
        kind = dim.dim_type
        if kind is DimTypes.SLICE:
            n_slices = len(vidx)
            n_vol = vidx.get_block_size()
            if n_vol is None:
                raise NdSortError("There are missing or extra slices")
            shape.append(n_slices)
            curr_size *= n_slices
        elif kind is DimTypes.TIME:
            time_dim = dim
        elif kind is DimTypes.PARAM:
            if vidx.get_block_size() is None:
                raise NdSortError(f"The parameter {dim.key} doesn't evenly divide inputs")
            n_param = len(vidx)
            param_info.append((dim, n_param))
            total_params *= n_param
    if n_vol is None:
        # Without a slice dim every input counts as a volume
        n_vol = self._n_input

    # Size of the time dimension must be inferred from the size of the other dims
    n_time = 1
    if time_dim is not None:
        n_time, rem = divmod(n_vol, total_params)
        if rem != 0:
            raise NdSortError("The combined parameters don't evenly divide inputs")
        shape.append(n_time)
        curr_size *= n_time

    # Complete the "shape", and do a more detailed check that our param dims make sense
    for dim, n_param in param_info:
        if self._v_idxs[dim.key].get_block_size() != curr_size:
            raise NdSortError(f"The parameter {dim.key} doesn't evenly divide inputs")
        shape.append(n_param)
        curr_size *= n_param

    # Pair each input index with its dim values (last dim first) and sort on those
    sort_keys = []
    for idx in range(self._n_input):
        dim_vals = tuple(self.get_val(idx, dim.key) for dim in reversed(dims))
        sort_keys.append((idx, dim_vals))
    sort_keys.sort(key=lambda item: item[1])

    # TODO: Finish this
0 commit comments