Skip to content

Commit a8550c2

Browse files
committed
SolveTridiagonalImpl
1 parent 17f9fb8 commit a8550c2

File tree

2 files changed

+66
-0
lines changed

2 files changed

+66
-0
lines changed

lax/src/tridiagonal/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22
//! for tridiagonal matrix
33
44
mod matrix;
5+
mod solve;
56

67
pub use matrix::*;
8+
pub use solve::*;
79

810
use crate::{error::*, layout::*, *};
911
use cauchy::*;

lax/src/tridiagonal/solve.rs

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
use crate::{error::*, layout::*, *};
2+
use cauchy::*;
3+
4+
pub trait SolveTridiagonalImpl: Scalar {
5+
fn solve_tridiagonal(
6+
lu: &LUFactorizedTridiagonal<Self>,
7+
bl: MatrixLayout,
8+
t: Transpose,
9+
b: &mut [Self],
10+
) -> Result<()>;
11+
}
12+
13+
macro_rules! impl_solve_tridiagonal {
14+
($s:ty, $trs:path) => {
15+
impl SolveTridiagonalImpl for $s {
16+
fn solve_tridiagonal(
17+
lu: &LUFactorizedTridiagonal<Self>,
18+
b_layout: MatrixLayout,
19+
t: Transpose,
20+
b: &mut [Self],
21+
) -> Result<()> {
22+
let (n, _) = lu.a.l.size();
23+
let ipiv = &lu.ipiv;
24+
// Transpose if b is C-continuous
25+
let mut b_t = None;
26+
let b_layout = match b_layout {
27+
MatrixLayout::C { .. } => {
28+
let (layout, t) = transpose(b_layout, b);
29+
b_t = Some(t);
30+
layout
31+
}
32+
MatrixLayout::F { .. } => b_layout,
33+
};
34+
let (ldb, nrhs) = b_layout.size();
35+
let mut info = 0;
36+
unsafe {
37+
$trs(
38+
t.as_ptr(),
39+
&n,
40+
&nrhs,
41+
AsPtr::as_ptr(&lu.a.dl),
42+
AsPtr::as_ptr(&lu.a.d),
43+
AsPtr::as_ptr(&lu.a.du),
44+
AsPtr::as_ptr(&lu.du2),
45+
ipiv.as_ptr(),
46+
AsPtr::as_mut_ptr(b_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(b)),
47+
&ldb,
48+
&mut info,
49+
);
50+
}
51+
info.as_lapack_result()?;
52+
if let Some(b_t) = b_t {
53+
transpose_over(b_layout, &b_t, b);
54+
}
55+
Ok(())
56+
}
57+
}
58+
};
59+
}
60+
61+
impl_solve_tridiagonal!(c64, lapack_sys::zgttrs_);
62+
impl_solve_tridiagonal!(c32, lapack_sys::cgttrs_);
63+
impl_solve_tridiagonal!(f64, lapack_sys::dgttrs_);
64+
impl_solve_tridiagonal!(f32, lapack_sys::sgttrs_);

0 commit comments

Comments
 (0)