RFC: add array support for the shift kwarg in roll #914

Open
1 of 1 issue completed
@amacati

Description

The Array API standard specifies two accepted types for the shift argument of xp.roll: int and tuple[int, ...]. However, array_api_strict currently also accepts an array, so the following code runs:

```python
import array_api_strict as xp

a = xp.asarray([1, 2, 3])
b = xp.asarray([1])
print(xp.roll(a, shift=b))  # Type of shift is Array, not tuple[int, ...] or int
# >>> prints [3 1 2]
```

It is nice that this works, because it allows a function to be JIT-compiled in JAX while passing the shift as a dynamic JAX array. However, the spec does not guarantee this behaviour, so I would expect array_api_strict, which is meant to enforce the spec, to reject this code. Relying on the behaviour is indeed risky: executing the same snippet with PyTorch (via array_api_compat) yields:

```python
import torch
from array_api_compat import array_namespace

xp = array_namespace(torch.tensor(1))
a = xp.asarray([1, 2, 3])
b = xp.asarray([1])
print(xp.roll(a, shift=b))
# Traceback (most recent call last):
#   File "xxx", line 7, in <module>
#     print(xp.roll(a, shift=b))
#           ^^^^^^^^^^^^^^^^^^^
#   File "xxx", line 503, in roll
#     return torch.roll(x, shift, axis, **kwargs)
#            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# RuntimeError: `shifts` required
```

Passing a plain integer (e.g. shift=1) instead of a tensor does work.
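Until the spec settles this, one portable workaround is to convert the array into a type the spec already allows before calling roll. The sketch below assumes array-like shifts expose a tolist() method (NumPy arrays and torch tensors do); normalize_shift is a hypothetical helper, not part of array_api_compat or any other library. Because tolist() copies the values back into Python, it cannot be applied to traced values inside jax.jit, which is the limitation this RFC is about.

```python
from typing import Any, Tuple, Union


def normalize_shift(shift: Any) -> Union[int, Tuple[int, ...]]:
    """Coerce shift into int or tuple[int, ...] (hypothetical helper).

    Assumes array-like shifts expose tolist(): a 0-D array returns a Python
    scalar and a 1-D array returns a list.
    """
    if isinstance(shift, int):
        return shift
    if isinstance(shift, tuple):
        return tuple(int(s) for s in shift)
    if hasattr(shift, "tolist"):
        # tolist() copies values to Python, so this fails on traced values under jit.
        values = shift.tolist()
        if isinstance(values, list):
            # 1-D array: one shift per axis, mirroring the tuple form.
            # Nested lists from higher-dimensional arrays make int() raise TypeError.
            return tuple(int(v) for v in values)
        # 0-D array: a single scalar shift.
        return int(values)
    raise TypeError(f"unsupported shift type: {type(shift).__name__}")
```

For example, xp.roll(a, shift=normalize_shift(b), axis=(0,)) should then also work with the torch namespace; whenever the result is a tuple, the spec requires axis to be a tuple of the same length.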

Remedies

There are two ways to fix this:

  1. Update the spec so that shift also accepts the Array type as input.
  2. Make array_api_strict explicitly check the types of xp.roll's arguments and reject anything the spec does not allow.
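The second remedy amounts to a type check before dispatching to the underlying library. A minimal sketch, assuming only the spec's documented types are allowed (int or tuple of ints for shift; None, int or tuple of ints for axis); check_roll_args is a hypothetical name, and array_api_strict may structure or word such a check differently. Rejecting bool, which Python treats as a subclass of int, is a design choice of this sketch rather than something the spec spells out.

```python
from typing import Any


def _is_int(value: Any) -> bool:
    # bool subclasses int in Python; this sketch treats it as invalid.
    return isinstance(value, int) and not isinstance(value, bool)


def _is_int_tuple(value: Any) -> bool:
    return isinstance(value, tuple) and all(_is_int(v) for v in value)


def check_roll_args(shift: Any, axis: Any = None) -> None:
    """Raise if shift/axis fall outside the spec's signature (hypothetical check)."""
    if not (_is_int(shift) or _is_int_tuple(shift)):
        raise TypeError(
            f"shift must be an int or a tuple of ints, got {type(shift).__name__}"
        )
    if axis is not None and not (_is_int(axis) or _is_int_tuple(axis)):
        raise TypeError(
            f"axis must be None, an int or a tuple of ints, got {type(axis).__name__}"
        )
    # The spec pairs a tuple shift with an axis tuple of the same size.
    if isinstance(shift, tuple) and not (
        isinstance(axis, tuple) and len(axis) == len(shift)
    ):
        raise ValueError("a tuple shift requires an axis tuple of the same length")
```

With a check like this in place, the array_api_strict snippet in the description would raise a TypeError instead of silently succeeding.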

If the spec is updated to include arrays, array_api_compat will need to update its torch wrapper for roll so that the example above runs without errors. I would be strongly in favour of accepting arrays (both zero- and one-dimensional), because it lets us do more with xp.roll inside JIT-compiled functions.
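Part of why an array-valued shift fits JIT compilation is that roll's output shape never depends on the shift value: along the rolled axis, output index i simply reads input index (i - shift) mod n. The following pure-Python sketch of a 1-D roll (roll_1d is a hypothetical helper) illustrates these semantics only; it is not how JAX or any other library implements roll.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def roll_1d(values: Sequence[T], shift: int) -> List[T]:
    """Roll a 1-D sequence by shift positions, wrapping elements around."""
    n = len(values)
    # The output length is always n; shift only changes which input each slot reads.
    return [values[(i - shift) % n] for i in range(n)]


print(roll_1d([1, 2, 3], 1))   # [3, 1, 2]
print(roll_1d([1, 2, 3], -1))  # [2, 3, 1]
```

Under the proposal, a zero-dimensional array would presumably play the role of the scalar shift here, while a one-dimensional array would supply one such value per axis, mirroring the existing tuple form.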

Metadata

Assignees: No one assigned
Labels: RFC (Request for comments. Feature requests and proposed changes.)
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
