diff --git a/.github/workflows/pypi.yml b/.github/workflows/pypi.yml index 7b58f32420..ee6391ae5c 100644 --- a/.github/workflows/pypi.yml +++ b/.github/workflows/pypi.yml @@ -49,7 +49,7 @@ jobs: fetch-depth: 0 - name: Build wheels - uses: pypa/cibuildwheel@v2.16.0 + uses: pypa/cibuildwheel@v2.14.1 - uses: actions/upload-artifact@v3 with: diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index 9589c7aa79..e681eb6a17 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -6,7 +6,7 @@ from pytensor.tensor.blas import Dot22 from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.elemwise import DimShuffle -from pytensor.tensor.math import Dot, Prod, log, prod +from pytensor.tensor.math import Dot, Prod, _matrix_matrix_matmul, log, prod from pytensor.tensor.nlinalg import MatrixInverse, det from pytensor.tensor.rewriting.basic import ( register_canonicalize, @@ -168,13 +168,25 @@ def cholesky_ldotlt(fgraph, node): rewrite cholesky(dot(L, L.T), lower=True) = L, where L is lower triangular, or cholesky(dot(U.T, U), upper=True) = U where U is upper triangular. + Also works with matmul. + This utilizes a boolean `lower_triangular` or `upper_triangular` tag on matrices. """ if not isinstance(node.op.core_op, Cholesky): return A = node.inputs[0] - if not (A.owner and isinstance(A.owner.op, (Dot, Dot22))): + if not ( + A.owner is not None + and ( + ( + isinstance(A.owner.op, (Dot, Dot22)) + # This rewrite only applies to matrix Dot + and A.owner.inputs[0].type.ndim == 2 + ) + or (A.owner.op == _matrix_matrix_matmul) + ) + ): return l, r = A.owner.inputs diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py index 58ea98d626..498ed18bf9 100644 --- a/tests/tensor/rewriting/test_linalg.py +++ b/tests/tensor/rewriting/test_linalg.py @@ -1,3 +1,5 @@ +from functools import partial + import numpy as np import numpy.linalg import pytest @@ -9,13 +11,14 @@ from pytensor import tensor as at from pytensor.compile import get_default_mode from pytensor.configdefaults import config +from pytensor.tensor import swapaxes from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.elemwise import DimShuffle -from pytensor.tensor.math import _allclose +from pytensor.tensor.math import _allclose, dot, matmul from pytensor.tensor.nlinalg import Det, MatrixInverse, matrix_inverse from pytensor.tensor.rewriting.linalg import inv_as_solve from pytensor.tensor.slinalg import Cholesky, Solve, SolveTriangular, cholesky, solve -from pytensor.tensor.type import dmatrix, matrix, vector +from pytensor.tensor.type import dmatrix, matrix, tensor, vector from tests import unittest_tools as utt from tests.test_rop import break_op @@ -137,18 +140,20 @@ def test_matrix_inverse_solve(): @pytest.mark.parametrize("tag", ("lower", "upper", None)) @pytest.mark.parametrize("cholesky_form", ("lower", "upper")) @pytest.mark.parametrize("product", ("lower", "upper", None)) -def test_cholesky_ldotlt(tag, cholesky_form, product): +@pytest.mark.parametrize("op", (dot, matmul)) +def test_cholesky_ldotlt(tag, cholesky_form, product, op): transform_removes_chol = tag is not None and product == tag transform_transposes = transform_removes_chol and cholesky_form != tag - A = matrix("L") + ndim = 2 if op == dot else 3 + A = tensor("L", shape=(None,) * ndim) if tag: setattr(A.tag, tag + "_triangular", True) if product == "lower": - M = A.dot(A.T) + M = op(A, swapaxes(A, -1, -2)) elif product == "upper": - M = A.T.dot(A) + M = op(swapaxes(A, -1, -2), A) else: M = A @@ -156,14 +161,17 @@ def test_cholesky_ldotlt(tag, cholesky_form, product): f = pytensor.function([A], C, mode=get_default_mode().including("cholesky_ldotlt")) no_cholesky_in_graph = not any( - isinstance(node.op, Cholesky) for node in f.maker.fgraph.apply_nodes + isinstance(node.op, Cholesky) + or (isinstance(node.op, Blockwise) and isinstance(node.op.core_op, Cholesky)) + for node in f.maker.fgraph.apply_nodes ) assert no_cholesky_in_graph == transform_removes_chol if transform_transposes: + expected_order = (1, 0) if ndim == 2 else (0, 2, 1) assert any( - isinstance(node.op, DimShuffle) and node.op.new_order == (1, 0) + isinstance(node.op, DimShuffle) and node.op.new_order == expected_order for node in f.maker.fgraph.apply_nodes ) @@ -183,6 +191,11 @@ def test_cholesky_ldotlt(tag, cholesky_form, product): ] ) + cholesky_vect_fn = np.vectorize( + partial(scipy.linalg.cholesky, lower=(cholesky_form == "lower")), + signature="(a, a)->(a, a)", + ) + for Av in Avs: if tag == "upper": Av = Av.T @@ -194,11 +207,13 @@ def test_cholesky_ldotlt(tag, cholesky_form, product): else: Mv = Av - assert np.all( - np.isclose( - scipy.linalg.cholesky(Mv, lower=(cholesky_form == "lower")), - f(Av), - ) + if ndim == 3: + Av = np.broadcast_to(Av, (5, *Av.shape)) + Mv = np.broadcast_to(Mv, (5, *Mv.shape)) + + np.testing.assert_allclose( + cholesky_vect_fn(Mv), + f(Av), )