Skip to content

Commit 9ac68db

Browse files
committed
added function for rewriting cholesky(eye) -> eye
1 parent e2f9cb8 commit 9ac68db

File tree

2 files changed

+63
-0
lines changed

2 files changed

+63
-0
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -536,3 +536,39 @@ def svd_uv_merge(fgraph, node):
536536
or len(fgraph.clients[cl.outputs[2]]) > 0
537537
):
538538
return [cl.outputs[1]]
539+
540+
541+
@register_canonicalize
542+
@register_stabilize
543+
@node_rewriter([Blockwise])
544+
def rewrite_cholesky_eye_to_eye(fgraph, node):
545+
"""
546+
This rewrite takes advantage of the fact that the cholesky decomposition of an identity matrix is the matrix itself
547+
548+
The presence of an identity matrix is identified by checking whether we have k = 0 for an Eye Op inside Cholesky.
549+
550+
Parameters
551+
----------
552+
fgraph: FunctionGraph
553+
Function graph being optimized
554+
node: Apply
555+
Node of the function graph to be optimized
556+
557+
Returns
558+
-------
559+
list of Variable, optional
560+
List of optimized variables, or None if no optimization was performed
561+
"""
562+
# Find whether cholesky op is being applied
563+
if not isinstance(node.op.core_op, Cholesky):
564+
return None
565+
566+
# Check whether input to Cholesky is Eye and the 1's are on main diagonal
567+
eye_check = node.inputs[0]
568+
if not (
569+
eye_check.owner
570+
and isinstance(eye_check.owner.op, Eye)
571+
and getattr(eye_check.owner.inputs[-1], "data", -1).item() == 0
572+
):
573+
return None
574+
return [eye_check]

tests/tensor/rewriting/test_linalg.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -545,3 +545,30 @@ def test_svd_uv_merge():
545545
assert node.op.compute_uv
546546
svd_counter += 1
547547
assert svd_counter == 1
548+
549+
550+
def test_cholesky_eye_rewrite():
551+
x = pt.eye(10)
552+
x_mat = pt.matrix("x")
553+
L = pt.linalg.cholesky(x)
554+
L_mat = pt.linalg.cholesky(x_mat)
555+
f_rewritten = function([], L, mode="FAST_RUN")
556+
f_rewritten_mat = function([x_mat], L_mat, mode="FAST_RUN")
557+
nodes = f_rewritten.maker.fgraph.apply_nodes
558+
nodes_mat = f_rewritten_mat.maker.fgraph.apply_nodes
559+
560+
# Rewrite Test
561+
assert not any(isinstance(node.op, Cholesky) for node in nodes)
562+
assert any(isinstance(node.op, Cholesky) for node in nodes_mat)
563+
564+
# Value Test
565+
x_test = np.eye(10)
566+
L = np.linalg.cholesky(x_test)
567+
rewritten_val = f_rewritten()
568+
569+
assert_allclose(
570+
L,
571+
rewritten_val,
572+
atol=1e-3 if config.floatX == "float32" else 1e-8,
573+
rtol=1e-3 if config.floatX == "float32" else 1e-8,
574+
)

0 commit comments

Comments
 (0)