Skip to content

Commit 4913818

Browse files
committed
Merge Tridiagonal_ into Lapack
1 parent d9c52a2 commit 4913818

File tree

2 files changed

+42
-148
lines changed

2 files changed

+42
-148
lines changed

lax/src/lib.rs

Lines changed: 42 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -84,14 +84,14 @@ extern crate openblas_src as _src;
8484
#[cfg(any(feature = "netlib-system", feature = "netlib-static"))]
8585
extern crate netlib_src as _src;
8686

87-
pub mod error;
88-
pub mod flags;
89-
pub mod layout;
90-
87+
pub mod alloc;
9188
pub mod cholesky;
9289
pub mod eig;
9390
pub mod eigh;
9491
pub mod eigh_generalized;
92+
pub mod error;
93+
pub mod flags;
94+
pub mod layout;
9595
pub mod least_squares;
9696
pub mod opnorm;
9797
pub mod qr;
@@ -101,16 +101,12 @@ pub mod solveh;
101101
pub mod svd;
102102
pub mod svddc;
103103
pub mod triangular;
104+
pub mod tridiagonal;
104105

105-
mod alloc;
106-
mod tridiagonal;
107-
108-
pub use self::cholesky::*;
109106
pub use self::flags::*;
110107
pub use self::least_squares::LeastSquaresOwned;
111-
pub use self::opnorm::*;
112108
pub use self::svd::{SvdOwned, SvdRef};
113-
pub use self::tridiagonal::*;
109+
pub use self::tridiagonal::{LUFactorizedTridiagonal, Tridiagonal};
114110

115111
use self::{alloc::*, error::*, layout::*};
116112
use cauchy::*;
@@ -120,7 +116,7 @@ pub type Pivot = Vec<i32>;
120116

121117
#[cfg_attr(doc, katexit::katexit)]
122118
/// Trait for primitive types which implements LAPACK subroutines
123-
pub trait Lapack: Tridiagonal_ {
119+
pub trait Lapack: Scalar {
124120
/// Compute right eigenvalue and eigenvectors for a general matrix
125121
fn eig(
126122
calc_v: bool,
@@ -306,6 +302,19 @@ pub trait Lapack: Tridiagonal_ {
306302
a: &[Self],
307303
b: &mut [Self],
308304
) -> Result<()>;
305+
306+
/// Computes the LU factorization of a tridiagonal `m x n` matrix `a` using
307+
/// partial pivoting with row interchanges.
308+
fn lu_tridiagonal(a: Tridiagonal<Self>) -> Result<LUFactorizedTridiagonal<Self>>;
309+
310+
fn rcond_tridiagonal(lu: &LUFactorizedTridiagonal<Self>) -> Result<Self::Real>;
311+
312+
fn solve_tridiagonal(
313+
lu: &LUFactorizedTridiagonal<Self>,
314+
bl: MatrixLayout,
315+
t: Transpose,
316+
b: &mut [Self],
317+
) -> Result<()>;
309318
}
310319

311320
macro_rules! impl_lapack {
@@ -491,6 +500,28 @@ macro_rules! impl_lapack {
491500
use triangular::*;
492501
SolveTriangularImpl::solve_triangular(al, bl, uplo, d, a, b)
493502
}
503+
504+
fn lu_tridiagonal(a: Tridiagonal<Self>) -> Result<LUFactorizedTridiagonal<Self>> {
505+
use tridiagonal::*;
506+
let work = LuTridiagonalWork::<$s>::new(a.l);
507+
work.eval(a)
508+
}
509+
510+
fn rcond_tridiagonal(lu: &LUFactorizedTridiagonal<Self>) -> Result<Self::Real> {
511+
use tridiagonal::*;
512+
let mut work = RcondTridiagonalWork::<$s>::new(lu.a.l);
513+
work.calc(lu)
514+
}
515+
516+
fn solve_tridiagonal(
517+
lu: &LUFactorizedTridiagonal<Self>,
518+
bl: MatrixLayout,
519+
t: Transpose,
520+
b: &mut [Self],
521+
) -> Result<()> {
522+
use tridiagonal::*;
523+
SolveTridiagonalImpl::solve_tridiagonal(lu, bl, t, b)
524+
}
494525
}
495526
};
496527
}

lax/src/tridiagonal/mod.rs

Lines changed: 0 additions & 137 deletions
Original file line numberDiff line numberDiff line change
@@ -10,140 +10,3 @@ pub use lu::*;
1010
pub use matrix::*;
1111
pub use rcond::*;
1212
pub use solve::*;
13-
14-
use crate::{error::*, layout::*, *};
15-
use cauchy::*;
16-
use num_traits::Zero;
17-
18-
/// Wraps `*gttrf`, `*gtcon` and `*gttrs`
19-
pub trait Tridiagonal_: Scalar + Sized {
20-
/// Computes the LU factorization of a tridiagonal `m x n` matrix `a` using
21-
/// partial pivoting with row interchanges.
22-
fn lu_tridiagonal(a: Tridiagonal<Self>) -> Result<LUFactorizedTridiagonal<Self>>;
23-
24-
fn rcond_tridiagonal(lu: &LUFactorizedTridiagonal<Self>) -> Result<Self::Real>;
25-
26-
fn solve_tridiagonal(
27-
lu: &LUFactorizedTridiagonal<Self>,
28-
bl: MatrixLayout,
29-
t: Transpose,
30-
b: &mut [Self],
31-
) -> Result<()>;
32-
}
33-
34-
macro_rules! impl_tridiagonal {
35-
(@real, $scalar:ty, $gttrf:path, $gtcon:path, $gttrs:path) => {
36-
impl_tridiagonal!(@body, $scalar, $gttrf, $gtcon, $gttrs, iwork);
37-
};
38-
(@complex, $scalar:ty, $gttrf:path, $gtcon:path, $gttrs:path) => {
39-
impl_tridiagonal!(@body, $scalar, $gttrf, $gtcon, $gttrs, );
40-
};
41-
(@body, $scalar:ty, $gttrf:path, $gtcon:path, $gttrs:path, $($iwork:ident)*) => {
42-
impl Tridiagonal_ for $scalar {
43-
fn lu_tridiagonal(mut a: Tridiagonal<Self>) -> Result<LUFactorizedTridiagonal<Self>> {
44-
let (n, _) = a.l.size();
45-
let mut du2 = vec_uninit( (n - 2) as usize);
46-
let mut ipiv = vec_uninit( n as usize);
47-
// We have to calc one-norm before LU factorization
48-
let a_opnorm_one = a.opnorm_one();
49-
let mut info = 0;
50-
unsafe {
51-
$gttrf(
52-
&n,
53-
AsPtr::as_mut_ptr(&mut a.dl),
54-
AsPtr::as_mut_ptr(&mut a.d),
55-
AsPtr::as_mut_ptr(&mut a.du),
56-
AsPtr::as_mut_ptr(&mut du2),
57-
AsPtr::as_mut_ptr(&mut ipiv),
58-
&mut info,
59-
)
60-
};
61-
info.as_lapack_result()?;
62-
let du2 = unsafe { du2.assume_init() };
63-
let ipiv = unsafe { ipiv.assume_init() };
64-
Ok(LUFactorizedTridiagonal {
65-
a,
66-
du2,
67-
ipiv,
68-
a_opnorm_one,
69-
})
70-
}
71-
72-
fn rcond_tridiagonal(lu: &LUFactorizedTridiagonal<Self>) -> Result<Self::Real> {
73-
let (n, _) = lu.a.l.size();
74-
let ipiv = &lu.ipiv;
75-
let mut work: Vec<MaybeUninit<Self>> = vec_uninit(2 * n as usize);
76-
$(
77-
let mut $iwork: Vec<MaybeUninit<i32>> = vec_uninit(n as usize);
78-
)*
79-
let mut rcond = Self::Real::zero();
80-
let mut info = 0;
81-
unsafe {
82-
$gtcon(
83-
NormType::One.as_ptr(),
84-
&n,
85-
AsPtr::as_ptr(&lu.a.dl),
86-
AsPtr::as_ptr(&lu.a.d),
87-
AsPtr::as_ptr(&lu.a.du),
88-
AsPtr::as_ptr(&lu.du2),
89-
ipiv.as_ptr(),
90-
&lu.a_opnorm_one,
91-
&mut rcond,
92-
AsPtr::as_mut_ptr(&mut work),
93-
$(AsPtr::as_mut_ptr(&mut $iwork),)*
94-
&mut info,
95-
);
96-
}
97-
info.as_lapack_result()?;
98-
Ok(rcond)
99-
}
100-
101-
fn solve_tridiagonal(
102-
lu: &LUFactorizedTridiagonal<Self>,
103-
b_layout: MatrixLayout,
104-
t: Transpose,
105-
b: &mut [Self],
106-
) -> Result<()> {
107-
let (n, _) = lu.a.l.size();
108-
let ipiv = &lu.ipiv;
109-
// Transpose if b is C-continuous
110-
let mut b_t = None;
111-
let b_layout = match b_layout {
112-
MatrixLayout::C { .. } => {
113-
let (layout, t) = transpose(b_layout, b);
114-
b_t = Some(t);
115-
layout
116-
}
117-
MatrixLayout::F { .. } => b_layout,
118-
};
119-
let (ldb, nrhs) = b_layout.size();
120-
let mut info = 0;
121-
unsafe {
122-
$gttrs(
123-
t.as_ptr(),
124-
&n,
125-
&nrhs,
126-
AsPtr::as_ptr(&lu.a.dl),
127-
AsPtr::as_ptr(&lu.a.d),
128-
AsPtr::as_ptr(&lu.a.du),
129-
AsPtr::as_ptr(&lu.du2),
130-
ipiv.as_ptr(),
131-
AsPtr::as_mut_ptr(b_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(b)),
132-
&ldb,
133-
&mut info,
134-
);
135-
}
136-
info.as_lapack_result()?;
137-
if let Some(b_t) = b_t {
138-
transpose_over(b_layout, &b_t, b);
139-
}
140-
Ok(())
141-
}
142-
}
143-
};
144-
} // impl_tridiagonal!
145-
146-
impl_tridiagonal!(@real, f64, lapack_sys::dgttrf_, lapack_sys::dgtcon_, lapack_sys::dgttrs_);
147-
impl_tridiagonal!(@real, f32, lapack_sys::sgttrf_, lapack_sys::sgtcon_, lapack_sys::sgttrs_);
148-
impl_tridiagonal!(@complex, c64, lapack_sys::zgttrf_, lapack_sys::zgtcon_, lapack_sys::zgttrs_);
149-
impl_tridiagonal!(@complex, c32, lapack_sys::cgttrf_, lapack_sys::cgtcon_, lapack_sys::cgttrs_);

0 commit comments

Comments
 (0)