Skip to content

Add Exponentially scaled modified Bessel Op #543

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions pytensor/link/jax/dispatch/scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
Erfcx,
Erfinv,
Iv,
Ive,
Log1mexp,
Psi,
TriGamma,
Expand Down Expand Up @@ -267,6 +268,13 @@ def iv(v, x):
return iv


@jax_funcify.register(Ive)
def jax_funcify_Ive(op, **kwargs):
ive = try_import_tfp_jax_op(op, jax_op_name="bessel_ive")

return ive


@jax_funcify.register(Log1mexp)
def jax_funcify_Log1mexp(op, node, **kwargs):
def log1mexp(x):
Expand Down
31 changes: 31 additions & 0 deletions pytensor/scalar/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -1197,6 +1197,37 @@ def c_code(self, *args, **kwargs):
i0 = I0(upgrade_to_float, name="i0")


class Ive(BinaryScalarOp):
"""
Exponentially scaled modified Bessel function of the first kind of order v (real).
"""

nfunc_spec = ("scipy.special.ive", 2, 1)

@staticmethod
def st_impl(v, x):
return scipy.special.ive(v, x)

def impl(self, v, x):
return self.st_impl(v, x)

def grad(self, inputs, grads):
v, x = inputs
(gz,) = grads
return [
grad_not_implemented(self, 0, v),
gz
* (ive(v - 1, x) - 2.0 * _unsafe_sign(x) * ive(v, x) + ive(v + 1, x))
/ 2.0,
]

def c_code(self, *args, **kwargs):
raise NotImplementedError()


ive = Ive(upgrade_to_float, name="ive")


class Sigmoid(UnaryScalarOp):
"""
Logistic sigmoid function (1 / (1 + exp(-x)), also known as expit or inverse logit
Expand Down
5 changes: 5 additions & 0 deletions pytensor/tensor/inplace.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,11 @@ def iv_inplace(v, x):
"""Modified Bessel function of the first kind of order v (real)."""


@scalar_elemwise
def ive_inplace(v, x):
"""Exponentially scaled modified Bessel function of the first kind of order v (real)."""


@scalar_elemwise
def sigmoid_inplace(x):
"""Logistic sigmoid function (1 / (1 + exp(-x)), also known as expit or inverse logit"""
Expand Down
6 changes: 6 additions & 0 deletions pytensor/tensor/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -1435,6 +1435,11 @@ def iv(v, x):
"""Modified Bessel function of the first kind of order v (real)."""


@scalar_elemwise
def ive(v, x):
"""Exponentially scaled modified Bessel function of the first kind of order v (real)."""


@scalar_elemwise
def sigmoid(x):
"""Logistic sigmoid function (1 / (1 + exp(-x)), also known as expit or inverse logit"""
Expand Down Expand Up @@ -3039,6 +3044,7 @@ def vectorize_node_to_matmul(op, node, batched_x, batched_y):
"i0",
"i1",
"iv",
"ive",
"sigmoid",
"expit",
"softplus",
Expand Down
30 changes: 30 additions & 0 deletions tests/tensor/test_math_scipy.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def scipy_special_gammal(k, x):
expected_i0 = scipy.special.i0
expected_i1 = scipy.special.i1
expected_iv = scipy.special.iv
expected_ive = scipy.special.ive
expected_erfcx = scipy.special.erfcx
expected_sigmoid = scipy.special.expit
expected_hyp2f1 = scipy.special.hyp2f1
Expand Down Expand Up @@ -639,6 +640,23 @@ def fixed_first_input_jv(x):
inplace=True,
)

TestIveBroadcast = makeBroadcastTester(
op=at.ive,
expected=expected_ive,
good=_good_broadcast_binary_bessel,
eps=2e-10,
mode=mode_no_scipy,
)

TestIveInplaceBroadcast = makeBroadcastTester(
op=inplace.ive_inplace,
expected=expected_ive,
good=_good_broadcast_binary_bessel,
eps=2e-10,
mode=mode_no_scipy,
inplace=True,
)


def test_verify_iv_grad():
# Verify Iv gradient.
Expand All @@ -652,6 +670,18 @@ def fixed_first_input_iv(x):
utt.verify_grad(fixed_first_input_iv, [x_val])


def test_verify_ive_grad():
# Verify Ive gradient.
# Implemented separately due to need to fix first input for which grad is
# not defined.
v_val, x_val = _grad_broadcast_binary_bessel["normal"]

def fixed_first_input_ive(x):
return at.ive(v_val, x)

utt.verify_grad(fixed_first_input_ive, [x_val])


TestSigmoidBroadcast = makeBroadcastTester(
op=at.sigmoid,
expected=expected_sigmoid,
Expand Down