Skip to content

Commit 3098cc7

Browse files
committed
Fix lu_solve with batch inputs
1 parent afe934b commit 3098cc7

File tree

2 files changed

+64
-26
lines changed

2 files changed

+64
-26
lines changed

pytensor/tensor/slinalg.py

Lines changed: 49 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import logging
22
import warnings
33
from collections.abc import Sequence
4-
from functools import reduce
4+
from functools import partial, reduce
55
from typing import Literal, cast
66

77
import numpy as np
@@ -589,6 +589,7 @@ def lu(
589589

590590

591591
class PivotToPermutations(Op):
592+
gufunc_signature = "(x)->(x)"
592593
__props__ = ("inverse",)
593594

594595
def __init__(self, inverse=True):
@@ -723,40 +724,22 @@ def lu_factor(
723724
)
724725

725726

726-
def lu_solve(
727-
LU_and_pivots: tuple[TensorLike, TensorLike],
727+
def _lu_solve(
728+
LU: TensorLike,
729+
pivots: TensorLike,
728730
b: TensorLike,
729731
trans: bool = False,
730732
b_ndim: int | None = None,
731733
check_finite: bool = True,
732-
overwrite_b: bool = False,
733734
):
734-
"""
735-
Solve a system of linear equations given the LU decomposition of the matrix.
736-
737-
Parameters
738-
----------
739-
LU_and_pivots: tuple[TensorLike, TensorLike]
740-
LU decomposition of the matrix, as returned by `lu_factor`
741-
b: TensorLike
742-
Right-hand side of the equation
743-
trans: bool
744-
If True, solve A^T x = b, instead of Ax = b. Default is False
745-
b_ndim: int, optional
746-
The number of core dimensions in b. Used to distinguish between a batch of vectors (b_ndim=1) and a matrix
747-
of vectors (b_ndim=2). Default is None, which will infer the number of core dimensions from the input.
748-
check_finite: bool
749-
If True, check that the input matrices contain only finite numbers. Default is True.
750-
overwrite_b: bool
751-
Ignored by Pytensor. Pytensor will always compute inplace when possible.
752-
"""
753735
b_ndim = _default_b_ndim(b, b_ndim)
754-
LU, pivots = LU_and_pivots
755736

756737
LU, pivots, b = map(pt.as_tensor_variable, [LU, pivots, b])
757-
inv_permutation = pivot_to_permutation(pivots, inverse=True)
758738

739+
inv_permutation = pivot_to_permutation(pivots, inverse=True)
759740
x = b[inv_permutation] if not trans else b
741+
# TODO: Use PermuteRows on b
742+
# x = permute_rows(b, pivots) if not trans else b
760743

761744
x = solve_triangular(
762745
LU,
@@ -777,11 +760,51 @@ def lu_solve(
777760
b_ndim=b_ndim,
778761
check_finite=check_finite,
779762
)
780-
x = x[pt.argsort(inv_permutation)] if trans else x
781763

764+
# TODO: Use PermuteRows(inverse=True) on x
765+
# if trans:
766+
# x = permute_rows(x, pivots, inverse=True)
767+
x = x[pt.argsort(inv_permutation)] if trans else x
782768
return x
783769

784770

771+
def lu_solve(
772+
LU_and_pivots: tuple[TensorLike, TensorLike],
773+
b: TensorLike,
774+
trans: bool = False,
775+
b_ndim: int | None = None,
776+
check_finite: bool = True,
777+
overwrite_b: bool = False,
778+
):
779+
"""
780+
Solve a system of linear equations given the LU decomposition of the matrix.
781+
782+
Parameters
783+
----------
784+
LU_and_pivots: tuple[TensorLike, TensorLike]
785+
LU decomposition of the matrix, as returned by `lu_factor`
786+
b: TensorLike
787+
Right-hand side of the equation
788+
trans: bool
789+
If True, solve A^T x = b, instead of Ax = b. Default is False
790+
b_ndim: int, optional
791+
The number of core dimensions in b. Used to distinguish between a batch of vectors (b_ndim=1) and a matrix
792+
of vectors (b_ndim=2). Default is None, which will infer the number of core dimensions from the input.
793+
check_finite: bool
794+
If True, check that the input matrices contain only finite numbers. Default is True.
795+
overwrite_b: bool
796+
Ignored by Pytensor. Pytensor will always compute inplace when possible.
797+
"""
798+
if b_ndim == 1:
799+
signature = "(m,m),(m),(m)->(m)"
800+
else:
801+
signature = "(m,m),(m),(m,n)->(m,n)"
802+
partialled_func = partial(
803+
_lu_solve, trans=trans, b_ndim=b_ndim, check_finite=check_finite
804+
)
805+
return pt.vectorize(partialled_func, signature=signature)(*LU_and_pivots, b)
806+
807+
785808
class SolveTriangular(SolveBase):
786809
"""Solve a system of linear equations."""
787810

tests/tensor/test_slinalg.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -738,6 +738,21 @@ def test_lu_solve_gradient(self, b_shape: tuple[int], trans: bool):
738738
test_fn = functools.partial(self.factor_and_solve, sum=True, trans=trans)
739739
utt.verify_grad(test_fn, [A_val, b_val], 3, rng)
740740

741+
def test_lu_solve_batch_dims(self):
742+
A = pt.tensor("A", shape=(3, 1, 5, 5))
743+
b = pt.tensor("b", shape=(1, 4, 5))
744+
lu_and_pivots = lu_factor(A)
745+
x = lu_solve(lu_and_pivots, b, b_ndim=1)
746+
assert x.type.shape in {(3, 4, None), (3, 4, 5)}
747+
748+
rng = np.random.default_rng(748)
749+
A_test = rng.random(A.type.shape).astype(A.type.dtype)
750+
b_test = rng.random(b.type.shape).astype(b.type.dtype)
751+
np.testing.assert_allclose(
752+
x.eval({A: A_test, b: b_test}),
753+
solve(A, b, b_ndim=1).eval({A: A_test, b: b_test}),
754+
)
755+
741756

742757
def test_lu_factor():
743758
rng = np.random.default_rng(utt.fetch_seed())

0 commit comments

Comments
 (0)