From 32eff42241b4c686ce32b17f19b8592853b84ad2 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Fri, 10 Feb 2023 20:10:00 -0600 Subject: [PATCH 1/2] Fix empty numba DimShuffle --- pytensor/link/numba/dispatch/elemwise.py | 2 +- tests/link/numba/test_elemwise.py | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/pytensor/link/numba/dispatch/elemwise.py b/pytensor/link/numba/dispatch/elemwise.py index aad741c67d..256b6bb63d 100644 --- a/pytensor/link/numba/dispatch/elemwise.py +++ b/pytensor/link/numba/dispatch/elemwise.py @@ -843,7 +843,7 @@ def dimshuffle_inner(x, shuffle): @numba_basic.numba_njit def dimshuffle_inner(x, shuffle): - return x.item() + return np.reshape(x, ()) # Without the following wrapper function we would see this error: # E No implementation of function Function() found for signature: diff --git a/tests/link/numba/test_elemwise.py b/tests/link/numba/test_elemwise.py index 8fbf026e11..fee49a85da 100644 --- a/tests/link/numba/test_elemwise.py +++ b/tests/link/numba/test_elemwise.py @@ -211,6 +211,14 @@ def test_Dimshuffle(v, new_order): ) +def test_Dimshuffle_returns_array(): + x = at.vector("x", shape=(1,)) + y = 2 * at_elemwise.DimShuffle([True], [])(x) + func = pytensor.function([x], y, mode="NUMBA") + out = func(np.zeros(1)) + assert out.ndim == 0 + + @pytest.mark.parametrize( "careduce_fn, axis, v", [ From 6b40c6efaba25b570f58d8a5a653f37aaba82ec1 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Mon, 13 Feb 2023 10:11:00 -0600 Subject: [PATCH 2/2] Fix floatX issue in test --- tests/link/numba/test_elemwise.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/link/numba/test_elemwise.py b/tests/link/numba/test_elemwise.py index fee49a85da..54778bda92 100644 --- a/tests/link/numba/test_elemwise.py +++ b/tests/link/numba/test_elemwise.py @@ -215,7 +215,7 @@ def test_Dimshuffle_returns_array(): x = at.vector("x", shape=(1,)) y = 2 * at_elemwise.DimShuffle([True], [])(x) func = pytensor.function([x], y, mode="NUMBA") - out = func(np.zeros(1)) + out = func(np.zeros(1, dtype=config.floatX)) assert out.ndim == 0