Skip to content

Commit 2cef9c0

Browse files
dehorsleyricardoV94
authored andcommitted
add Exponentially scaled modified Bessel function
Fixes #542
1 parent 68b41a4 commit 2cef9c0

File tree

5 files changed

+80
-0
lines changed

5 files changed

+80
-0
lines changed

pytensor/link/jax/dispatch/scalar.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
Erfcx,
2828
Erfinv,
2929
Iv,
30+
Ive,
3031
Log1mexp,
3132
Psi,
3233
TriGamma,
@@ -267,6 +268,13 @@ def iv(v, x):
267268
return iv
268269

269270

271+
@jax_funcify.register(Ive)
272+
def jax_funcify_Ive(op, **kwargs):
273+
ive = try_import_tfp_jax_op(op, jax_op_name="bessel_ive")
274+
275+
return ive
276+
277+
270278
@jax_funcify.register(Log1mexp)
271279
def jax_funcify_Log1mexp(op, node, **kwargs):
272280
def log1mexp(x):

pytensor/scalar/math.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1197,6 +1197,37 @@ def c_code(self, *args, **kwargs):
11971197
i0 = I0(upgrade_to_float, name="i0")
11981198

11991199

1200+
class Ive(BinaryScalarOp):
1201+
"""
1202+
Exponentially scaled modified Bessel function of the first kind of order v (real).
1203+
"""
1204+
1205+
nfunc_spec = ("scipy.special.ive", 2, 1)
1206+
1207+
@staticmethod
1208+
def st_impl(v, x):
1209+
return scipy.special.ive(v, x)
1210+
1211+
def impl(self, v, x):
1212+
return self.st_impl(v, x)
1213+
1214+
def grad(self, inputs, grads):
1215+
v, x = inputs
1216+
(gz,) = grads
1217+
return [
1218+
grad_not_implemented(self, 0, v),
1219+
gz
1220+
* (ive(v - 1, x) - 2.0 * _unsafe_sign(x) * ive(v, x) + ive(v + 1, x))
1221+
/ 2.0,
1222+
]
1223+
1224+
def c_code(self, *args, **kwargs):
1225+
raise NotImplementedError()
1226+
1227+
1228+
ive = Ive(upgrade_to_float, name="ive")
1229+
1230+
12001231
class Sigmoid(UnaryScalarOp):
12011232
"""
12021233
Logistic sigmoid function (1 / (1 + exp(-x)), also known as expit or inverse logit

pytensor/tensor/inplace.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,11 @@ def iv_inplace(v, x):
313313
"""Modified Bessel function of the first kind of order v (real)."""
314314

315315

316+
@scalar_elemwise
317+
def ive_inplace(v, x):
318+
"""Exponentially scaled modified Bessel function of the first kind of order v (real)."""
319+
320+
316321
@scalar_elemwise
317322
def sigmoid_inplace(x):
318323
"""Logistic sigmoid function (1 / (1 + exp(-x)), also known as expit or inverse logit"""

pytensor/tensor/math.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1435,6 +1435,11 @@ def iv(v, x):
14351435
"""Modified Bessel function of the first kind of order v (real)."""
14361436

14371437

1438+
@scalar_elemwise
1439+
def ive(v, x):
1440+
"""Exponentially scaled modified Bessel function of the first kind of order v (real)."""
1441+
1442+
14381443
@scalar_elemwise
14391444
def sigmoid(x):
14401445
"""Logistic sigmoid function (1 / (1 + exp(-x)), also known as expit or inverse logit"""
@@ -3039,6 +3044,7 @@ def vectorize_node_to_matmul(op, node, batched_x, batched_y):
30393044
"i0",
30403045
"i1",
30413046
"iv",
3047+
"ive",
30423048
"sigmoid",
30433049
"expit",
30443050
"softplus",

tests/tensor/test_math_scipy.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ def scipy_special_gammal(k, x):
7575
expected_i0 = scipy.special.i0
7676
expected_i1 = scipy.special.i1
7777
expected_iv = scipy.special.iv
78+
expected_ive = scipy.special.ive
7879
expected_erfcx = scipy.special.erfcx
7980
expected_sigmoid = scipy.special.expit
8081
expected_hyp2f1 = scipy.special.hyp2f1
@@ -639,6 +640,23 @@ def fixed_first_input_jv(x):
639640
inplace=True,
640641
)
641642

643+
TestIveBroadcast = makeBroadcastTester(
644+
op=at.ive,
645+
expected=expected_ive,
646+
good=_good_broadcast_binary_bessel,
647+
eps=2e-10,
648+
mode=mode_no_scipy,
649+
)
650+
651+
TestIveInplaceBroadcast = makeBroadcastTester(
652+
op=inplace.ive_inplace,
653+
expected=expected_ive,
654+
good=_good_broadcast_binary_bessel,
655+
eps=2e-10,
656+
mode=mode_no_scipy,
657+
inplace=True,
658+
)
659+
642660

643661
def test_verify_iv_grad():
644662
# Verify Iv gradient.
@@ -652,6 +670,18 @@ def fixed_first_input_iv(x):
652670
utt.verify_grad(fixed_first_input_iv, [x_val])
653671

654672

673+
def test_verify_ive_grad():
674+
# Verify Ive gradient.
675+
# Implemented separately due to need to fix first input for which grad is
676+
# not defined.
677+
v_val, x_val = _grad_broadcast_binary_bessel["normal"]
678+
679+
def fixed_first_input_ive(x):
680+
return at.ive(v_val, x)
681+
682+
utt.verify_grad(fixed_first_input_ive, [x_val])
683+
684+
655685
TestSigmoidBroadcast = makeBroadcastTester(
656686
op=at.sigmoid,
657687
expected=expected_sigmoid,

0 commit comments

Comments
 (0)