1
+ from functools import partial
2
+
1
3
import numpy as np
2
4
import numpy .linalg
3
5
import pytest
9
11
from pytensor import tensor as at
10
12
from pytensor .compile import get_default_mode
11
13
from pytensor .configdefaults import config
14
+ from pytensor .tensor import swapaxes
12
15
from pytensor .tensor .blockwise import Blockwise
13
16
from pytensor .tensor .elemwise import DimShuffle
14
- from pytensor .tensor .math import _allclose
17
+ from pytensor .tensor .math import _allclose , dot , matmul
15
18
from pytensor .tensor .nlinalg import Det , MatrixInverse , matrix_inverse
16
19
from pytensor .tensor .rewriting .linalg import inv_as_solve
17
20
from pytensor .tensor .slinalg import Cholesky , Solve , SolveTriangular , cholesky , solve
18
- from pytensor .tensor .type import dmatrix , matrix , vector
21
+ from pytensor .tensor .type import dmatrix , matrix , tensor , vector
19
22
from tests import unittest_tools as utt
20
23
from tests .test_rop import break_op
21
24
@@ -137,26 +140,30 @@ def test_matrix_inverse_solve():
137
140
@pytest .mark .parametrize ("tag" , ("lower" , "upper" , None ))
138
141
@pytest .mark .parametrize ("cholesky_form" , ("lower" , "upper" ))
139
142
@pytest .mark .parametrize ("product" , ("lower" , "upper" , None ))
140
- def test_cholesky_ldotlt (tag , cholesky_form , product ):
143
+ @pytest .mark .parametrize ("op" , (dot , matmul ))
144
+ def test_cholesky_ldotlt (tag , cholesky_form , product , op ):
141
145
transform_removes_chol = tag is not None and product == tag
142
146
transform_transposes = transform_removes_chol and cholesky_form != tag
143
147
144
- A = matrix ("L" )
148
+ ndim = 2 if op == dot else 2
149
+ A = tensor ("L" , shape = (None ,) * ndim )
145
150
if tag :
146
151
setattr (A .tag , tag + "_triangular" , True )
147
152
148
153
if product == "lower" :
149
- M = A . dot ( A . T )
154
+ M = op ( A , swapaxes ( A , - 1 , - 2 ) )
150
155
elif product == "upper" :
151
- M = A . T . dot ( A )
156
+ M = op ( swapaxes ( A , - 1 , - 2 ), A )
152
157
else :
153
158
M = A
154
159
155
160
C = cholesky (M , lower = (cholesky_form == "lower" ))
156
161
f = pytensor .function ([A ], C , mode = get_default_mode ().including ("cholesky_ldotlt" ))
157
162
158
163
no_cholesky_in_graph = not any (
159
- isinstance (node .op , Cholesky ) for node in f .maker .fgraph .apply_nodes
164
+ isinstance (node .op , Cholesky )
165
+ or (isinstance (node .op , Blockwise ) and isinstance (node .op .core_op , Cholesky ))
166
+ for node in f .maker .fgraph .apply_nodes
160
167
)
161
168
162
169
assert no_cholesky_in_graph == transform_removes_chol
@@ -183,6 +190,11 @@ def test_cholesky_ldotlt(tag, cholesky_form, product):
183
190
]
184
191
)
185
192
193
+ cholesky_vect_fn = np .vectorize (
194
+ partial (scipy .linalg .cholesky , lower = (cholesky_form == "lower" )),
195
+ signature = "(a, a)->(a, a)" ,
196
+ )
197
+
186
198
for Av in Avs :
187
199
if tag == "upper" :
188
200
Av = Av .T
@@ -194,11 +206,13 @@ def test_cholesky_ldotlt(tag, cholesky_form, product):
194
206
else :
195
207
Mv = Av
196
208
197
- assert np .all (
198
- np .isclose (
199
- scipy .linalg .cholesky (Mv , lower = (cholesky_form == "lower" )),
200
- f (Av ),
201
- )
209
+ if ndim == 3 :
210
+ Av = np .broadcast_to (Av , (5 , * Av .shape ))
211
+ Mv = np .broadcast_to (Mv , (5 , * Mv .shape ))
212
+
213
+ np .testing .assert_allclose (
214
+ cholesky_vect_fn (Mv ),
215
+ f (Av ),
202
216
)
203
217
204
218
0 commit comments