6
6
// option. This file may not be copied, modified, or distributed
7
7
// except according to those terms.
8
8
9
+ use rayon;
9
10
use libnum:: Zero ;
10
11
use itertools:: free:: enumerate;
11
12
@@ -413,6 +414,8 @@ fn mat_mul_impl<A>(alpha: A,
413
414
mat_mul_general ( alpha, lhs, rhs, beta, c)
414
415
}
415
416
417
+ const SPLIT : usize = 64 ; // Size cutoff for mat_mul_general: while lhs's first dimension (m) or, failing that, rhs's second dimension (n) is strictly greater than this, the product is halved along that axis and both halves are computed recursively via rayon::join.
418
+
416
419
/// C ← α A B + β C
417
420
fn mat_mul_general < A > ( alpha : A ,
418
421
lhs : & ArrayView < A , ( Ix , Ix ) > ,
@@ -421,7 +424,27 @@ fn mat_mul_general<A>(alpha: A,
421
424
c : & mut ArrayViewMut < A , ( Ix , Ix ) > )
422
425
where A : LinalgScalar ,
423
426
{
424
- let ( ( m, k) , ( _, n) ) = ( lhs. dim , rhs. dim ) ;
427
+ let ( ( m, k) , ( k2, n) ) = ( lhs. dim , rhs. dim ) ;
428
+
429
+ debug_assert_eq ! ( k, k2) ;
430
+ if m > SPLIT {
431
+ // [ A0 ] B = [ C0 ]
432
+ // [ A1 ] [ C1 ]
433
+ let mid = m / 2 ;
434
+ let ( a0, a1) = lhs. split_at ( Axis ( 0 ) , mid) ;
435
+ let ( mut c0, mut c1) = c. view_mut ( ) . split_at ( Axis ( 0 ) , mid) ;
436
+ rayon:: join ( move || mat_mul_general ( alpha, & a0, rhs, beta, & mut c0) ,
437
+ move || mat_mul_general ( alpha, & a1, rhs, beta, & mut c1) ) ;
438
+ return ;
439
+ } else if n > SPLIT {
440
+ // A [ B0 B1 ] = [ C0 C1 ]
441
+ let mid = n / 2 ;
442
+ let ( b0, b1) = rhs. split_at ( Axis ( 1 ) , mid) ;
443
+ let ( mut c0, mut c1) = c. view_mut ( ) . split_at ( Axis ( 1 ) , mid) ;
444
+ rayon:: join ( move || mat_mul_general ( alpha, lhs, & b0, beta, & mut c0) ,
445
+ move || mat_mul_general ( alpha, lhs, & b1, beta, & mut c1) ) ;
446
+ return ;
447
+ }
425
448
426
449
// common parameters for gemm
427
450
let ap = lhs. as_ptr ( ) ;
0 commit comments