diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py
index 1de6dbb373..47ca08cf21 100644
--- a/pytensor/tensor/rewriting/linalg.py
+++ b/pytensor/tensor/rewriting/linalg.py
@@ -3,6 +3,7 @@
 from typing import cast
 
 from pytensor import Variable
+from pytensor import tensor as pt
 from pytensor.graph import Apply, FunctionGraph
 from pytensor.graph.rewriting.basic import (
     copy_stack_trace,
@@ -48,6 +49,7 @@
 
 
 logger = logging.getLogger(__name__)
+ALL_INVERSE_OPS = (MatrixInverse, MatrixPinv)
 
 
 def is_matrix_transpose(x: TensorVariable) -> bool:
@@ -592,11 +594,10 @@ def rewrite_inv_inv(fgraph, node):
     list of Variable, optional
         List of optimized variables, or None if no optimization was performed
     """
-    valid_inverses = (MatrixInverse, MatrixPinv)
     # Check if its a valid inverse operation (either inv/pinv)
     # In case the outer operation is an inverse, it directly goes to the next step of finding inner operation
     # If the outer operation is not a valid inverse, we do not apply this rewrite
-    if not isinstance(node.op.core_op, valid_inverses):
+    if not isinstance(node.op.core_op, ALL_INVERSE_OPS):
         return None
 
     potential_inner_inv = node.inputs[0].owner
@@ -607,7 +608,104 @@ def rewrite_inv_inv(fgraph, node):
     if not (
         potential_inner_inv
         and isinstance(potential_inner_inv.op, Blockwise)
-        and isinstance(potential_inner_inv.op.core_op, valid_inverses)
+        and isinstance(potential_inner_inv.op.core_op, ALL_INVERSE_OPS)
     ):
         return None
     return [potential_inner_inv.inputs[0]]
+
+
+@register_canonicalize
+@register_stabilize
+@node_rewriter([Blockwise])
+def rewrite_inv_eye_to_eye(fgraph, node):
+    """
+    This rewrite takes advantage of the fact that the inverse of an identity matrix is the matrix itself
+    The presence of an identity matrix is identified by checking whether we have k = 0 for an Eye Op inside an inverse op.
+    If the value of k cannot be read (the k input exposes no ``data``), the rewrite is not applied.
+
+    Parameters
+    ----------
+    fgraph: FunctionGraph
+        Function graph being optimized
+    node: Apply
+        Node of the function graph to be optimized
+
+    Returns
+    -------
+    list of Variable, optional
+        List of optimized variables, or None if no optimization was performed
+    """
+    core_op = node.op.core_op
+    if not (isinstance(core_op, ALL_INVERSE_OPS)):
+        return None
+
+    # Check whether input to inverse is Eye and the 1's are on main diagonal
+    potential_eye = node.inputs[0]
+    eye_node = potential_eye.owner
+    if not (eye_node and isinstance(eye_node.op, Eye)):
+        return None
+
+    # k is read from the last input of the Eye node. Its value can only be
+    # inspected when that input exposes `data`; otherwise k == 0 cannot be
+    # established, so the rewrite is skipped instead of raising.
+    k_data = getattr(eye_node.inputs[-1], "data", None)
+    if k_data is None or k_data.item() != 0:
+        return None
+    # NOTE(review): only k == 0 is verified; the Eye is not checked to be square
+    return [potential_eye]
+
+
+@register_canonicalize
+@register_stabilize
+@node_rewriter([Blockwise])
+def rewrite_inv_diag_to_diag_reciprocal(fgraph, node):
+    """
+    This rewrite takes advantage of the fact that for a diagonal matrix, the inverse is a diagonal matrix with the new diagonal entries as reciprocals of the original diagonal elements.
+    This function deals with diagonal matrix arising from the multiplication of eye with a scalar/vector/matrix
+
+    Parameters
+    ----------
+    fgraph: FunctionGraph
+        Function graph being optimized
+    node: Apply
+        Node of the function graph to be optimized
+
+    Returns
+    -------
+    list of Variable, optional
+        List of optimized variables, or None if no optimization was performed
+    """
+    core_op = node.op.core_op
+    if not (isinstance(core_op, ALL_INVERSE_OPS)):
+        return None
+
+    inputs = node.inputs[0]
+    # Check for use of pt.diag first
+    if (
+        inputs.owner
+        and isinstance(inputs.owner.op, AllocDiag)
+        and AllocDiag.is_offset_zero(inputs.owner)
+    ):
+        inv_input = inputs.owner.inputs[0]
+        inv_val = pt.diag(1 / inv_input)
+        return [inv_val]
+
+    # Check if the input is an elemwise multiply with identity matrix -- this also results in a diagonal matrix
+    inputs_or_none = _find_diag_from_eye_mul(inputs)
+    if inputs_or_none is None:
+        return None
+
+    eye_input, non_eye_inputs = inputs_or_none
+
+    # Dealing with only one other input
+    if len(non_eye_inputs) != 1:
+        return None
+
+    non_eye_input = non_eye_inputs[0]
+
+    # For a matrix, we have to first extract the diagonal (non-zero values) and then only use those
+    if non_eye_input.type.broadcastable[-2:] == (False, False):
+        non_eye_diag = non_eye_input.diagonal(axis1=-1, axis2=-2)
+        non_eye_input = pt.shape_padaxis(non_eye_diag, -2)
+
+    return [eye_input / non_eye_input]
diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py
index 7353a82be0..0bee56eb30 100644
--- a/tests/tensor/rewriting/test_linalg.py
+++ b/tests/tensor/rewriting/test_linalg.py
@@ -41,6 +41,9 @@
 from tests.test_rop import break_op
 
 
+ATOL = RTOL = 1e-3 if config.floatX == "float32" else 1e-8
+
+
 def test_rop_lop():
     mx = matrix("mx")
     mv = matrix("mv")
@@ -557,14 +560,103 @@ def test_svd_uv_merge():
     assert svd_counter == 1
 
 
+def get_pt_function(x, op_name):
+    return getattr(pt.linalg, op_name)(x)
+
+
 @pytest.mark.parametrize("inv_op_1", ["inv", "pinv"])
 @pytest.mark.parametrize("inv_op_2", ["inv", "pinv"])
 def test_inv_inv_rewrite(inv_op_1, inv_op_2):
-    def get_pt_function(x, op_name):
-        return getattr(pt.linalg, op_name)(x)
-
     x = pt.matrix("x")
     op1 = get_pt_function(x, inv_op_1)
     op2 = get_pt_function(op1, inv_op_2)
     rewritten_out = rewrite_graph(op2)
     assert rewritten_out == x
+
+
+@pytest.mark.parametrize("inv_op", ["inv", "pinv"])
+def test_inv_eye_to_eye(inv_op):
+    x = pt.eye(10)
+    x_inv = get_pt_function(x, inv_op)
+    f_rewritten = function([], x_inv, mode="FAST_RUN")
+    nodes = f_rewritten.maker.fgraph.apply_nodes
+
+    # Rewrite Test
+    valid_inverses = (MatrixInverse, MatrixPinv)
+    assert not any(isinstance(node.op, valid_inverses) for node in nodes)
+
+    # Value Test
+    x_test = np.eye(10)
+    x_inv_val = np.linalg.inv(x_test)
+    rewritten_val = f_rewritten()
+
+    assert_allclose(
+        x_inv_val,
+        rewritten_val,
+        atol=ATOL,
+        rtol=RTOL,
+    )
+
+
+@pytest.mark.parametrize(
+    "shape",
+    [(), (7,), (7, 7), (5, 7, 7)],
+    ids=["scalar", "vector", "matrix", "batched"],
+)
+@pytest.mark.parametrize("inv_op", ["inv", "pinv"])
+def test_inv_diag_from_eye_mul(shape, inv_op):
+    # Initializing x based on scalar/vector/matrix
+    x = pt.tensor("x", shape=shape)
+    x_diag = pt.eye(7) * x
+    # Calculating inverse using pt.linalg.inv
+    x_inv = get_pt_function(x_diag, inv_op)
+
+    # REWRITE TEST
+    f_rewritten = function([x], x_inv, mode="FAST_RUN")
+    nodes = f_rewritten.maker.fgraph.apply_nodes
+
+    valid_inverses = (MatrixInverse, MatrixPinv)
+    assert not any(isinstance(node.op, valid_inverses) for node in nodes)
+
+    # NUMERIC VALUE TEST
+    if len(shape) == 0:
+        x_test = np.array(np.random.rand()).astype(config.floatX)
+    else:
+        x_test = np.random.rand(*shape).astype(config.floatX)
+    x_test_matrix = np.eye(7) * x_test
+    inverse_matrix = np.linalg.inv(x_test_matrix)
+    rewritten_inverse = f_rewritten(x_test)
+
+    assert_allclose(
+        inverse_matrix,
+        rewritten_inverse,
+        atol=ATOL,
+        rtol=RTOL,
+    )
+
+
+@pytest.mark.parametrize("inv_op", ["inv", "pinv"])
+def test_inv_diag_from_diag(inv_op):
+    x = pt.dvector("x")
+    x_diag = pt.diag(x)
+    x_inv = get_pt_function(x_diag, inv_op)
+
+    # REWRITE TEST
+    f_rewritten = function([x], x_inv, mode="FAST_RUN")
+    nodes = f_rewritten.maker.fgraph.apply_nodes
+
+    valid_inverses = (MatrixInverse, MatrixPinv)
+    assert not any(isinstance(node.op, valid_inverses) for node in nodes)
+
+    # NUMERIC VALUE TEST
+    x_test = np.random.rand(10)
+    x_test_matrix = np.eye(10) * x_test
+    inverse_matrix = np.linalg.inv(x_test_matrix)
+    rewritten_inverse = f_rewritten(x_test)
+
+    assert_allclose(
+        inverse_matrix,
+        rewritten_inverse,
+        atol=ATOL,
+        rtol=RTOL,
+    )