diff --git a/pytensor/link/numba/dispatch/elemwise.py b/pytensor/link/numba/dispatch/elemwise.py index aad741c67d..256b6bb63d 100644 --- a/pytensor/link/numba/dispatch/elemwise.py +++ b/pytensor/link/numba/dispatch/elemwise.py @@ -843,7 +843,7 @@ def dimshuffle_inner(x, shuffle): @numba_basic.numba_njit def dimshuffle_inner(x, shuffle): - return x.item() + return np.reshape(x, ()) # Without the following wrapper function we would see this error: # E No implementation of function Function() found for signature: diff --git a/tests/link/numba/test_elemwise.py b/tests/link/numba/test_elemwise.py index 8fbf026e11..54778bda92 100644 --- a/tests/link/numba/test_elemwise.py +++ b/tests/link/numba/test_elemwise.py @@ -211,6 +211,14 @@ def test_Dimshuffle(v, new_order): ) +def test_Dimshuffle_returns_array(): + x = at.vector("x", shape=(1,)) + y = 2 * at_elemwise.DimShuffle([True], [])(x) + func = pytensor.function([x], y, mode="NUMBA") + out = func(np.zeros(1, dtype=config.floatX)) + assert out.ndim == 0 + + @pytest.mark.parametrize( "careduce_fn, axis, v", [