diff --git a/Cargo.toml b/Cargo.toml
index 714e2ae69..d908ed3ef 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -35,6 +35,8 @@ version = "0.3.16"
optional = true
[dependencies]
+rayon = "0.3.1"
+
# Use via the `blas` crate feature!
blas-sys = { version = "0.6.2", optional = true, default-features = false }
openblas-provider = { version = "0.4.1", optional = true, default-features = false }
diff --git a/src/lib.rs b/src/lib.rs
index 8b69b2141..177a1a70e 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -73,6 +73,7 @@ extern crate rustc_serialize as serialize;
extern crate blas_sys;
extern crate matrixmultiply;
+extern crate rayon;
extern crate itertools;
extern crate num as libnum;
diff --git a/src/linalg/impl_linalg.rs b/src/linalg/impl_linalg.rs
index 824fdae0a..e7df8c143 100644
--- a/src/linalg/impl_linalg.rs
+++ b/src/linalg/impl_linalg.rs
@@ -6,6 +6,7 @@
// option. This file may not be copied, modified, or distributed
// except according to those terms.
+use rayon;
use libnum::Zero;
use itertools::free::enumerate;
@@ -413,6 +414,48 @@ fn mat_mul_impl(alpha: A,
mat_mul_general(alpha, lhs, rhs, beta, c)
}
+const SPLIT: usize = 64;
+
+#[inline(never)]
+fn mat_mul_par_start(alpha: A,
+ lhs: &ArrayView,
+ rhs: &ArrayView,
+ beta: A,
+ c: &mut ArrayViewMut)
+ where A: LinalgScalar,
+{
+ mat_mul_par(alpha, lhs, rhs, beta, c);
+}
+
+fn mat_mul_par(alpha: A,
+ lhs: &ArrayView,
+ rhs: &ArrayView,
+ beta: A,
+ c: &mut ArrayViewMut)
+ where A: LinalgScalar,
+{
+ let ((m, k), (k2, n)) = (lhs.dim, rhs.dim);
+ debug_assert_eq!(k, k2);
+ if m > SPLIT {
+ // [ A0 ] B = [ C0 ]
+ // [ A1 ] [ C1 ]
+ let mid = m / 2;
+ let (a0, a1) = lhs.split_at(Axis(0), mid);
+ let (mut c0, mut c1) = c.view_mut().split_at(Axis(0), mid);
+ rayon::join(move || mat_mul_par(alpha, &a0, rhs, beta, &mut c0),
+ move || mat_mul_par(alpha, &a1, rhs, beta, &mut c1));
+ } else if n > SPLIT {
+ // A [ B0 B1 ] = [ C0 C1 ]
+ let mid = n / 2;
+ let (b0, b1) = rhs.split_at(Axis(1), mid);
+ let (mut c0, mut c1) = c.view_mut().split_at(Axis(1), mid);
+ rayon::join(move || mat_mul_par(alpha, lhs, &b0, beta, &mut c0),
+ move || mat_mul_par(alpha, lhs, &b1, beta, &mut c1));
+ } else {
+ mat_mul_general(alpha, lhs, rhs, beta, c);
+ }
+}
+
/// C ← α A B + β C
fn mat_mul_general(alpha: A,
lhs: &ArrayView,
@@ -421,7 +464,12 @@ fn mat_mul_general(alpha: A,
c: &mut ArrayViewMut)
where A: LinalgScalar,
{
- let ((m, k), (_, n)) = (lhs.dim, rhs.dim);
+ let ((m, k), (k2, n)) = (lhs.dim, rhs.dim);
+
+ debug_assert_eq!(k, k2);
+ if m > SPLIT || n > SPLIT {
+ return mat_mul_par_start(alpha, lhs, rhs, beta, c);
+ }
// common parameters for gemm
let ap = lhs.as_ptr();
diff --git a/src/linalg_traits.rs b/src/linalg_traits.rs
index d73612500..60865256e 100644
--- a/src/linalg_traits.rs
+++ b/src/linalg_traits.rs
@@ -24,7 +24,7 @@ use ScalarOperand;
/// semantics or destructors, and the rest are numerical traits.
pub trait LinalgScalar :
Any +
- Copy +
+ Copy + Send + Sync +
Zero + One +
Add