diff --git a/pytensor/tensor/slinalg.py b/pytensor/tensor/slinalg.py
index 14a6d91a7d..bbdc9cbba7 100644
--- a/pytensor/tensor/slinalg.py
+++ b/pytensor/tensor/slinalg.py
@@ -1,7 +1,7 @@
 import logging
 import warnings
 from collections.abc import Sequence
-from functools import reduce
+from functools import partial, reduce
 from typing import Literal, cast
 
 import numpy as np
@@ -589,6 +589,7 @@ def lu(
 
 
 class PivotToPermutations(Op):
+    gufunc_signature = "(x)->(x)"
     __props__ = ("inverse",)
 
     def __init__(self, inverse=True):
@@ -723,40 +724,22 @@ def lu_factor(
     )
 
 
-def lu_solve(
-    LU_and_pivots: tuple[TensorLike, TensorLike],
+def _lu_solve(
+    LU: TensorLike,
+    pivots: TensorLike,
     b: TensorLike,
     trans: bool = False,
     b_ndim: int | None = None,
     check_finite: bool = True,
-    overwrite_b: bool = False,
 ):
-    """
-    Solve a system of linear equations given the LU decomposition of the matrix.
-
-    Parameters
-    ----------
-    LU_and_pivots: tuple[TensorLike, TensorLike]
-        LU decomposition of the matrix, as returned by `lu_factor`
-    b: TensorLike
-        Right-hand side of the equation
-    trans: bool
-        If True, solve A^T x = b, instead of Ax = b. Default is False
-    b_ndim: int, optional
-        The number of core dimensions in b. Used to distinguish between a batch of vectors (b_ndim=1) and a matrix
-        of vectors (b_ndim=2). Default is None, which will infer the number of core dimensions from the input.
-    check_finite: bool
-        If True, check that the input matrices contain only finite numbers. Default is True.
-    overwrite_b: bool
-        Ignored by Pytensor. Pytensor will always compute inplace when possible.
-    """
     b_ndim = _default_b_ndim(b, b_ndim)
-    LU, pivots = LU_and_pivots
 
     LU, pivots, b = map(pt.as_tensor_variable, [LU, pivots, b])
 
-    inv_permutation = pivot_to_permutation(pivots, inverse=True)
+    inv_permutation = pivot_to_permutation(pivots, inverse=True)
     x = b[inv_permutation] if not trans else b
+    # TODO: Use PermuteRows on b
+    # x = permute_rows(b, pivots) if not trans else b
 
     x = solve_triangular(
         LU,
@@ -777,11 +760,52 @@ def lu_solve(
         b_ndim=b_ndim,
         check_finite=check_finite,
     )
-    x = x[pt.argsort(inv_permutation)] if trans else x
+    # TODO: Use PermuteRows(inverse=True) on x
+    # if trans:
+    #     x = permute_rows(x, pivots, inverse=True)
+    x = x[pt.argsort(inv_permutation)] if trans else x
 
     return x
 
 
+def lu_solve(
+    LU_and_pivots: tuple[TensorLike, TensorLike],
+    b: TensorLike,
+    trans: bool = False,
+    b_ndim: int | None = None,
+    check_finite: bool = True,
+    overwrite_b: bool = False,
+):
+    """
+    Solve a system of linear equations given the LU decomposition of the matrix.
+
+    Parameters
+    ----------
+    LU_and_pivots: tuple[TensorLike, TensorLike]
+        LU decomposition of the matrix, as returned by `lu_factor`
+    b: TensorLike
+        Right-hand side of the equation
+    trans: bool
+        If True, solve A^T x = b, instead of Ax = b. Default is False
+    b_ndim: int, optional
+        The number of core dimensions in b. Used to distinguish between a batch of vectors (b_ndim=1) and a matrix
+        of vectors (b_ndim=2). Default is None, which will infer the number of core dimensions from the input.
+    check_finite: bool
+        If True, check that the input matrices contain only finite numbers. Default is True.
+    overwrite_b: bool
+        Ignored by Pytensor. Pytensor will always compute inplace when possible.
+ """ + b_ndim = _default_b_ndim(b, b_ndim) + if b_ndim == 1: + signature = "(m,m),(m),(m)->(m)" + else: + signature = "(m,m),(m),(m,n)->(m,n)" + partialled_func = partial( + _lu_solve, trans=trans, b_ndim=b_ndim, check_finite=check_finite + ) + return pt.vectorize(partialled_func, signature=signature)(*LU_and_pivots, b) + + class SolveTriangular(SolveBase): """Solve a system of linear equations.""" diff --git a/tests/tensor/test_slinalg.py b/tests/tensor/test_slinalg.py index f57488a9b8..f18f514244 100644 --- a/tests/tensor/test_slinalg.py +++ b/tests/tensor/test_slinalg.py @@ -6,7 +6,6 @@ import pytest import scipy -import pytensor from pytensor import function, grad from pytensor import tensor as pt from pytensor.configdefaults import config @@ -130,7 +129,7 @@ def test_cholesky_grad_indef(): def test_cholesky_infer_shape(): x = matrix() - f_chol = pytensor.function([x], [cholesky(x).shape, cholesky(x, lower=False).shape]) + f_chol = function([x], [cholesky(x).shape, cholesky(x, lower=False).shape]) if config.mode != "FAST_COMPILE": topo_chol = f_chol.maker.fgraph.toposort() f_chol.dprint() @@ -313,7 +312,7 @@ def test_solve_correctness( b_ndim=len(b_size), ) - solve_func = pytensor.function([A, b], y) + solve_func = function([A, b], y) X_np = solve_func(A_val.copy(), b_val.copy()) ATOL = 1e-8 if config.floatX.endswith("64") else 1e-4 @@ -444,7 +443,7 @@ def test_correctness(self, b_shape: tuple[int], lower, trans, unit_diagonal): b_ndim=len(b_shape), ) - f = pytensor.function([A, b], x) + f = function([A, b], x) x_pt = f(A_val, b_val) x_sp = scipy.linalg.solve_triangular( @@ -508,8 +507,8 @@ def test_infer_shape(self): A = matrix() b = matrix() self._compile_and_check( - [A, b], # pytensor.function inputs - [self.op_class(b_ndim=2)(A, b)], # pytensor.function outputs + [A, b], # function inputs + [self.op_class(b_ndim=2)(A, b)], # function outputs # A must be square [ np.asarray(rng.random((5, 5)), dtype=config.floatX), @@ -522,8 +521,8 @@ def test_infer_shape(self): A = matrix() b = vector() self._compile_and_check( - [A, b], # pytensor.function inputs - [self.op_class(b_ndim=1)(A, b)], # pytensor.function outputs + [A, b], # function inputs + [self.op_class(b_ndim=1)(A, b)], # function outputs # A must be square [ np.asarray(rng.random((5, 5)), dtype=config.floatX), @@ -538,10 +537,10 @@ def test_solve_correctness(self): A = matrix() b = matrix() y = self.op_class(lower=True, b_ndim=2)(A, b) - cho_solve_lower_func = pytensor.function([A, b], y) + cho_solve_lower_func = function([A, b], y) y = self.op_class(lower=False, b_ndim=2)(A, b) - cho_solve_upper_func = pytensor.function([A, b], y) + cho_solve_upper_func = function([A, b], y) b_val = np.asarray(rng.random((5, 1)), dtype=config.floatX) @@ -603,7 +602,7 @@ def test_lu_decomposition( A = tensor("A", shape=shape, dtype=dtype) out = lu(A, permute_l=permute_l, p_indices=p_indices) - f = pytensor.function([A], out) + f = function([A], out) rng = np.random.default_rng(utt.fetch_seed()) x = rng.normal(size=shape).astype(config.floatX) @@ -706,7 +705,7 @@ def test_lu_solve(self, b_shape: tuple[int], trans): x = self.factor_and_solve(A, b, trans=trans, sum=False) - f = pytensor.function([A, b], x) + f = function([A, b], x) x_pt = f(A_val.copy(), b_val.copy()) x_sp = scipy.linalg.lu_solve( scipy.linalg.lu_factor(A_val.copy()), b_val.copy(), trans=trans @@ -738,13 +737,29 @@ def test_lu_solve_gradient(self, b_shape: tuple[int], trans: bool): test_fn = functools.partial(self.factor_and_solve, sum=True, trans=trans) utt.verify_grad(test_fn, 
diff --git a/tests/tensor/test_slinalg.py b/tests/tensor/test_slinalg.py
index f57488a9b8..f18f514244 100644
--- a/tests/tensor/test_slinalg.py
+++ b/tests/tensor/test_slinalg.py
@@ -6,7 +6,6 @@
 import pytest
 import scipy
 
-import pytensor
 from pytensor import function, grad
 from pytensor import tensor as pt
 from pytensor.configdefaults import config
@@ -130,7 +129,7 @@ def test_cholesky_grad_indef():
 
 def test_cholesky_infer_shape():
     x = matrix()
-    f_chol = pytensor.function([x], [cholesky(x).shape, cholesky(x, lower=False).shape])
+    f_chol = function([x], [cholesky(x).shape, cholesky(x, lower=False).shape])
     if config.mode != "FAST_COMPILE":
         topo_chol = f_chol.maker.fgraph.toposort()
         f_chol.dprint()
@@ -313,7 +312,7 @@ def test_solve_correctness(
         b_ndim=len(b_size),
     )
 
-    solve_func = pytensor.function([A, b], y)
+    solve_func = function([A, b], y)
     X_np = solve_func(A_val.copy(), b_val.copy())
 
     ATOL = 1e-8 if config.floatX.endswith("64") else 1e-4
@@ -444,7 +443,7 @@ def test_correctness(self, b_shape: tuple[int], lower, trans, unit_diagonal):
             b_ndim=len(b_shape),
        )
 
-        f = pytensor.function([A, b], x)
+        f = function([A, b], x)
         x_pt = f(A_val, b_val)
 
         x_sp = scipy.linalg.solve_triangular(
@@ -508,8 +507,8 @@ def test_infer_shape(self):
         A = matrix()
         b = matrix()
         self._compile_and_check(
-            [A, b],  # pytensor.function inputs
-            [self.op_class(b_ndim=2)(A, b)],  # pytensor.function outputs
+            [A, b],  # function inputs
+            [self.op_class(b_ndim=2)(A, b)],  # function outputs
             # A must be square
             [
                 np.asarray(rng.random((5, 5)), dtype=config.floatX),
@@ -522,8 +521,8 @@ def test_infer_shape(self):
         A = matrix()
         b = vector()
         self._compile_and_check(
-            [A, b],  # pytensor.function inputs
-            [self.op_class(b_ndim=1)(A, b)],  # pytensor.function outputs
+            [A, b],  # function inputs
+            [self.op_class(b_ndim=1)(A, b)],  # function outputs
             # A must be square
             [
                 np.asarray(rng.random((5, 5)), dtype=config.floatX),
@@ -538,10 +537,10 @@ def test_solve_correctness(self):
         A = matrix()
         b = matrix()
 
         y = self.op_class(lower=True, b_ndim=2)(A, b)
-        cho_solve_lower_func = pytensor.function([A, b], y)
+        cho_solve_lower_func = function([A, b], y)
 
         y = self.op_class(lower=False, b_ndim=2)(A, b)
-        cho_solve_upper_func = pytensor.function([A, b], y)
+        cho_solve_upper_func = function([A, b], y)
 
         b_val = np.asarray(rng.random((5, 1)), dtype=config.floatX)
@@ -603,7 +602,7 @@ def test_lu_decomposition(
     A = tensor("A", shape=shape, dtype=dtype)
     out = lu(A, permute_l=permute_l, p_indices=p_indices)
 
-    f = pytensor.function([A], out)
+    f = function([A], out)
 
     rng = np.random.default_rng(utt.fetch_seed())
     x = rng.normal(size=shape).astype(config.floatX)
@@ -706,7 +705,7 @@ def test_lu_solve(self, b_shape: tuple[int], trans):
 
         x = self.factor_and_solve(A, b, trans=trans, sum=False)
 
-        f = pytensor.function([A, b], x)
+        f = function([A, b], x)
         x_pt = f(A_val.copy(), b_val.copy())
         x_sp = scipy.linalg.lu_solve(
             scipy.linalg.lu_factor(A_val.copy()), b_val.copy(), trans=trans
@@ -738,13 +737,29 @@ def test_lu_solve_gradient(self, b_shape: tuple[int], trans: bool):
         test_fn = functools.partial(self.factor_and_solve, sum=True, trans=trans)
         utt.verify_grad(test_fn, [A_val, b_val], 3, rng)
 
+    def test_lu_solve_batch_dims(self):
+        A = pt.tensor("A", shape=(3, 1, 5, 5))
+        b = pt.tensor("b", shape=(1, 4, 5))
+        lu_and_pivots = lu_factor(A)
+        x = lu_solve(lu_and_pivots, b, b_ndim=1)
+        assert x.type.shape in {(3, 4, None), (3, 4, 5)}
+
+        rng = np.random.default_rng(748)
+        A_test = rng.random(A.type.shape).astype(A.type.dtype)
+        b_test = rng.random(b.type.shape).astype(b.type.dtype)
+        np.testing.assert_allclose(
+            x.eval({A: A_test, b: b_test}),
+            solve(A, b, b_ndim=1).eval({A: A_test, b: b_test}),
+            rtol=1e-9 if config.floatX == "float64" else 1e-5,
+        )
+
 
 def test_lu_factor():
     rng = np.random.default_rng(utt.fetch_seed())
     A = matrix()
     A_val = rng.normal(size=(5, 5)).astype(config.floatX)
 
-    f = pytensor.function([A], lu_factor(A))
+    f = function([A], lu_factor(A))
     LU, pt_p_idx = f(A_val)
 
     sp_LU, sp_p_idx = scipy.linalg.lu_factor(A_val)
@@ -764,7 +779,7 @@ def test_cho_solve():
     A = matrix()
     b = matrix()
     y = cho_solve((A, True), b)
-    cho_solve_lower_func = pytensor.function([A, b], y)
+    cho_solve_lower_func = function([A, b], y)
 
     b_val = np.asarray(rng.random((5, 1)), dtype=config.floatX)
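For reference, the per-core computation that _lu_solve performs (in the trans=False case) is the standard two-stage LU substitution. Below is a NumPy/SciPy sketch of that same computation for a single 5x5 system; it is purely illustrative and its names are local to this example, not part of the patch:

import numpy as np
from scipy.linalg import lu_factor, solve_triangular

rng = np.random.default_rng(0)
A = rng.normal(size=(5, 5))
b = rng.normal(size=5)

LU, piv = lu_factor(A)

# SciPy pivots encode sequential row swaps; unrolling them yields the row
# permutation, mirroring what pivot_to_permutation(pivots, inverse=True)
# computes in the graph above.
perm = np.arange(5)
for i, p in enumerate(piv):
    perm[i], perm[p] = perm[p], perm[i]

# P A x = P b is solved as L (U x) = P b: a unit-lower-triangular solve
# followed by an upper-triangular solve, matching the two solve_triangular
# calls in _lu_solve.
y = solve_triangular(LU, b[perm], lower=True, unit_diagonal=True)
x = solve_triangular(LU, y, lower=False)

np.testing.assert_allclose(x, np.linalg.solve(A, b))

The trans=True branch runs the two solves against the transposed factors (upper first), which is why the inverse permutation is applied to the result via pt.argsort(inv_permutation) instead of to b.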