Description
Our users relatively often need to extract a few labels from an axis as another (possibly existing) axis with different labels. Currently,
they usually use set_axes for this, as in:
>>> arr = ndtest(5)
>>> arr
a  a0  a1  a2  a3  a4
    0   1   2   3   4
>>> b = Axis('b=b1,b2')
>>> arr['a1,a3'].set_axes('a', b)
b   b1   b2
   1.0  3.0
But since the axis definition is often far from the set_axes call (sometimes even in a different file), there is a high risk of getting the label order wrong (I have seen this happen a few times), which is a pity given that preventing this class of errors is one of the missions of LArray.
The alternative I recommend is to use set_labels with a map, but then the original labels are specified twice (once in the selection and once in the map). I am not sure that is the reason, but our users are generally not very enthusiastic about this recommendation.
>>> arr['a1,a3'].set_labels({'a1': 'b1', 'a3': 'b2'}).rename('a', 'b')
b   b1   b2
   1.0  3.0
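To make the difference between the two approaches concrete, here is a library-free sketch. It does not use LArray at all: a plain dict stands in for a 1-D labelled array, and the helper names relabel_by_position and relabel_by_map are hypothetical, chosen only for this illustration. It shows that pairing by position silently swaps values when the target axis is defined in another order, while pairing through an explicit map does not.

```python
# Plain-Python stand-in for a 1-D labelled array: label -> value.
data = {"a0": 0, "a1": 1, "a2": 2, "a3": 3, "a4": 4}
selected = ["a1", "a3"]
label_map = {"a1": "b1", "a3": "b2"}


def relabel_by_position(data, selected, new_labels):
    """Pair the selected values with new labels purely by position."""
    return {new: data[old] for old, new in zip(selected, new_labels)}


def relabel_by_map(data, label_map, new_labels):
    """Pair values with new labels through an explicit old -> new map,
    then order the result like new_labels."""
    by_new_label = {new: data[old] for old, new in label_map.items()}
    return {new: by_new_label[new] for new in new_labels}


# Target axis defined in the expected order: nothing goes wrong.
print(relabel_by_position(data, selected, ["b1", "b2"]))  # {'b1': 1, 'b2': 3}

# Target axis defined elsewhere as b2, b1: the positional version
# attaches the value of a1 to b2 without any error.
print(relabel_by_position(data, selected, ["b2", "b1"]))  # {'b2': 1, 'b1': 3}

# The map-based version keeps a1 -> b1 and a3 -> b2 whatever the order.
print(relabel_by_map(data, label_map, ["b2", "b1"]))  # {'b2': 3, 'b1': 1}
```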
I wonder if introducing a new "extract" method would help with this:
>>> arr.extract({'a1': 'b1', 'a3': 'b2'}, 'b')
b   b1   b2
   1.0  3.0
>>> # works with a predefined axis too
>>> arr.extract({'a3': 'b2', 'a1': 'b1'}, b)
b   b1   b2
   1.0  3.0
Here is a quick and dirty implementation I did for testing:
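def extract(array, label_map, axis=None):
    # find which axis the original labels belong to
    orig_keys = list(label_map.keys())
    subset = array.axes._guess_axis(orig_keys)
    old_axis = subset.axis
    # select the original labels, then relabel them using the map
    array = array[subset].set_labels(old_axis, label_map)
    if axis is not None:
        array = array.rename(old_axis, axis)
        if isinstance(axis, Axis):
            # align on the predefined axis (label order included)
            array = array.reindex(axis, axis)
    return array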
Another option would be to generalize the aggregate methods so that the new aggregated axis can be named explicitly (see #1002), which we probably need to implement anyway.
Currently, the above "extract" test can also be spelled like:
arr.sum('a1 >> b1;a3 >> b2').rename('a', 'b')
and it would be nice to be able to express it like this instead:
arr.sum(b='a1 >> b1;a3 >> b2')
OR (unsure which, or both):
arr.sum('b=a1 >> b1;a3 >> b2')
One final option would be to have an extract method that uses the same syntax as the aggregate methods instead of a dict:
arr.extract('b=a1 >> b1;a3 >> b2')
In either case, this does not support the existing-axis use case. I have not found a way to express that nicely yet.
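As a rough illustration of what the string form carries, here is a small stand-alone sketch. parse_extract_spec is a hypothetical helper written only for this issue, not how LArray parses aggregate strings, and it only handles one original label per group. It shows that the string reduces to an axis name plus the same old-to-new map as the dict variant, which is also why a predefined Axis object has no obvious place in it.

```python
def parse_extract_spec(spec):
    """Split 'axis=old >> new;old >> new' into (axis_name, {old: new}).

    The 'axis=' prefix is optional; without it, axis_name is None.
    """
    axis_name = None
    if "=" in spec:
        axis_name, spec = spec.split("=", 1)
        axis_name = axis_name.strip()
    label_map = {}
    for part in spec.split(";"):
        old, sep, new = part.partition(">>")
        if not sep:
            raise ValueError(f"missing '>>' in {part!r}")
        label_map[old.strip()] = new.strip()
    return axis_name, label_map


print(parse_extract_spec("b=a1 >> b1;a3 >> b2"))
# ('b', {'a1': 'b1', 'a3': 'b2'})
```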