File tree 2 files changed +63
-0
lines changed
pytensor/tensor/rewriting 2 files changed +63
-0
lines changed Original file line number Diff line number Diff line change @@ -536,3 +536,39 @@ def svd_uv_merge(fgraph, node):
536
536
or len (fgraph .clients [cl .outputs [2 ]]) > 0
537
537
):
538
538
return [cl .outputs [1 ]]
539
+
540
+
541
+ @register_canonicalize
542
+ @register_stabilize
543
+ @node_rewriter ([Blockwise ])
544
+ def rewrite_cholesky_eye_to_eye (fgraph , node ):
545
+ """
546
+ This rewrite takes advantage of the fact that the cholesky decomposition of an identity matrix is the matrix itself
547
+
548
+ The presence of an identity matrix is identified by checking whether we have k = 0 for an Eye Op inside Cholesky.
549
+
550
+ Parameters
551
+ ----------
552
+ fgraph: FunctionGraph
553
+ Function graph being optimized
554
+ node: Apply
555
+ Node of the function graph to be optimized
556
+
557
+ Returns
558
+ -------
559
+ list of Variable, optional
560
+ List of optimized variables, or None if no optimization was performed
561
+ """
562
+ # Find whether cholesky op is being applied
563
+ if not isinstance (node .op .core_op , Cholesky ):
564
+ return None
565
+
566
+ # Check whether input to Cholesky is Eye and the 1's are on main diagonal
567
+ eye_check = node .inputs [0 ]
568
+ if not (
569
+ eye_check .owner
570
+ and isinstance (eye_check .owner .op , Eye )
571
+ and getattr (eye_check .owner .inputs [- 1 ], "data" , - 1 ).item () == 0
572
+ ):
573
+ return None
574
+ return [eye_check ]
Original file line number Diff line number Diff line change @@ -545,3 +545,30 @@ def test_svd_uv_merge():
545
545
assert node .op .compute_uv
546
546
svd_counter += 1
547
547
assert svd_counter == 1
548
+
549
+
550
+ def test_cholesky_eye_rewrite ():
551
+ x = pt .eye (10 )
552
+ x_mat = pt .matrix ("x" )
553
+ L = pt .linalg .cholesky (x )
554
+ L_mat = pt .linalg .cholesky (x_mat )
555
+ f_rewritten = function ([], L , mode = "FAST_RUN" )
556
+ f_rewritten_mat = function ([x_mat ], L_mat , mode = "FAST_RUN" )
557
+ nodes = f_rewritten .maker .fgraph .apply_nodes
558
+ nodes_mat = f_rewritten_mat .maker .fgraph .apply_nodes
559
+
560
+ # Rewrite Test
561
+ assert not any (isinstance (node .op , Cholesky ) for node in nodes )
562
+ assert any (isinstance (node .op , Cholesky ) for node in nodes_mat )
563
+
564
+ # Value Test
565
+ x_test = np .eye (10 )
566
+ L = np .linalg .cholesky (x_test )
567
+ rewritten_val = f_rewritten ()
568
+
569
+ assert_allclose (
570
+ L ,
571
+ rewritten_val ,
572
+ atol = 1e-3 if config .floatX == "float32" else 1e-8 ,
573
+ rtol = 1e-3 if config .floatX == "float32" else 1e-8 ,
574
+ )
You can’t perform that action at this time.
0 commit comments