Skip to content

Add slogdet for Numba and JAX #172

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jan 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion pytensor/link/jax/dispatch/nlinalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from pytensor.link.jax.dispatch import jax_funcify
from pytensor.tensor.blas import BatchedDot
from pytensor.tensor.math import Dot, MaxAndArgmax
from pytensor.tensor.nlinalg import SVD, Det, Eig, Eigh, MatrixInverse, QRFull
from pytensor.tensor.nlinalg import SVD, Det, Eig, Eigh, MatrixInverse, QRFull, SLogDet


@jax_funcify.register(SVD)
Expand All @@ -25,6 +25,14 @@ def det(x):
return det


@jax_funcify.register(SLogDet)
def jax_funcify_SLogDet(op, **kwargs):
def slogdet(x):
return jnp.linalg.slogdet(x)

return slogdet


@jax_funcify.register(Eig)
def jax_funcify_Eig(op, **kwargs):
def eig(x):
Expand Down
20 changes: 20 additions & 0 deletions pytensor/link/numba/dispatch/nlinalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
MatrixInverse,
MatrixPinv,
QRFull,
SLogDet,
)


Expand Down Expand Up @@ -58,6 +59,25 @@ def det(x):
return det


@numba_funcify.register(SLogDet)
def numba_funcify_SLogDet(op, node, **kwargs):

out_dtype_1 = node.outputs[0].type.numpy_dtype
out_dtype_2 = node.outputs[1].type.numpy_dtype

inputs_cast = int_to_float_fn(node.inputs, out_dtype_1)

@numba_basic.numba_njit
def slogdet(x):
sign, det = np.linalg.slogdet(inputs_cast(x))
return (
numba_basic.direct_cast(sign, out_dtype_1),
numba_basic.direct_cast(det, out_dtype_2),
)

return slogdet


@numba_funcify.register(Eig)
def numba_funcify_Eig(op, node, **kwargs):

Expand Down
33 changes: 33 additions & 0 deletions pytensor/tensor/nlinalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,39 @@ def __str__(self):
det = Det()


class SLogDet(Op):
"""
Compute the log determinant and its sign of the matrix. Input should be a square matrix.
"""

__props__ = ()

def make_node(self, x):
x = as_tensor_variable(x)
assert x.ndim == 2
sign = scalar(dtype=x.dtype)
det = scalar(dtype=x.dtype)
return Apply(self, [x], [sign, det])

def perform(self, node, inputs, outputs):
(x,) = inputs
(sign, det) = outputs
try:
sign[0], det[0] = (z.astype(x.dtype) for z in np.linalg.slogdet(x))
except Exception:
print("Failed to compute determinant", x)
raise

def infer_shape(self, fgraph, node, shapes):
return [(), ()]

def __str__(self):
return "SLogDet"


slogdet = SLogDet()


class Eig(Op):
"""
Compute the eigenvalues and right eigenvectors of a square array.
Expand Down
4 changes: 4 additions & 0 deletions tests/link/jax/test_nlinalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,10 @@ def assert_fn(x, y):
out_fg = FunctionGraph([x], outs)
compare_jax_and_py(out_fg, [X.astype(config.floatX)], assert_fn=assert_fn)

outs = at_nlinalg.slogdet(x)
out_fg = FunctionGraph([x], outs)
compare_jax_and_py(out_fg, [X.astype(config.floatX)], assert_fn=assert_fn)


@pytest.mark.xfail(
version_parse(jax.__version__) >= version_parse("0.2.12"),
Expand Down
35 changes: 35 additions & 0 deletions tests/link/numba/test_nlinalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,41 @@ def test_Det(x, exc):
)


@pytest.mark.parametrize(
"x, exc",
[
(
set_test_value(
at.dmatrix(),
(lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")),
),
None,
),
(
set_test_value(
at.lmatrix(),
(lambda x: x.T.dot(x))(rng.poisson(size=(3, 3)).astype("int64")),
),
None,
),
],
)
def test_SLogDet(x, exc):
g = nlinalg.SLogDet()(x)
g_fg = FunctionGraph(outputs=g)

cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm:
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, (SharedVariable, Constant))
],
)


# We were seeing some weird results in CI where the following two almost
# sign-swapped results were being return from Numba and Python, respectively.
# The issue might be related to https://github.com/numba/numba/issues/4519.
Expand Down
13 changes: 13 additions & 0 deletions tests/tensor/test_nlinalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
norm,
pinv,
qr,
slogdet,
svd,
tensorinv,
tensorsolve,
Expand Down Expand Up @@ -280,6 +281,18 @@ def test_det_shape():
assert tuple(det_shape.data) == ()


def test_slogdet():
rng = np.random.default_rng(utt.fetch_seed())

r = rng.standard_normal((5, 5)).astype(config.floatX)
x = matrix()
f = pytensor.function([x], slogdet(x))
f_sign, f_det = f(r)
sign, det = np.linalg.slogdet(r)
assert np.equal(sign, f_sign)
assert np.allclose(det, f_det)


def test_trace():
rng = np.random.default_rng(utt.fetch_seed())
x = matrix()
Expand Down