Skip to content

Commit cf87362

Browse files
tanish1729jessegrabowski
authored andcommitted
minor changes
1 parent 5fc76c2 commit cf87362

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -934,19 +934,19 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node):
934934
if not isinstance(node.op.core_op, Cholesky):
935935
return None
936936

937-
inputs = node.inputs[0]
937+
[input] = node.inputs
938938
# Check for use of pt.diag first
939939
if (
940-
inputs.owner
941-
and isinstance(inputs.owner.op, AllocDiag)
942-
and AllocDiag.is_offset_zero(inputs.owner)
940+
input.owner
941+
and isinstance(input.owner.op, AllocDiag)
942+
and AllocDiag.is_offset_zero(input.owner)
943943
):
944-
diag_input = inputs.owner.inputs[0]
944+
diag_input = input.owner.inputs[0]
945945
cholesky_val = pt.diag(diag_input**0.5)
946946
return [cholesky_val]
947947

948948
# Check if the input is an elemwise multiply with identity matrix -- this also results in a diagonal matrix
949-
inputs_or_none = _find_diag_from_eye_mul(inputs)
949+
inputs_or_none = _find_diag_from_eye_mul(input)
950950
if inputs_or_none is None:
951951
return None
952952

@@ -956,7 +956,7 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node):
956956
if len(non_eye_inputs) != 1:
957957
return None
958958

959-
non_eye_input = non_eye_inputs[0]
959+
[non_eye_input] = non_eye_inputs
960960

961961
# Now, we can simply return the matrix consisting of sqrt values of the original diagonal elements
962962
# For a matrix, we have to first extract the diagonal (non-zero values) and then only use those

0 commit comments

Comments
 (0)