From 4e88e2910152a81a3921392c8f9ed7f680ea563c Mon Sep 17 00:00:00 2001 From: Tanish Taneja Date: Sun, 21 Jul 2024 14:32:24 +0530 Subject: [PATCH 1/5] fixed merge conflicts --- pytensor/tensor/rewriting/linalg.py | 61 ++++++++++++++++++++++++++ tests/tensor/rewriting/test_linalg.py | 63 +++++++++++++++++++++++++++ 2 files changed, 124 insertions(+) diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index 3ab2960562..e094e2d527 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -887,3 +887,64 @@ def rewrite_slogdet_kronecker(fgraph, node): logdet_final = [logdets[i] * prod_sizes / sizes[i] for i in range(2)] return [prod(signs_final, no_zeros_in_input=True), sum(logdet_final)] + + +@register_canonicalize +@register_stabilize +@node_rewriter([Blockwise]) +def rewrite_cholesky_eye_to_eye(fgraph, node): + """ + This rewrite takes advantage of the fact that the cholesky decomposition of an identity matrix is the matrix itself + + The presence of an identity matrix is identified by checking whether we have k = 0 for an Eye Op inside Cholesky. + + Parameters + ---------- + fgraph: FunctionGraph + Function graph being optimized + node: Apply + Node of the function graph to be optimized + + Returns + ------- + list of Variable, optional + List of optimized variables, or None if no optimization was performed + """ + # Find whether cholesky op is being applied + if not isinstance(node.op.core_op, Cholesky): + return None + + # Check whether input to Cholesky is Eye and the 1's are on main diagonal + eye_check = node.inputs[0] + if not ( + eye_check.owner + and isinstance(eye_check.owner.op, Eye) + and getattr(eye_check.owner.inputs[-1], "data", -1).item() == 0 + ): + return None + return [eye_check] + + +@register_canonicalize +@register_stabilize +@node_rewriter([Blockwise]) +def rewrite_cholesky_diag_from_eye_mul(fgraph, node): + # Find whether cholesky op is being applied + if not isinstance(node.op.core_op, Cholesky): + return None + + # Check whether input is diagonal from multiplcation of identity matrix with a tensor + inputs = node.inputs[0] + inputs_or_none = _find_diag_from_eye_mul(inputs) + if inputs_or_none is None: + return None + + eye_input, non_eye_inputs = inputs_or_none + + # Dealing with only one other input + if len(non_eye_inputs) != 1: + return None + + eye_input, non_eye_input = eye_input[0], non_eye_inputs[0] + + return [eye_input * (non_eye_input**0.5)] diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py index 211facb484..4021ac792d 100644 --- a/tests/tensor/rewriting/test_linalg.py +++ b/tests/tensor/rewriting/test_linalg.py @@ -803,3 +803,66 @@ def test_slogdet_kronecker_rewrite(): atol=1e-3 if config.floatX == "float32" else 1e-8, rtol=1e-3 if config.floatX == "float32" else 1e-8, ) + + +def test_cholesky_eye_rewrite(): + x = pt.eye(10) + x_mat = pt.matrix("x") + L = pt.linalg.cholesky(x) + L_mat = pt.linalg.cholesky(x_mat) + f_rewritten = function([], L, mode="FAST_RUN") + f_rewritten_mat = function([x_mat], L_mat, mode="FAST_RUN") + nodes = f_rewritten.maker.fgraph.apply_nodes + nodes_mat = f_rewritten_mat.maker.fgraph.apply_nodes + + # Rewrite Test + assert not any(isinstance(node.op, Cholesky) for node in nodes) + assert any(isinstance(node.op, Cholesky) for node in nodes_mat) + + # Value Test + x_test = np.eye(10) + L = np.linalg.cholesky(x_test) + rewritten_val = f_rewritten() + + assert_allclose( + L, + rewritten_val, + atol=1e-3 if config.floatX == "float32" else 1e-8, + rtol=1e-3 if config.floatX == "float32" else 1e-8, + ) + + +@pytest.mark.parametrize( + "shape", + [(), (7,), (1, 7), (7, 1), (7, 7), (3, 7, 7)], + ids=["scalar", "vector", "row_vec", "col_vec", "matrix", "batched_input"], +) +def test_cholesky_diag_from_eye_mul(shape): + # Initializing x based on scalar/vector/matrix + x = pt.tensor("x", shape=shape) + y = pt.eye(7) * x + # Performing cholesky decomposition using pt.linalg.cholesky + z_cholesky = pt.linalg.cholesky(y) + + # REWRITE TEST + f_rewritten = function([x], z_cholesky, mode="FAST_RUN") + nodes = f_rewritten.maker.fgraph.apply_nodes + assert not any(isinstance(node.op, Cholesky) for node in nodes) + + # NUMERIC VALUE TEST + if len(shape) == 0: + x_test = np.array(np.random.rand()).astype(config.floatX) + elif len(shape) == 1: + x_test = np.random.rand(*shape).astype(config.floatX) + else: + x_test = np.random.rand(*shape).astype(config.floatX) + x_test_matrix = np.eye(7) * x_test + cholesky_val = np.linalg.cholesky(x_test_matrix) + rewritten_val = f_rewritten(x_test) + + assert_allclose( + cholesky_val, + rewritten_val, + atol=1e-3 if config.floatX == "float32" else 1e-8, + rtol=1e-3 if config.floatX == "float32" else 1e-8, + ) From d573d570afe640b2b1e8a3642684176db4f1517d Mon Sep 17 00:00:00 2001 From: Tanish Taneja Date: Sun, 21 Jul 2024 15:41:20 +0530 Subject: [PATCH 2/5] fixed failing tests and added rewrite for pt.diag --- pytensor/tensor/rewriting/linalg.py | 27 +++++++++++++++++++++---- tests/tensor/rewriting/test_linalg.py | 29 +++++++++++++++++++++++++-- 2 files changed, 50 insertions(+), 6 deletions(-) diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index e094e2d527..5c686bd520 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -2,6 +2,7 @@ from collections.abc import Callable from typing import cast +import pytensor.tensor as pt from pytensor import Variable from pytensor import tensor as pt from pytensor.graph import Apply, FunctionGraph @@ -928,13 +929,24 @@ def rewrite_cholesky_eye_to_eye(fgraph, node): @register_canonicalize @register_stabilize @node_rewriter([Blockwise]) -def rewrite_cholesky_diag_from_eye_mul(fgraph, node): +def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node): # Find whether cholesky op is being applied if not isinstance(node.op.core_op, Cholesky): return None - # Check whether input is diagonal from multiplcation of identity matrix with a tensor inputs = node.inputs[0] + # Check for use of pt.diag first + if ( + inputs.owner + and isinstance(inputs.owner.op, AllocDiag) + and AllocDiag.is_offset_zero(inputs.owner) + ): + cholesky_input = inputs.owner.inputs[0] + if cholesky_input.type.ndim == 1: + cholesky_val = pt.diag(cholesky_input**0.5) + return [cholesky_val] + + # Check if the input is an elemwise multiply with identity matrix -- this also results in a diagonal matrix inputs_or_none = _find_diag_from_eye_mul(inputs) if inputs_or_none is None: return None @@ -945,6 +957,13 @@ def rewrite_cholesky_diag_from_eye_mul(fgraph, node): if len(non_eye_inputs) != 1: return None - eye_input, non_eye_input = eye_input[0], non_eye_inputs[0] + non_eye_input = non_eye_inputs[0] - return [eye_input * (non_eye_input**0.5)] + # Now, we can simply return the matrix consisting of sqrt values of the original diagonal elements + # For a matrix, we have to first extract the diagonal (non-zero values) and then only use those + if non_eye_input.type.broadcastable[-2:] == (False, False): + # For Matrix + return [eye_input * (non_eye_input.diagonal(axis1=-1, axis2=-2) ** 0.5)] + else: + # For Vector or Scalar + return [eye_input * (non_eye_input**0.5)] diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py index 4021ac792d..234e81566f 100644 --- a/tests/tensor/rewriting/test_linalg.py +++ b/tests/tensor/rewriting/test_linalg.py @@ -834,8 +834,8 @@ def test_cholesky_eye_rewrite(): @pytest.mark.parametrize( "shape", - [(), (7,), (1, 7), (7, 1), (7, 7), (3, 7, 7)], - ids=["scalar", "vector", "row_vec", "col_vec", "matrix", "batched_input"], + [(), (7,), (7, 7)], + ids=["scalar", "vector", "matrix"], ) def test_cholesky_diag_from_eye_mul(shape): # Initializing x based on scalar/vector/matrix @@ -866,3 +866,28 @@ def test_cholesky_diag_from_eye_mul(shape): atol=1e-3 if config.floatX == "float32" else 1e-8, rtol=1e-3 if config.floatX == "float32" else 1e-8, ) + + +def test_cholesky_diag_from_diag(): + x = pt.dvector("x") + x_diag = pt.diag(x) + x_cholesky = pt.linalg.cholesky(x_diag) + + # REWRITE TEST + f_rewritten = function([x], x_cholesky, mode="FAST_RUN") + nodes = f_rewritten.maker.fgraph.apply_nodes + + assert not any(isinstance(node.op, Cholesky) for node in nodes) + + # NUMERIC VALUE TEST + x_test = np.random.rand(10) + x_test_matrix = np.eye(10) * x_test + cholesky_val = np.linalg.cholesky(x_test_matrix) + rewritten_cholesky = f_rewritten(x_test) + + assert_allclose( + cholesky_val, + rewritten_cholesky, + atol=1e-3 if config.floatX == "float32" else 1e-8, + rtol=1e-3 if config.floatX == "float32" else 1e-8, + ) From c54f3c144b318825980c02857c0d96e9bd3d17e2 Mon Sep 17 00:00:00 2001 From: Tanish Date: Tue, 30 Jul 2024 15:56:43 +0530 Subject: [PATCH 3/5] minor changes; added test to not apply rewrite --- pytensor/tensor/rewriting/linalg.py | 29 ++++++++++++--------------- tests/tensor/rewriting/test_linalg.py | 17 +++++++++++----- 2 files changed, 25 insertions(+), 21 deletions(-) diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index 5c686bd520..cfac580b52 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -2,7 +2,6 @@ from collections.abc import Callable from typing import cast -import pytensor.tensor as pt from pytensor import Variable from pytensor import tensor as pt from pytensor.graph import Apply, FunctionGraph @@ -893,7 +892,7 @@ def rewrite_slogdet_kronecker(fgraph, node): @register_canonicalize @register_stabilize @node_rewriter([Blockwise]) -def rewrite_cholesky_eye_to_eye(fgraph, node): +def rewrite_remove_useless_cholesky(fgraph, node): """ This rewrite takes advantage of the fact that the cholesky decomposition of an identity matrix is the matrix itself @@ -916,14 +915,15 @@ def rewrite_cholesky_eye_to_eye(fgraph, node): return None # Check whether input to Cholesky is Eye and the 1's are on main diagonal - eye_check = node.inputs[0] + potential_eye = node.inputs[0] if not ( - eye_check.owner - and isinstance(eye_check.owner.op, Eye) - and getattr(eye_check.owner.inputs[-1], "data", -1).item() == 0 + potential_eye.owner + and isinstance(potential_eye.owner.op, Eye) + and hasattr(potential_eye.owner.inputs[-1], "data") + and potential_eye.owner.inputs[-1].data.item() == 0 ): return None - return [eye_check] + return [potential_eye] @register_canonicalize @@ -941,10 +941,9 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node): and isinstance(inputs.owner.op, AllocDiag) and AllocDiag.is_offset_zero(inputs.owner) ): - cholesky_input = inputs.owner.inputs[0] - if cholesky_input.type.ndim == 1: - cholesky_val = pt.diag(cholesky_input**0.5) - return [cholesky_val] + diag_input = inputs.owner.inputs[0] + cholesky_val = pt.diag(diag_input**0.5) + return [cholesky_val] # Check if the input is an elemwise multiply with identity matrix -- this also results in a diagonal matrix inputs_or_none = _find_diag_from_eye_mul(inputs) @@ -962,8 +961,6 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node): # Now, we can simply return the matrix consisting of sqrt values of the original diagonal elements # For a matrix, we have to first extract the diagonal (non-zero values) and then only use those if non_eye_input.type.broadcastable[-2:] == (False, False): - # For Matrix - return [eye_input * (non_eye_input.diagonal(axis1=-1, axis2=-2) ** 0.5)] - else: - # For Vector or Scalar - return [eye_input * (non_eye_input**0.5)] + non_eye_input = non_eye_input.diagonal(axis1=-1, axis2=-2) + + return [eye_input * (non_eye_input**0.5)] diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py index 234e81566f..98210a0dc4 100644 --- a/tests/tensor/rewriting/test_linalg.py +++ b/tests/tensor/rewriting/test_linalg.py @@ -807,17 +807,12 @@ def test_slogdet_kronecker_rewrite(): def test_cholesky_eye_rewrite(): x = pt.eye(10) - x_mat = pt.matrix("x") L = pt.linalg.cholesky(x) - L_mat = pt.linalg.cholesky(x_mat) f_rewritten = function([], L, mode="FAST_RUN") - f_rewritten_mat = function([x_mat], L_mat, mode="FAST_RUN") nodes = f_rewritten.maker.fgraph.apply_nodes - nodes_mat = f_rewritten_mat.maker.fgraph.apply_nodes # Rewrite Test assert not any(isinstance(node.op, Cholesky) for node in nodes) - assert any(isinstance(node.op, Cholesky) for node in nodes_mat) # Value Test x_test = np.eye(10) @@ -891,3 +886,15 @@ def test_cholesky_diag_from_diag(): atol=1e-3 if config.floatX == "float32" else 1e-8, rtol=1e-3 if config.floatX == "float32" else 1e-8, ) + + +def test_dont_apply_cholesky(): + x = pt.tensor("x", shape=(7, 7)) + y = pt.eye(7, k=-1) * x + # Here, y is not a diagonal matrix because of k = -1 + z_cholesky = pt.linalg.cholesky(y) + + # REWRITE TEST (should not be applied) + f_rewritten = function([x], z_cholesky, mode="FAST_RUN") + nodes = f_rewritten.maker.fgraph.apply_nodes + assert any(isinstance(node.op, Cholesky) for node in nodes) From 5fc76c261311f8915bbb332740caa0dccfb60ddf Mon Sep 17 00:00:00 2001 From: Tanish Date: Mon, 5 Aug 2024 17:48:01 +0530 Subject: [PATCH 4/5] added test for batched case and more cases of not applying rewrite --- pytensor/tensor/rewriting/linalg.py | 2 ++ tests/tensor/rewriting/test_linalg.py | 16 ++++++++++++---- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index cfac580b52..49a73012ae 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -962,5 +962,7 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node): # For a matrix, we have to first extract the diagonal (non-zero values) and then only use those if non_eye_input.type.broadcastable[-2:] == (False, False): non_eye_input = non_eye_input.diagonal(axis1=-1, axis2=-2) + if eye_input.type.ndim > 2: + non_eye_input = pt.shape_padaxis(non_eye_input, -2) return [eye_input * (non_eye_input**0.5)] diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py index 98210a0dc4..9dd2a247a8 100644 --- a/tests/tensor/rewriting/test_linalg.py +++ b/tests/tensor/rewriting/test_linalg.py @@ -829,8 +829,8 @@ def test_cholesky_eye_rewrite(): @pytest.mark.parametrize( "shape", - [(), (7,), (7, 7)], - ids=["scalar", "vector", "matrix"], + [(), (7,), (7, 7), (5, 7, 7)], + ids=["scalar", "vector", "matrix", "batched"], ) def test_cholesky_diag_from_eye_mul(shape): # Initializing x based on scalar/vector/matrix @@ -888,13 +888,21 @@ def test_cholesky_diag_from_diag(): ) -def test_dont_apply_cholesky(): +def test_rewrite_cholesky_diag_to_sqrt_diag_not_applied(): + # Case 1 : y is not a diagonal matrix because of k = -1 x = pt.tensor("x", shape=(7, 7)) y = pt.eye(7, k=-1) * x - # Here, y is not a diagonal matrix because of k = -1 z_cholesky = pt.linalg.cholesky(y) # REWRITE TEST (should not be applied) f_rewritten = function([x], z_cholesky, mode="FAST_RUN") nodes = f_rewritten.maker.fgraph.apply_nodes assert any(isinstance(node.op, Cholesky) for node in nodes) + + # Case 2 : eye is degenerate + x = pt.scalar("x") + y = pt.eye(1) * x + z_cholesky = pt.linalg.cholesky(y) + f_rewritten = function([x], z_cholesky, mode="FAST_RUN") + nodes = f_rewritten.maker.fgraph.apply_nodes + assert any(isinstance(node.op, Cholesky) for node in nodes) From cf873626445e37e520da965e9a4fb96d5c4ffee5 Mon Sep 17 00:00:00 2001 From: Tanish Date: Mon, 5 Aug 2024 23:25:01 +0530 Subject: [PATCH 5/5] minor changes --- pytensor/tensor/rewriting/linalg.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index 49a73012ae..798d590d7f 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -934,19 +934,19 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node): if not isinstance(node.op.core_op, Cholesky): return None - inputs = node.inputs[0] + [input] = node.inputs # Check for use of pt.diag first if ( - inputs.owner - and isinstance(inputs.owner.op, AllocDiag) - and AllocDiag.is_offset_zero(inputs.owner) + input.owner + and isinstance(input.owner.op, AllocDiag) + and AllocDiag.is_offset_zero(input.owner) ): - diag_input = inputs.owner.inputs[0] + diag_input = input.owner.inputs[0] cholesky_val = pt.diag(diag_input**0.5) return [cholesky_val] # Check if the input is an elemwise multiply with identity matrix -- this also results in a diagonal matrix - inputs_or_none = _find_diag_from_eye_mul(inputs) + inputs_or_none = _find_diag_from_eye_mul(input) if inputs_or_none is None: return None @@ -956,7 +956,7 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node): if len(non_eye_inputs) != 1: return None - non_eye_input = non_eye_inputs[0] + [non_eye_input] = non_eye_inputs # Now, we can simply return the matrix consisting of sqrt values of the original diagonal elements # For a matrix, we have to first extract the diagonal (non-zero values) and then only use those