Skip to content

Commit 9ada945

Browse files
committed
Provide JAX Ops from Optional tfp dependency
1 parent 8ac8342 commit 9ada945

File tree

3 files changed

+65
-2
lines changed

3 files changed

+65
-2
lines changed

.github/workflows/test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ jobs:
145145
# PyTensor next, pip installs a lower version of numpy via the PyPI.
146146
if [[ $INSTALL_NUMBA == "1" ]] && [[ $PYTHON_VERSION == "3.9" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numpy<1.23" "numba>=0.57" numba-scipy; fi
147147
if [[ $INSTALL_NUMBA == "1" ]] && [[ $PYTHON_VERSION != "3.9" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numba>=0.57" numba-scipy; fi
148-
if [[ $INSTALL_JAX == "1" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" jax jaxlib numpyro; fi
148+
if [[ $INSTALL_JAX == "1" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" jax jaxlib numpyro && pip install tensorflow-probability; fi
149149
pip install -e ./
150150
mamba list && pip freeze
151151
python -c 'import pytensor; print(pytensor.config.__str__(print_doc=False))'

pytensor/link/jax/dispatch/scalar.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
import functools
2+
import typing
3+
from typing import Callable, Optional
24

35
import jax
46
import jax.numpy as jnp
@@ -18,7 +20,21 @@
1820
Second,
1921
Sub,
2022
)
21-
from pytensor.scalar.math import Erf, Erfc, Erfinv, Log1mexp, Psi
23+
from pytensor.scalar.math import Erf, Erfc, Erfcinv, Erfcx, Erfinv, Iv, Log1mexp, Psi
24+
25+
26+
def try_import_tfp_jax_op(op: ScalarOp, jax_op_name: Optional[str] = None) -> Callable:
27+
try:
28+
import tensorflow_probability.substrates.jax.math as tfp_jax_math
29+
except ModuleNotFoundError:
30+
raise NotImplementedError(
31+
f"No JAX implementation for Op {op.name}. "
32+
"Implementation is available if TensorFlow Probability is installed"
33+
)
34+
35+
if jax_op_name is None:
36+
jax_op_name = op.name
37+
return typing.cast(Callable, getattr(tfp_jax_math, jax_op_name))
2238

2339

2440
def check_if_inputs_scalars(node):
@@ -211,6 +227,24 @@ def erfinv(x):
211227
return erfinv
212228

213229

230+
@jax_funcify.register(Erfcx)
231+
@jax_funcify.register(Erfcinv)
232+
def jax_funcify_from_tfp(op, **kwargs):
233+
tfp_jax_op = try_import_tfp_jax_op(op)
234+
235+
return tfp_jax_op
236+
237+
238+
@jax_funcify.register(Iv)
239+
def jax_funcify_Iv(op, **kwargs):
240+
ive = try_import_tfp_jax_op(op, jax_op_name="bessel_ive")
241+
242+
def iv(v, x):
243+
return ive(v, x) / jnp.exp(-jnp.abs(jnp.real(x)))
244+
245+
return iv
246+
247+
214248
@jax_funcify.register(Log1mexp)
215249
def jax_funcify_Log1mexp(op, node, **kwargs):
216250
def log1mexp(x):

tests/link/jax/test_scalar.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,17 @@
77
from pytensor.graph.fg import FunctionGraph
88
from pytensor.graph.op import get_test_value
99
from pytensor.scalar.basic import Composite
10+
from pytensor.tensor import as_tensor
1011
from pytensor.tensor.elemwise import Elemwise
1112
from pytensor.tensor.math import all as at_all
1213
from pytensor.tensor.math import (
1314
cosh,
1415
erf,
1516
erfc,
17+
erfcinv,
18+
erfcx,
1619
erfinv,
20+
iv,
1721
log,
1822
log1mexp,
1923
psi,
@@ -28,6 +32,14 @@
2832
from pytensor.link.jax.dispatch import jax_funcify
2933

3034

35+
try:
36+
pass
37+
38+
TFP_INSTALLED = True
39+
except ModuleNotFoundError:
40+
TFP_INSTALLED = False
41+
42+
3143
def test_second():
3244
a0 = scalar("a0")
3345
b = scalar("b")
@@ -134,6 +146,23 @@ def test_erfinv():
134146
compare_jax_and_py(fg, [0.95])
135147

136148

149+
@pytest.mark.parametrize(
150+
"op, test_values",
151+
[
152+
(erfcx, (0.7,)),
153+
(erfcinv, (0.7,)),
154+
(iv, (0.3, 0.7)),
155+
],
156+
)
157+
@pytest.mark.skipif(not TFP_INSTALLED, reason="Test requires tensorflow-probability")
158+
def test_tfp_ops(op, test_values):
159+
inputs = [as_tensor(test_value).type() for test_value in test_values]
160+
output = op(*inputs)
161+
162+
fg = FunctionGraph(inputs, [output])
163+
compare_jax_and_py(fg, test_values)
164+
165+
137166
def test_psi():
138167
x = scalar("x")
139168
out = psi(x)

0 commit comments

Comments
 (0)