diff --git a/pytensor/compile/mode.py b/pytensor/compile/mode.py
index 63a1ba835b..48cff4238c 100644
--- a/pytensor/compile/mode.py
+++ b/pytensor/compile/mode.py
@@ -477,6 +477,9 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
             "fusion",
             "inplace",
             "scan_save_mem_prealloc",
+            # There are specific variants for the LU decompositions supported by JAX
+            "reuse_lu_decomposition_multiple_solves",
+            "scan_split_non_sequence_lu_decomposition_solve",
         ],
     ),
 )
diff --git a/pytensor/link/numba/dispatch/linalg/solve/tridiagonal.py b/pytensor/link/numba/dispatch/linalg/solve/tridiagonal.py
index 241c776010..717d0469ff 100644
--- a/pytensor/link/numba/dispatch/linalg/solve/tridiagonal.py
+++ b/pytensor/link/numba/dispatch/linalg/solve/tridiagonal.py
@@ -6,6 +6,7 @@
 from numpy import ndarray
 from scipy import linalg
 
+from pytensor.link.numba.dispatch import numba_funcify
 from pytensor.link.numba.dispatch.basic import numba_njit
 from pytensor.link.numba.dispatch.linalg._LAPACK import (
     _LAPACK,
@@ -20,6 +21,10 @@
     _solve_check,
     _trans_char_to_int,
 )
+from pytensor.tensor._linalg.solve.tridiagonal import (
+    LUFactorTridiagonal,
+    SolveLUFactorTridiagonal,
+)
 
 
 @numba_njit
@@ -297,3 +302,48 @@ def impl(
         return X
 
     return impl
+
+
+@numba_funcify.register(LUFactorTridiagonal)
+def numba_funcify_LUFactorTridiagonal(op: LUFactorTridiagonal, node, **kwargs):
+    overwrite_dl = op.overwrite_dl
+    overwrite_d = op.overwrite_d
+    overwrite_du = op.overwrite_du
+
+    @numba_njit(cache=False)
+    def lu_factor_tridiagonal(dl, d, du):
+        if not overwrite_dl:
+            dl = dl.copy()
+        if not overwrite_d:
+            d = d.copy()
+        if not overwrite_du:
+            du = du.copy()
+
+        dl, d, du, du2, ipiv, _ = _gttrf(dl, d, du)
+        return dl, d, du, du2, ipiv
+
+    return lu_factor_tridiagonal
+
+
+@numba_funcify.register(SolveLUFactorTridiagonal)
+def numba_funcify_SolveLUFactorTridiagonal(
+    op: SolveLUFactorTridiagonal, node, **kwargs
+):
+    overwrite_b = op.overwrite_b
+    transposed = op.transposed
+
+    @numba_njit(cache=False)
+    def solve_lu_factor_tridiagonal(dl, d, du, du2, ipiv, b):
+        x, _ = _gttrs(
+            dl,
+            d,
+            du,
+            du2,
+            ipiv,
+            b,
+            overwrite_b=overwrite_b,
+            trans=transposed,
+        )
+        return x
+
+    return solve_lu_factor_tridiagonal
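Note: the two `numba_funcify` registrations above only strip the input copies when the Op is allowed to work in place, then defer to the module's existing `_gttrf`/`_gttrs` helpers, which wrap LAPACK's ?gttrf/?gttrs. A minimal SciPy-level sketch of the factor-once/solve-many pattern this enables (not part of the diff; the 5x5 system is made up for illustration, and the LAPACK functions are fetched the same way the Ops' perform methods fetch them):

import numpy as np
from scipy.linalg import get_lapack_funcs

# Illustrative 5x5 tridiagonal system: factor once with gttrf, solve twice with gttrs
n = 5
dl, d, du = np.full(n - 1, -1.0), np.full(n, 2.0), np.full(n - 1, -1.0)
gttrf, gttrs = get_lapack_funcs(("gttrf", "gttrs"), (d,))

dl_f, d_f, du_f, du2, ipiv, info = gttrf(dl, d, du)
assert info == 0

x1, info = gttrs(dl_f, d_f, du_f, du2, ipiv, np.ones(n))
x2, info = gttrs(dl_f, d_f, du_f, du2, ipiv, np.arange(n, dtype=np.float64), trans="T")

# Cross-check the first solve against a dense solve
A = np.diag(d) + np.diag(dl, -1) + np.diag(du, 1)
np.testing.assert_allclose(x1, np.linalg.solve(A, np.ones(n)))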
diff --git a/pytensor/tensor/_linalg/solve/rewriting.py b/pytensor/tensor/_linalg/solve/rewriting.py
index 9ea8db37fc..8f3cda3e0f 100644
--- a/pytensor/tensor/_linalg/solve/rewriting.py
+++ b/pytensor/tensor/_linalg/solve/rewriting.py
@@ -1,10 +1,15 @@
 from collections.abc import Container
 from copy import copy
 
+from pytensor.compile import optdb
 from pytensor.graph import Constant, graph_inputs
 from pytensor.graph.rewriting.basic import copy_stack_trace, in2out, node_rewriter
 from pytensor.scan.op import Scan
 from pytensor.scan.rewriting import scan_seqopt1
+from pytensor.tensor._linalg.solve.tridiagonal import (
+    tridiagonal_lu_factor,
+    tridiagonal_lu_solve,
+)
 from pytensor.tensor.basic import atleast_Nd
 from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.elemwise import DimShuffle
@@ -17,18 +22,32 @@
 def decompose_A(A, assume_a, check_finite):
     if assume_a == "gen":
         return lu_factor(A, check_finite=check_finite)
+    elif assume_a == "tridiagonal":
+        # We didn't implement check_finite for tridiagonal LU factorization
+        return tridiagonal_lu_factor(A)
     else:
         raise NotImplementedError
 
 
 def solve_lu_decomposed_system(A_decomp, b, transposed=False, *, core_solve_op: Solve):
-    if core_solve_op.assume_a == "gen":
+    b_ndim = core_solve_op.b_ndim
+    check_finite = core_solve_op.check_finite
+    assume_a = core_solve_op.assume_a
+    if assume_a == "gen":
         return lu_solve(
             A_decomp,
             b,
+            b_ndim=b_ndim,
             trans=transposed,
-            b_ndim=core_solve_op.b_ndim,
-            check_finite=core_solve_op.check_finite,
+            check_finite=check_finite,
+        )
+    elif assume_a == "tridiagonal":
+        # We didn't implement check_finite for tridiagonal LU solve
+        return tridiagonal_lu_solve(
+            A_decomp,
+            b,
+            b_ndim=b_ndim,
+            transposed=transposed,
         )
     else:
         raise NotImplementedError
@@ -189,13 +208,15 @@ def _scan_split_non_sequence_lu_decomposition_solve(
 @register_specialize
 @node_rewriter([Blockwise])
 def reuse_lu_decomposition_multiple_solves(fgraph, node):
-    return _split_lu_solve_steps(fgraph, node, eager=False, allowed_assume_a={"gen"})
+    return _split_lu_solve_steps(
+        fgraph, node, eager=False, allowed_assume_a={"gen", "tridiagonal"}
+    )
 
 
 @node_rewriter([Scan])
 def scan_split_non_sequence_lu_decomposition_solve(fgraph, node):
     return _scan_split_non_sequence_lu_decomposition_solve(
-        fgraph, node, allowed_assume_a={"gen"}
+        fgraph, node, allowed_assume_a={"gen", "tridiagonal"}
     )
 
 
@@ -207,3 +228,32 @@ def scan_split_non_sequence_lu_decomposition_solve(fgraph, node):
     "scan_pushout",
     position=2,
 )
+
+
+@node_rewriter([Blockwise])
+def reuse_lu_decomposition_multiple_solves_jax(fgraph, node):
+    return _split_lu_solve_steps(fgraph, node, eager=False, allowed_assume_a={"gen"})
+
+
+optdb["specialize"].register(
+    reuse_lu_decomposition_multiple_solves_jax.__name__,
+    in2out(reuse_lu_decomposition_multiple_solves_jax, ignore_newtrees=True),
+    "jax",
+    use_db_name_as_tag=False,
+)
+
+
+@node_rewriter([Scan])
+def scan_split_non_sequence_lu_decomposition_solve_jax(fgraph, node):
+    return _scan_split_non_sequence_lu_decomposition_solve(
+        fgraph, node, allowed_assume_a={"gen"}
+    )
+
+
+scan_seqopt1.register(
+    scan_split_non_sequence_lu_decomposition_solve_jax.__name__,
+    in2out(scan_split_non_sequence_lu_decomposition_solve_jax, ignore_newtrees=True),
+    "jax",
+    use_db_name_as_tag=False,
+    position=2,
+)
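The rewrites above are what turn repeated `solve(A, b, assume_a="tridiagonal")` calls (e.g. a forward solve plus the solves introduced by its gradient, or the per-iteration solves of a scan) into a single factorization shared by all solves. A minimal sketch of the graph the rewrite targets, using only the helpers added in this diff (shapes are illustrative):

import pytensor.tensor as pt
from pytensor.tensor._linalg.solve.tridiagonal import (
    tridiagonal_lu_factor,
    tridiagonal_lu_solve,
)

A = pt.tensor("A", shape=(5, 5))
b1 = pt.tensor("b1", shape=(5,))
b2 = pt.tensor("b2", shape=(5,))

# gttrf runs once; both solves consume the same factored diagonals and pivots
A_decomp = tridiagonal_lu_factor(A)
x1 = tridiagonal_lu_solve(A_decomp, b1, b_ndim=1)
x2 = tridiagonal_lu_solve(A_decomp, b2, b_ndim=1, transposed=True)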
diff --git a/pytensor/tensor/_linalg/solve/tridiagonal.py b/pytensor/tensor/_linalg/solve/tridiagonal.py
new file mode 100644
index 0000000000..7b8198d2a9
--- /dev/null
+++ b/pytensor/tensor/_linalg/solve/tridiagonal.py
@@ -0,0 +1,169 @@
+import numpy as np
+from scipy.linalg import get_lapack_funcs
+
+from pytensor.graph import Apply, Op
+from pytensor.tensor.basic import as_tensor, diagonal
+from pytensor.tensor.blockwise import Blockwise
+from pytensor.tensor.type import tensor, vector
+
+
+class LUFactorTridiagonal(Op):
+    """Compute LU factorization of a tridiagonal matrix (lapack gttrf)"""
+
+    __props__ = (
+        "overwrite_dl",
+        "overwrite_d",
+        "overwrite_du",
+    )
+    gufunc_signature = "(dl),(d),(dl)->(dl),(d),(dl),(du2),(d)"
+
+    def __init__(self, overwrite_dl=False, overwrite_d=False, overwrite_du=False):
+        self.destroy_map = dm = {}
+        if overwrite_dl:
+            dm[0] = [0]
+        if overwrite_d:
+            dm[1] = [1]
+        if overwrite_du:
+            dm[2] = [2]
+        self.overwrite_dl = overwrite_dl
+        self.overwrite_d = overwrite_d
+        self.overwrite_du = overwrite_du
+        super().__init__()
+
+    def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
+        return type(self)(
+            overwrite_dl=0 in allowed_inplace_inputs,
+            overwrite_d=1 in allowed_inplace_inputs,
+            overwrite_du=2 in allowed_inplace_inputs,
+        )
+
+    def make_node(self, dl, d, du):
+        dl, d, du = map(as_tensor, (dl, d, du))
+
+        if not all(inp.type.ndim == 1 for inp in (dl, d, du)):
+            raise ValueError("Diagonals must be vectors")
+
+        ndl, nd, ndu = (inp.type.shape[-1] for inp in (dl, d, du))
+        n = (
+            ndl + 1
+            if ndl is not None
+            else (nd if nd is not None else (ndu + 1 if ndu is not None else None))
+        )
+        dummy_arrays = [np.zeros((), dtype=inp.type.dtype) for inp in (dl, d, du)]
+        out_dtype = get_lapack_funcs("gttrf", dummy_arrays).dtype
+        outputs = [
+            vector(shape=(None if n is None else (n - 1),), dtype=out_dtype),
+            vector(shape=(n,), dtype=out_dtype),
+            vector(shape=(None if n is None else n - 1,), dtype=out_dtype),
+            vector(shape=(None if n is None else n - 2,), dtype=out_dtype),
+            vector(shape=(n,), dtype=np.int32),
+        ]
+        return Apply(self, [dl, d, du], outputs)
+
+    def perform(self, node, inputs, output_storage):
+        gttrf = get_lapack_funcs("gttrf", dtype=node.outputs[0].type.dtype)
+        dl, d, du, du2, ipiv, _ = gttrf(
+            *inputs,
+            overwrite_dl=self.overwrite_dl,
+            overwrite_d=self.overwrite_d,
+            overwrite_du=self.overwrite_du,
+        )
+        output_storage[0][0] = dl
+        output_storage[1][0] = d
+        output_storage[2][0] = du
+        output_storage[3][0] = du2
+        output_storage[4][0] = ipiv
+
+
+class SolveLUFactorTridiagonal(Op):
+    """Solve a system of linear equations with a tridiagonal coefficient matrix (lapack gttrs)."""
+
+    __props__ = ("b_ndim", "overwrite_b", "transposed")
+
+    def __init__(self, b_ndim: int, transposed: bool, overwrite_b=False):
+        if b_ndim not in (1, 2):
+            raise ValueError("b_ndim must be 1 or 2")
+        if b_ndim == 1:
+            self.gufunc_signature = "(dl),(d),(dl),(du2),(d),(d)->(d)"
+        else:
+            self.gufunc_signature = "(dl),(d),(dl),(du2),(d),(d,rhs)->(d,rhs)"
+        if overwrite_b:
+            self.destroy_map = {0: [5]}
+        self.b_ndim = b_ndim
+        self.transposed = transposed
+        self.overwrite_b = overwrite_b
+        super().__init__()
+
+    def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
+        if 5 in allowed_inplace_inputs:
+            props = self._props_dict()
+            props["overwrite_b"] = True
+            return type(self)(**props)
+
+        return self
+
+    def make_node(self, dl, d, du, du2, ipiv, b):
+        dl, d, du, du2, ipiv, b = map(as_tensor, (dl, d, du, du2, ipiv, b))
+
+        if b.type.ndim != self.b_ndim:
+            raise ValueError("Wrong number of dimensions for input b.")
+
+        if not all(inp.type.ndim == 1 for inp in (dl, d, du, du2, ipiv)):
+            raise ValueError("Inputs must be vectors")
+
+        ndl, nd, ndu, ndu2, nipiv = (
+            inp.type.shape[-1] for inp in (dl, d, du, du2, ipiv)
+        )
+        nb = b.type.shape[0]
+        n = (
+            ndl + 1
+            if ndl is not None
+            else (
+                nd
+                if nd is not None
+                else (
+                    ndu + 1
+                    if ndu is not None
+                    else (
+                        ndu2 + 2
+                        if ndu2 is not None
+                        else (nipiv if nipiv is not None else nb)
+                    )
+                )
+            )
+        )
+        dummy_arrays = [
+            np.zeros((), dtype=inp.type.dtype) for inp in (dl, d, du, du2, ipiv)
+        ]
+        # Seems to always be float64?
+        out_dtype = get_lapack_funcs("gttrs", dummy_arrays).dtype
+        if self.b_ndim == 1:
+            output_shape = (n,)
+        else:
+            output_shape = (n, b.type.shape[-1])
+
+        outputs = [tensor(shape=output_shape, dtype=out_dtype)]
+        return Apply(self, [dl, d, du, du2, ipiv, b], outputs)
+
+    def perform(self, node, inputs, output_storage):
+        gttrs = get_lapack_funcs("gttrs", dtype=node.outputs[0].type.dtype)
+        x, _ = gttrs(
+            *inputs,
+            overwrite_b=self.overwrite_b,
+            trans="N" if not self.transposed else "T",
+        )
+        output_storage[0][0] = x
+
+
+def tridiagonal_lu_factor(a):
+    # Return the LU decomposition of the tridiagonal matrix A implied by a tridiagonal solve
+    dl, d, du = (diagonal(a, offset=o, axis1=-2, axis2=-1) for o in (-1, 0, 1))
+    dl, d, du, du2, ipiv = Blockwise(LUFactorTridiagonal())(dl, d, du)
+    return dl, d, du, du2, ipiv
+
+
+def tridiagonal_lu_solve(a_diagonals, b, *, b_ndim: int, transposed: bool = False):
+    dl, d, du, du2, ipiv = a_diagonals
+    return Blockwise(SolveLUFactorTridiagonal(b_ndim=b_ndim, transposed=transposed))(
+        dl, d, du, du2, ipiv, b
+    )
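The gufunc signatures encode the shape relations of the diagonals: for an n x n tridiagonal matrix, dl and du have length n-1, du2 (the second superdiagonal produced by partial pivoting) has length n-2, and d and ipiv have length n. A rough numerical check of the two new Ops against a dense solve (a sketch, not from the diff; float64 is requested explicitly to sidestep floatX defaults):

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor._linalg.solve.tridiagonal import (
    tridiagonal_lu_factor,
    tridiagonal_lu_solve,
)

A = pt.tensor("A", shape=(4, 4), dtype="float64")
b = pt.tensor("b", shape=(4,), dtype="float64")
x = tridiagonal_lu_solve(tridiagonal_lu_factor(A), b, b_ndim=1)
fn = pytensor.function([A, b], x)

A_val = np.diag(np.full(4, 2.0)) + np.diag(np.full(3, -1.0), 1) + np.diag(np.full(3, -1.0), -1)
b_val = np.arange(4.0)
np.testing.assert_allclose(fn(A_val, b_val), np.linalg.solve(A_val, b_val))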
diff --git a/pytensor/tensor/rewriting/subtensor.py b/pytensor/tensor/rewriting/subtensor.py
index ad5c1fc16d..7a298f2bc5 100644
--- a/pytensor/tensor/rewriting/subtensor.py
+++ b/pytensor/tensor/rewriting/subtensor.py
@@ -19,6 +19,7 @@
 from pytensor.scalar import constant as scalar_constant
 from pytensor.tensor.basic import (
     Alloc,
+    ExtractDiag,
     Join,
     ScalarFromTensor,
     TensorFromScalar,
@@ -26,6 +27,7 @@
     cast,
     concatenate,
     expand_dims,
+    full,
     get_scalar_constant_value,
     get_underlying_scalar_constant_value,
     register_infer_shape,
@@ -1793,3 +1795,82 @@ def ravel_multidimensional_int_idx(fgraph, node):
     "numba",
     use_db_name_as_tag=False,  # Not included if only "specialize" is requested
 )
+
+
+@register_canonicalize
+@register_stabilize
+@register_specialize
+@node_rewriter([ExtractDiag])
+def extract_diag_of_diagonal_set_subtensor(fgraph, node):
+    def is_constant_arange(var) -> bool:
+        if not (isinstance(var, TensorConstant) and var.type.ndim == 1):
+            return False
+
+        data = var.data
+        start, stop = data[0], data[-1] + 1
+        return data.size == (stop - start) and (data == np.arange(start, stop)).all()
+
+    [diag_x] = node.inputs
+    if not (
+        diag_x.owner is not None
+        and isinstance(diag_x.owner.op, AdvancedIncSubtensor)
+        and diag_x.owner.op.set_instead_of_inc
+    ):
+        return None
+
+    x, y, *idxs = diag_x.owner.inputs
+
+    if not (
+        x.type.ndim >= 2
+        and None not in x.type.shape[-2:]
+        and x.type.shape[-2] == x.type.shape[-1]
+    ):
+        # For now we only support the rewrite when x has a static square shape
+        return None
+
+    op = node.op
+    if op.axis2 >= len(idxs):
+        return None
+
+    # Check that all non-axis indices are full slices
+    axis = {op.axis1, op.axis2}
+    if not all(is_full_slice(idx) for i, idx in enumerate(idxs) if i not in axis):
+        return None
+
+    # Check that the axis indices are the aranges we would expect from setting the diagonal
+    axis1_idx = idxs[op.axis1]
+    axis2_idx = idxs[op.axis2]
+    if not (is_constant_arange(axis1_idx) and is_constant_arange(axis2_idx)):
+        return None
+
+    dim_length = x.type.shape[-1]
+    offset = op.offset
+    start_stop1 = (axis1_idx.data[0], axis1_idx.data[-1] + 1)
+    start_stop2 = (axis2_idx.data[0], axis2_idx.data[-1] + 1)
+    orig_start1, orig_start2 = start_stop1[0], start_stop2[0]
+
+    if offset < 0:
+        # The logic for checking whether a diagonal is selected with a negative offset is
+        # the same as with a positive offset, but with the axes swapped
+        start_stop1, start_stop2 = start_stop2, start_stop1
+        offset = -offset
+
+    start1, stop1 = start_stop1
+    start2, stop2 = start_stop2
+    if (
+        start1 == 0
+        and start2 == offset
+        and stop1 == dim_length - offset
+        and stop2 == dim_length
+    ):
+        # We are extracting the diagonal that was just written
+        if y.type.ndim == 0 or y.type.shape[-1] == 1:
+            # We may need to broadcast y
+            y = full((*x.shape[:-2], dim_length - offset), y, dtype=x.type.dtype)
+        return [y]
+    elif (orig_start2 - orig_start1) != op.offset:
+        # Some other diagonal was written; the write cannot affect the one being extracted
+        return [op(x)]
+    else:
+        # A portion of the diagonal, but not the whole diagonal, was written; don't rewrite
+        return None
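In words: when a full diagonal was written with `x[arange, arange].set(value)` and then read back with `diagonal`, the read short-circuits to the written value; reading another, untouched diagonal drops the write entirely; a partially written diagonal is left alone. A minimal sketch of the first case (hypothetical session; `ShapeOpt` is included so the arange shapes become static, mirroring the test at the end of this diff):

import numpy as np
import pytensor.tensor as pt
from pytensor.graph.rewriting.utils import rewrite_graph

x = pt.full((6, 6), np.nan)
rows, cols = pt.arange(6), pt.arange(6)
x = x[rows, cols].set(1.0)

# After rewriting, reading the diagonal back no longer touches x at all;
# the graph should be equivalent to pt.full((6,), 1.0)
out = rewrite_graph(x.diagonal(), include=("ShapeOpt", "canonicalize"))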
diff --git a/pytensor/tensor/subtensor.py b/pytensor/tensor/subtensor.py
index 278d1e8da6..99ae67af9b 100644
--- a/pytensor/tensor/subtensor.py
+++ b/pytensor/tensor/subtensor.py
@@ -3021,12 +3021,7 @@ def make_node(self, x, y, *inputs):
         return Apply(
             self,
             (x, y, *new_inputs),
-            [
-                tensor(
-                    dtype=x.type.dtype,
-                    shape=tuple(1 if s == 1 else None for s in x.type.shape),
-                )
-            ],
+            [x.type()],
         )
 
     def perform(self, node, inputs, out_):
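This `make_node` change makes `AdvancedIncSubtensor` preserve the full static shape of `x` (previously only size-1 dims survived), which is exactly what the diagonal rewrite above needs in order to see a static square shape. A one-liner check (a sketch, not from the diff):

import pytensor.tensor as pt

x = pt.tensor("x", shape=(6, 6))
y = x[pt.arange(6), pt.arange(6)].set(0.0)
assert y.type.shape == (6, 6)  # previously (None, None)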
diff --git a/tests/tensor/linalg/test_rewriting.py b/tests/tensor/linalg/test_rewriting.py
index 32683029f0..898a07aae7 100644
--- a/tests/tensor/linalg/test_rewriting.py
+++ b/tests/tensor/linalg/test_rewriting.py
@@ -9,6 +9,10 @@
     reuse_lu_decomposition_multiple_solves,
     scan_split_non_sequence_lu_decomposition_solve,
 )
+from pytensor.tensor._linalg.solve.tridiagonal import (
+    LUFactorTridiagonal,
+    SolveLUFactorTridiagonal,
+)
 from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.linalg import solve
 from pytensor.tensor.slinalg import LUFactor, Solve, SolveTriangular
@@ -28,9 +32,10 @@ def count_vanilla_solve_nodes(nodes) -> int:
 def count_lu_decom_nodes(nodes) -> int:
     return sum(
         (
-            isinstance(node.op, LUFactor)
+            isinstance(node.op, LUFactor | LUFactorTridiagonal)
             or (
-                isinstance(node.op, Blockwise) and isinstance(node.op.core_op, LUFactor)
+                isinstance(node.op, Blockwise)
+                and isinstance(node.op.core_op, LUFactor | LUFactorTridiagonal)
             )
         )
         for node in nodes
@@ -40,27 +45,38 @@ def count_lu_decom_nodes(nodes) -> int:
 def count_lu_solve_nodes(nodes) -> int:
     count = sum(
         (
-            isinstance(node.op, SolveTriangular)
+            # Each LU solve is implemented with 2 SolveTriangular nodes, so count each as 0.5
+            0.5
+            * (
+                isinstance(node.op, SolveTriangular)
+                or (
+                    isinstance(node.op, Blockwise)
+                    and isinstance(node.op.core_op, SolveTriangular)
+                )
+            )
             or (
-                isinstance(node.op, Blockwise)
-                and isinstance(node.op.core_op, SolveTriangular)
+                isinstance(node.op, SolveLUFactorTridiagonal)
+                or (
+                    isinstance(node.op, Blockwise)
+                    and isinstance(node.op.core_op, SolveLUFactorTridiagonal)
+                )
             )
         )
         for node in nodes
     )
-    # Each LU solve uses two Triangular solves
-    return count // 2
+    return int(count)
 
 
 @pytest.mark.parametrize("transposed", (False, True))
-def test_lu_decomposition_reused_forward_and_gradient(transposed):
+@pytest.mark.parametrize("assume_a", ("gen", "tridiagonal"))
+def test_lu_decomposition_reused_forward_and_gradient(assume_a, transposed):
     rewrite_name = reuse_lu_decomposition_multiple_solves.__name__
     mode = get_default_mode()
 
-    A = tensor("A", shape=(2, 2))
-    b = tensor("b", shape=(2, 3))
+    A = tensor("A", shape=(3, 3))
+    b = tensor("b", shape=(3, 4))
 
-    x = solve(A, b, assume_a="gen", transposed=transposed)
+    x = solve(A, b, assume_a=assume_a, transposed=transposed)
     grad_x_wrt_A = grad(x.sum(), A)
     fn_no_opt = function([A, b], [x, grad_x_wrt_A], mode=mode.excluding(rewrite_name))
     no_opt_nodes = fn_no_opt.maker.fgraph.apply_nodes
@@ -86,14 +102,15 @@
 
 
 @pytest.mark.parametrize("transposed", (False, True))
-def test_lu_decomposition_reused_blockwise(transposed):
+@pytest.mark.parametrize("assume_a", ("gen", "tridiagonal"))
+def test_lu_decomposition_reused_blockwise(assume_a, transposed):
     rewrite_name = reuse_lu_decomposition_multiple_solves.__name__
     mode = get_default_mode()
 
-    A = tensor("A", shape=(2, 2))
-    b = tensor("b", shape=(2, 2, 3))
+    A = tensor("A", shape=(3, 3))
+    b = tensor("b", shape=(2, 3, 4))
 
-    x = solve(A, b, transposed=transposed)
+    x = solve(A, b, assume_a=assume_a, transposed=transposed)
     fn_no_opt = function([A, b], [x], mode=mode.excluding(rewrite_name))
     no_opt_nodes = fn_no_opt.maker.fgraph.apply_nodes
     assert count_vanilla_solve_nodes(no_opt_nodes) == 1
@@ -116,15 +133,16 @@
 
 
 @pytest.mark.parametrize("transposed", (False, True))
-def test_lu_decomposition_reused_scan(transposed):
+@pytest.mark.parametrize("assume_a", ("gen", "tridiagonal"))
+def test_lu_decomposition_reused_scan(assume_a, transposed):
     rewrite_name = scan_split_non_sequence_lu_decomposition_solve.__name__
     mode = get_default_mode()
 
-    A = tensor("A", shape=(2, 2))
-    x0 = tensor("b", shape=(2, 3))
+    A = tensor("A", shape=(3, 3))
+    x0 = tensor("b", shape=(3, 4))
 
     xs, _ = scan(
-        lambda xtm1, A: solve(A, xtm1, assume_a="general", transposed=transposed),
+        lambda xtm1, A: solve(A, xtm1, assume_a=assume_a, transposed=transposed),
         outputs_info=[x0],
         non_sequences=[A],
         n_steps=10,
diff --git a/tests/tensor/rewriting/test_subtensor.py b/tests/tensor/rewriting/test_subtensor.py
index d20082ed36..3d59c76604 100644
--- a/tests/tensor/rewriting/test_subtensor.py
+++ b/tests/tensor/rewriting/test_subtensor.py
@@ -1,3 +1,5 @@
+import random
+
 import numpy as np
 import pytest
 
@@ -1956,3 +1958,37 @@ def test_unknown_step(self):
         f(test_x, -2),
         test_x[0:3:-2, -1:-6:2, ::],
     )
+
+
+def test_extract_diag_of_diagonal_set_subtensor():
+    A = pt.full((2, 6, 6), np.nan)
+    rows = pt.arange(A.shape[-2])
+    cols = pt.arange(A.shape[-1])
+    write_offsets = [-2, -1, 0, 1, 2]
+    # Randomize the order of the write operations, to make sure the rewrite is not sensitive to it
+    random.shuffle(write_offsets)
+    for offset in write_offsets:
+        value = offset + 0.1 * offset
+        if offset == 0:
+            A = A[..., rows, cols].set(value)
+        elif offset > 0:
+            A = A[..., rows[:-offset], cols[offset:]].set(value)
+        else:
+            offset = -offset
+            A = A[..., rows[offset:], cols[:-offset]].set(value)
+    # Add a partial diagonal along offset 3
+    A = A[..., rows[1:-3], cols[4:]].set(np.pi)
+
+    read_offsets = [-2, -1, 0, 1, 2, 3]
+    outs = [A.diagonal(offset=offset, axis1=-2, axis2=-1) for offset in read_offsets]
+    rewritten_outs = rewrite_graph(outs, include=("ShapeOpt", "canonicalize"))
+
+    # Every output except the last should be just an Alloc with the written value
+    expected_outs = []
+    for offset in read_offsets[:-1]:
+        value = np.asarray(offset + 0.1 * offset, dtype=A.type.dtype)
+        expected_outs.append(pt.full((np.int64(2), np.int8(6 - abs(offset))), value))
+    # The partial diagonal shouldn't be rewritten
+    expected_outs.append(outs[-1])
+
+    assert equal_computations(rewritten_outs, expected_outs)
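End-to-end, the behaviour the new tests pin down can be seen with a sketch like the following (mirrors test_lu_decomposition_reused_forward_and_gradient; not part of the diff):

import pytensor
import pytensor.tensor as pt
from pytensor import grad
from pytensor.tensor.linalg import solve

A = pt.tensor("A", shape=(3, 3))
b = pt.tensor("b", shape=(3, 4))
x = solve(A, b, assume_a="tridiagonal")
fn = pytensor.function([A, b], [x, grad(x.sum(), A)])
# The compiled graph should contain one LUFactorTridiagonal node and two
# SolveLUFactorTridiagonal nodes (one forward, one transposed for the gradient),
# instead of two separate tridiagonal factorizations.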