From 9c8fbaec416c927f627d6dc447ad8f09dddb4655 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Mon, 19 May 2025 13:49:00 +0200 Subject: [PATCH 1/2] Include signature in Blockwise tester RNG seed --- tests/tensor/test_blockwise.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tensor/test_blockwise.py b/tests/tensor/test_blockwise.py index a140e07846..a660a0c094 100644 --- a/tests/tensor/test_blockwise.py +++ b/tests/tensor/test_blockwise.py @@ -328,7 +328,7 @@ class BlockwiseOpTester: @classmethod def setup_class(cls): - seed = sum(map(ord, str(cls.core_op))) + seed = sum(map(ord, str(cls.core_op) + cls.signature)) cls.rng = np.random.default_rng(seed) cls.params_sig, cls.outputs_sig = _parse_gufunc_signature(cls.signature) if cls.batcheable_axes is None: From 72780769b0d2bfab2db15c0d1887352ec407ca95 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Tue, 29 Apr 2025 13:48:55 +0200 Subject: [PATCH 2/2] Reuse LU decomposition in Solve --- pytensor/compile/mode.py | 2 + pytensor/scan/rewriting.py | 10 +- pytensor/tensor/__init__.py | 1 + pytensor/tensor/_linalg/__init__.py | 2 + pytensor/tensor/_linalg/solve/__init__.py | 2 + pytensor/tensor/_linalg/solve/rewriting.py | 198 +++++++++++++++++++++ pytensor/tensor/rewriting/linalg.py | 7 + tests/tensor/linalg/__init__.py | 0 tests/tensor/linalg/test_rewriting.py | 163 +++++++++++++++++ tests/tensor/test_blockwise.py | 5 +- 10 files changed, 383 insertions(+), 7 deletions(-) create mode 100644 pytensor/tensor/_linalg/__init__.py create mode 100644 pytensor/tensor/_linalg/solve/__init__.py create mode 100644 pytensor/tensor/_linalg/solve/rewriting.py create mode 100644 tests/tensor/linalg/__init__.py create mode 100644 tests/tensor/linalg/test_rewriting.py diff --git a/pytensor/compile/mode.py b/pytensor/compile/mode.py index f80dfaaf5c..63a1ba835b 100644 --- a/pytensor/compile/mode.py +++ b/pytensor/compile/mode.py @@ -490,6 +490,8 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs): "fusion", "inplace", "scan_save_mem_prealloc", + "reuse_lu_decomposition_multiple_solves", + "scan_split_non_sequence_lu_decomposition_solve", ], ), ) diff --git a/pytensor/scan/rewriting.py b/pytensor/scan/rewriting.py index b8e6b009d8..c49fbadce4 100644 --- a/pytensor/scan/rewriting.py +++ b/pytensor/scan/rewriting.py @@ -2561,7 +2561,6 @@ def scan_push_out_dot1(fgraph, node): position=1, ) - scan_seqopt1.register( "scan_push_out_non_seq", in2out(scan_push_out_non_seq, ignore_newtrees=True), @@ -2569,10 +2568,9 @@ def scan_push_out_dot1(fgraph, node): "fast_run", "scan", "scan_pushout", - position=2, + position=3, ) - scan_seqopt1.register( "scan_push_out_seq", in2out(scan_push_out_seq, ignore_newtrees=True), @@ -2580,7 +2578,7 @@ def scan_push_out_dot1(fgraph, node): "fast_run", "scan", "scan_pushout", - position=3, + position=4, ) @@ -2592,7 +2590,7 @@ def scan_push_out_dot1(fgraph, node): "more_mem", "scan", "scan_pushout", - position=4, + position=5, ) @@ -2605,7 +2603,7 @@ def scan_push_out_dot1(fgraph, node): "more_mem", "scan", "scan_pushout", - position=5, + position=6, ) scan_eqopt2.register( diff --git a/pytensor/tensor/__init__.py b/pytensor/tensor/__init__.py index c6b421d003..ce590f8228 100644 --- a/pytensor/tensor/__init__.py +++ b/pytensor/tensor/__init__.py @@ -114,6 +114,7 @@ def _get_vector_length_Constant(op: Op | Variable, var: Constant) -> int: # isort: off +import pytensor.tensor._linalg from pytensor.tensor import linalg from pytensor.tensor import special from pytensor.tensor import signal diff --git 
a/pytensor/tensor/_linalg/__init__.py b/pytensor/tensor/_linalg/__init__.py new file mode 100644 index 0000000000..767374b10b --- /dev/null +++ b/pytensor/tensor/_linalg/__init__.py @@ -0,0 +1,2 @@ +# Register rewrites +import pytensor.tensor._linalg.solve diff --git a/pytensor/tensor/_linalg/solve/__init__.py b/pytensor/tensor/_linalg/solve/__init__.py new file mode 100644 index 0000000000..1d85f4a66b --- /dev/null +++ b/pytensor/tensor/_linalg/solve/__init__.py @@ -0,0 +1,2 @@ +# Register rewrites in the database +import pytensor.tensor._linalg.solve.rewriting diff --git a/pytensor/tensor/_linalg/solve/rewriting.py b/pytensor/tensor/_linalg/solve/rewriting.py new file mode 100644 index 0000000000..ff1c74cdec --- /dev/null +++ b/pytensor/tensor/_linalg/solve/rewriting.py @@ -0,0 +1,198 @@ +from collections.abc import Container +from copy import copy + +from pytensor.graph import Constant, graph_inputs +from pytensor.graph.rewriting.basic import copy_stack_trace, in2out, node_rewriter +from pytensor.scan.op import Scan +from pytensor.scan.rewriting import scan_seqopt1 +from pytensor.tensor.basic import atleast_Nd +from pytensor.tensor.blockwise import Blockwise +from pytensor.tensor.elemwise import DimShuffle +from pytensor.tensor.rewriting.basic import register_specialize +from pytensor.tensor.rewriting.linalg import is_matrix_transpose +from pytensor.tensor.slinalg import Solve, lu_factor, lu_solve +from pytensor.tensor.variable import TensorVariable + + +def decompose_A(A, assume_a): + if assume_a == "gen": + return lu_factor(A, check_finite=False) + else: + raise NotImplementedError + + +def solve_lu_decomposed_system(A_decomp, b, b_ndim, assume_a, transposed=False): + if assume_a == "gen": + return lu_solve(A_decomp, b, b_ndim=b_ndim, trans=transposed) + else: + raise NotImplementedError + + +def _split_lu_solve_steps( + fgraph, node, *, eager: bool, allowed_assume_a: Container[str] +): + if not isinstance(node.op.core_op, Solve): + return None + + def get_root_A(a: TensorVariable) -> tuple[TensorVariable, bool]: + # Find the root variable of the first input to Solve + # If `a` is a left expand_dims or matrix transpose (DimShuffle variants), + # the root variable is the pre-DimShuffled input. + # Otherwise, `a` is considered the root variable. + # We also return whether the root `a` is transposed. 
+        transposed = False
+        if a.owner is not None and isinstance(a.owner.op, DimShuffle):
+            if a.owner.op.is_left_expand_dims:
+                [a] = a.owner.inputs
+            elif is_matrix_transpose(a):
+                [a] = a.owner.inputs
+                transposed = True
+        return a, transposed
+
+    def find_solve_clients(var, assume_a):
+        clients = []
+        for cl, idx in fgraph.clients[var]:
+            if (
+                idx == 0
+                and isinstance(cl.op, Blockwise)
+                and isinstance(cl.op.core_op, Solve)
+                and (cl.op.core_op.assume_a == assume_a)
+            ):
+                clients.append(cl)
+            elif isinstance(cl.op, DimShuffle) and cl.op.is_left_expand_dims:
+                # If it's a left expand_dims, recurse on the output
+                clients.extend(find_solve_clients(cl.outputs[0], assume_a))
+        return clients
+
+    assume_a = node.op.core_op.assume_a
+
+    if assume_a not in allowed_assume_a:
+        return None
+
+    A, _ = get_root_A(node.inputs[0])
+
+    # Find Solve using A (or left expand_dims of A)
+    # TODO: We could handle arbitrary shuffle of the batch dimensions, just need to propagate
+    # that to the A_decomp outputs
+    A_solve_clients_and_transpose = [
+        (client, False) for client in find_solve_clients(A, assume_a)
+    ]
+
+    # Find Solves using A.T
+    for cl, _ in fgraph.clients[A]:
+        if isinstance(cl.op, DimShuffle) and is_matrix_transpose(cl.out):
+            A_T = cl.out
+            A_solve_clients_and_transpose.extend(
+                (client, True) for client in find_solve_clients(A_T, assume_a)
+            )
+
+    if not eager and len(A_solve_clients_and_transpose) == 1:
+        # If there's a single use, don't do it... unless it's being broadcast in a Blockwise (or we're eager).
+        # That's a "reuse" inside the inner vectorized loop.
+        batch_ndim = node.op.batch_ndim(node)
+        (client, _) = A_solve_clients_and_transpose[0]
+        original_A, b = client.inputs
+        if not any(
+            a_bcast and not b_bcast
+            for a_bcast, b_bcast in zip(
+                original_A.type.broadcastable[:batch_ndim],
+                b.type.broadcastable[:batch_ndim],
+                strict=True,
+            )
+        ):
+            return None
+
+    A_decomp = decompose_A(A, assume_a=assume_a)
+
+    replacements = {}
+    for client, transposed in A_solve_clients_and_transpose:
+        _, b = client.inputs
+        b_ndim = client.op.core_op.b_ndim
+        new_x = solve_lu_decomposed_system(
+            A_decomp, b, b_ndim=b_ndim, assume_a=assume_a, transposed=transposed
+        )
+        [old_x] = client.outputs
+        new_x = atleast_Nd(new_x, n=old_x.type.ndim).astype(old_x.type.dtype)
+        copy_stack_trace(old_x, new_x)
+        replacements[old_x] = new_x
+
+    return replacements
+
+
+def _scan_split_non_sequence_lu_decomposition_solve(
+    fgraph, node, *, allowed_assume_a: Container[str]
+):
+    """If the A of a Solve within a Scan is a function of non-sequences, split the LU decomposition step.
+
+    The LU decomposition step can then be pushed out of the inner loop by the `scan_push_out_non_seq` rewrite.
+ """ + scan_op: Scan = node.op + non_sequences = set(scan_op.inner_non_seqs(scan_op.inner_inputs)) + new_scan_fgraph = scan_op.fgraph + + changed = False + while True: + for inner_node in new_scan_fgraph.toposort(): + if ( + isinstance(inner_node.op, Blockwise) + and isinstance(inner_node.op.core_op, Solve) + and inner_node.op.core_op.assume_a in allowed_assume_a + ): + A, b = inner_node.inputs + if all( + (isinstance(root_inp, Constant) or (root_inp in non_sequences)) + for root_inp in graph_inputs([A]) + ): + if new_scan_fgraph is scan_op.fgraph: + # Clone the first time to avoid mutating the original fgraph + new_scan_fgraph, equiv = new_scan_fgraph.clone_get_equiv() + non_sequences = {equiv[non_seq] for non_seq in non_sequences} + inner_node = equiv[inner_node] # type: ignore + + replace_dict = _split_lu_solve_steps( + new_scan_fgraph, + inner_node, + eager=True, + allowed_assume_a=allowed_assume_a, + ) + assert ( + isinstance(replace_dict, dict) and len(replace_dict) > 0 + ), "Rewrite failed" + new_scan_fgraph.replace_all(replace_dict.items()) + changed = True + break # Break to start over with a fresh toposort + else: # no_break + break # Nothing else changed + + if not changed: + return + + # Return a new scan to indicate that a rewrite was done + new_scan_op = copy(scan_op) + new_scan_op.fgraph = new_scan_fgraph + new_outs = new_scan_op.make_node(*node.inputs).outputs + copy_stack_trace(node.outputs, new_outs) + return new_outs + + +@register_specialize +@node_rewriter([Blockwise]) +def reuse_lu_decomposition_multiple_solves(fgraph, node): + return _split_lu_solve_steps(fgraph, node, eager=False, allowed_assume_a={"gen"}) + + +@node_rewriter([Scan]) +def scan_split_non_sequence_lu_decomposition_solve(fgraph, node): + return _scan_split_non_sequence_lu_decomposition_solve( + fgraph, node, allowed_assume_a={"gen"} + ) + + +scan_seqopt1.register( + "scan_split_non_sequence_lu_decomposition_solve", + in2out(scan_split_non_sequence_lu_decomposition_solve, ignore_newtrees=True), + "fast_run", + "scan", + "scan_pushout", + position=2, +) diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index cd202fe3ed..af42bee236 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -75,6 +75,13 @@ def is_matrix_transpose(x: TensorVariable) -> bool: if ndims < 2: return False transpose_order = (*range(ndims - 2), ndims - 1, ndims - 2) + + # Allow expand_dims on the left of the transpose + if (diff := len(transpose_order) - len(node.op.new_order)) > 0: + transpose_order = ( + *(["x"] * diff), + *transpose_order, + ) return node.op.new_order == transpose_order return False diff --git a/tests/tensor/linalg/__init__.py b/tests/tensor/linalg/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/tensor/linalg/test_rewriting.py b/tests/tensor/linalg/test_rewriting.py new file mode 100644 index 0000000000..6f04fac5fb --- /dev/null +++ b/tests/tensor/linalg/test_rewriting.py @@ -0,0 +1,163 @@ +import numpy as np +import pytest + +from pytensor import config, function, scan +from pytensor.compile.mode import get_default_mode +from pytensor.gradient import grad +from pytensor.scan.op import Scan +from pytensor.tensor._linalg.solve.rewriting import ( + reuse_lu_decomposition_multiple_solves, + scan_split_non_sequence_lu_decomposition_solve, +) +from pytensor.tensor.blockwise import Blockwise +from pytensor.tensor.linalg import solve +from pytensor.tensor.slinalg import LUFactor, Solve, SolveTriangular +from 
pytensor.tensor.type import tensor + + +def count_vanilla_solve_nodes(nodes) -> int: + return sum( + ( + isinstance(node.op, Solve) + or (isinstance(node.op, Blockwise) and isinstance(node.op.core_op, Solve)) + ) + for node in nodes + ) + + +def count_lu_decom_nodes(nodes) -> int: + return sum( + ( + isinstance(node.op, LUFactor) + or ( + isinstance(node.op, Blockwise) and isinstance(node.op.core_op, LUFactor) + ) + ) + for node in nodes + ) + + +def count_lu_solve_nodes(nodes) -> int: + count = sum( + ( + isinstance(node.op, SolveTriangular) + or ( + isinstance(node.op, Blockwise) + and isinstance(node.op.core_op, SolveTriangular) + ) + ) + for node in nodes + ) + # Each LU solve uses two Triangular solves + return count // 2 + + +@pytest.mark.parametrize("transposed", (False, True)) +def test_lu_decomposition_reused_forward_and_gradient(transposed): + rewrite_name = reuse_lu_decomposition_multiple_solves.__name__ + mode = get_default_mode() + + A = tensor("A", shape=(2, 2)) + b = tensor("b", shape=(2, 3)) + + x = solve(A, b, assume_a="gen", transposed=transposed) + grad_x_wrt_A = grad(x.sum(), A) + fn_no_opt = function([A, b], [x, grad_x_wrt_A], mode=mode.excluding(rewrite_name)) + no_opt_nodes = fn_no_opt.maker.fgraph.apply_nodes + assert count_vanilla_solve_nodes(no_opt_nodes) == 2 + assert count_lu_decom_nodes(no_opt_nodes) == 0 + assert count_lu_solve_nodes(no_opt_nodes) == 0 + + fn_opt = function([A, b], [x, grad_x_wrt_A], mode=mode.including(rewrite_name)) + opt_nodes = fn_opt.maker.fgraph.apply_nodes + assert count_vanilla_solve_nodes(opt_nodes) == 0 + assert count_lu_decom_nodes(opt_nodes) == 1 + assert count_lu_solve_nodes(opt_nodes) == 2 + + # Make sure results are correct + rng = np.random.default_rng(31) + A_test = rng.random(A.type.shape, dtype=A.type.dtype) + b_test = rng.random(b.type.shape, dtype=b.type.dtype) + resx0, resg0 = fn_no_opt(A_test, b_test) + resx1, resg1 = fn_opt(A_test, b_test) + rtol = 1e-7 if config.floatX == "float64" else 1e-6 + np.testing.assert_allclose(resx0, resx1, rtol=rtol) + np.testing.assert_allclose(resg0, resg1, rtol=rtol) + + +@pytest.mark.parametrize("transposed", (False, True)) +def test_lu_decomposition_reused_blockwise(transposed): + rewrite_name = reuse_lu_decomposition_multiple_solves.__name__ + mode = get_default_mode() + + A = tensor("A", shape=(2, 2)) + b = tensor("b", shape=(2, 2, 3)) + + x = solve(A, b, transposed=transposed) + fn_no_opt = function([A, b], [x], mode=mode.excluding(rewrite_name)) + no_opt_nodes = fn_no_opt.maker.fgraph.apply_nodes + assert count_vanilla_solve_nodes(no_opt_nodes) == 1 + assert count_lu_decom_nodes(no_opt_nodes) == 0 + assert count_lu_solve_nodes(no_opt_nodes) == 0 + + fn_opt = function([A, b], [x], mode=mode.including(rewrite_name)) + opt_nodes = fn_opt.maker.fgraph.apply_nodes + assert count_vanilla_solve_nodes(opt_nodes) == 0 + assert count_lu_decom_nodes(opt_nodes) == 1 + assert count_lu_solve_nodes(opt_nodes) == 1 + + # Make sure results are correct + rng = np.random.default_rng(31) + A_test = rng.random(A.type.shape, dtype=A.type.dtype) + b_test = rng.random(b.type.shape, dtype=b.type.dtype) + resx0 = fn_no_opt(A_test, b_test) + resx1 = fn_opt(A_test, b_test) + np.testing.assert_allclose(resx0, resx1) + + +@pytest.mark.parametrize("transposed", (False, True)) +def test_lu_decomposition_reused_scan(transposed): + rewrite_name = scan_split_non_sequence_lu_decomposition_solve.__name__ + mode = get_default_mode() + + A = tensor("A", shape=(2, 2)) + x0 = tensor("b", shape=(2, 3)) + + xs, _ = scan( + 
lambda xtm1, A: solve(A, xtm1, assume_a="general", transposed=transposed), + outputs_info=[x0], + non_sequences=[A], + n_steps=10, + ) + + fn_no_opt = function( + [A, x0], + [xs], + mode=mode.excluding(rewrite_name), + ) + [no_opt_scan_node] = [ + node for node in fn_no_opt.maker.fgraph.apply_nodes if isinstance(node.op, Scan) + ] + no_opt_nodes = no_opt_scan_node.op.fgraph.apply_nodes + assert count_vanilla_solve_nodes(no_opt_nodes) == 1 + assert count_lu_decom_nodes(no_opt_nodes) == 0 + assert count_lu_solve_nodes(no_opt_nodes) == 0 + + fn_opt = function([A, x0], [xs], mode=mode.including("scan", rewrite_name)) + [opt_scan_node] = [ + node for node in fn_opt.maker.fgraph.apply_nodes if isinstance(node.op, Scan) + ] + opt_nodes = opt_scan_node.op.fgraph.apply_nodes + assert count_vanilla_solve_nodes(opt_nodes) == 0 + # The LU decomp is outside of the scan! + assert count_lu_decom_nodes(opt_nodes) == 0 + assert count_lu_solve_nodes(opt_nodes) == 1 + + # Make sure results are correct + rng = np.random.default_rng(170) + A_test = rng.random(A.type.shape, dtype=A.type.dtype) + x0_test = rng.random(x0.type.shape, dtype=x0.type.dtype) + resx0 = fn_no_opt(A_test, x0_test) + resx1 = fn_opt(A_test, x0_test) + rtol = 1e-7 if config.floatX == "float64" else 1e-6 + np.testing.assert_allclose(resx0, resx1, rtol=rtol) diff --git a/tests/tensor/test_blockwise.py b/tests/tensor/test_blockwise.py index a660a0c094..cbaf27da29 100644 --- a/tests/tensor/test_blockwise.py +++ b/tests/tensor/test_blockwise.py @@ -579,7 +579,10 @@ def test_solve(self, solve_fn, batched_A, batched_b): else: x = solve_fn(A, b, b_ndim=1) - mode = get_default_mode().excluding("batched_vector_b_solve_to_matrix_b_solve") + mode = get_default_mode().excluding( + "batched_vector_b_solve_to_matrix_b_solve", + "reuse_lu_decomposition_multiple_solves", + ) fn = function([In(A, mutable=True), In(b, mutable=True)], x, mode=mode) op = fn.maker.fgraph.outputs[0].owner.op
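Reviewer note (illustration only, not part of the patch): a minimal sketch of the user-facing effect of `reuse_lu_decomposition_multiple_solves`, using only the public API already exercised by the tests above. The variable names (`fn_opt`, `fn_no_opt`) and the 3x3 shapes are arbitrary. Two solves sharing the same A (one of them against A.T) should compile to a single LU factorization plus two LU solves, and agree numerically with the unrewritten graph:

    import numpy as np

    from pytensor import function
    from pytensor.compile.mode import get_default_mode
    from pytensor.tensor.linalg import solve
    from pytensor.tensor.slinalg import LUFactor
    from pytensor.tensor.type import tensor

    A = tensor("A", shape=(3, 3))
    b1 = tensor("b1", shape=(3,))
    b2 = tensor("b2", shape=(3,))

    # Two solves sharing A: the rewrite factorizes A once and turns both
    # Solve nodes into lu_solve calls on the shared decomposition.
    x1 = solve(A, b1, assume_a="gen", b_ndim=1)
    x2 = solve(A.T, b2, assume_a="gen", b_ndim=1)  # A.T reuses the same LU

    mode = get_default_mode()
    fn_opt = function([A, b1, b2], [x1, x2], mode=mode)
    fn_no_opt = function(
        [A, b1, b2],
        [x1, x2],
        mode=mode.excluding("reuse_lu_decomposition_multiple_solves"),
    )

    # The rewritten graph should contain an LUFactor node
    # (possibly wrapped in a Blockwise)
    assert any(
        isinstance(getattr(node.op, "core_op", node.op), LUFactor)
        for node in fn_opt.maker.fgraph.apply_nodes
    )

    # Both compiled functions compute the same solutions
    rng = np.random.default_rng(0)
    A_val = rng.random(A.type.shape, dtype=A.type.dtype)
    b1_val = rng.random(b1.type.shape, dtype=b1.type.dtype)
    b2_val = rng.random(b2.type.shape, dtype=b2.type.dtype)
    for res_opt, res_no_opt in zip(
        fn_opt(A_val, b1_val, b2_val), fn_no_opt(A_val, b1_val, b2_val)
    ):
        np.testing.assert_allclose(res_opt, res_no_opt, rtol=1e-6)

The Scan variant works the same way, except that the split `lu_factor` is subsequently hoisted out of the inner loop by the existing `scan_push_out_non_seq` pushout, which is why the registration positions in pytensor/scan/rewriting.py were each shifted up by one to make room for the new rewrite at position=2.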