Skip to content

Commit 28fc9ac

Browse files
committed
Do not raise early when a Shape operation is an input to Arange in the JAX backend
1 parent 71c58f3 commit 28fc9ac

File tree

3 files changed

+20
-6
lines changed

3 files changed

+20
-6
lines changed

pytensor/link/jax/dispatch/tensor_basic.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
get_underlying_scalar_constant_value,
2222
)
2323
from pytensor.tensor.exceptions import NotScalarConstantError
24+
from pytensor.tensor.shape import Shape_i
2425

2526

2627
ARANGE_CONCRETE_VALUE_ERROR = """JAX requires the arguments of `jax.numpy.arange`
@@ -61,14 +62,20 @@ def jax_funcify_ARange(op, node, **kwargs):
6162
arange_args = node.inputs
6263
constant_args = []
6364
for arg in arange_args:
64-
if not isinstance(arg, Constant):
65+
if arg.owner and isinstance(arg.owner.op, Shape_i):
66+
constant_args.append(None)
67+
elif isinstance(arg, Constant):
68+
constant_args.append(arg.value)
69+
else:
70+
# TODO: This might be failing without need (e.g., if arg = shape(x)[-1] + 1)!
6571
raise NotImplementedError(ARANGE_CONCRETE_VALUE_ERROR)
6672

67-
constant_args.append(arg.value)
68-
69-
start, stop, step = constant_args
73+
constant_start, constant_stop, constant_step = constant_args
7074

71-
def arange(*_):
75+
def arange(start, stop, step):
76+
start = start if constant_start is None else constant_start
77+
stop = stop if constant_stop is None else constant_stop
78+
step = step if constant_step is None else constant_step
7279
return jnp.arange(start, stop, step, dtype=op.dtype)
7380

7481
return arange

tests/link/jax/test_slinalg.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def test_jax_basic():
8585
],
8686
)
8787

88-
out = at.diag(at.specify_shape(b, shape=(10,)))
88+
out = at.diag(b)
8989
out_fg = FunctionGraph([b], [out])
9090
compare_jax_and_py(out_fg, [np.arange(10).astype(config.floatX)])
9191

tests/link/jax/test_tensor_basic.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,13 @@ def test_arange():
6363
compare_jax_and_py(fgraph, [])
6464

6565

66+
def test_arange_of_shape():
67+
x = vector("x")
68+
out = at.arange(1, x.shape[-1], 2)
69+
fgraph = FunctionGraph([x], [out])
70+
compare_jax_and_py(fgraph, [np.zeros((5,))])
71+
72+
6673
def test_arange_nonconcrete():
6774
"""JAX cannot JIT-compile `jax.numpy.arange` when arguments are not concrete values."""
6875

0 commit comments

Comments
 (0)