
RFC: expand_dims for tuple of axes #105

Closed as not planned
@izaid

Description


Hello all! I've just spent time converting a reasonably large code base to use array-api-compat. I was pleasantly surprised that nearly everything worked, but one function didn't map smoothly: expand_dims. The issue isn't necessarily array-api-compat's problem, but I thought I'd start here.

In the array API standard, expand_dims supports only a single axis (https://data-apis.org/array-api/latest/API_specification/generated/array_api.expand_dims.html) rather than a tuple of axes. This differs from NumPy, CuPy, and JAX, which accept a tuple of axes; PyTorch, however, supports only a single axis. I don't know the rationale for restricting the standard to a single axis, but as a result, expand_dims calls that pass a tuple of axes no longer work in many places.

In practice, expand_dims is just a light wrapper around reshape; see https://github.com/numpy/numpy/blob/3b246c6488cf246d488bbe5726ca58dc26b6ea74/numpy/lib/_shape_base_impl.py#L594. Still, I'd really rather not write my own version of expand_dims in every library. Would array_api_compat be willing to provide a non-strict version of expand_dims that still supports a tuple of axes? Or has there been a clear discussion and decision that expand_dims will only support a single axis going forward, which would effectively require every user of expand_dims to copy and paste the NumPy implementation?
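For reference, here is a minimal sketch of what such a tuple-of-axes expand_dims could look like when built on reshape, following the same idea as the NumPy implementation linked above. The names `expanded_shape` and `expand_dims(x, axis, xp)` are hypothetical helpers, not part of array-api-compat. `xp` is assumed to be any array API namespace that provides `reshape`, and `x.shape` is assumed to be fully known (no `None` entries).

```python
from typing import Any, Sequence, Tuple, Union


def expanded_shape(shape: Sequence[int], axis: Union[int, Sequence[int]]) -> Tuple[int, ...]:
    """Shape after inserting size-1 dimensions at the given output positions."""
    # Accept a single int or a tuple/list of ints, as NumPy does.
    axes = (axis,) if isinstance(axis, int) else tuple(axis)
    # Axes index into the *output* array, which has one extra dim per axis.
    out_ndim = len(shape) + len(axes)
    positions = set()
    for ax in axes:
        if not -out_ndim <= ax < out_ndim:
            raise IndexError(f"axis {ax} is out of bounds for output of dimension {out_ndim}")
        positions.add(ax % out_ndim)  # normalize negative axes
    if len(positions) != len(axes):
        raise ValueError("repeated axis")
    # Walk the output positions: emit 1 at inserted axes, else the next input dim.
    dims = iter(shape)
    return tuple(1 if i in positions else next(dims) for i in range(out_ndim))


def expand_dims(x: Any, axis: Union[int, Sequence[int]], xp: Any) -> Any:
    """Hypothetical tuple-of-axes expand_dims using only the standard's reshape."""
    return xp.reshape(x, expanded_shape(x.shape, axis))
```

With array-api-compat, `xp` could come from `array_api_compat.array_namespace(x)`, so for an array of shape `(2, 3)`, `expand_dims(x, (0, 2), xp)` would yield shape `(1, 2, 1, 3)`. One design difference: NumPy raises `AxisError` for out-of-bounds axes, while this sketch raises a plain `IndexError` (and `ValueError` for repeated axes).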

Many thanks!
