
add Array.extract? #1001

Open
@gdementen

Description


Our users relatively often need to extract a few labels from an axis and turn them into another (possibly existing) axis with different labels. Currently, they usually use set_axes for this, as in:

>>> arr = ndtest(5)
>>> arr
a  a0  a1  a2  a3  a4
    0   1   2   3   4
>>> b = Axis('b=b1,b2')
>>> arr['a1,a3'].set_axes('a', b)
b   b1   b2
   1.0  3.0

But because the axis definition is often far from the set_axes call (sometimes even in a different file), there is a high risk of getting the label order wrong (I have seen this happen a few times). That is a pity, given that preventing this class of errors is one of the missions of LArray.
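To make the failure mode concrete, here is a small pure-Python model (not LArray code) in which a 1D array is a dict from label to value. The hypothetical helper `relabel_positional` mimics what set_axes does (old and new labels are matched by position), while `relabel_by_map` mimics set_labels with a mapping. Suppose axis b had been defined elsewhere as 'b=b2,b1':

```python
# Pure-Python model (not LArray): a 1D labeled array is a dict {label: value}
# whose insertion order is the axis order. Both helpers are hypothetical.

def relabel_positional(arr, new_labels):
    """Mimics set_axes: the i-th old label simply becomes the i-th new label."""
    if len(new_labels) != len(arr):
        raise ValueError("length mismatch between array and new labels")
    return dict(zip(new_labels, arr.values()))


def relabel_by_map(arr, label_map):
    """Mimics set_labels with a mapping: each old label names its new label."""
    return {label_map[old]: value for old, value in arr.items()}


subset = {'a1': 1, 'a3': 3}     # models arr['a1,a3']
b_labels = ['b2', 'b1']         # axis b defined elsewhere, in another order

wrong = relabel_positional(subset, b_labels)               # silently mislabels
right = relabel_by_map(subset, {'a1': 'b1', 'a3': 'b2'})   # order-independent

print(wrong)  # {'b2': 1, 'b1': 3}
print(right)  # {'b1': 1, 'b2': 3}
```

The positional version raises no error at all: the values are simply attached to the wrong labels, which is exactly the kind of mistake a mapping makes impossible.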

The alternative I recommend is to use set_labels with a mapping, but then the original labels are specified twice (once to select the subset and once as the mapping keys). I am not sure whether that is the reason, but our users are generally not very enthusiastic about this recommendation.

>>> arr['a1,a3'].set_labels({'a1': 'b1', 'a3': 'b2'}).rename('a', 'b')
b   b1   b2
   1.0  3.0

I wonder if introducing a new "extract" method would help with this:

>>> arr.extract({'a1': 'b1', 'a3': 'b2'}, 'b')
b   b1   b2
   1.0  3.0
>>> # works with a predefined axis too
>>> arr.extract({'a3': 'b2', 'a1': 'b1'}, b)
b   b1   b2
   1.0  3.0
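The intended semantics of the proposed extract can be sketched with the same dict-based model. This is a hypothetical pure-Python illustration, not LArray's API: it ignores axis names, and the optional `axis_labels` argument stands in for a predefined Axis, whose label order the result follows regardless of the order of the mapping.

```python
# Pure-Python model of the proposed Array.extract (not LArray code).
# A 1D array is a dict {label: value}; insertion order is the axis order.

def extract(arr, label_map, axis_labels=None):
    """Hypothetical model of arr.extract(label_map, axis).

    - selects the old labels given as keys of label_map,
    - renames each one to its mapped new label,
    - if axis_labels (a predefined axis) is given, reorders the result to
      follow that axis, so the order of label_map does not matter.
    """
    missing = [old for old in label_map if old not in arr]
    if missing:
        raise KeyError(f"labels not found: {missing}")
    if len(set(label_map.values())) != len(label_map):
        raise ValueError("new labels must be unique")
    result = {new: arr[old] for old, new in label_map.items()}
    if axis_labels is not None:
        # strict check: the mapped labels must be exactly the axis labels
        if set(axis_labels) != set(result):
            raise ValueError("mapped labels do not match the target axis")
        result = {label: result[label] for label in axis_labels}
    return result


arr = {'a0': 0, 'a1': 1, 'a2': 2, 'a3': 3, 'a4': 4}   # models ndtest(5)

by_map = extract(arr, {'a1': 'b1', 'a3': 'b2'})
by_axis = extract(arr, {'a3': 'b2', 'a1': 'b1'}, ['b1', 'b2'])
print(by_map)   # {'b1': 1, 'b2': 3}
print(by_axis)  # {'b1': 1, 'b2': 3}
```

The strict label check is a simplification chosen for the model; a reindex-based version would instead fill labels missing from the mapping (and drop extra ones).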

Here is a quick and dirty implementation I did for testing:

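from larray import Axis

def extract(array, label_map, axis=None):
    # the keys of label_map are the labels to extract from the original axis
    orig_keys = list(label_map.keys())
    # find which axis those labels belong to (relies on a private helper)
    subset = array.axes._guess_axis(orig_keys)
    old_axis = subset.axis
    # select the labels, then rename each one using the mapping
    array = array[subset].set_labels(old_axis, label_map)
    if axis is not None:
        # give the extracted axis its new name (str or Axis)
        array = array.rename(old_axis, axis)
        if isinstance(axis, Axis):
            # follow the label order of the predefined axis
            array = array.reindex(axis, axis)
    return array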

Another option would be to generalize aggregate methods so that the new aggregated axis can be named explicitly (see #1002), which we probably need to implement anyway.

Currently, the "extract" example above can also be spelled as:

arr.sum('a1 >> b1;a3 >> b2').rename('a', 'b')

and it would be nice to be able to express it like this instead:

arr.sum(b='a1 >> b1;a3 >> b2')

OR (unsure which, or both):

arr.sum('b=a1 >> b1;a3 >> b2')
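The sum spelling works because each group in 'a1 >> b1;a3 >> b2' contains a single label, so "summing" a group just returns that label's value under the group's name. A small pure-Python model (the hypothetical `group_sum`, not LArray's implementation) shows this, with the new axis name passed explicitly, which is what naming the aggregated axis in the sum call would provide:

```python
# Pure-Python model (not LArray): aggregating an axis over named groups.
# Each group is (new_label, [old labels]); the resulting axis gets an
# explicitly given name instead of keeping the old axis name.

def group_sum(arr, groups, new_axis_name):
    """Hypothetical model of arr.sum(<named groups>) on a renamed axis.

    arr    : dict {old_label: value}
    groups : list of (new_label, [old_label, ...])
    Returns (new_axis_name, {new_label: sum of the group's values}).
    """
    data = {new: sum(arr[old] for old in olds) for new, olds in groups}
    return new_axis_name, data


arr = {'a0': 0, 'a1': 1, 'a2': 2, 'a3': 3, 'a4': 4}   # models ndtest(5)

# Single-label groups: summing is the same as extracting and relabeling.
print(group_sum(arr, [('b1', ['a1']), ('b2', ['a3'])], 'b'))
# ('b', {'b1': 1, 'b2': 3})

# Multi-label groups: a genuine aggregation through the same mechanism.
print(group_sum(arr, [('low', ['a0', 'a1']), ('high', ['a3', 'a4'])], 'b'))
# ('b', {'low': 1, 'high': 7})
```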

One final option would be to have an extract method that uses the same syntax as aggregate methods instead of a dict:

arr.extract('b=a1 >> b1;a3 >> b2')

In all of these cases, however, the existing-axis use case is not supported. I have not found a nice way to express that yet.
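The string form 'b=a1 >> b1;a3 >> b2' carries the same information as the dict plus the new axis name. The hypothetical parser below is a sketch that only handles one old label per group (not LArray's full string syntax with slices or comma-separated lists). It also makes the limitation visible: the result has a slot for an axis name but none for a predefined Axis object.

```python
# Hypothetical parser for the proposed spec string 'b=a1 >> b1;a3 >> b2'.
# Only single old labels per group are supported in this sketch.

def parse_extract_spec(spec):
    """Return (new_axis_name, {old_label: new_label}) for a spec string."""
    axis_name, sep, body = spec.partition('=')
    if not sep:
        # no 'b=' prefix: the new axis name is unknown
        axis_name, body = None, spec
    else:
        axis_name = axis_name.strip()
    label_map = {}
    for part in body.split(';'):
        old, arrow, new = part.partition('>>')
        if not arrow:
            raise ValueError(f"missing '>>' in {part!r}")
        old, new = old.strip(), new.strip()
        if old in label_map:
            raise ValueError(f"label {old!r} listed twice")
        label_map[old] = new
    return axis_name, label_map


print(parse_extract_spec('b=a1 >> b1;a3 >> b2'))
# ('b', {'a1': 'b1', 'a3': 'b2'})
```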
