Skip to content

Commit d5f9cb5

Browse files
committed
implement scalar_prod
1 parent 408f42b commit d5f9cb5

File tree

3 files changed

+71
-28
lines changed

3 files changed

+71
-28
lines changed

src/numeric/impl_numeric.rs

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
// option. This file may not be copied, modified, or distributed
77
// except according to those terms.
88

9-
use std::ops::{Add, Div};
9+
use std::ops::{Add, Div, Mul};
1010
use libnum::{self, One, Zero, Float};
1111
use itertools::free::enumerate;
1212

@@ -33,19 +33,45 @@ impl<A, S, D> ArrayBase<S, D>
3333
where A: Clone + Add<Output=A> + libnum::Zero,
3434
{
3535
if let Some(slc) = self.as_slice_memory_order() {
36-
return numeric_util::unrolled_sum(slc);
36+
return numeric_util::unrolled_fold(slc, A::zero, A::add);
3737
}
3838
let mut sum = A::zero();
3939
for row in self.inner_rows() {
4040
if let Some(slc) = row.as_slice() {
41-
sum = sum + numeric_util::unrolled_sum(slc);
41+
sum = sum + numeric_util::unrolled_fold(slc, A::zero, A::add);
4242
} else {
4343
sum = sum + row.iter().fold(A::zero(), |acc, elt| acc + elt.clone());
4444
}
4545
}
4646
sum
4747
}
4848

49+
/// Return the product of all elements in the array.
50+
///
51+
/// ```
52+
/// use ndarray::arr2;
53+
///
54+
/// let a = arr2(&[[1., 2.],
55+
/// [3., 4.]]);
56+
/// assert_eq!(a.scalar_prod(), 24.);
57+
/// ```
58+
pub fn scalar_prod(&self) -> A
59+
where A: Clone + Mul<Output=A> + libnum::One,
60+
{
61+
if let Some(slc) = self.as_slice_memory_order() {
62+
return numeric_util::unrolled_fold(slc, A::one, A::mul);
63+
}
64+
let mut sum = A::one();
65+
for row in self.inner_rows() {
66+
if let Some(slc) = row.as_slice() {
67+
sum = sum * numeric_util::unrolled_fold(slc, A::one, A::mul);
68+
} else {
69+
sum = sum * row.iter().fold(A::one(), |acc, elt| acc * elt.clone());
70+
}
71+
}
72+
sum
73+
}
74+
4975
/// Return sum along `axis`.
5076
///
5177
/// ```

src/numeric_util.rs

Lines changed: 22 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -5,50 +5,47 @@
55
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
66
// option. This file may not be copied, modified, or distributed
77
// except according to those terms.
8-
use libnum;
9-
108
use std::cmp;
11-
use std::ops::{
12-
Add,
13-
};
149

1510
use LinalgScalar;
1611

17-
/// Compute the sum of the values in `xs`
18-
pub fn unrolled_sum<A>(mut xs: &[A]) -> A
19-
where A: Clone + Add<Output=A> + libnum::Zero,
12+
/// Fold over the manually unrolled `xs` with `f`
13+
pub fn unrolled_fold<A, I, F>(mut xs: &[A], init: I, f: F) -> A
14+
where A: Clone,
15+
I: Fn() -> A,
16+
F: Fn(A, A) -> A,
2017
{
2118
// eightfold unrolled so that floating point can be vectorized
2219
// (even with strict floating point accuracy semantics)
23-
let mut sum = A::zero();
20+
let mut acc = init();
2421
let (mut p0, mut p1, mut p2, mut p3,
2522
mut p4, mut p5, mut p6, mut p7) =
26-
(A::zero(), A::zero(), A::zero(), A::zero(),
27-
A::zero(), A::zero(), A::zero(), A::zero());
23+
(init(), init(), init(), init(),
24+
init(), init(), init(), init());
2825
while xs.len() >= 8 {
29-
p0 = p0 + xs[0].clone();
30-
p1 = p1 + xs[1].clone();
31-
p2 = p2 + xs[2].clone();
32-
p3 = p3 + xs[3].clone();
33-
p4 = p4 + xs[4].clone();
34-
p5 = p5 + xs[5].clone();
35-
p6 = p6 + xs[6].clone();
36-
p7 = p7 + xs[7].clone();
26+
p0 = f(p0, xs[0].clone());
27+
p1 = f(p1, xs[1].clone());
28+
p2 = f(p2, xs[2].clone());
29+
p3 = f(p3, xs[3].clone());
30+
p4 = f(p4, xs[4].clone());
31+
p5 = f(p5, xs[5].clone());
32+
p6 = f(p6, xs[6].clone());
33+
p7 = f(p7, xs[7].clone());
3734

3835
xs = &xs[8..];
3936
}
40-
sum = sum.clone() + (p0 + p4);
41-
sum = sum.clone() + (p1 + p5);
42-
sum = sum.clone() + (p2 + p6);
43-
sum = sum.clone() + (p3 + p7);
37+
acc = f(acc.clone(), f(p0, p4));
38+
acc = f(acc.clone(), f(p1, p5));
39+
acc = f(acc.clone(), f(p2, p6));
40+
acc = f(acc.clone(), f(p3, p7));
4441

4542
// make it clear to the optimizer that this loop is short
4643
// and can not be autovectorized.
4744
for i in 0..xs.len() {
4845
if i >= 7 { break; }
49-
sum = sum.clone() + xs[i].clone()
46+
acc = f(acc.clone(), xs[i].clone())
5047
}
51-
sum
48+
acc
5249
}
5350

5451
/// Compute the dot product.

tests/oper.rs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,26 @@ fn fold_and_sum() {
271271
}
272272
}
273273

274+
#[test]
275+
fn scalar_prod() {
276+
let a = Array::linspace(0.5, 2., 128).into_shape((8, 16)).unwrap();
277+
assert_approx_eq(a.fold(1., |acc, &x| acc * x), a.scalar_prod(), 1e-5);
278+
279+
// test different strides
280+
let max = 8 as Ixs;
281+
for i in 1..max {
282+
for j in 1..max {
283+
let a1 = a.slice(s![..;i, ..;j]);
284+
let mut prod = 1.;
285+
for elt in a1.iter() {
286+
prod *= *elt;
287+
}
288+
assert_approx_eq(a1.fold(1., |acc, &x| acc * x), prod, 1e-5);
289+
assert_approx_eq(prod, a1.scalar_prod(), 1e-5);
290+
}
291+
}
292+
}
293+
274294
fn range_mat(m: Ix, n: Ix) -> Array2<f32> {
275295
Array::linspace(0., (m * n) as f32 - 1., m * n).into_shape((m, n)).unwrap()
276296
}

0 commit comments

Comments
 (0)