|
1 | 1 | import numpy as np
|
2 | 2 | import numpy.linalg
|
| 3 | +import pytest |
| 4 | +import scipy.linalg |
3 | 5 |
|
4 | 6 | import pytensor
|
5 | 7 | from pytensor import function
|
6 | 8 | from pytensor import tensor as at
|
| 9 | +from pytensor.compile import get_default_mode |
7 | 10 | from pytensor.configdefaults import config
|
8 | 11 | from pytensor.tensor.elemwise import DimShuffle
|
9 | 12 | from pytensor.tensor.math import _allclose
|
@@ -105,3 +108,75 @@ def test_matrix_inverse_solve():
|
105 | 108 | node = matrix_inverse(A).dot(b).owner
|
106 | 109 | [out] = inv_as_solve.transform(None, node)
|
107 | 110 | assert isinstance(out.owner.op, Solve)
|
| 111 | + |
| 112 | + |
| 113 | +@pytest.mark.parametrize("tag", ("lower", "upper", None)) |
| 114 | +@pytest.mark.parametrize("cholesky_form", ("lower", "upper")) |
| 115 | +@pytest.mark.parametrize("product", ("lower", "upper", None)) |
| 116 | +def test_cholesky_ldotlt(tag, cholesky_form, product): |
| 117 | + cholesky = Cholesky(lower=(cholesky_form == "lower")) |
| 118 | + |
| 119 | + transform_removes_chol = tag is not None and product == tag |
| 120 | + transform_transposes = transform_removes_chol and cholesky_form != tag |
| 121 | + |
| 122 | + A = matrix("L") |
| 123 | + if tag: |
| 124 | + setattr(A.tag, tag + "_triangular", True) |
| 125 | + |
| 126 | + if product == "lower": |
| 127 | + M = A.dot(A.T) |
| 128 | + elif product == "upper": |
| 129 | + M = A.T.dot(A) |
| 130 | + else: |
| 131 | + M = A |
| 132 | + |
| 133 | + C = cholesky(M) |
| 134 | + f = pytensor.function([A], C, mode=get_default_mode().including("cholesky_ldotlt")) |
| 135 | + |
| 136 | + print(f.maker.fgraph.apply_nodes) |
| 137 | + |
| 138 | + no_cholesky_in_graph = not any( |
| 139 | + isinstance(node.op, Cholesky) for node in f.maker.fgraph.apply_nodes |
| 140 | + ) |
| 141 | + |
| 142 | + assert no_cholesky_in_graph == transform_removes_chol |
| 143 | + |
| 144 | + if transform_transposes: |
| 145 | + assert any( |
| 146 | + isinstance(node.op, DimShuffle) and node.op.new_order == (1, 0) |
| 147 | + for node in f.maker.fgraph.apply_nodes |
| 148 | + ) |
| 149 | + |
| 150 | + # Test some concrete value through f |
| 151 | + # there must be lower triangular (f assumes they are) |
| 152 | + Avs = [ |
| 153 | + np.eye(1, dtype=pytensor.config.floatX), |
| 154 | + np.eye(10, dtype=pytensor.config.floatX), |
| 155 | + np.array([[2, 0], [1, 4]], dtype=pytensor.config.floatX), |
| 156 | + ] |
| 157 | + if not tag: |
| 158 | + # these must be positive def |
| 159 | + Avs.extend( |
| 160 | + [ |
| 161 | + np.ones((4, 4), dtype=pytensor.config.floatX) |
| 162 | + + np.eye(4, dtype=pytensor.config.floatX), |
| 163 | + ] |
| 164 | + ) |
| 165 | + |
| 166 | + for Av in Avs: |
| 167 | + if tag == "upper": |
| 168 | + Av = Av.T |
| 169 | + |
| 170 | + if product == "lower": |
| 171 | + Mv = Av.dot(Av.T) |
| 172 | + elif product == "upper": |
| 173 | + Mv = Av.T.dot(Av) |
| 174 | + else: |
| 175 | + Mv = Av |
| 176 | + |
| 177 | + assert np.all( |
| 178 | + np.isclose( |
| 179 | + scipy.linalg.cholesky(Mv, lower=(cholesky_form == "lower")), |
| 180 | + f(Av), |
| 181 | + ) |
| 182 | + ) |
0 commit comments