Skip to content

Commit 2d41bdd

Browse files
committed
Experimental divide & conquer in matrix multiply using rayon
1 parent 638d09d commit 2d41bdd

File tree

3 files changed

+27
-1
lines changed

3 files changed

+27
-1
lines changed

Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -35,6 +35,8 @@ version = "0.3.16"
3535
optional = true
3636

3737
[dependencies]
38+
rayon = "0.3.1"
39+
3840
# Use via the `blas` crate feature!
3941
blas-sys = { version = "0.6.2", optional = true, default-features = false }
4042
openblas-provider = { version = "0.4.1", optional = true, default-features = false }

src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ extern crate rustc_serialize as serialize;
7373
extern crate blas_sys;
7474

7575
extern crate matrixmultiply;
76+
extern crate rayon;
7677

7778
extern crate itertools;
7879
extern crate num as libnum;

src/linalg/impl_linalg.rs

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
// option. This file may not be copied, modified, or distributed
77
// except according to those terms.
88

9+
use rayon;
910
use libnum::Zero;
1011
use itertools::free::enumerate;
1112

@@ -413,6 +414,8 @@ fn mat_mul_impl<A>(alpha: A,
413414
mat_mul_general(alpha, lhs, rhs, beta, c)
414415
}
415416

417+
/// Size threshold for the divide & conquer step in `mat_mul_general`.
///
/// When the row count `m` of the left-hand side exceeds this value, the
/// product is halved along axis 0; otherwise, when the column count `n` of
/// the right-hand side exceeds it, the product is halved along axis 1. The
/// two halves are handed to `rayon::join`, each calling `mat_mul_general`
/// again, so the splitting stops once both `m` and `n` are at most `SPLIT`.
const SPLIT: usize = 64;
418+
416419
/// C ← α A B + β C
417420
fn mat_mul_general<A>(alpha: A,
418421
lhs: &ArrayView<A, (Ix, Ix)>,
@@ -421,7 +424,27 @@ fn mat_mul_general<A>(alpha: A,
421424
c: &mut ArrayViewMut<A, (Ix, Ix)>)
422425
where A: LinalgScalar,
423426
{
424-
let ((m, k), (_, n)) = (lhs.dim, rhs.dim);
427+
let ((m, k), (k2, n)) = (lhs.dim, rhs.dim);
428+
429+
debug_assert_eq!(k, k2);
430+
if m > SPLIT {
431+
// [ A0 ] B = [ C0 ]
432+
// [ A1 ] [ C1 ]
433+
let mid = m / 2;
434+
let (a0, a1) = lhs.split_at(Axis(0), mid);
435+
let (mut c0, mut c1) = c.view_mut().split_at(Axis(0), mid);
436+
rayon::join(move || mat_mul_general(alpha, &a0, rhs, beta, &mut c0),
437+
move || mat_mul_general(alpha, &a1, rhs, beta, &mut c1));
438+
return;
439+
} else if n > SPLIT {
440+
// A [ B0 B1 ] = [ C0 C1 ]
441+
let mid = n / 2;
442+
let (b0, b1) = rhs.split_at(Axis(1), mid);
443+
let (mut c0, mut c1) = c.view_mut().split_at(Axis(1), mid);
444+
rayon::join(move || mat_mul_general(alpha, lhs, &b0, beta, &mut c0),
445+
move || mat_mul_general(alpha, lhs, &b1, beta, &mut c1));
446+
return;
447+
}
425448

426449
// common parameters for gemm
427450
let ap = lhs.as_ptr();

0 commit comments

Comments
 (0)