Skip to content

Commit d76c1b2

Browse files
tanish1729jessegrabowski
authored andcommitted
Added rewrite for slogdet; added docstrings for rewrites
1 parent 032cbca commit d76c1b2

File tree

2 files changed

+81
-0
lines changed

2 files changed

+81
-0
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -824,6 +824,23 @@ def rewrite_slogdet_blockdiag(fgraph, node):
824824
@register_stabilize
825825
@node_rewriter([ExtractDiag])
826826
def rewrite_diag_kronecker(fgraph, node):
827+
"""
828+
This rewrite simplifies the diagonal of the kronecker product of 2 matrices by extracting the individual sub matrices and returning their outer product as a vector.
829+
830+
diag(kron(a,b)) -> outer(diag(a), diag(b))
831+
832+
Parameters
833+
----------
834+
fgraph: FunctionGraph
835+
Function graph being optimized
836+
node: Apply
837+
Node of the function graph to be optimized
838+
839+
Returns
840+
-------
841+
list of Variable, optional
842+
List of optimized variables, or None if no optimization was performed
843+
"""
827844
# Check for inner kron operation
828845
potential_kron = node.inputs[0].owner
829846
if not (potential_kron and isinstance(potential_kron.op, KroneckerProduct)):
@@ -835,3 +852,38 @@ def rewrite_diag_kronecker(fgraph, node):
835852
outer_prod_as_vector = outer(diag_a, diag_b).flatten()
836853

837854
return [outer_prod_as_vector]
855+
856+
857+
@register_canonicalize
858+
@register_stabilize
859+
@node_rewriter([slogdet])
860+
def rewrite_slogdet_kronecker(fgraph, node):
861+
"""
862+
This rewrite simplifies the slogdet of a blockdiagonal matrix by extracting the individual sub matrices and returning the sign and logdet values computed using those
863+
864+
Parameters
865+
----------
866+
fgraph: FunctionGraph
867+
Function graph being optimized
868+
node: Apply
869+
Node of the function graph to be optimized
870+
871+
Returns
872+
-------
873+
list of Variable, optional
874+
List of optimized variables, or None if no optimization was performed
875+
"""
876+
# Check for inner kron operation
877+
potential_kron = node.inputs[0].owner
878+
if not (potential_kron and isinstance(potential_kron.op, KroneckerProduct)):
879+
return None
880+
881+
# Find the matrices
882+
a, b = potential_kron.inputs
883+
signs, logdets = zip(*[slogdet(a), slogdet(b)])
884+
sizes = [a.shape[-1], b.shape[-1]]
885+
prod_sizes = prod(sizes, no_zeros_in_input=True)
886+
signs_final = [signs[i] ** (prod_sizes / sizes[i]) for i in range(2)]
887+
logdet_final = [logdets[i] * prod_sizes / sizes[i] for i in range(2)]
888+
889+
return [prod(signs_final, no_zeros_in_input=True), sum(logdet_final)]

tests/tensor/rewriting/test_linalg.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -774,3 +774,32 @@ def test_diag_kronecker_rewrite():
774774
atol=1e-3 if config.floatX == "float32" else 1e-8,
775775
rtol=1e-3 if config.floatX == "float32" else 1e-8,
776776
)
777+
778+
779+
def test_slogdet_kronecker_rewrite():
780+
a, b = pt.dmatrices("a", "b")
781+
kron_prod = pt.linalg.kron(a, b)
782+
sign_output, logdet_output = pt.linalg.slogdet(kron_prod)
783+
f_rewritten = function([kron_prod], [sign_output, logdet_output], mode="FAST_RUN")
784+
785+
# Rewrite Test
786+
nodes = f_rewritten.maker.fgraph.apply_nodes
787+
assert not any(isinstance(node.op, KroneckerProduct) for node in nodes)
788+
789+
# Value Test
790+
a_test, b_test = np.random.rand(2, 20, 20)
791+
kron_prod_test = np.kron(a_test, b_test)
792+
sign_output_test, logdet_output_test = np.linalg.slogdet(kron_prod_test)
793+
rewritten_sign_val, rewritten_logdet_val = f_rewritten(kron_prod_test)
794+
assert_allclose(
795+
sign_output_test,
796+
rewritten_sign_val,
797+
atol=1e-3 if config.floatX == "float32" else 1e-8,
798+
rtol=1e-3 if config.floatX == "float32" else 1e-8,
799+
)
800+
assert_allclose(
801+
logdet_output_test,
802+
rewritten_logdet_val,
803+
atol=1e-3 if config.floatX == "float32" else 1e-8,
804+
rtol=1e-3 if config.floatX == "float32" else 1e-8,
805+
)

0 commit comments

Comments
 (0)