diff --git a/Cargo.toml b/Cargo.toml index 714e2ae69..d908ed3ef 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -35,6 +35,8 @@ version = "0.3.16" optional = true [dependencies] +rayon = "0.3.1" + # Use via the `blas` crate feature! blas-sys = { version = "0.6.2", optional = true, default-features = false } openblas-provider = { version = "0.4.1", optional = true, default-features = false } diff --git a/src/lib.rs b/src/lib.rs index 8b69b2141..177a1a70e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -73,6 +73,7 @@ extern crate rustc_serialize as serialize; extern crate blas_sys; extern crate matrixmultiply; +extern crate rayon; extern crate itertools; extern crate num as libnum; diff --git a/src/linalg/impl_linalg.rs b/src/linalg/impl_linalg.rs index 824fdae0a..e7df8c143 100644 --- a/src/linalg/impl_linalg.rs +++ b/src/linalg/impl_linalg.rs @@ -6,6 +6,7 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. +use rayon; use libnum::Zero; use itertools::free::enumerate; @@ -413,6 +414,48 @@ fn mat_mul_impl(alpha: A, mat_mul_general(alpha, lhs, rhs, beta, c) } +const SPLIT: usize = 64; + +#[inline(never)] +fn mat_mul_par_start(alpha: A, + lhs: &ArrayView, + rhs: &ArrayView, + beta: A, + c: &mut ArrayViewMut) + where A: LinalgScalar, +{ + mat_mul_par(alpha, lhs, rhs, beta, c); +} + +fn mat_mul_par(alpha: A, + lhs: &ArrayView, + rhs: &ArrayView, + beta: A, + c: &mut ArrayViewMut) + where A: LinalgScalar, +{ + let ((m, k), (k2, n)) = (lhs.dim, rhs.dim); + debug_assert_eq!(k, k2); + if m > SPLIT { + // [ A0 ] B = [ C0 ] + // [ A1 ] [ C1 ] + let mid = m / 2; + let (a0, a1) = lhs.split_at(Axis(0), mid); + let (mut c0, mut c1) = c.view_mut().split_at(Axis(0), mid); + rayon::join(move || mat_mul_par(alpha, &a0, rhs, beta, &mut c0), + move || mat_mul_par(alpha, &a1, rhs, beta, &mut c1)); + } else if n > SPLIT { + // A [ B0 B1 ] = [ C0 C1 ] + let mid = n / 2; + let (b0, b1) = rhs.split_at(Axis(1), mid); + let (mut c0, mut c1) = c.view_mut().split_at(Axis(1), mid); + rayon::join(move || mat_mul_par(alpha, lhs, &b0, beta, &mut c0), + move || mat_mul_par(alpha, lhs, &b1, beta, &mut c1)); + } else { + mat_mul_general(alpha, lhs, rhs, beta, c); + } +} + /// C ← α A B + β C fn mat_mul_general(alpha: A, lhs: &ArrayView, @@ -421,7 +464,12 @@ fn mat_mul_general(alpha: A, c: &mut ArrayViewMut) where A: LinalgScalar, { - let ((m, k), (_, n)) = (lhs.dim, rhs.dim); + let ((m, k), (k2, n)) = (lhs.dim, rhs.dim); + + debug_assert_eq!(k, k2); + if m > SPLIT || n > SPLIT { + return mat_mul_par_start(alpha, lhs, rhs, beta, c); + } // common parameters for gemm let ap = lhs.as_ptr(); diff --git a/src/linalg_traits.rs b/src/linalg_traits.rs index d73612500..60865256e 100644 --- a/src/linalg_traits.rs +++ b/src/linalg_traits.rs @@ -24,7 +24,7 @@ use ScalarOperand; /// semantics or destructors, and the rest are numerical traits. pub trait LinalgScalar : Any + - Copy + + Copy + Send + Sync + Zero + One + Add + Sub + @@ -35,7 +35,7 @@ pub trait LinalgScalar : impl LinalgScalar for T where T: Any + - Copy + + Copy + Send + Sync + Zero + One + Add + Sub +