Description
The Array API specification states that `xp.roll` accepts two types for the `shift` argument: `int` and `tuple[int, ...]`. However, it is currently possible to run the following piece of code:
```python
import array_api_strict as xp

a = xp.asarray([1, 2, 3])
b = xp.asarray([1])
print(xp.roll(a, shift=b))  # type of shift is Array, not int or tuple[int, ...]
# >>> prints [3 1 2]
```
This is convenient, because it allows a function that calls `xp.roll` to be JIT-compiled in JAX while passing a JAX array as a dynamic shift. However, the spec does not currently guarantee this behaviour, so I would expect the code to fail. Relying on it is indeed dangerous: executing the same snippet with PyTorch yields
```python
import torch
from array_api_compat import array_namespace

xp = array_namespace(torch.tensor(1))
a = xp.asarray([1, 2, 3])
b = xp.asarray([1])
print(xp.roll(a, shift=b))
# Traceback (most recent call last):
#   File "xxx", line 7, in <module>
#     print(xp.roll(a, shift=b))
#           ^^^^^^^^^^^^^^^^^^^
#   File "xxx", line 503, in roll
#     return torch.roll(x, shift, axis, **kwargs)
#            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# RuntimeError: `shifts` required
```
Passing a plain integer instead of a tensor as the `shift` argument does work.
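For comparison, the JAX use case mentioned above can be sketched with `jax.numpy` directly, assuming JAX is installed. Inside `jax.jit` the shift is a traced value whose concrete integer is unknown at trace time, and `jnp.roll` accepts it as a dynamic shift; an `int`-only signature would rule this out. The function name `roll_by` is purely illustrative.

```python
import jax
import jax.numpy as jnp


@jax.jit
def roll_by(x, shift):
    # `shift` is a traced JAX array here, so its value is not known when the
    # function is traced; jnp.roll treats it as a dynamic shift.
    return jnp.roll(x, shift)


a = jnp.asarray([1, 2, 3])
print(roll_by(a, jnp.asarray(1)))    # zero-dimensional array shift
print(roll_by(a, jnp.asarray([2])))  # one-dimensional array shift
```

Both calls run under JIT compilation; a new array shape for `shift` triggers a retrace, but changing its value does not.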
Remedies
There are two ways to fix this:
- Update the spec so that `shift` also accepts the `Array` type as input.
- Make `array_api_strict` explicitly check the types of the arguments passed to `xp.roll` (in particular `shift`), so that the first example above raises an error.
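The second option could look roughly like the following pure-Python sketch. `check_roll_shift` is a hypothetical helper, not part of `array_api_strict`; it accepts only the two types the spec currently lists and raises `TypeError` for anything else, including arrays.

```python
def check_roll_shift(shift):
    """Hypothetical strict check for the `shift` argument of roll.

    Accepts only the types listed in the spec: int or tuple[int, ...].
    Note that bool is a subclass of int, so it is not rejected here.
    """
    if isinstance(shift, int):
        return shift
    if isinstance(shift, tuple) and all(isinstance(s, int) for s in shift):
        return shift
    # Anything else (lists, NumPy or array_api_strict arrays, ...) is rejected.
    raise TypeError(
        "shift must be an int or a tuple of ints, "
        f"got {type(shift).__name__}"
    )


print(check_roll_shift(1))        # accepted
print(check_roll_shift((1, -2)))  # accepted
try:
    check_roll_shift([1])
except TypeError as exc:
    print(exc)
```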
If the spec is updated to include arrays, `array_api_compat` needs to update its torch wrapper for `roll` so that the example above does not raise an error. I would be strongly in favour of accepting arrays (both zero- and one-dimensional), because this allows us to do more with `xp.roll` within JIT-compiled functions.
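If the spec is extended, one possible shape for the torch wrapper fix is to turn a zero- or one-dimensional integer array into a plain `int` or `tuple[int, ...]` before calling the backend. The sketch below uses NumPy for illustration only; `normalize_shift` is a hypothetical name, not an existing `array_api_compat` function.

```python
import numpy as np


def normalize_shift(shift):
    """Hypothetical helper: map an int, a tuple of ints, or a 0-D/1-D integer
    array to the int / tuple[int, ...] form that torch.roll's `shifts` accepts."""
    if isinstance(shift, int):
        return shift
    if isinstance(shift, tuple):
        return tuple(int(s) for s in shift)
    arr = np.asarray(shift)
    if not np.issubdtype(arr.dtype, np.integer):
        raise TypeError(f"shift must have an integer dtype, got {arr.dtype}")
    if arr.ndim == 0:
        return int(arr)  # zero-dimensional array -> int
    if arr.ndim == 1:
        return tuple(int(s) for s in arr)  # one-dimensional array -> tuple
    raise ValueError(f"shift must be 0- or 1-dimensional, got ndim={arr.ndim}")


print(normalize_shift(np.asarray(1)))
print(normalize_shift(np.asarray([1, 2])))
```

This converts the shift to host-side Python integers, which is fine for an eager backend such as PyTorch but cannot work under `jax.jit`, where the shift has no concrete value; a JAX-backed namespace would instead pass the array straight through.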