File tree 1 file changed +41
-0
lines changed 1 file changed +41
-0
lines changed Original file line number Diff line number Diff line change @@ -109,6 +109,47 @@ def psd_solve_with_chol(fgraph, node):
109
109
return [x ]
110
110
111
111
112
+ @register_canonicalize
113
+ @register_stabilize
114
+ @node_rewriter ([Cholesky ])
115
+ def chol_of_dot_chol (fgraph , node ):
116
+ """
117
+ This utilizes a boolean `lower_triangular` or `upper_triangular` tag on matrices.
118
+ """
119
+ if not isinstance (node .op , Cholesky ):
120
+ return
121
+
122
+ A = node .inputs [0 ]
123
+ if not (A .owner and isinstance (A .owner .op , (Dot , Dot22 ))):
124
+ return
125
+
126
+ l , r = A .owner .inputs
127
+
128
+ # cholesky(dot(L,L.T)) case
129
+ if (
130
+ getattr (l .tag , "lower_triangular" , False )
131
+ and r .owner
132
+ and isinstance (r .owner .op , DimShuffle )
133
+ and r .owner .op .new_order == (1 , 0 )
134
+ and r .owner .inputs [0 ] == l
135
+ ):
136
+ if node .op .lower :
137
+ return [l ]
138
+ return [r ]
139
+
140
+ # cholesky(dot(U.T,U)) case
141
+ if (
142
+ getattr (r .tag , "upper_triangular" , False )
143
+ and l .owner
144
+ and isinstance (l .owner .op , DimShuffle )
145
+ and l .owner .op .new_order == (1 , 0 )
146
+ and l .owner .inputs [0 ] == r
147
+ ):
148
+ if node .op .lower :
149
+ return [l ]
150
+ return [r ]
151
+
152
+
112
153
@register_stabilize
113
154
@register_specialize
114
155
@node_rewriter ([Det ])
You can’t perform that action at this time.
0 commit comments