Skip to content

Commit 8ef585b

Browse files
mtsokoltwiecki
authored andcommitted
Add slogdet for JAX
1 parent 06e7afe commit 8ef585b

File tree

2 files changed

+13
-1
lines changed

2 files changed

+13
-1
lines changed

pytensor/link/jax/dispatch/nlinalg.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from pytensor.link.jax.dispatch import jax_funcify
44
from pytensor.tensor.blas import BatchedDot
55
from pytensor.tensor.math import Dot, MaxAndArgmax
6-
from pytensor.tensor.nlinalg import SVD, Det, Eig, Eigh, MatrixInverse, QRFull
6+
from pytensor.tensor.nlinalg import SVD, Det, Eig, Eigh, MatrixInverse, QRFull, SLogDet
77

88

99
@jax_funcify.register(SVD)
@@ -25,6 +25,14 @@ def det(x):
2525
return det
2626

2727

28+
@jax_funcify.register(SLogDet)
29+
def jax_funcify_SLogDet(op, **kwargs):
30+
def slogdet(x):
31+
return jnp.linalg.slogdet(x)
32+
33+
return slogdet
34+
35+
2836
@jax_funcify.register(Eig)
2937
def jax_funcify_Eig(op, **kwargs):
3038
def eig(x):

tests/link/jax/test_nlinalg.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,10 @@ def assert_fn(x, y):
8585
out_fg = FunctionGraph([x], outs)
8686
compare_jax_and_py(out_fg, [X.astype(config.floatX)], assert_fn=assert_fn)
8787

88+
outs = at_nlinalg.slogdet(x)
89+
out_fg = FunctionGraph([x], outs)
90+
compare_jax_and_py(out_fg, [X.astype(config.floatX)], assert_fn=assert_fn)
91+
8892

8993
@pytest.mark.xfail(
9094
version_parse(jax.__version__) >= version_parse("0.2.12"),

0 commit comments

Comments
 (0)