Skip to content

Commit 3e98b9f

Browse files
authored
Adding rewrites involving kronecker product (#975)
* Added rewrite for diag of kronecker product * Added rewrite for slogdet; added docstrings for rewrites * fixed typo
1 parent 5632777 commit 3e98b9f

File tree

2 files changed

+122
-1
lines changed

2 files changed

+122
-1
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from pytensor.tensor.blas import Dot22
2323
from pytensor.tensor.blockwise import Blockwise
2424
from pytensor.tensor.elemwise import DimShuffle, Elemwise
25-
from pytensor.tensor.math import Dot, Prod, _matrix_matrix_matmul, log, prod
25+
from pytensor.tensor.math import Dot, Prod, _matrix_matrix_matmul, log, outer, prod
2626
from pytensor.tensor.nlinalg import (
2727
SVD,
2828
KroneckerProduct,
@@ -818,3 +818,72 @@ def rewrite_slogdet_blockdiag(fgraph, node):
818818
)
819819

820820
return [prod(sign_sub_matrices), sum(logdet_sub_matrices)]
821+
822+
823+
@register_canonicalize
824+
@register_stabilize
825+
@node_rewriter([ExtractDiag])
826+
def rewrite_diag_kronecker(fgraph, node):
827+
"""
828+
This rewrite simplifies the diagonal of the kronecker product of 2 matrices by extracting the individual sub matrices and returning their outer product as a vector.
829+
830+
diag(kron(a,b)) -> outer(diag(a), diag(b))
831+
832+
Parameters
833+
----------
834+
fgraph: FunctionGraph
835+
Function graph being optimized
836+
node: Apply
837+
Node of the function graph to be optimized
838+
839+
Returns
840+
-------
841+
list of Variable, optional
842+
List of optimized variables, or None if no optimization was performed
843+
"""
844+
# Check for inner kron operation
845+
potential_kron = node.inputs[0].owner
846+
if not (potential_kron and isinstance(potential_kron.op, KroneckerProduct)):
847+
return None
848+
849+
# Find the matrices
850+
a, b = potential_kron.inputs
851+
diag_a, diag_b = diag(a), diag(b)
852+
outer_prod_as_vector = outer(diag_a, diag_b).flatten()
853+
854+
return [outer_prod_as_vector]
855+
856+
857+
@register_canonicalize
858+
@register_stabilize
859+
@node_rewriter([slogdet])
860+
def rewrite_slogdet_kronecker(fgraph, node):
861+
"""
862+
This rewrite simplifies the slogdet of a kronecker-structured matrix by extracting the individual sub matrices and returning the sign and logdet values computed using those
863+
864+
Parameters
865+
----------
866+
fgraph: FunctionGraph
867+
Function graph being optimized
868+
node: Apply
869+
Node of the function graph to be optimized
870+
871+
Returns
872+
-------
873+
list of Variable, optional
874+
List of optimized variables, or None if no optimization was performed
875+
"""
876+
# Check for inner kron operation
877+
potential_kron = node.inputs[0].owner
878+
if not (potential_kron and isinstance(potential_kron.op, KroneckerProduct)):
879+
return None
880+
881+
# Find the matrices
882+
a, b = potential_kron.inputs
883+
signs, logdets = zip(*[slogdet(a), slogdet(b)])
884+
sizes = [a.shape[-1], b.shape[-1]]
885+
prod_sizes = prod(sizes, no_zeros_in_input=True)
886+
signs_final = [signs[i] ** (prod_sizes / sizes[i]) for i in range(2)]
887+
logdet_final = [logdets[i] * prod_sizes / sizes[i] for i in range(2)]
888+
889+
return [prod(signs_final, no_zeros_in_input=True), sum(logdet_final)]

tests/tensor/rewriting/test_linalg.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -751,3 +751,55 @@ def test_slogdet_blockdiag_rewrite():
751751
atol=1e-3 if config.floatX == "float32" else 1e-8,
752752
rtol=1e-3 if config.floatX == "float32" else 1e-8,
753753
)
754+
755+
756+
def test_diag_kronecker_rewrite():
757+
a, b = pt.dmatrices("a", "b")
758+
kron_prod = pt.linalg.kron(a, b)
759+
diag_kron_prod = pt.diag(kron_prod)
760+
f_rewritten = function([a, b], diag_kron_prod, mode="FAST_RUN")
761+
762+
# Rewrite Test
763+
nodes = f_rewritten.maker.fgraph.apply_nodes
764+
assert not any(isinstance(node.op, KroneckerProduct) for node in nodes)
765+
766+
# Value Test
767+
a_test, b_test = np.random.rand(2, 20, 20)
768+
kron_prod_test = np.kron(a_test, b_test)
769+
diag_kron_prod_test = np.diag(kron_prod_test)
770+
rewritten_val = f_rewritten(a_test, b_test)
771+
assert_allclose(
772+
diag_kron_prod_test,
773+
rewritten_val,
774+
atol=1e-3 if config.floatX == "float32" else 1e-8,
775+
rtol=1e-3 if config.floatX == "float32" else 1e-8,
776+
)
777+
778+
779+
def test_slogdet_kronecker_rewrite():
780+
a, b = pt.dmatrices("a", "b")
781+
kron_prod = pt.linalg.kron(a, b)
782+
sign_output, logdet_output = pt.linalg.slogdet(kron_prod)
783+
f_rewritten = function([kron_prod], [sign_output, logdet_output], mode="FAST_RUN")
784+
785+
# Rewrite Test
786+
nodes = f_rewritten.maker.fgraph.apply_nodes
787+
assert not any(isinstance(node.op, KroneckerProduct) for node in nodes)
788+
789+
# Value Test
790+
a_test, b_test = np.random.rand(2, 20, 20)
791+
kron_prod_test = np.kron(a_test, b_test)
792+
sign_output_test, logdet_output_test = np.linalg.slogdet(kron_prod_test)
793+
rewritten_sign_val, rewritten_logdet_val = f_rewritten(kron_prod_test)
794+
assert_allclose(
795+
sign_output_test,
796+
rewritten_sign_val,
797+
atol=1e-3 if config.floatX == "float32" else 1e-8,
798+
rtol=1e-3 if config.floatX == "float32" else 1e-8,
799+
)
800+
assert_allclose(
801+
logdet_output_test,
802+
rewritten_logdet_val,
803+
atol=1e-3 if config.floatX == "float32" else 1e-8,
804+
rtol=1e-3 if config.floatX == "float32" else 1e-8,
805+
)

0 commit comments

Comments
 (0)