From 638d09d480e30a1026b12c40c6d52675b998c785 Mon Sep 17 00:00:00 2001 From: bluss Date: Mon, 11 Apr 2016 23:23:17 +0200 Subject: [PATCH 1/3] Use Send + Sync in LinalgScalar --- src/linalg_traits.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/linalg_traits.rs b/src/linalg_traits.rs index d73612500..60865256e 100644 --- a/src/linalg_traits.rs +++ b/src/linalg_traits.rs @@ -24,7 +24,7 @@ use ScalarOperand; /// semantics or destructors, and the rest are numerical traits. pub trait LinalgScalar : Any + - Copy + + Copy + Send + Sync + Zero + One + Add + Sub + @@ -35,7 +35,7 @@ pub trait LinalgScalar : impl LinalgScalar for T where T: Any + - Copy + + Copy + Send + Sync + Zero + One + Add + Sub + From 2d41bddef60a3dd3ec6a81c6ba4b7ee324a2eb0c Mon Sep 17 00:00:00 2001 From: bluss Date: Mon, 11 Apr 2016 23:23:51 +0200 Subject: [PATCH 2/3] Experimental divide & conquer in matrix multiply using rayon --- Cargo.toml | 2 ++ src/lib.rs | 1 + src/linalg/impl_linalg.rs | 25 ++++++++++++++++++++++++- 3 files changed, 27 insertions(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 714e2ae69..d908ed3ef 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -35,6 +35,8 @@ version = "0.3.16" optional = true [dependencies] +rayon = "0.3.1" + # Use via the `blas` crate feature! blas-sys = { version = "0.6.2", optional = true, default-features = false } openblas-provider = { version = "0.4.1", optional = true, default-features = false } diff --git a/src/lib.rs b/src/lib.rs index 8b69b2141..177a1a70e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -73,6 +73,7 @@ extern crate rustc_serialize as serialize; extern crate blas_sys; extern crate matrixmultiply; +extern crate rayon; extern crate itertools; extern crate num as libnum; diff --git a/src/linalg/impl_linalg.rs b/src/linalg/impl_linalg.rs index 824fdae0a..bc2bdd685 100644 --- a/src/linalg/impl_linalg.rs +++ b/src/linalg/impl_linalg.rs @@ -6,6 +6,7 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. +use rayon; use libnum::Zero; use itertools::free::enumerate; @@ -413,6 +414,8 @@ fn mat_mul_impl(alpha: A, mat_mul_general(alpha, lhs, rhs, beta, c) } +const SPLIT: usize = 64; + /// C ← α A B + β C fn mat_mul_general(alpha: A, lhs: &ArrayView, @@ -421,7 +424,27 @@ fn mat_mul_general(alpha: A, c: &mut ArrayViewMut) where A: LinalgScalar, { - let ((m, k), (_, n)) = (lhs.dim, rhs.dim); + let ((m, k), (k2, n)) = (lhs.dim, rhs.dim); + + debug_assert_eq!(k, k2); + if m > SPLIT { + // [ A0 ] B = [ C0 ] + // [ A1 ] [ C1 ] + let mid = m / 2; + let (a0, a1) = lhs.split_at(Axis(0), mid); + let (mut c0, mut c1) = c.view_mut().split_at(Axis(0), mid); + rayon::join(move || mat_mul_general(alpha, &a0, rhs, beta, &mut c0), + move || mat_mul_general(alpha, &a1, rhs, beta, &mut c1)); + return; + } else if n > SPLIT { + // A [ B0 B1 ] = [ C0 C1 ] + let mid = n / 2; + let (b0, b1) = rhs.split_at(Axis(1), mid); + let (mut c0, mut c1) = c.view_mut().split_at(Axis(1), mid); + rayon::join(move || mat_mul_general(alpha, lhs, &b0, beta, &mut c0), + move || mat_mul_general(alpha, lhs, &b1, beta, &mut c1)); + return; + } // common parameters for gemm let ap = lhs.as_ptr(); From f3ac6095b0dd98dc593d24f05054bdb96b7b3f53 Mon Sep 17 00:00:00 2001 From: bluss Date: Mon, 18 Apr 2016 14:05:04 +0200 Subject: [PATCH 3/3] Move parallel computation to a separate function --- src/linalg/impl_linalg.rs | 51 +++++++++++++++++++++++++++++---------- 1 file changed, 38 insertions(+), 13 deletions(-) diff --git a/src/linalg/impl_linalg.rs b/src/linalg/impl_linalg.rs index bc2bdd685..e7df8c143 100644 --- a/src/linalg/impl_linalg.rs +++ b/src/linalg/impl_linalg.rs @@ -416,16 +416,25 @@ fn mat_mul_impl(alpha: A, const SPLIT: usize = 64; -/// C ← α A B + β C -fn mat_mul_general(alpha: A, - lhs: &ArrayView, - rhs: &ArrayView, - beta: A, - c: &mut ArrayViewMut) +#[inline(never)] +fn mat_mul_par_start(alpha: A, + lhs: &ArrayView, + rhs: &ArrayView, + beta: A, + c: &mut ArrayViewMut) where A: LinalgScalar, { - let ((m, k), (k2, n)) = (lhs.dim, rhs.dim); + mat_mul_par(alpha, lhs, rhs, beta, c); +} +fn mat_mul_par(alpha: A, + lhs: &ArrayView, + rhs: &ArrayView, + beta: A, + c: &mut ArrayViewMut) + where A: LinalgScalar, +{ + let ((m, k), (k2, n)) = (lhs.dim, rhs.dim); debug_assert_eq!(k, k2); if m > SPLIT { // [ A0 ] B = [ C0 ] @@ -433,17 +442,33 @@ fn mat_mul_general(alpha: A, let mid = m / 2; let (a0, a1) = lhs.split_at(Axis(0), mid); let (mut c0, mut c1) = c.view_mut().split_at(Axis(0), mid); - rayon::join(move || mat_mul_general(alpha, &a0, rhs, beta, &mut c0), - move || mat_mul_general(alpha, &a1, rhs, beta, &mut c1)); - return; + rayon::join(move || mat_mul_par(alpha, &a0, rhs, beta, &mut c0), + move || mat_mul_par(alpha, &a1, rhs, beta, &mut c1)); } else if n > SPLIT { // A [ B0 B1 ] = [ C0 C1 ] let mid = n / 2; let (b0, b1) = rhs.split_at(Axis(1), mid); let (mut c0, mut c1) = c.view_mut().split_at(Axis(1), mid); - rayon::join(move || mat_mul_general(alpha, lhs, &b0, beta, &mut c0), - move || mat_mul_general(alpha, lhs, &b1, beta, &mut c1)); - return; + rayon::join(move || mat_mul_par(alpha, lhs, &b0, beta, &mut c0), + move || mat_mul_par(alpha, lhs, &b1, beta, &mut c1)); + } else { + mat_mul_general(alpha, lhs, rhs, beta, c); + } +} + +/// C ← α A B + β C +fn mat_mul_general(alpha: A, + lhs: &ArrayView, + rhs: &ArrayView, + beta: A, + c: &mut ArrayViewMut) + where A: LinalgScalar, +{ + let ((m, k), (k2, n)) = (lhs.dim, rhs.dim); + + debug_assert_eq!(k, k2); + if m > SPLIT || n > SPLIT { + return mat_mul_par_start(alpha, lhs, rhs, beta, c); } // common parameters for gemm