diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index aa5b911552..ee1623e822 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -139,7 +139,7 @@ jobs: - name: Install dependencies shell: bash -l {0} run: | - mamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock sympy + mamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" mkl "numpy<1.26" scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock sympy # numba-scipy downgrades the installed scipy to 1.7.3 in Python 3.9, but # not numpy, even though scipy 1.7 requires numpy<1.23. When installing # PyTensor next, pip installs a lower version of numpy via the PyPI. diff --git a/pytensor/scalar/basic.py b/pytensor/scalar/basic.py index db80f89c47..98b0506d6e 100644 --- a/pytensor/scalar/basic.py +++ b/pytensor/scalar/basic.py @@ -1898,7 +1898,7 @@ class Mul(ScalarOp): commutative = True associative = True nfunc_spec = ("multiply", 2, 1) - nfunc_variadic = "product" + nfunc_variadic = "prod" def impl(self, *inputs): return np.prod(inputs) diff --git a/tests/link/jax/test_elemwise.py b/tests/link/jax/test_elemwise.py index 7d13d1aaf7..e0002d3873 100644 --- a/tests/link/jax/test_elemwise.py +++ b/tests/link/jax/test_elemwise.py @@ -13,7 +13,7 @@ from pytensor.tensor.math import prod from pytensor.tensor.math import sum as at_sum from pytensor.tensor.special import SoftmaxGrad, log_softmax, softmax -from pytensor.tensor.type import matrix, tensor, vector +from pytensor.tensor.type import matrix, tensor, vector, vectors from tests.link.jax.test_basic import compare_jax_and_py from tests.tensor.test_elemwise import TestElemwise @@ -129,3 +129,11 @@ def test_logsumexp_benchmark(size, axis, benchmark): exp_res = scipy.special.logsumexp(X_val, axis=axis, keepdims=True) np.testing.assert_array_almost_equal(res, exp_res) + + +def test_multiple_input_multiply(): + x, y, z = vectors("xyz") + out = at.mul(x, y, z) + + fg = FunctionGraph(outputs=[out], clone=False) + compare_jax_and_py(fg, [[1.5], [2.5], [3.5]])