Skip to content

Decompose Tridiagonal Solve into core steps #1382

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pytensor/compile/mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,9 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
"fusion",
"inplace",
"scan_save_mem_prealloc",
# There are specific variants for the LU decompositions supported by JAX
"reuse_lu_decomposition_multiple_solves",
"scan_split_non_sequence_lu_decomposition_solve",
],
),
)
Expand Down
50 changes: 50 additions & 0 deletions pytensor/link/numba/dispatch/linalg/solve/tridiagonal.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from numpy import ndarray
from scipy import linalg

from pytensor.link.numba.dispatch import numba_funcify
from pytensor.link.numba.dispatch.basic import numba_njit
from pytensor.link.numba.dispatch.linalg._LAPACK import (
_LAPACK,
Expand All @@ -20,6 +21,10 @@
_solve_check,
_trans_char_to_int,
)
from pytensor.tensor._linalg.solve.tridiagonal import (
LUFactorTridiagonal,
SolveLUFactorTridiagonal,
)


@numba_njit
Expand Down Expand Up @@ -297,3 +302,48 @@ def impl(
return X

return impl


@numba_funcify.register(LUFactorTridiagonal)
def numba_funcify_LUFactorTridiagonal(op: LUFactorTridiagonal, node, **kwargs):
    """Build the Numba implementation of ``LUFactorTridiagonal``.

    The ``overwrite_*`` flags of ``op`` are read once here and captured by the
    jitted closure. Any diagonal whose flag is unset is copied before it is
    passed to ``_gttrf``.
    """
    copy_dl = not op.overwrite_dl
    copy_d = not op.overwrite_d
    copy_du = not op.overwrite_du

    @numba_njit(cache=False)
    def lu_factor_tridiagonal(dl, d, du):
        # Work on copies of the diagonals the Op is not allowed to overwrite
        if copy_dl:
            dl = dl.copy()
        if copy_d:
            d = d.copy()
        if copy_du:
            du = du.copy()

        # The sixth value returned by _gttrf is not part of the Op outputs
        dl, d, du, du2, ipiv, _ = _gttrf(dl, d, du)
        return dl, d, du, du2, ipiv

    return lu_factor_tridiagonal


@numba_funcify.register(SolveLUFactorTridiagonal)
def numba_funcify_SolveLUFactorTridiagonal(
    op: SolveLUFactorTridiagonal, node, **kwargs
):
    """Build the Numba implementation of ``SolveLUFactorTridiagonal``.

    ``op.overwrite_b`` and ``op.transposed`` are captured by the jitted closure
    and forwarded to ``_gttrs`` as ``overwrite_b`` and ``trans``.
    """
    transposed = op.transposed
    overwrite_b = op.overwrite_b

    @numba_njit(cache=False)
    def solve_lu_factor_tridiagonal(dl, d, du, du2, ipiv, b):
        # Only the first value returned by _gttrs is the Op output
        solution, _ = _gttrs(
            dl, d, du, du2, ipiv, b, overwrite_b=overwrite_b, trans=transposed
        )
        return solution

    return solve_lu_factor_tridiagonal
60 changes: 55 additions & 5 deletions pytensor/tensor/_linalg/solve/rewriting.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
from collections.abc import Container
from copy import copy

from pytensor.compile import optdb
from pytensor.graph import Constant, graph_inputs
from pytensor.graph.rewriting.basic import copy_stack_trace, in2out, node_rewriter
from pytensor.scan.op import Scan
from pytensor.scan.rewriting import scan_seqopt1
from pytensor.tensor._linalg.solve.tridiagonal import (
tridiagonal_lu_factor,
tridiagonal_lu_solve,
)
from pytensor.tensor.basic import atleast_Nd
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle
Expand All @@ -17,18 +22,32 @@
def decompose_A(A, assume_a, check_finite):
    """Return the factorization of ``A`` selected by ``assume_a``.

    ``"gen"`` uses ``lu_factor`` and ``"tridiagonal"`` uses
    ``tridiagonal_lu_factor``; any other value raises ``NotImplementedError``.
    """
    if assume_a == "tridiagonal":
        # We didn't implement check_finite for tridiagonal LU factorization
        return tridiagonal_lu_factor(A)

    if assume_a == "gen":
        return lu_factor(A, check_finite=check_finite)

    raise NotImplementedError


def solve_lu_decomposed_system(A_decomp, b, transposed=False, *, core_solve_op: Solve):
    """Solve a system from an already-computed decomposition of ``A``.

    Parameters
    ----------
    A_decomp
        Decomposition of ``A``, passed through to the matching solve routine.
    b
        Right-hand side.
    transposed
        Forwarded as ``trans`` to ``lu_solve`` or as ``transposed`` to
        ``tridiagonal_lu_solve``.
    core_solve_op
        The ``Solve`` Op whose ``assume_a``, ``b_ndim`` and ``check_finite``
        select and configure the solve routine.

    Raises
    ------
    NotImplementedError
        If ``core_solve_op.assume_a`` is neither ``"gen"`` nor ``"tridiagonal"``.
    """
    b_ndim = core_solve_op.b_ndim
    check_finite = core_solve_op.check_finite
    assume_a = core_solve_op.assume_a
    if assume_a == "gen":
        return lu_solve(
            A_decomp,
            b,
            b_ndim=b_ndim,
            trans=transposed,
            check_finite=check_finite,
        )
    elif assume_a == "tridiagonal":
        # We didn't implement check_finite for tridiagonal LU solve
        return tridiagonal_lu_solve(
            A_decomp,
            b,
            b_ndim=b_ndim,
            transposed=transposed,
        )
    else:
        raise NotImplementedError
Expand Down Expand Up @@ -189,13 +208,15 @@ def _scan_split_non_sequence_lu_decomposition_solve(
@register_specialize
@node_rewriter([Blockwise])
def reuse_lu_decomposition_multiple_solves(fgraph, node):
    """Split LU-based solves for both ``"gen"`` and ``"tridiagonal"`` systems.

    Delegates to ``_split_lu_solve_steps`` with ``eager=False``.
    """
    return _split_lu_solve_steps(
        fgraph, node, eager=False, allowed_assume_a={"gen", "tridiagonal"}
    )


@node_rewriter([Scan])
def scan_split_non_sequence_lu_decomposition_solve(fgraph, node):
    """Scan rewrite covering both ``"gen"`` and ``"tridiagonal"`` solves.

    Delegates to ``_scan_split_non_sequence_lu_decomposition_solve``.
    """
    return _scan_split_non_sequence_lu_decomposition_solve(
        fgraph, node, allowed_assume_a={"gen", "tridiagonal"}
    )


Expand All @@ -207,3 +228,32 @@ def scan_split_non_sequence_lu_decomposition_solve(fgraph, node):
"scan_pushout",
position=2,
)


@node_rewriter([Blockwise])
def reuse_lu_decomposition_multiple_solves_jax(fgraph, node):
    """Variant of the LU-solve split rewrite restricted to ``assume_a == "gen"``."""
    jax_allowed_assume_a = {"gen"}
    return _split_lu_solve_steps(
        fgraph, node, eager=False, allowed_assume_a=jax_allowed_assume_a
    )


# Register the "gen"-only rewrite in the "specialize" database under the single
# tag "jax", using the rewriter's own name as the registration name.
# NOTE(review): use_db_name_as_tag=False presumably keeps the database name from
# being added as an extra tag, so only "jax" selects this rewrite — confirm.
optdb["specialize"].register(
    reuse_lu_decomposition_multiple_solves_jax.__name__,
    in2out(reuse_lu_decomposition_multiple_solves_jax, ignore_newtrees=True),
    "jax",
    use_db_name_as_tag=False,
)


@node_rewriter([Scan])
def scan_split_non_sequence_lu_decomposition_solve_jax(fgraph, node):
    """Variant of the Scan LU-solve split rewrite restricted to ``assume_a == "gen"``."""
    jax_allowed_assume_a = {"gen"}
    return _scan_split_non_sequence_lu_decomposition_solve(
        fgraph, node, allowed_assume_a=jax_allowed_assume_a
    )


# Register the "gen"-only Scan rewrite in scan_seqopt1 at position 2 under the
# single tag "jax", using the rewriter's own name as the registration name.
# NOTE(review): use_db_name_as_tag=False presumably keeps the database name from
# being added as an extra tag, so only "jax" selects this rewrite — confirm.
scan_seqopt1.register(
    scan_split_non_sequence_lu_decomposition_solve_jax.__name__,
    in2out(scan_split_non_sequence_lu_decomposition_solve_jax, ignore_newtrees=True),
    "jax",
    use_db_name_as_tag=False,
    position=2,
)
169 changes: 169 additions & 0 deletions pytensor/tensor/_linalg/solve/tridiagonal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
import numpy as np
from scipy.linalg import get_lapack_funcs

from pytensor.graph import Apply, Op
from pytensor.tensor.basic import as_tensor, diagonal
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.type import tensor, vector


class LUFactorTridiagonal(Op):
    """Compute LU factorization of a tridiagonal matrix (lapack gttrf)

    The Op takes the three diagonals ``dl``, ``d`` and ``du`` as vectors and
    returns ``dl, d, du, du2, ipiv`` as produced by gttrf. The last value
    returned by the LAPACK wrapper is not exposed as an output.
    """

    __props__ = (
        "overwrite_dl",
        "overwrite_d",
        "overwrite_du",
    )
    gufunc_signature = "(dl),(d),(dl)->(dl),(d),(dl),(du2),(d)"

    def __init__(self, overwrite_dl=False, overwrite_d=False, overwrite_du=False):
        overwrite_flags = (overwrite_dl, overwrite_d, overwrite_du)
        # Output i destroys input i whenever the matching overwrite flag is set
        self.destroy_map = {
            idx: [idx] for idx, flag in enumerate(overwrite_flags) if flag
        }
        self.overwrite_dl = overwrite_dl
        self.overwrite_d = overwrite_d
        self.overwrite_du = overwrite_du
        super().__init__()

    def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
        """Return a new Op that overwrites exactly the allowed inputs."""
        overwrite_dl = 0 in allowed_inplace_inputs
        overwrite_d = 1 in allowed_inplace_inputs
        overwrite_du = 2 in allowed_inplace_inputs
        return type(self)(
            overwrite_dl=overwrite_dl,
            overwrite_d=overwrite_d,
            overwrite_du=overwrite_du,
        )

    def make_node(self, dl, d, du):
        """Build the Apply node, inferring static output shapes when possible."""
        dl, d, du = (as_tensor(diag) for diag in (dl, d, du))

        if any(diag.type.ndim != 1 for diag in (dl, d, du)):
            raise ValueError("Diagonals must be vectors")

        ndl = dl.type.shape[-1]
        nd = d.type.shape[-1]
        ndu = du.type.shape[-1]
        # Use the first diagonal with a known static length to infer n
        if ndl is not None:
            n = ndl + 1
        elif nd is not None:
            n = nd
        elif ndu is not None:
            n = ndu + 1
        else:
            n = None
        off_diag_len = None if n is None else n - 1
        du2_len = None if n is None else n - 2

        # Ask scipy which gttrf variant matches the input dtypes
        dummy_arrays = [np.zeros((), dtype=diag.type.dtype) for diag in (dl, d, du)]
        out_dtype = get_lapack_funcs("gttrf", dummy_arrays).dtype
        outputs = [
            vector(shape=(off_diag_len,), dtype=out_dtype),
            vector(shape=(n,), dtype=out_dtype),
            vector(shape=(off_diag_len,), dtype=out_dtype),
            vector(shape=(du2_len,), dtype=out_dtype),
            vector(shape=(n,), dtype=np.int32),
        ]
        return Apply(self, [dl, d, du], outputs)

    def perform(self, node, inputs, output_storage):
        """Run gttrf on the inputs and store its first five results."""
        gttrf = get_lapack_funcs("gttrf", dtype=node.outputs[0].type.dtype)
        *factors, _ = gttrf(
            *inputs,
            overwrite_dl=self.overwrite_dl,
            overwrite_d=self.overwrite_d,
            overwrite_du=self.overwrite_du,
        )
        # factors holds dl, d, du, du2, ipiv in output order
        for storage, value in zip(output_storage, factors):
            storage[0] = value


class SolveLUFactorTridiagonal(Op):
    """Solve a system of linear equations with a tridiagonal coefficient matrix (lapack gttrs).

    The inputs are the five vectors ``dl, d, du, du2, ipiv`` followed by the
    right-hand side ``b``, which has ``b_ndim`` dimensions (1 or 2).
    """

    __props__ = ("b_ndim", "overwrite_b", "transposed")

    def __init__(self, b_ndim: int, transposed: bool, overwrite_b=False):
        if b_ndim not in (1, 2):
            raise ValueError("b_ndim must be 1 or 2")
        if b_ndim == 1:
            self.gufunc_signature = "(dl),(d),(dl),(du2),(d),(d)->(d)"
        else:
            self.gufunc_signature = "(dl),(d),(dl),(du2),(d),(d,rhs)->(d,rhs)"
        if overwrite_b:
            # The single output reuses the storage of input 5 (b)
            self.destroy_map = {0: [5]}
        self.b_ndim = b_ndim
        self.transposed = transposed
        self.overwrite_b = overwrite_b
        super().__init__()

    def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
        """Return an Op that overwrites ``b`` if input 5 may be destroyed."""
        if 5 in allowed_inplace_inputs:
            props = self._props_dict()
            props["overwrite_b"] = True
            return type(self)(**props)

        return self

    def make_node(self, dl, d, du, du2, ipiv, b):
        """Build the Apply node, inferring the static output shape when possible.

        Raises
        ------
        ValueError
            If ``b`` does not have ``b_ndim`` dimensions, or if any of the
            other inputs is not a vector.
        """
        dl, d, du, du2, ipiv, b = map(as_tensor, (dl, d, du, du2, ipiv, b))

        if b.type.ndim != self.b_ndim:
            raise ValueError("Wrong number of dimensions for input b.")

        if not all(inp.type.ndim == 1 for inp in (dl, d, du, du2, ipiv)):
            raise ValueError("Inputs must be vectors")

        ndl, nd, ndu, ndu2, nipiv = (
            inp.type.shape[-1] for inp in (dl, d, du, du2, ipiv)
        )
        nb = b.type.shape[0]
        # Use the first input with a known static length to infer n;
        # fall back to the leading dimension of b, which may itself be None.
        if ndl is not None:
            n = ndl + 1
        elif nd is not None:
            n = nd
        elif ndu is not None:
            n = ndu + 1
        elif ndu2 is not None:
            n = ndu2 + 2
        elif nipiv is not None:
            n = nipiv
        else:
            n = nb

        dummy_arrays = [
            np.zeros((), dtype=inp.type.dtype) for inp in (dl, d, du, du2, ipiv)
        ]
        # Seems to always be float64?
        # NOTE(review): the dtype probe includes the pivot vector and leaves
        # out b; if ipiv is an integer vector this may be what forces float64,
        # and the dtype of b never influences the result. Confirm whether b
        # should replace ipiv here.
        out_dtype = get_lapack_funcs("gttrs", dummy_arrays).dtype
        if self.b_ndim == 1:
            output_shape = (n,)
        else:
            output_shape = (n, b.type.shape[-1])

        outputs = [tensor(shape=output_shape, dtype=out_dtype)]
        return Apply(self, [dl, d, du, du2, ipiv, b], outputs)

    def perform(self, node, inputs, output_storage):
        """Run gttrs on the inputs and store the solution."""
        gttrs = get_lapack_funcs("gttrs", dtype=node.outputs[0].type.dtype)
        # The second value returned by the LAPACK wrapper is discarded
        x, _ = gttrs(
            *inputs,
            overwrite_b=self.overwrite_b,
            trans="N" if not self.transposed else "T",
        )
        output_storage[0][0] = x


def tridiagonal_lu_factor(a):
    """Return the decomposition of A implied by a solve tridiagonal.

    The sub-, main and super-diagonal are taken over the last two axes of
    ``a`` and passed to a ``Blockwise`` ``LUFactorTridiagonal``.
    """
    lower, main, upper = (
        diagonal(a, offset=offset, axis1=-2, axis2=-1) for offset in (-1, 0, 1)
    )
    factor_op = Blockwise(LUFactorTridiagonal())
    dl, d, du, du2, ipiv = factor_op(lower, main, upper)
    return dl, d, du, du2, ipiv


def tridiagonal_lu_solve(a_diagonals, b, *, b_ndim: int, transposed: bool = False):
    """Solve against a tridiagonal LU decomposition.

    ``a_diagonals`` must unpack into the five values ``dl, d, du, du2, ipiv``;
    they are passed with ``b`` to a ``Blockwise`` ``SolveLUFactorTridiagonal``.
    """
    dl, d, du, du2, ipiv = a_diagonals
    core_op = SolveLUFactorTridiagonal(b_ndim=b_ndim, transposed=transposed)
    return Blockwise(core_op)(dl, d, du, du2, ipiv, b)
Loading
Loading