Skip to content

Commit ad68c7f

Browse files
committed
Exclude backend incompatible rewrites in Scan dispatch
1 parent 8dc67ea commit ad68c7f

File tree

3 files changed

+15
-4
lines changed

3 files changed

+15
-4
lines changed

pytensor/link/jax/dispatch/scan.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import jax
22
import jax.numpy as jnp
33

4+
from pytensor.compile.mode import JAX
45
from pytensor.link.jax.dispatch.basic import jax_funcify
56
from pytensor.scan.op import Scan
67

@@ -17,8 +18,8 @@ def jax_funcify_Scan(op: Scan, **kwargs):
1718
"Scan with MIT-MOT (gradients of scan) cannot yet be converted to JAX"
1819
)
1920

20-
# Optimize inner graph
21-
rewriter = op.mode_instance.optimizer
21+
# Optimize inner graph (exclude any defalut rewrites that are incompatible with JAX mode)
22+
rewriter = op.mode_instance.excluding(*JAX._optimizer.exclude).optimizer
2223
rewriter(op.fgraph)
2324
scan_inner_func = jax_funcify(op.fgraph, **kwargs)
2425

pytensor/link/numba/dispatch/scan.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from numba import types
66
from numba.extending import overload
77

8+
from pytensor.compile.mode import NUMBA
89
from pytensor.link.numba.dispatch import basic as numba_basic
910
from pytensor.link.numba.dispatch.basic import (
1011
create_arg_string,
@@ -58,7 +59,7 @@ def numba_funcify_Scan(op, node, **kwargs):
5859
# TODO: Not sure this is the right place to do this, should we have a rewrite that
5960
# explicitly triggers the optimization of the inner graphs of Scan?
6061
# The C-code defers it to the make_thunk phase
61-
rewriter = op.mode_instance.optimizer
62+
rewriter = op.mode_instance.excluding(*NUMBA._optimizer.exclude).optimizer
6263
rewriter(op.fgraph)
6364

6465
scan_inner_func = numba_basic.numba_njit(numba_funcify(op.fgraph))

tests/link/jax/test_scan.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from pytensor.scan.op import Scan
1414
from pytensor.tensor import random
1515
from pytensor.tensor.math import gammaln, log
16-
from pytensor.tensor.type import dmatrix, dvector, lscalar, scalar, vector
16+
from pytensor.tensor.type import dmatrix, dvector, lscalar, matrix, scalar, vector
1717
from tests.link.jax.test_basic import compare_jax_and_py
1818

1919

@@ -418,3 +418,12 @@ def step(x, A):
418418

419419
test_input_vals = [x0_val, A_val]
420420
compare_jax_and_py(fg, test_input_vals)
421+
422+
423+
def test_default_mode_excludes_incompatible_rewrites():
424+
# See issue #426
425+
A = matrix("A")
426+
B = matrix("B")
427+
out, _ = scan(lambda a, b: a @ b, outputs_info=[A], non_sequences=[B], n_steps=2)
428+
fg = FunctionGraph([A, B], [out])
429+
compare_jax_and_py(fg, [np.eye(3), np.eye(3)])

0 commit comments

Comments
 (0)