|
| 1 | +use crate::{error::*, layout::*, *}; |
| 2 | +use cauchy::*; |
| 3 | + |
| 4 | +pub trait SolveTridiagonalImpl: Scalar { |
| 5 | + fn solve_tridiagonal( |
| 6 | + lu: &LUFactorizedTridiagonal<Self>, |
| 7 | + bl: MatrixLayout, |
| 8 | + t: Transpose, |
| 9 | + b: &mut [Self], |
| 10 | + ) -> Result<()>; |
| 11 | +} |
| 12 | + |
| 13 | +macro_rules! impl_solve_tridiagonal { |
| 14 | + ($s:ty, $trs:path) => { |
| 15 | + impl SolveTridiagonalImpl for $s { |
| 16 | + fn solve_tridiagonal( |
| 17 | + lu: &LUFactorizedTridiagonal<Self>, |
| 18 | + b_layout: MatrixLayout, |
| 19 | + t: Transpose, |
| 20 | + b: &mut [Self], |
| 21 | + ) -> Result<()> { |
| 22 | + let (n, _) = lu.a.l.size(); |
| 23 | + let ipiv = &lu.ipiv; |
| 24 | + // Transpose if b is C-continuous |
| 25 | + let mut b_t = None; |
| 26 | + let b_layout = match b_layout { |
| 27 | + MatrixLayout::C { .. } => { |
| 28 | + let (layout, t) = transpose(b_layout, b); |
| 29 | + b_t = Some(t); |
| 30 | + layout |
| 31 | + } |
| 32 | + MatrixLayout::F { .. } => b_layout, |
| 33 | + }; |
| 34 | + let (ldb, nrhs) = b_layout.size(); |
| 35 | + let mut info = 0; |
| 36 | + unsafe { |
| 37 | + $trs( |
| 38 | + t.as_ptr(), |
| 39 | + &n, |
| 40 | + &nrhs, |
| 41 | + AsPtr::as_ptr(&lu.a.dl), |
| 42 | + AsPtr::as_ptr(&lu.a.d), |
| 43 | + AsPtr::as_ptr(&lu.a.du), |
| 44 | + AsPtr::as_ptr(&lu.du2), |
| 45 | + ipiv.as_ptr(), |
| 46 | + AsPtr::as_mut_ptr(b_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(b)), |
| 47 | + &ldb, |
| 48 | + &mut info, |
| 49 | + ); |
| 50 | + } |
| 51 | + info.as_lapack_result()?; |
| 52 | + if let Some(b_t) = b_t { |
| 53 | + transpose_over(b_layout, &b_t, b); |
| 54 | + } |
| 55 | + Ok(()) |
| 56 | + } |
| 57 | + } |
| 58 | + }; |
| 59 | +} |
| 60 | + |
| 61 | +impl_solve_tridiagonal!(c64, lapack_sys::zgttrs_); |
| 62 | +impl_solve_tridiagonal!(c32, lapack_sys::cgttrs_); |
| 63 | +impl_solve_tridiagonal!(f64, lapack_sys::dgttrs_); |
| 64 | +impl_solve_tridiagonal!(f32, lapack_sys::sgttrs_); |
0 commit comments