Skip to content

Commit fdaaafb

Browse files
committed
Add slogdet for Numba
1 parent 25236cf commit fdaaafb

File tree

4 files changed

+109
-0
lines changed

4 files changed

+109
-0
lines changed

pytensor/link/numba/dispatch/nlinalg.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
MatrixInverse,
1919
MatrixPinv,
2020
QRFull,
21+
SLogDet,
2122
)
2223

2324

@@ -58,6 +59,25 @@ def det(x):
5859
return det
5960

6061

62+
@numba_funcify.register(SLogDet)
63+
def numba_funcify_SLogDet(op, node, **kwargs):
64+
65+
out_dtype_1 = node.outputs[0].type.numpy_dtype
66+
out_dtype_2 = node.outputs[1].type.numpy_dtype
67+
68+
inputs_cast = int_to_float_fn(node.inputs, out_dtype_1)
69+
70+
@numba_basic.numba_njit
71+
def slogdet(x):
72+
sign, det = np.linalg.slogdet(inputs_cast(x))
73+
return (
74+
numba_basic.direct_cast(sign, out_dtype_1),
75+
numba_basic.direct_cast(det, out_dtype_2),
76+
)
77+
78+
return slogdet
79+
80+
6181
@numba_funcify.register(Eig)
6282
def numba_funcify_Eig(op, node, **kwargs):
6383

pytensor/tensor/nlinalg.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,39 @@ def __str__(self):
231231
det = Det()
232232

233233

234+
class SLogDet(Op):
235+
"""
236+
Compute the log determinant and its sign of the matrix. Input should be a square matrix.
237+
"""
238+
239+
__props__ = ()
240+
241+
def make_node(self, x):
242+
x = as_tensor_variable(x)
243+
assert x.ndim == 2
244+
sign = scalar(dtype=x.dtype)
245+
det = scalar(dtype=x.dtype)
246+
return Apply(self, [x], [sign, det])
247+
248+
def perform(self, node, inputs, outputs):
249+
(x,) = inputs
250+
(sign, det) = outputs
251+
try:
252+
sign[0], det[0] = (z.astype(x.dtype) for z in np.linalg.slogdet(x))
253+
except Exception:
254+
print("Failed to compute determinant", x)
255+
raise
256+
257+
def infer_shape(self, fgraph, node, shapes):
258+
return [(), ()]
259+
260+
def __str__(self):
261+
return "SLogDet"
262+
263+
264+
slogdet = SLogDet()
265+
266+
234267
class Eig(Op):
235268
"""
236269
Compute the eigenvalues and right eigenvectors of a square array.

tests/link/numba/test_nlinalg.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,41 @@ def test_Det(x, exc):
179179
)
180180

181181

182+
@pytest.mark.parametrize(
183+
"x, exc",
184+
[
185+
(
186+
set_test_value(
187+
at.dmatrix(),
188+
(lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")),
189+
),
190+
None,
191+
),
192+
(
193+
set_test_value(
194+
at.lmatrix(),
195+
(lambda x: x.T.dot(x))(rng.poisson(size=(3, 3)).astype("int64")),
196+
),
197+
None,
198+
),
199+
],
200+
)
201+
def test_SLogDet(x, exc):
202+
g = nlinalg.SLogDet()(x)
203+
g_fg = FunctionGraph(outputs=g)
204+
205+
cm = contextlib.suppress() if exc is None else pytest.warns(exc)
206+
with cm:
207+
compare_numba_and_py(
208+
g_fg,
209+
[
210+
i.tag.test_value
211+
for i in g_fg.inputs
212+
if not isinstance(i, (SharedVariable, Constant))
213+
],
214+
)
215+
216+
182217
# We were seeing some weird results in CI where the following two almost
183218
# sign-swapped results were being return from Numba and Python, respectively.
184219
# The issue might be related to https://github.com/numba/numba/issues/4519.

tests/tensor/test_nlinalg.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
norm,
2525
pinv,
2626
qr,
27+
slogdet,
2728
svd,
2829
tensorinv,
2930
tensorsolve,
@@ -280,6 +281,26 @@ def test_det_shape():
280281
assert tuple(det_shape.data) == ()
281282

282283

284+
def test_slogdet():
285+
rng = np.random.default_rng(utt.fetch_seed())
286+
287+
r = rng.standard_normal((5, 5)).astype(config.floatX)
288+
x = matrix()
289+
f = pytensor.function([x], slogdet(x))
290+
f_sign, f_det = f(r)
291+
sign, det = np.linalg.slogdet(r)
292+
assert np.equal(sign, f_sign)
293+
assert np.allclose(det, f_det)
294+
295+
296+
def test_slogdet_shape():
297+
x = matrix()
298+
sign, det = slogdet(x)
299+
for shape in [sign.shape, det.shape]:
300+
assert isinstance(shape, Constant)
301+
assert tuple(shape.data) == ()
302+
303+
283304
def test_trace():
284305
rng = np.random.default_rng(utt.fetch_seed())
285306
x = matrix()

0 commit comments

Comments
 (0)