Skip to content

Commit 6e06f81

Browse files
committed
Fix numba symmetrical solve reciprocal of condition number
1 parent 8a81a53 commit 6e06f81

File tree

1 file changed

+8
-7
lines changed

1 file changed

+8
-7
lines changed

pytensor/link/numba/dispatch/slinalg.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -653,7 +653,7 @@ def impl(
653653

654654
def _sysv(
655655
A: np.ndarray, B: np.ndarray, lower: bool, overwrite_a: bool, overwrite_b: bool
656-
) -> tuple[np.ndarray, np.ndarray, int]:
656+
) -> tuple[np.ndarray, np.ndarray, np.ndarray, int]:
657657
"""
658658
Placeholder for solving a linear system with a symmetric matrix; used by linalg.solve.
659659
"""
@@ -664,7 +664,8 @@ def _sysv(
664664
def sysv_impl(
665665
A: np.ndarray, B: np.ndarray, lower: bool, overwrite_a: bool, overwrite_b: bool
666666
) -> Callable[
667-
[np.ndarray, np.ndarray, bool, bool, bool], tuple[np.ndarray, np.ndarray, int]
667+
[np.ndarray, np.ndarray, bool, bool, bool],
668+
tuple[np.ndarray, np.ndarray, np.ndarray, int],
668669
]:
669670
ensure_lapack()
670671
_check_scipy_linalg_matrix(A, "sysv")
@@ -740,8 +741,8 @@ def impl(
740741
)
741742

742743
if B_is_1d:
743-
return B_copy[..., 0], IPIV, int_ptr_to_val(INFO)
744-
return B_copy, IPIV, int_ptr_to_val(INFO)
744+
B_copy = B_copy[..., 0]
745+
return A_copy, B_copy, IPIV, int_ptr_to_val(INFO)
745746

746747
return impl
747748

@@ -770,7 +771,7 @@ def impl(A: np.ndarray, ipiv: np.ndarray, anorm: float) -> tuple[np.ndarray, int
770771

771772
N = val_to_int_ptr(_N)
772773
LDA = val_to_int_ptr(_N)
773-
UPLO = val_to_int_ptr(ord("L"))
774+
UPLO = val_to_int_ptr(ord("U"))
774775
ANORM = np.array(anorm, dtype=dtype)
775776
RCOND = np.empty(1, dtype=dtype)
776777
WORK = np.empty(2 * _N, dtype=dtype)
@@ -843,10 +844,10 @@ def impl(
843844
) -> np.ndarray:
844845
_solve_check_input_shapes(A, B)
845846

846-
x, ipiv, info = _sysv(A, B, lower, overwrite_a, overwrite_b)
847+
lu, x, ipiv, info = _sysv(A, B, lower, overwrite_a, overwrite_b)
847848
_solve_check(A.shape[-1], info)
848849

849-
rcond, info = _sycon(A, ipiv, _xlange(A, order="I"))
850+
rcond, info = _sycon(lu, ipiv, _xlange(A, order="I"))
850851
_solve_check(A.shape[-1], info, True, rcond)
851852

852853
return x

0 commit comments

Comments
 (0)