1
+ from functools import partial
2
+
1
3
import numpy as np
2
4
import numpy .linalg
3
5
import pytest
9
11
from pytensor import tensor as at
10
12
from pytensor .compile import get_default_mode
11
13
from pytensor .configdefaults import config
14
+ from pytensor .tensor import swapaxes
12
15
from pytensor .tensor .blockwise import Blockwise
13
16
from pytensor .tensor .elemwise import DimShuffle
14
- from pytensor .tensor .math import _allclose
17
+ from pytensor .tensor .math import _allclose , dot , matmul
15
18
from pytensor .tensor .nlinalg import Det , MatrixInverse , matrix_inverse
16
19
from pytensor .tensor .rewriting .linalg import inv_as_solve
17
20
from pytensor .tensor .slinalg import Cholesky , Solve , SolveTriangular , cholesky , solve
18
- from pytensor .tensor .type import dmatrix , matrix , vector
21
+ from pytensor .tensor .type import dmatrix , matrix , tensor , vector
19
22
from tests import unittest_tools as utt
20
23
from tests .test_rop import break_op
21
24
@@ -137,33 +140,38 @@ def test_matrix_inverse_solve():
137
140
@pytest .mark .parametrize ("tag" , ("lower" , "upper" , None ))
138
141
@pytest .mark .parametrize ("cholesky_form" , ("lower" , "upper" ))
139
142
@pytest .mark .parametrize ("product" , ("lower" , "upper" , None ))
140
- def test_cholesky_ldotlt (tag , cholesky_form , product ):
143
+ @pytest .mark .parametrize ("op" , (dot , matmul ))
144
+ def test_cholesky_ldotlt (tag , cholesky_form , product , op ):
141
145
transform_removes_chol = tag is not None and product == tag
142
146
transform_transposes = transform_removes_chol and cholesky_form != tag
143
147
144
- A = matrix ("L" )
148
+ ndim = 2 if op == dot else 3
149
+ A = tensor ("L" , shape = (None ,) * ndim )
145
150
if tag :
146
151
setattr (A .tag , tag + "_triangular" , True )
147
152
148
153
if product == "lower" :
149
- M = A . dot ( A . T )
154
+ M = op ( A , swapaxes ( A , - 1 , - 2 ) )
150
155
elif product == "upper" :
151
- M = A . T . dot ( A )
156
+ M = op ( swapaxes ( A , - 1 , - 2 ), A )
152
157
else :
153
158
M = A
154
159
155
160
C = cholesky (M , lower = (cholesky_form == "lower" ))
156
161
f = pytensor .function ([A ], C , mode = get_default_mode ().including ("cholesky_ldotlt" ))
157
162
158
163
no_cholesky_in_graph = not any (
159
- isinstance (node .op , Cholesky ) for node in f .maker .fgraph .apply_nodes
164
+ isinstance (node .op , Cholesky )
165
+ or (isinstance (node .op , Blockwise ) and isinstance (node .op .core_op , Cholesky ))
166
+ for node in f .maker .fgraph .apply_nodes
160
167
)
161
168
162
169
assert no_cholesky_in_graph == transform_removes_chol
163
170
164
171
if transform_transposes :
172
+ expected_order = (1 , 0 ) if ndim == 2 else (0 , 2 , 1 )
165
173
assert any (
166
- isinstance (node .op , DimShuffle ) and node .op .new_order == ( 1 , 0 )
174
+ isinstance (node .op , DimShuffle ) and node .op .new_order == expected_order
167
175
for node in f .maker .fgraph .apply_nodes
168
176
)
169
177
@@ -183,6 +191,11 @@ def test_cholesky_ldotlt(tag, cholesky_form, product):
183
191
]
184
192
)
185
193
194
+ cholesky_vect_fn = np .vectorize (
195
+ partial (scipy .linalg .cholesky , lower = (cholesky_form == "lower" )),
196
+ signature = "(a, a)->(a, a)" ,
197
+ )
198
+
186
199
for Av in Avs :
187
200
if tag == "upper" :
188
201
Av = Av .T
@@ -194,11 +207,13 @@ def test_cholesky_ldotlt(tag, cholesky_form, product):
194
207
else :
195
208
Mv = Av
196
209
197
- assert np .all (
198
- np .isclose (
199
- scipy .linalg .cholesky (Mv , lower = (cholesky_form == "lower" )),
200
- f (Av ),
201
- )
210
+ if ndim == 3 :
211
+ Av = np .broadcast_to (Av , (5 , * Av .shape ))
212
+ Mv = np .broadcast_to (Mv , (5 , * Mv .shape ))
213
+
214
+ np .testing .assert_allclose (
215
+ cholesky_vect_fn (Mv ),
216
+ f (Av ),
202
217
)
203
218
204
219
0 commit comments