Skip to content

Commit f2fff31

Browse files
committed
added rewrites for inv(diag(x)) and inv(orthonormal(x))
1 parent df769f6 commit f2fff31

File tree

2 files changed

+157
-0
lines changed

2 files changed

+157
-0
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -539,3 +539,95 @@ def svd_uv_merge(fgraph, node):
539539
or len(fgraph.clients[cl.outputs[2]]) > 0
540540
):
541541
return [cl.outputs[1]]
542+
543+
544+
@register_canonicalize
@register_stabilize
@node_rewriter([Blockwise])
def rewrite_inv_for_diag_eye_mul(fgraph, node):
    """
    Rewrite the inverse of a diagonal matrix built as ``eye * x`` into an elementwise reciprocal.

    This rewrite takes advantage of the fact that for a diagonal matrix, the inverse is a diagonal matrix with the new diagonal entries as reciprocals of the original diagonal elements.
    This function deals with diagonal matrices arising from the multiplication of eye with a scalar/vector/matrix.

    Parameters
    ----------
    fgraph: FunctionGraph
        Function graph being optimized
    node: Apply
        Node of the function graph to be optimized

    Returns
    -------
    list of Variable, optional
        List of optimized variables, or None if no optimization was performed
    """
    # Only inverse-like core ops are candidates for this rewrite.
    # NOTE(review): for MatrixPinv, a zero diagonal entry maps to 0 in the true
    # pseudo-inverse but to inf under a plain reciprocal -- confirm this is acceptable.
    valid_inverses = (MatrixInverse, MatrixPinv)
    core_op = node.op.core_op
    if not isinstance(core_op, valid_inverses):
        return None

    # The input must be a diagonal matrix built as a product involving eye
    eye_non_eye_inputs = _find_diag_from_eye_mul(node.inputs[0])
    if eye_non_eye_inputs is None:
        return None
    eye_input, non_eye_inputs = eye_non_eye_inputs

    # Only the case of a single non-eye factor is handled
    if len(non_eye_inputs) != 1:
        return None

    useful_eye, useful_non_eye = eye_input[0], non_eye_inputs[0]

    # For a full matrix factor only its diagonal survives the product with eye.
    # Extract it and put it back as a broadcastable row, shape (..., 1, n), so
    # that the division lines up with eye's columns even when the factor
    # carries leading batch dimensions.
    if useful_non_eye.type.broadcastable[-2:] == (False, False):
        diagonal = useful_non_eye.diagonal(axis1=-1, axis2=-2)
        useful_non_eye = diagonal[..., None, :]

    # Scalars / vectors / broadcastable matrices already broadcast correctly
    return [useful_eye / useful_non_eye]
591+
592+
593+
def rewrite_inv_for_diag_ptdiag(fgraph, node):
    """
    Placeholder for a rewrite of the inverse of a diagonal matrix; not implemented.

    The function carries no rewriter decorators, so it is not registered as a
    rewrite. It accepts the usual ``(fgraph, node)`` pair and always returns
    None, i.e. "no optimization performed".
    """
    return None
595+
596+
597+
@register_canonicalize
@register_stabilize
@node_rewriter([Blockwise])
def rewrite_inv_for_orthonormal(fgraph, node):
    """
    Rewrite the inverse of an orthonormal matrix into its transpose.

    This rewrite takes advantage of the fact that for an orthonormal matrix, the inverse is simply the transpose.
    This function deals with orthonormal matrices arising from pt.linalg.svd decomposition (U, Vh); other sources of orthonormal matrices are not detected here.

    Parameters
    ----------
    fgraph: FunctionGraph
        Function graph being optimized
    node: Apply
        Node of the function graph to be optimized

    Returns
    -------
    list of Variable, optional
        List of optimized variables, or None if no optimization was performed
    """
    # This rewriter is tracked on every Blockwise node, so first make sure the
    # node really is an inverse; otherwise any Blockwise op applied to U / Vh
    # would be replaced by a transpose.
    valid_inverses = (MatrixInverse, MatrixPinv)
    if not isinstance(node.op.core_op, valid_inverses):
        return None

    # Check that the input to the inverse comes from SVD with compute_uv = True
    input_to_inv = node.inputs[0]
    svd_node = input_to_inv.owner
    if not (
        svd_node
        and isinstance(svd_node.op, Blockwise)
        and isinstance(svd_node.op.core_op, SVD)
        and svd_node.op.core_op.compute_uv is True
    ):
        return None

    # Output order is U, S, Vh: S (index 1) holds singular values and is not
    # orthonormal, so it must not be rewritten.
    if input_to_inv is svd_node.outputs[1]:
        return None

    # Swap only the two core (matrix) axes so that leading batch dimensions
    # are preserved; for a plain matrix this is the same as ``.T``.
    return [input_to_inv.swapaxes(-1, -2)]

tests/tensor/rewriting/test_linalg.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -545,3 +545,68 @@ def test_svd_uv_merge():
545545
assert node.op.compute_uv
546546
svd_counter += 1
547547
assert svd_counter == 1
548+
549+
550+
@pytest.mark.parametrize(
    "shape",
    [(), (7,), (7, 7)],
    ids=["scalar", "vector", "matrix"],
)
def test_inv_diag_from_eye_mul(shape):
    """Check that inv(eye * x) is rewritten away and still matches numpy's inverse."""
    # Initializing x based on scalar/vector/matrix
    x = pt.tensor("x", shape=shape)
    x_diag = pt.eye(7) * x
    # Calculating inverse using pt.linalg.inv
    x_inv = pt.linalg.inv(x_diag)

    # REWRITE TEST
    f_rewritten = function([x], x_inv, mode="FAST_RUN")
    nodes = f_rewritten.maker.fgraph.apply_nodes

    # Look through a Blockwise wrapper (if any) so the check cannot pass
    # vacuously when the inverse is still wrapped.
    valid_inverses = (MatrixInverse, MatrixPinv)
    assert not any(
        isinstance(getattr(node.op, "core_op", node.op), valid_inverses)
        for node in nodes
    )

    # NUMERIC VALUE TEST
    # Seeded values bounded away from zero keep the reference inverse
    # well-conditioned and the test reproducible; one expression covers
    # the scalar, vector and matrix cases.
    rng = np.random.default_rng(sum(map(ord, "test_inv_diag_from_eye_mul")))
    x_test = np.asarray(rng.uniform(0.5, 1.5, size=shape)).astype(config.floatX)
    x_test_matrix = np.eye(7) * x_test
    inverse_matrix = np.linalg.inv(x_test_matrix)
    rewritten_inverse = f_rewritten(x_test)

    assert_allclose(
        inverse_matrix,
        rewritten_inverse,
        atol=1e-3 if config.floatX == "float32" else 1e-8,
        rtol=1e-3 if config.floatX == "float32" else 1e-8,
    )
586+
587+
588+
def test_inv_orthonormal():
    """Check that inv(U), with U from an SVD, is rewritten away and matches numpy's inverse."""
    x = pt.dmatrix("x")
    u, s, vh = pt.linalg.svd(x)
    # Calculating inverse using pt.linalg.inv
    u_inv = pt.linalg.inv(u)

    # REWRITE TEST
    f_rewritten = function([x], u_inv, mode="FAST_RUN")
    nodes = f_rewritten.maker.fgraph.apply_nodes

    # Look through a Blockwise wrapper (if any) so the check cannot pass
    # vacuously when the inverse is still wrapped.
    valid_inverses = (MatrixInverse, MatrixPinv)
    assert not any(
        isinstance(getattr(node.op, "core_op", node.op), valid_inverses)
        for node in nodes
    )

    # NUMERIC VALUE TEST
    # Seeded input keeps the test reproducible.
    rng = np.random.default_rng(sum(map(ord, "test_inv_orthonormal")))
    x_test = rng.random((7, 7)).astype(config.floatX)
    u_test, _, _ = np.linalg.svd(x_test)
    inverse_matrix = np.linalg.inv(u_test)
    rewritten_inverse = f_rewritten(x_test)

    assert_allclose(
        inverse_matrix,
        rewritten_inverse,
        atol=1e-3 if config.floatX == "float32" else 1e-8,
        rtol=1e-3 if config.floatX == "float32" else 1e-8,
    )

0 commit comments

Comments
 (0)