diff --git a/pytensor/link/jax/dispatch/scalar.py b/pytensor/link/jax/dispatch/scalar.py index 81f2d6c047..bb63bf3135 100644 --- a/pytensor/link/jax/dispatch/scalar.py +++ b/pytensor/link/jax/dispatch/scalar.py @@ -181,7 +181,8 @@ def composite(*args): @jax_funcify.register(Second) def jax_funcify_Second(op, **kwargs): def second(x, y): - return jnp.broadcast_to(y, x.shape) + _, y = jnp.broadcast_arrays(x, y) + return y return second diff --git a/tests/link/jax/test_scalar.py b/tests/link/jax/test_scalar.py index 260f198944..8d428f450f 100644 --- a/tests/link/jax/test_scalar.py +++ b/tests/link/jax/test_scalar.py @@ -25,6 +25,7 @@ jax = pytest.importorskip("jax") +from pytensor.link.jax.dispatch import jax_funcify def test_second(): @@ -40,6 +41,25 @@ def test_second(): fgraph = FunctionGraph([a1, b], [out]) compare_jax_and_py(fgraph, [np.zeros([5], dtype=config.floatX), 5.0]) + a2 = matrix("a2", shape=(1, None), dtype="float64") + b2 = matrix("b2", shape=(None, 1), dtype="int32") + out = at.second(a2, b2) + fgraph = FunctionGraph([a2, b2], [out]) + compare_jax_and_py( + fgraph, [np.zeros((1, 3), dtype="float64"), np.ones((5, 1), dtype="int32")] + ) + + +def test_second_constant_scalar(): + b = scalar("b", dtype="int") + out = at.second(0.0, b) + fgraph = FunctionGraph([b], [out]) + # Test dispatch directly as useless second is removed during compilation + fn = jax_funcify(fgraph) + [res] = fn(1) + assert res == 1 + assert res.dtype == out.dtype + def test_identity(): a = scalar("a")