Skip to content

Commit 867d774

Browse files
committed
Temporarily exclude fusion rewrite from Numba Scan tests
Otherwise they fail due to lack of support for multi-output Elemwises in the Numba backend
1 parent 187fd07 commit 867d774

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

tests/link/numba/test_scan.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,8 @@ def f_pow2(x_tm2, x_tm1):
370370
state_val = np.array([1.0, 2.0])
371371

372372
numba_mode = get_mode("NUMBA").including("scan_save_mem")
373+
# multi-output Elemwise not supported in NUMBA
374+
numba_mode = numba_mode.excluding("fusion")
373375
py_mode = Mode("py").including("scan_save_mem")
374376

375377
out_fg = FunctionGraph([init_x, n_steps], [output])
@@ -409,6 +411,8 @@ def inner_fct(seq, state_old, state_current):
409411
g_outs = grad(out.sum(), [seq, init_x])
410412

411413
numba_mode = get_mode("NUMBA").including("scan_save_mem")
414+
# multi-output Elemwise not supported in NUMBA
415+
numba_mode = numba_mode.excluding("fusion")
412416
py_mode = Mode("py").including("scan_save_mem")
413417

414418
out_fg = FunctionGraph([seq, init_x], g_outs)

0 commit comments

Comments
 (0)