Skip to content

Commit 032cbca

Browse files
tanish1729jessegrabowski
authored andcommitted
Added rewrite for diag of kronecker product
1 parent 5632777 commit 032cbca

File tree

2 files changed

+41
-1
lines changed

2 files changed

+41
-1
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from pytensor.tensor.blas import Dot22
2323
from pytensor.tensor.blockwise import Blockwise
2424
from pytensor.tensor.elemwise import DimShuffle, Elemwise
25-
from pytensor.tensor.math import Dot, Prod, _matrix_matrix_matmul, log, prod
25+
from pytensor.tensor.math import Dot, Prod, _matrix_matrix_matmul, log, outer, prod
2626
from pytensor.tensor.nlinalg import (
2727
SVD,
2828
KroneckerProduct,
@@ -818,3 +818,20 @@ def rewrite_slogdet_blockdiag(fgraph, node):
818818
)
819819

820820
return [prod(sign_sub_matrices), sum(logdet_sub_matrices)]
821+
822+
823+
@register_canonicalize
824+
@register_stabilize
825+
@node_rewriter([ExtractDiag])
826+
def rewrite_diag_kronecker(fgraph, node):
827+
# Check for inner kron operation
828+
potential_kron = node.inputs[0].owner
829+
if not (potential_kron and isinstance(potential_kron.op, KroneckerProduct)):
830+
return None
831+
832+
# Find the matrices
833+
a, b = potential_kron.inputs
834+
diag_a, diag_b = diag(a), diag(b)
835+
outer_prod_as_vector = outer(diag_a, diag_b).flatten()
836+
837+
return [outer_prod_as_vector]

tests/tensor/rewriting/test_linalg.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -751,3 +751,26 @@ def test_slogdet_blockdiag_rewrite():
751751
atol=1e-3 if config.floatX == "float32" else 1e-8,
752752
rtol=1e-3 if config.floatX == "float32" else 1e-8,
753753
)
754+
755+
756+
def test_diag_kronecker_rewrite():
757+
a, b = pt.dmatrices("a", "b")
758+
kron_prod = pt.linalg.kron(a, b)
759+
diag_kron_prod = pt.diag(kron_prod)
760+
f_rewritten = function([a, b], diag_kron_prod, mode="FAST_RUN")
761+
762+
# Rewrite Test
763+
nodes = f_rewritten.maker.fgraph.apply_nodes
764+
assert not any(isinstance(node.op, KroneckerProduct) for node in nodes)
765+
766+
# Value Test
767+
a_test, b_test = np.random.rand(2, 20, 20)
768+
kron_prod_test = np.kron(a_test, b_test)
769+
diag_kron_prod_test = np.diag(kron_prod_test)
770+
rewritten_val = f_rewritten(a_test, b_test)
771+
assert_allclose(
772+
diag_kron_prod_test,
773+
rewritten_val,
774+
atol=1e-3 if config.floatX == "float32" else 1e-8,
775+
rtol=1e-3 if config.floatX == "float32" else 1e-8,
776+
)

0 commit comments

Comments
 (0)