File tree Expand file tree Collapse file tree 4 files changed +107
-0
lines changed Expand file tree Collapse file tree 4 files changed +107
-0
lines changed Original file line number Diff line number Diff line change 18
18
MatrixInverse ,
19
19
MatrixPinv ,
20
20
QRFull ,
21
+ SLogDet ,
21
22
)
22
23
23
24
@@ -58,6 +59,25 @@ def det(x):
58
59
return det
59
60
60
61
62
+ @numba_funcify .register (SLogDet )
63
+ def numba_funcify_SLogDet (op , node , ** kwargs ):
64
+
65
+ out_dtype_1 = node .outputs [0 ].type .numpy_dtype
66
+ out_dtype_2 = node .outputs [1 ].type .numpy_dtype
67
+
68
+ inputs_cast = int_to_float_fn (node .inputs , out_dtype_1 )
69
+
70
+ @numba_basic .numba_njit
71
+ def slogdet (x ):
72
+ sign , det = np .linalg .slogdet (inputs_cast (x ))
73
+ return (
74
+ numba_basic .direct_cast (sign , out_dtype_1 ),
75
+ numba_basic .direct_cast (det , out_dtype_2 ),
76
+ )
77
+
78
+ return slogdet
79
+
80
+
61
81
@numba_funcify .register (Eig )
62
82
def numba_funcify_Eig (op , node , ** kwargs ):
63
83
Original file line number Diff line number Diff line change @@ -231,6 +231,45 @@ def __str__(self):
231
231
det = Det ()
232
232
233
233
234
+ class SLogDet (Op ):
235
+ """
236
+ Compute sign and log determinant of the matrix. Input should be a square matrix.
237
+ """
238
+
239
+ __props__ = ()
240
+
241
+ def make_node (self , x ):
242
+ x = as_tensor_variable (x )
243
+ assert x .ndim == 2
244
+ s = scalar (dtype = x .dtype )
245
+ d = scalar (dtype = x .dtype )
246
+ return Apply (self , [x ], [s , d ])
247
+
248
+ def perform (self , node , inputs , outputs ):
249
+ (x ,) = inputs
250
+ (s , d ) = outputs
251
+ try :
252
+ s [0 ], d [0 ] = (z .astype (x .dtype ) for z in np .linalg .slogdet (x ))
253
+ except Exception :
254
+ print ("Failed to compute determinant" , x )
255
+ raise
256
+
257
+ def grad (self , inputs , g_outputs ):
258
+ (gz ,) = g_outputs
259
+ (x ,) = inputs
260
+ sign , det = self (x )
261
+ return [gz * sign * np .exp (det ) * matrix_inverse (x ).T ]
262
+
263
+ def infer_shape (self , fgraph , node , shapes ):
264
+ return [(), ()]
265
+
266
+ def __str__ (self ):
267
+ return "SLogDet"
268
+
269
+
270
+ slogdet = SLogDet ()
271
+
272
+
234
273
class Eig (Op ):
235
274
"""
236
275
Compute the eigenvalues and right eigenvectors of a square array.
Original file line number Diff line number Diff line change @@ -179,6 +179,41 @@ def test_Det(x, exc):
179
179
)
180
180
181
181
182
+ @pytest .mark .parametrize (
183
+ "x, exc" ,
184
+ [
185
+ (
186
+ set_test_value (
187
+ at .dmatrix (),
188
+ (lambda x : x .T .dot (x ))(rng .random (size = (3 , 3 )).astype ("float64" )),
189
+ ),
190
+ None ,
191
+ ),
192
+ (
193
+ set_test_value (
194
+ at .lmatrix (),
195
+ (lambda x : x .T .dot (x ))(rng .poisson (size = (3 , 3 )).astype ("int64" )),
196
+ ),
197
+ None ,
198
+ ),
199
+ ],
200
+ )
201
+ def test_SLogDet (x , exc ):
202
+ g = nlinalg .SLogDet ()(x )
203
+ g_fg = FunctionGraph (outputs = g )
204
+
205
+ cm = contextlib .suppress () if exc is None else pytest .warns (exc )
206
+ with cm :
207
+ compare_numba_and_py (
208
+ g_fg ,
209
+ [
210
+ i .tag .test_value
211
+ for i in g_fg .inputs
212
+ if not isinstance (i , (SharedVariable , Constant ))
213
+ ],
214
+ )
215
+
216
+
182
217
# We were seeing some weird results in CI where the following two almost
183
218
# sign-swapped results were being return from Numba and Python, respectively.
184
219
# The issue might be related to https://github.com/numba/numba/issues/4519.
Original file line number Diff line number Diff line change 24
24
norm ,
25
25
pinv ,
26
26
qr ,
27
+ slogdet ,
27
28
svd ,
28
29
tensorinv ,
29
30
tensorsolve ,
@@ -266,6 +267,18 @@ def test_det():
266
267
assert np .allclose (np .linalg .det (r ), f (r ))
267
268
268
269
270
+ def test_slogdet ():
271
+ rng = np .random .default_rng (utt .fetch_seed ())
272
+
273
+ r = rng .standard_normal ((5 , 5 )).astype (config .floatX )
274
+ x = matrix ()
275
+ f = pytensor .function ([x ], slogdet (x ))
276
+ f_sign , f_det = f (r )
277
+ sign , det = np .linalg .slogdet (r )
278
+ assert np .equal (sign , f_sign )
279
+ assert np .allclose (det , f_det )
280
+
281
+
269
282
def test_det_grad ():
270
283
rng = np .random .default_rng (utt .fetch_seed ())
271
284
You can’t perform that action at this time.
0 commit comments