Skip to content

Commit 182cb96

Browse files
committed
added rewrites for inv(diag(x)) and inv(orthonormal(x))
1 parent ad27dc7 commit 182cb96

File tree

2 files changed

+157
-0
lines changed

2 files changed

+157
-0
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -569,3 +569,95 @@ def svd_uv_merge(fgraph, node):
569569
or len(fgraph.clients[cl.outputs[2]]) > 0
570570
):
571571
return [cl.outputs[1]]
572+
573+
574+
@register_canonicalize
575+
@register_stabilize
576+
@node_rewriter([Blockwise])
577+
def rewrite_inv_for_diag_eye_mul(fgraph, node):
578+
"""
579+
This rewrite takes advantage of the fact that for a diagonal matrix, the inverse is a diagonal matrix with the new diagonal entries as reciprocals of the original diagonal elements.
580+
This function deals with diagonal matrix arising from the multiplicaton of eye with a scalar/vector/matrix
581+
582+
Parameters
583+
----------
584+
fgraph: FunctionGraph
585+
Function graph being optimized
586+
node: Apply
587+
Node of the function graph to be optimized
588+
589+
Returns
590+
-------
591+
list of Variable, optional
592+
List of optimized variables, or None if no optimization was performed
593+
"""
594+
# List of useful operations : Inv, Pinv
595+
valid_inverses = (MatrixInverse, MatrixPinv)
596+
core_op = node.op.core_op
597+
if not (isinstance(core_op, valid_inverses)):
598+
return None
599+
600+
# Dealing with diagonal matrix from eye_mul
601+
potential_mul_input = node.inputs[0]
602+
eye_non_eye_inputs = _find_diag_from_eye_mul(potential_mul_input)
603+
if eye_non_eye_inputs is not None:
604+
eye_input, non_eye_inputs = eye_non_eye_inputs
605+
else:
606+
return None
607+
608+
# Dealing with only one other input
609+
if len(non_eye_inputs) != 1:
610+
return None
611+
612+
useful_eye, useful_non_eye = eye_input[0], non_eye_inputs[0]
613+
614+
# For a matrix, we can first get the diagonal and then only use those
615+
if useful_non_eye.type.broadcastable[-2:] == (False, False):
616+
# For Matrix
617+
return [useful_eye * 1 / useful_non_eye.diagonal(axis1=-1, axis2=-2)]
618+
else:
619+
# For Scalar/Vector
620+
return [useful_eye * 1 / useful_non_eye]
621+
622+
623+
def rewrite_inv_for_diag_ptdiag(fgraph, node):
624+
pass
625+
626+
627+
@register_canonicalize
628+
@register_stabilize
629+
@node_rewriter([Blockwise])
630+
def rewrite_inv_for_orthonormal(fgraph, node):
631+
"""
632+
This rewrite takes advantage of the fact that for an orthonormal matrix, the inverse is simply the transpose.
633+
This function deals with orthonormal matrix arising from pt.linalg.svd decomposition (U, Vh) or arising from pt.linalg.qr
634+
635+
Parameters
636+
----------
637+
fgraph: FunctionGraph
638+
Function graph being optimized
639+
node: Apply
640+
Node of the function graph to be optimized
641+
642+
Returns
643+
-------
644+
list of Variable, optional
645+
List of optimized variables, or None if no optimization was performed
646+
"""
647+
# Dealing with orthonormal matrix from SVD
648+
# Check if input to Inverse is coming from SVD
649+
input_to_inv = node.inputs[0]
650+
# Check if this input is coming from SVD with compute_uv = True
651+
if not (
652+
input_to_inv.owner
653+
and isinstance(input_to_inv.owner.op, Blockwise)
654+
and isinstance(input_to_inv.owner.op.core_op, SVD)
655+
and input_to_inv.owner.op.core_op.compute_uv is True
656+
):
657+
return None
658+
659+
# To make sure input is orthonormal, we have to check that its not S (output order is U, S, Vh, so S is index 1)
660+
if input_to_inv == input_to_inv.owner.outputs[1]:
661+
return None
662+
663+
return [input_to_inv.T]

tests/tensor/rewriting/test_linalg.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -554,3 +554,68 @@ def test_svd_uv_merge():
554554
assert node.op.compute_uv
555555
svd_counter += 1
556556
assert svd_counter == 1
557+
558+
559+
@pytest.mark.parametrize(
560+
"shape",
561+
[(), (7,), (7, 7)],
562+
ids=["scalar", "vector", "matrix"],
563+
)
564+
def test_inv_diag_from_eye_mul(shape):
565+
# Initializing x based on scalar/vector/matrix
566+
x = pt.tensor("x", shape=shape)
567+
x_diag = pt.eye(7) * x
568+
# Calculating inverse using pt.linalg.inv
569+
x_inv = pt.linalg.inv(x_diag)
570+
571+
# REWRITE TEST
572+
f_rewritten = function([x], x_inv, mode="FAST_RUN")
573+
nodes = f_rewritten.maker.fgraph.apply_nodes
574+
575+
valid_inverses = (MatrixInverse, MatrixPinv)
576+
assert not any(isinstance(node.op, valid_inverses) for node in nodes)
577+
578+
# NUMERIC VALUE TEST
579+
if len(shape) == 0:
580+
x_test = np.array(np.random.rand()).astype(config.floatX)
581+
elif len(shape) == 1:
582+
x_test = np.random.rand(*shape).astype(config.floatX)
583+
else:
584+
x_test = np.random.rand(*shape).astype(config.floatX)
585+
x_test_matrix = np.eye(7) * x_test
586+
inverse_matrix = np.linalg.inv(x_test_matrix)
587+
rewritten_inverse = f_rewritten(x_test)
588+
589+
assert_allclose(
590+
inverse_matrix,
591+
rewritten_inverse,
592+
atol=1e-3 if config.floatX == "float32" else 1e-8,
593+
rtol=1e-3 if config.floatX == "float32" else 1e-8,
594+
)
595+
596+
597+
def test_inv_orthonormal():
598+
x = pt.dmatrix("x")
599+
u, s, vh = pt.linalg.svd(x)
600+
# Calculating inverse using pt.linalg.inv
601+
u_inv = pt.linalg.inv(u)
602+
print(u_inv.dprint())
603+
# REWRITE TEST
604+
f_rewritten = function([x], u_inv, mode="FAST_RUN")
605+
nodes = f_rewritten.maker.fgraph.apply_nodes
606+
607+
valid_inverses = (MatrixInverse, MatrixPinv)
608+
assert not any(isinstance(node.op, valid_inverses) for node in nodes)
609+
610+
# NUMERIC VALUE TEST
611+
x_test = np.random.rand(7, 7).astype(config.floatX)
612+
u_test, _, _ = np.linalg.svd(x_test)
613+
inverse_matrix = np.linalg.inv(u_test)
614+
rewritten_inverse = f_rewritten(x_test)
615+
616+
assert_allclose(
617+
inverse_matrix,
618+
rewritten_inverse,
619+
atol=1e-3 if config.floatX == "float32" else 1e-8,
620+
rtol=1e-3 if config.floatX == "float32" else 1e-8,
621+
)

0 commit comments

Comments
 (0)