Skip to content

Commit 748a32c

Browse files
committed
Added rewrite for slogdet; added docstrings for rewrites
1 parent 1aa9cb6 commit 748a32c

File tree

2 files changed

+82
-0
lines changed

2 files changed

+82
-0
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
inv,
3131
kron,
3232
pinv,
33+
slogdet,
3334
svd,
3435
)
3536
from pytensor.tensor.rewriting.basic import (
@@ -619,6 +620,23 @@ def rewrite_inv_inv(fgraph, node):
619620
@register_stabilize
620621
@node_rewriter([ExtractDiag])
621622
def rewrite_diag_kronecker(fgraph, node):
623+
"""
624+
This rewrite simplifies the diagonal of the kronecker product of 2 matrices by extracting the individual sub matrices and returning their outer product as a vector.
625+
626+
diag(kron(a,b)) -> outer(diag(a), diag(b))
627+
628+
Parameters
629+
----------
630+
fgraph: FunctionGraph
631+
Function graph being optimized
632+
node: Apply
633+
Node of the function graph to be optimized
634+
635+
Returns
636+
-------
637+
list of Variable, optional
638+
List of optimized variables, or None if no optimization was performed
639+
"""
622640
# Check for inner kron operation
623641
potential_kron = node.inputs[0].owner
624642
if not (potential_kron and isinstance(potential_kron.op, KroneckerProduct)):
@@ -630,3 +648,38 @@ def rewrite_diag_kronecker(fgraph, node):
630648
outer_prod_as_vector = outer(diag_a, diag_b).flatten()
631649

632650
return [outer_prod_as_vector]
651+
652+
653+
@register_canonicalize
654+
@register_stabilize
655+
@node_rewriter([slogdet])
656+
def rewrite_slogdet_kronecker(fgraph, node):
657+
"""
658+
This rewrite simplifies the slogdet of a blockdiagonal matrix by extracting the individual sub matrices and returning the sign and logdet values computed using those
659+
660+
Parameters
661+
----------
662+
fgraph: FunctionGraph
663+
Function graph being optimized
664+
node: Apply
665+
Node of the function graph to be optimized
666+
667+
Returns
668+
-------
669+
list of Variable, optional
670+
List of optimized variables, or None if no optimization was performed
671+
"""
672+
# Check for inner kron operation
673+
potential_kron = node.inputs[0].owner
674+
if not (potential_kron and isinstance(potential_kron.op, KroneckerProduct)):
675+
return None
676+
677+
# Find the matrices
678+
a, b = potential_kron.inputs
679+
signs, logdets = zip(*[slogdet(a), slogdet(b)])
680+
sizes = [a.shape[-1], b.shape[-1]]
681+
prod_sizes = prod(sizes, no_zeros_in_input=True)
682+
signs_final = [signs[i] ** (prod_sizes / sizes[i]) for i in range(2)]
683+
logdet_final = [logdets[i] * prod_sizes / sizes[i] for i in range(2)]
684+
685+
return [prod(signs_final, no_zeros_in_input=True), sum(logdet_final)]

tests/tensor/rewriting/test_linalg.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -591,3 +591,32 @@ def test_diag_kronecker_rewrite():
591591
atol=1e-3 if config.floatX == "float32" else 1e-8,
592592
rtol=1e-3 if config.floatX == "float32" else 1e-8,
593593
)
594+
595+
596+
def test_slogdet_kronecker_rewrite():
597+
a, b = pt.dmatrices("a", "b")
598+
kron_prod = pt.linalg.kron(a, b)
599+
sign_output, logdet_output = pt.linalg.slogdet(kron_prod)
600+
f_rewritten = function([kron_prod], [sign_output, logdet_output], mode="FAST_RUN")
601+
602+
# Rewrite Test
603+
nodes = f_rewritten.maker.fgraph.apply_nodes
604+
assert not any(isinstance(node.op, KroneckerProduct) for node in nodes)
605+
606+
# Value Test
607+
a_test, b_test = np.random.rand(2, 20, 20)
608+
kron_prod_test = np.kron(a_test, b_test)
609+
sign_output_test, logdet_output_test = np.linalg.slogdet(kron_prod_test)
610+
rewritten_sign_val, rewritten_logdet_val = f_rewritten(kron_prod_test)
611+
assert_allclose(
612+
sign_output_test,
613+
rewritten_sign_val,
614+
atol=1e-3 if config.floatX == "float32" else 1e-8,
615+
rtol=1e-3 if config.floatX == "float32" else 1e-8,
616+
)
617+
assert_allclose(
618+
logdet_output_test,
619+
rewritten_logdet_val,
620+
atol=1e-3 if config.floatX == "float32" else 1e-8,
621+
rtol=1e-3 if config.floatX == "float32" else 1e-8,
622+
)

0 commit comments

Comments
 (0)