Skip to content

Commit 40df71b

Browse files
committed
Added rewrite for slogdet; added docstrings for rewrites
1 parent 13b20e2 commit 40df71b

File tree

2 files changed

+82
-0
lines changed

2 files changed

+82
-0
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
inv,
3232
kron,
3333
pinv,
34+
slogdet,
3435
svd,
3536
)
3637
from pytensor.tensor.rewriting.basic import (
@@ -709,6 +710,23 @@ def rewrite_inv_diag_to_diag_reciprocal(fgraph, node):
709710
@register_stabilize
710711
@node_rewriter([ExtractDiag])
711712
def rewrite_diag_kronecker(fgraph, node):
713+
"""
714+
This rewrite simplifies the diagonal of the kronecker product of 2 matrices by extracting the individual sub matrices and returning their outer product as a vector.
715+
716+
diag(kron(a,b)) -> outer(diag(a), diag(b))
717+
718+
Parameters
719+
----------
720+
fgraph: FunctionGraph
721+
Function graph being optimized
722+
node: Apply
723+
Node of the function graph to be optimized
724+
725+
Returns
726+
-------
727+
list of Variable, optional
728+
List of optimized variables, or None if no optimization was performed
729+
"""
712730
# Check for inner kron operation
713731
potential_kron = node.inputs[0].owner
714732
if not (potential_kron and isinstance(potential_kron.op, KroneckerProduct)):
@@ -720,3 +738,38 @@ def rewrite_diag_kronecker(fgraph, node):
720738
outer_prod_as_vector = outer(diag_a, diag_b).flatten()
721739

722740
return [outer_prod_as_vector]
741+
742+
743+
@register_canonicalize
744+
@register_stabilize
745+
@node_rewriter([slogdet])
746+
def rewrite_slogdet_kronecker(fgraph, node):
747+
"""
748+
This rewrite simplifies the slogdet of a blockdiagonal matrix by extracting the individual sub matrices and returning the sign and logdet values computed using those
749+
750+
Parameters
751+
----------
752+
fgraph: FunctionGraph
753+
Function graph being optimized
754+
node: Apply
755+
Node of the function graph to be optimized
756+
757+
Returns
758+
-------
759+
list of Variable, optional
760+
List of optimized variables, or None if no optimization was performed
761+
"""
762+
# Check for inner kron operation
763+
potential_kron = node.inputs[0].owner
764+
if not (potential_kron and isinstance(potential_kron.op, KroneckerProduct)):
765+
return None
766+
767+
# Find the matrices
768+
a, b = potential_kron.inputs
769+
signs, logdets = zip(*[slogdet(a), slogdet(b)])
770+
sizes = [a.shape[-1], b.shape[-1]]
771+
prod_sizes = prod(sizes, no_zeros_in_input=True)
772+
signs_final = [signs[i] ** (prod_sizes / sizes[i]) for i in range(2)]
773+
logdet_final = [logdets[i] * prod_sizes / sizes[i] for i in range(2)]
774+
775+
return [prod(signs_final, no_zeros_in_input=True), sum(logdet_final)]

tests/tensor/rewriting/test_linalg.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -685,3 +685,32 @@ def test_diag_kronecker_rewrite():
685685
atol=1e-3 if config.floatX == "float32" else 1e-8,
686686
rtol=1e-3 if config.floatX == "float32" else 1e-8,
687687
)
688+
689+
690+
def test_slogdet_kronecker_rewrite():
691+
a, b = pt.dmatrices("a", "b")
692+
kron_prod = pt.linalg.kron(a, b)
693+
sign_output, logdet_output = pt.linalg.slogdet(kron_prod)
694+
f_rewritten = function([kron_prod], [sign_output, logdet_output], mode="FAST_RUN")
695+
696+
# Rewrite Test
697+
nodes = f_rewritten.maker.fgraph.apply_nodes
698+
assert not any(isinstance(node.op, KroneckerProduct) for node in nodes)
699+
700+
# Value Test
701+
a_test, b_test = np.random.rand(2, 20, 20)
702+
kron_prod_test = np.kron(a_test, b_test)
703+
sign_output_test, logdet_output_test = np.linalg.slogdet(kron_prod_test)
704+
rewritten_sign_val, rewritten_logdet_val = f_rewritten(kron_prod_test)
705+
assert_allclose(
706+
sign_output_test,
707+
rewritten_sign_val,
708+
atol=1e-3 if config.floatX == "float32" else 1e-8,
709+
rtol=1e-3 if config.floatX == "float32" else 1e-8,
710+
)
711+
assert_allclose(
712+
logdet_output_test,
713+
rewritten_logdet_val,
714+
atol=1e-3 if config.floatX == "float32" else 1e-8,
715+
rtol=1e-3 if config.floatX == "float32" else 1e-8,
716+
)

0 commit comments

Comments
 (0)