|
22 | 22 | from pytensor.tensor.blas import Dot22
|
23 | 23 | from pytensor.tensor.blockwise import Blockwise
|
24 | 24 | from pytensor.tensor.elemwise import DimShuffle, Elemwise
|
25 |
| -from pytensor.tensor.math import Dot, Prod, _matrix_matrix_matmul, log, prod |
| 25 | +from pytensor.tensor.math import Dot, Prod, _matrix_matrix_matmul, log, outer, prod |
26 | 26 | from pytensor.tensor.nlinalg import (
|
27 | 27 | SVD,
|
28 | 28 | KroneckerProduct,
|
@@ -818,3 +818,72 @@ def rewrite_slogdet_blockdiag(fgraph, node):
|
818 | 818 | )
|
819 | 819 |
|
820 | 820 | return [prod(sign_sub_matrices), sum(logdet_sub_matrices)]
|
| 821 | + |
| 822 | + |
| 823 | +@register_canonicalize |
| 824 | +@register_stabilize |
| 825 | +@node_rewriter([ExtractDiag]) |
| 826 | +def rewrite_diag_kronecker(fgraph, node): |
| 827 | + """ |
| 828 | + This rewrite simplifies the diagonal of the kronecker product of 2 matrices by extracting the individual sub matrices and returning their outer product as a vector. |
| 829 | +
|
| 830 | + diag(kron(a,b)) -> outer(diag(a), diag(b)) |
| 831 | +
|
| 832 | + Parameters |
| 833 | + ---------- |
| 834 | + fgraph: FunctionGraph |
| 835 | + Function graph being optimized |
| 836 | + node: Apply |
| 837 | + Node of the function graph to be optimized |
| 838 | +
|
| 839 | + Returns |
| 840 | + ------- |
| 841 | + list of Variable, optional |
| 842 | + List of optimized variables, or None if no optimization was performed |
| 843 | + """ |
| 844 | + # Check for inner kron operation |
| 845 | + potential_kron = node.inputs[0].owner |
| 846 | + if not (potential_kron and isinstance(potential_kron.op, KroneckerProduct)): |
| 847 | + return None |
| 848 | + |
| 849 | + # Find the matrices |
| 850 | + a, b = potential_kron.inputs |
| 851 | + diag_a, diag_b = diag(a), diag(b) |
| 852 | + outer_prod_as_vector = outer(diag_a, diag_b).flatten() |
| 853 | + |
| 854 | + return [outer_prod_as_vector] |
| 855 | + |
| 856 | + |
| 857 | +@register_canonicalize |
| 858 | +@register_stabilize |
| 859 | +@node_rewriter([slogdet]) |
| 860 | +def rewrite_slogdet_kronecker(fgraph, node): |
| 861 | + """ |
| 862 | + This rewrite simplifies the slogdet of a kronecker-structured matrix by extracting the individual sub matrices and returning the sign and logdet values computed using those |
| 863 | +
|
| 864 | + Parameters |
| 865 | + ---------- |
| 866 | + fgraph: FunctionGraph |
| 867 | + Function graph being optimized |
| 868 | + node: Apply |
| 869 | + Node of the function graph to be optimized |
| 870 | +
|
| 871 | + Returns |
| 872 | + ------- |
| 873 | + list of Variable, optional |
| 874 | + List of optimized variables, or None if no optimization was performed |
| 875 | + """ |
| 876 | + # Check for inner kron operation |
| 877 | + potential_kron = node.inputs[0].owner |
| 878 | + if not (potential_kron and isinstance(potential_kron.op, KroneckerProduct)): |
| 879 | + return None |
| 880 | + |
| 881 | + # Find the matrices |
| 882 | + a, b = potential_kron.inputs |
| 883 | + signs, logdets = zip(*[slogdet(a), slogdet(b)]) |
| 884 | + sizes = [a.shape[-1], b.shape[-1]] |
| 885 | + prod_sizes = prod(sizes, no_zeros_in_input=True) |
| 886 | + signs_final = [signs[i] ** (prod_sizes / sizes[i]) for i in range(2)] |
| 887 | + logdet_final = [logdets[i] * prod_sizes / sizes[i] for i in range(2)] |
| 888 | + |
| 889 | + return [prod(signs_final, no_zeros_in_input=True), sum(logdet_final)] |
0 commit comments