Skip to content

Commit da66c2e

Browse files
ricardoV94twiecki
authored andcommitted
Fix SoftmaxGrad failure with constant dy in numba backend
1 parent fc5e10f commit da66c2e

File tree

2 files changed

+16
-1
lines changed

2 files changed

+16
-1
lines changed

pytensor/link/numba/dispatch/elemwise.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -925,7 +925,12 @@ def softmax_grad_py_fn(dy, sm):
925925
dx = dy_times_sm - sum_dy_times_sm * sm
926926
return dx
927927

928-
softmax_grad = jit_compile_reducer(node, softmax_grad_py_fn)
928+
# The signature inferred by jit_compile_reducer is wrong when dy is a constant (readonly=True)
929+
# softmax_grad = jit_compile_reducer(node, softmax_grad_py_fn)
930+
softmax_grad = numba_njit(
931+
boundscheck=False,
932+
fastmath=config.numba__fastmath,
933+
)(softmax_grad_py_fn)
929934

930935
return softmax_grad
931936

tests/link/numba/test_elemwise.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -445,6 +445,16 @@ def test_SoftmaxGrad(dy, sm, axis, exc):
445445
)
446446

447447

448+
def test_SoftMaxGrad_constant_dy():
449+
dy = at.constant(np.zeros((3,), dtype=config.floatX))
450+
sm = at.vector(shape=(3,))
451+
452+
g = SoftmaxGrad(axis=None)(dy, sm)
453+
g_fg = FunctionGraph(outputs=[g])
454+
455+
compare_numba_and_py(g_fg, [np.ones((3,), dtype=config.floatX)])
456+
457+
448458
@pytest.mark.parametrize(
449459
"x, axis, exc",
450460
[

0 commit comments

Comments
 (0)