Description
Hello all! I've just spent time converting a reasonably large code base to use array-api-compat. I was pleasantly surprised that nearly everything worked, but one function didn't map smoothly: expand_dims. The issue is not necessarily a problem with array-api-compat itself, but I thought I'd start here.
In the array API, expand_dims supports only a single axis (https://data-apis.org/array-api/latest/API_specification/generated/array_api.expand_dims.html) rather than a tuple of axes. This differs from NumPy, CuPy, and JAX, which accept a tuple of axes; PyTorch, however, supports only a single axis. I don't know why the array API restricts expand_dims to a single axis, but in practice it means that code calling expand_dims with a tuple of axes no longer works in many places.
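To make the mismatch concrete, here is a small NumPy comparison (the input shape is arbitrary). NumPy accepts a tuple of axes directly, whereas with a single-axis signature the same result needs one call per new axis, applied in increasing order of the (non-negative) output positions.

```python
import numpy as np

x = np.ones((3, 4))

# NumPy (like CuPy and JAX) accepts a tuple of axes in one call.
y = np.expand_dims(x, axis=(0, 2))
print(y.shape)  # (1, 3, 1, 4)

# With a single-int axis, as in the array API signature, the same
# result needs one call per inserted axis, lowest position first.
z = np.expand_dims(np.expand_dims(x, axis=0), axis=2)
print(z.shape)  # (1, 3, 1, 4)
```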
In practice, expand_dims is just a light wrapper around reshape; see https://github.com/numpy/numpy/blob/3b246c6488cf246d488bbe5726ca58dc26b6ea74/numpy/lib/_shape_base_impl.py#L594. But I'd really rather not have to write my own version of expand_dims in every library. Would array_api_compat be willing to provide a non-strict version of expand_dims that still accepts a tuple of axes? Or has there been a clear discussion and decision that expand_dims will only ever support a single axis, effectively forcing everyone who needs multiple axes to copy and paste the NumPy implementation?
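For reference, this is roughly the kind of helper every library ends up carrying. It is a minimal sketch, not an existing array-api-compat function: the name expand_dims_multi is hypothetical, and it assumes xp is an array-API-compatible namespace (in real code it could come from array_api_compat.array_namespace(x)) and that every dimension in x.shape is known. It follows the same idea as the linked NumPy code: work out the output shape, then call reshape.

```python
import numpy as np


def expand_dims_multi(x, axis, xp):
    """Hypothetical helper: insert a length-1 axis at each position in `axis`.

    `axis` may be an int or a tuple of ints; positions refer to the output
    array, as in NumPy. Assumes `xp` provides an array-API `reshape` and that
    all entries of `x.shape` are known integers (not None).
    """
    if isinstance(axis, int):
        axis = (axis,)
    out_ndim = x.ndim + len(axis)

    # Normalize negative axes against the output rank and validate them.
    normalized = set()
    for ax in axis:
        if not -out_ndim <= ax < out_ndim:
            raise ValueError(f"axis {ax} is out of bounds for output rank {out_ndim}")
        normalized.add(ax % out_ndim)
    if len(normalized) != len(axis):
        raise ValueError("repeated axis in expand_dims_multi")

    # Output shape: 1 at each inserted position, original sizes (in order) elsewhere.
    in_shape = iter(x.shape)
    new_shape = tuple(1 if i in normalized else next(in_shape) for i in range(out_ndim))
    return xp.reshape(x, new_shape)


x = np.ones((3, 4))
print(expand_dims_multi(x, (0, 2), np).shape)  # (1, 3, 1, 4)
print(expand_dims_multi(x, -1, np).shape)      # (3, 4, 1)
```

Because it only relies on x.ndim, x.shape, and xp.reshape, the same sketch should work unchanged for any namespace that implements those parts of the standard.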
Many thanks!