diff --git a/pytensor/link/jax/dispatch/scalar.py b/pytensor/link/jax/dispatch/scalar.py index 8709633da1..45b3c34c73 100644 --- a/pytensor/link/jax/dispatch/scalar.py +++ b/pytensor/link/jax/dispatch/scalar.py @@ -27,6 +27,7 @@ Erfcx, Erfinv, Iv, + Ive, Log1mexp, Psi, TriGamma, @@ -267,6 +268,13 @@ def iv(v, x): return iv +@jax_funcify.register(Ive) +def jax_funcify_Ive(op, **kwargs): + ive = try_import_tfp_jax_op(op, jax_op_name="bessel_ive") + + return ive + + @jax_funcify.register(Log1mexp) def jax_funcify_Log1mexp(op, node, **kwargs): def log1mexp(x): diff --git a/pytensor/scalar/math.py b/pytensor/scalar/math.py index 1f326b3fab..2ab97ff122 100644 --- a/pytensor/scalar/math.py +++ b/pytensor/scalar/math.py @@ -1197,6 +1197,37 @@ def c_code(self, *args, **kwargs): i0 = I0(upgrade_to_float, name="i0") +class Ive(BinaryScalarOp): + """ + Exponentially scaled modified Bessel function of the first kind of order v (real). + """ + + nfunc_spec = ("scipy.special.ive", 2, 1) + + @staticmethod + def st_impl(v, x): + return scipy.special.ive(v, x) + + def impl(self, v, x): + return self.st_impl(v, x) + + def grad(self, inputs, grads): + v, x = inputs + (gz,) = grads + return [ + grad_not_implemented(self, 0, v), + gz + * (ive(v - 1, x) - 2.0 * _unsafe_sign(x) * ive(v, x) + ive(v + 1, x)) + / 2.0, + ] + + def c_code(self, *args, **kwargs): + raise NotImplementedError() + + +ive = Ive(upgrade_to_float, name="ive") + + class Sigmoid(UnaryScalarOp): """ Logistic sigmoid function (1 / (1 + exp(-x)), also known as expit or inverse logit diff --git a/pytensor/tensor/inplace.py b/pytensor/tensor/inplace.py index afb7f7ac7c..5f78f54615 100644 --- a/pytensor/tensor/inplace.py +++ b/pytensor/tensor/inplace.py @@ -313,6 +313,11 @@ def iv_inplace(v, x): """Modified Bessel function of the first kind of order v (real).""" +@scalar_elemwise +def ive_inplace(v, x): + """Exponentially scaled modified Bessel function of the first kind of order v (real).""" + + @scalar_elemwise def sigmoid_inplace(x): """Logistic sigmoid function (1 / (1 + exp(-x)), also known as expit or inverse logit""" diff --git a/pytensor/tensor/math.py b/pytensor/tensor/math.py index 7d1e32ba21..6fa0065ce2 100644 --- a/pytensor/tensor/math.py +++ b/pytensor/tensor/math.py @@ -1435,6 +1435,11 @@ def iv(v, x): """Modified Bessel function of the first kind of order v (real).""" +@scalar_elemwise +def ive(v, x): + """Exponentially scaled modified Bessel function of the first kind of order v (real).""" + + @scalar_elemwise def sigmoid(x): """Logistic sigmoid function (1 / (1 + exp(-x)), also known as expit or inverse logit""" @@ -3039,6 +3044,7 @@ def vectorize_node_to_matmul(op, node, batched_x, batched_y): "i0", "i1", "iv", + "ive", "sigmoid", "expit", "softplus", diff --git a/tests/tensor/test_math_scipy.py b/tests/tensor/test_math_scipy.py index 11e0e1730a..2d4c52282a 100644 --- a/tests/tensor/test_math_scipy.py +++ b/tests/tensor/test_math_scipy.py @@ -75,6 +75,7 @@ def scipy_special_gammal(k, x): expected_i0 = scipy.special.i0 expected_i1 = scipy.special.i1 expected_iv = scipy.special.iv +expected_ive = scipy.special.ive expected_erfcx = scipy.special.erfcx expected_sigmoid = scipy.special.expit expected_hyp2f1 = scipy.special.hyp2f1 @@ -639,6 +640,23 @@ def fixed_first_input_jv(x): inplace=True, ) +TestIveBroadcast = makeBroadcastTester( + op=at.ive, + expected=expected_ive, + good=_good_broadcast_binary_bessel, + eps=2e-10, + mode=mode_no_scipy, +) + +TestIveInplaceBroadcast = makeBroadcastTester( + op=inplace.ive_inplace, + expected=expected_ive, + good=_good_broadcast_binary_bessel, + eps=2e-10, + mode=mode_no_scipy, + inplace=True, +) + def test_verify_iv_grad(): # Verify Iv gradient. @@ -652,6 +670,18 @@ def fixed_first_input_iv(x): utt.verify_grad(fixed_first_input_iv, [x_val]) +def test_verify_ive_grad(): + # Verify Ive gradient. + # Implemented separately due to need to fix first input for which grad is + # not defined. + v_val, x_val = _grad_broadcast_binary_bessel["normal"] + + def fixed_first_input_ive(x): + return at.ive(v_val, x) + + utt.verify_grad(fixed_first_input_ive, [x_val]) + + TestSigmoidBroadcast = makeBroadcastTester( op=at.sigmoid, expected=expected_sigmoid,