Skip to content

Commit d9c52a2

Browse files
committed
LuTridiagonalWork
1 parent 6a35654 commit d9c52a2

File tree

3 files changed

+103
-39
lines changed

3 files changed

+103
-39
lines changed

lax/src/tridiagonal/lu.rs

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
use crate::*;
2+
use cauchy::*;
3+
use num_traits::Zero;
4+
5+
/// Represents the LU factorization of a tridiagonal matrix `A` as `A = P*L*U`.
6+
#[derive(Clone, PartialEq)]
7+
pub struct LUFactorizedTridiagonal<A: Scalar> {
8+
/// A tridiagonal matrix which consists of
9+
/// - l : layout of raw matrix
10+
/// - dl: (n-1) multipliers that define the matrix L.
11+
/// - d : (n) diagonal elements of the upper triangular matrix U.
12+
/// - du: (n-1) elements of the first super-diagonal of U.
13+
pub a: Tridiagonal<A>,
14+
/// (n-2) elements of the second super-diagonal of U.
15+
pub du2: Vec<A>,
16+
/// The pivot indices that define the permutation matrix `P`.
17+
pub ipiv: Pivot,
18+
19+
pub a_opnorm_one: A::Real,
20+
}
21+
22+
impl<A: Scalar> Tridiagonal<A> {
23+
fn opnorm_one(&self) -> A::Real {
24+
let mut col_sum: Vec<A::Real> = self.d.iter().map(|val| val.abs()).collect();
25+
for i in 0..col_sum.len() {
26+
if i < self.dl.len() {
27+
col_sum[i] += self.dl[i].abs();
28+
}
29+
if i > 0 {
30+
col_sum[i] += self.du[i - 1].abs();
31+
}
32+
}
33+
let mut max = A::Real::zero();
34+
for &val in &col_sum {
35+
if max < val {
36+
max = val;
37+
}
38+
}
39+
max
40+
}
41+
}
42+
43+
pub struct LuTridiagonalWork<T: Scalar> {
44+
pub layout: MatrixLayout,
45+
pub du2: Vec<MaybeUninit<T>>,
46+
pub ipiv: Vec<MaybeUninit<i32>>,
47+
}
48+
49+
pub trait LuTridiagonalWorkImpl {
50+
type Elem: Scalar;
51+
fn new(layout: MatrixLayout) -> Self;
52+
fn eval(self, a: Tridiagonal<Self::Elem>) -> Result<LUFactorizedTridiagonal<Self::Elem>>;
53+
}
54+
55+
macro_rules! impl_lu_tridiagonal_work {
56+
($s:ty, $trf:path) => {
57+
impl LuTridiagonalWorkImpl for LuTridiagonalWork<$s> {
58+
type Elem = $s;
59+
60+
fn new(layout: MatrixLayout) -> Self {
61+
let (n, _) = layout.size();
62+
let du2 = vec_uninit((n - 2) as usize);
63+
let ipiv = vec_uninit(n as usize);
64+
LuTridiagonalWork { layout, du2, ipiv }
65+
}
66+
67+
fn eval(
68+
mut self,
69+
mut a: Tridiagonal<Self::Elem>,
70+
) -> Result<LUFactorizedTridiagonal<Self::Elem>> {
71+
let (n, _) = self.layout.size();
72+
// We have to calc one-norm before LU factorization
73+
let a_opnorm_one = a.opnorm_one();
74+
let mut info = 0;
75+
unsafe {
76+
$trf(
77+
&n,
78+
AsPtr::as_mut_ptr(&mut a.dl),
79+
AsPtr::as_mut_ptr(&mut a.d),
80+
AsPtr::as_mut_ptr(&mut a.du),
81+
AsPtr::as_mut_ptr(&mut self.du2),
82+
AsPtr::as_mut_ptr(&mut self.ipiv),
83+
&mut info,
84+
)
85+
};
86+
info.as_lapack_result()?;
87+
Ok(LUFactorizedTridiagonal {
88+
a,
89+
du2: unsafe { self.du2.assume_init() },
90+
ipiv: unsafe { self.ipiv.assume_init() },
91+
a_opnorm_one,
92+
})
93+
}
94+
}
95+
};
96+
}
97+
98+
impl_lu_tridiagonal_work!(c64, lapack_sys::zgttrf_);
99+
impl_lu_tridiagonal_work!(c32, lapack_sys::cgttrf_);
100+
impl_lu_tridiagonal_work!(f64, lapack_sys::dgttrf_);
101+
impl_lu_tridiagonal_work!(f32, lapack_sys::sgttrf_);

lax/src/tridiagonal/matrix.rs

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
use crate::layout::*;
22
use cauchy::*;
3-
use num_traits::Zero;
43
use std::ops::{Index, IndexMut};
54

65
/// Represents a tridiagonal matrix as 3 one-dimensional vectors.
@@ -24,27 +23,6 @@ pub struct Tridiagonal<A: Scalar> {
2423
pub du: Vec<A>,
2524
}
2625

27-
impl<A: Scalar> Tridiagonal<A> {
28-
pub fn opnorm_one(&self) -> A::Real {
29-
let mut col_sum: Vec<A::Real> = self.d.iter().map(|val| val.abs()).collect();
30-
for i in 0..col_sum.len() {
31-
if i < self.dl.len() {
32-
col_sum[i] += self.dl[i].abs();
33-
}
34-
if i > 0 {
35-
col_sum[i] += self.du[i - 1].abs();
36-
}
37-
}
38-
let mut max = A::Real::zero();
39-
for &val in &col_sum {
40-
if max < val {
41-
max = val;
42-
}
43-
}
44-
max
45-
}
46-
}
47-
4826
impl<A: Scalar> Index<(i32, i32)> for Tridiagonal<A> {
4927
type Output = A;
5028
#[inline]

lax/src/tridiagonal/mod.rs

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
//! Implement linear solver using LU decomposition
22
//! for tridiagonal matrix
33
4+
mod lu;
45
mod matrix;
56
mod rcond;
67
mod solve;
78

9+
pub use lu::*;
810
pub use matrix::*;
911
pub use rcond::*;
1012
pub use solve::*;
@@ -13,23 +15,6 @@ use crate::{error::*, layout::*, *};
1315
use cauchy::*;
1416
use num_traits::Zero;
1517

16-
/// Represents the LU factorization of a tridiagonal matrix `A` as `A = P*L*U`.
17-
#[derive(Clone, PartialEq)]
18-
pub struct LUFactorizedTridiagonal<A: Scalar> {
19-
/// A tridiagonal matrix which consists of
20-
/// - l : layout of raw matrix
21-
/// - dl: (n-1) multipliers that define the matrix L.
22-
/// - d : (n) diagonal elements of the upper triangular matrix U.
23-
/// - du: (n-1) elements of the first super-diagonal of U.
24-
pub a: Tridiagonal<A>,
25-
/// (n-2) elements of the second super-diagonal of U.
26-
pub du2: Vec<A>,
27-
/// The pivot indices that define the permutation matrix `P`.
28-
pub ipiv: Pivot,
29-
30-
a_opnorm_one: A::Real,
31-
}
32-
3318
/// Wraps `*gttrf`, `*gtcon` and `*gttrs`
3419
pub trait Tridiagonal_: Scalar + Sized {
3520
/// Computes the LU factorization of a tridiagonal `m x n` matrix `a` using

0 commit comments

Comments
 (0)