Commit 4fd55b3

Do not use deprecated variadic in jax dispatch of multiply
1 parent: 566cd7c

File tree

2 files changed: +10 -2 lines changed

  pytensor/scalar/basic.py
  tests/link/jax/test_elemwise.py

pytensor/scalar/basic.py

Lines changed: 1 addition & 1 deletion
@@ -1898,7 +1898,7 @@ class Mul(ScalarOp):
     commutative = True
     associative = True
     nfunc_spec = ("multiply", 2, 1)
-    nfunc_variadic = "product"
+    nfunc_variadic = "prod"
 
     def impl(self, *inputs):
         return np.prod(inputs)
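
As the commit message and the new test suggest, nfunc_variadic names the NumPy-level reduction the backend falls back to when Mul receives more than two inputs, and the JAX backend resolves that name against jax.numpy. The snippet below is a minimal sketch of that lookup under this assumption, not PyTensor's actual dispatch code; resolve_variadic is a hypothetical helper. "prod" exists in both NumPy and jax.numpy, whereas "product" is a deprecated NumPy alias (deprecated in NumPy 1.25) that recent jax.numpy releases may warn about or no longer expose at all.

import numpy as np
import jax.numpy as jnp

# Hypothetical helper, not PyTensor code: look up the variadic reduction by
# the name stored in nfunc_variadic, the way a numpy-like backend would.
def resolve_variadic(module, name):
    return getattr(module, name)

values = np.array([1.5, 2.5, 3.5])

# "prod" resolves in both namespaces and gives the same result (13.125).
print(resolve_variadic(np, "prod")(values))
print(resolve_variadic(jnp, "prod")(jnp.asarray(values)))

# "product" is only a deprecated alias of prod on the NumPy side and may be
# missing from jax.numpy entirely, in which case the lookup above raises
# AttributeError. Renaming the attribute to "prod" sidesteps both problems.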

tests/link/jax/test_elemwise.py

Lines changed: 9 additions & 1 deletion
@@ -13,7 +13,7 @@
 from pytensor.tensor.math import prod
 from pytensor.tensor.math import sum as at_sum
 from pytensor.tensor.special import SoftmaxGrad, log_softmax, softmax
-from pytensor.tensor.type import matrix, tensor, vector
+from pytensor.tensor.type import matrix, tensor, vector, vectors
 from tests.link.jax.test_basic import compare_jax_and_py
 from tests.tensor.test_elemwise import TestElemwise
 

@@ -129,3 +129,11 @@ def test_logsumexp_benchmark(size, axis, benchmark):
 
     exp_res = scipy.special.logsumexp(X_val, axis=axis, keepdims=True)
     np.testing.assert_array_almost_equal(res, exp_res)
+
+
+def test_multiple_input_multiply():
+    x, y, z = vectors("xyz")
+    out = at.mul(x, y, z)
+
+    fg = FunctionGraph(outputs=[out], clone=False)
+    compare_jax_and_py(fg, [[1.5], [2.5], [3.5]])
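
compare_jax_and_py, per its name and its use elsewhere in this test suite, compiles the graph with both the JAX and the Python linker and checks that the two results agree, so the test above exercises the variadic path touched in basic.py. For reference, a stand-alone sketch of the same check, assuming a PyTensor installation with JAX available and using the mode="JAX" string to select the JAX backend:

import numpy as np
import pytensor
import pytensor.tensor as at
from pytensor.tensor.type import vectors

# One variadic Mul node with three vector inputs, as in the new test.
x, y, z = vectors("xyz")
out = at.mul(x, y, z)

# Compile for the JAX backend and compare against the plain NumPy result.
fn = pytensor.function([x, y, z], out, mode="JAX")
res = fn([1.5], [2.5], [3.5])
np.testing.assert_allclose(res, np.array([1.5 * 2.5 * 3.5]))

Before this commit, compiling such a graph for JAX tripped over the deprecated "product" name; with nfunc_variadic set to "prod" it compiles and matches the NumPy result.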
