@@ -934,19 +934,19 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node):
934
934
if not isinstance (node .op .core_op , Cholesky ):
935
935
return None
936
936
937
- inputs = node .inputs [ 0 ]
937
+ [ input ] = node .inputs
938
938
# Check for use of pt.diag first
939
939
if (
940
- inputs .owner
941
- and isinstance (inputs .owner .op , AllocDiag )
942
- and AllocDiag .is_offset_zero (inputs .owner )
940
+ input .owner
941
+ and isinstance (input .owner .op , AllocDiag )
942
+ and AllocDiag .is_offset_zero (input .owner )
943
943
):
944
- diag_input = inputs .owner .inputs [0 ]
944
+ diag_input = input .owner .inputs [0 ]
945
945
cholesky_val = pt .diag (diag_input ** 0.5 )
946
946
return [cholesky_val ]
947
947
948
948
# Check if the input is an elemwise multiply with identity matrix -- this also results in a diagonal matrix
949
- inputs_or_none = _find_diag_from_eye_mul (inputs )
949
+ inputs_or_none = _find_diag_from_eye_mul (input )
950
950
if inputs_or_none is None :
951
951
return None
952
952
@@ -956,7 +956,7 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node):
956
956
if len (non_eye_inputs ) != 1 :
957
957
return None
958
958
959
- non_eye_input = non_eye_inputs [ 0 ]
959
+ [ non_eye_input ] = non_eye_inputs
960
960
961
961
# Now, we can simply return the matrix consisting of sqrt values of the original diagonal elements
962
962
# For a matrix, we have to first extract the diagonal (non-zero values) and then only use those
0 commit comments