Skip to content

Commit 938846e

Browse files
committed
implement scalar_prod
1 parent 408f42b commit 938846e

File tree

3 files changed

+85
-1
lines changed

3 files changed

+85
-1
lines changed

src/numeric/impl_numeric.rs

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
// option. This file may not be copied, modified, or distributed
77
// except according to those terms.
88

9-
use std::ops::{Add, Div};
9+
use std::ops::{Add, Div, Mul};
1010
use libnum::{self, One, Zero, Float};
1111
use itertools::free::enumerate;
1212

@@ -46,6 +46,32 @@ impl<A, S, D> ArrayBase<S, D>
4646
sum
4747
}
4848

49+
/// Return the product of all elements in the array.
50+
///
51+
/// ```
52+
/// use ndarray::arr2;
53+
///
54+
/// let a = arr2(&[[1., 2.],
55+
/// [3., 4.]]);
56+
/// assert_eq!(a.scalar_prod(), 24.);
57+
/// ```
58+
pub fn scalar_prod(&self) -> A
59+
where A: Clone + Mul<Output=A> + libnum::One,
60+
{
61+
if let Some(slc) = self.as_slice_memory_order() {
62+
return numeric_util::unrolled_prod(slc);
63+
}
64+
let mut sum = A::one();
65+
for row in self.inner_rows() {
66+
if let Some(slc) = row.as_slice() {
67+
sum = sum * numeric_util::unrolled_prod(slc);
68+
} else {
69+
sum = sum * row.iter().fold(A::one(), |acc, elt| acc * elt.clone());
70+
}
71+
}
72+
sum
73+
}
74+
4975
/// Return sum along `axis`.
5076
///
5177
/// ```

src/numeric_util.rs

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ use libnum;
1010
use std::cmp;
1111
use std::ops::{
1212
Add,
13+
Mul,
1314
};
1415

1516
use LinalgScalar;
@@ -51,6 +52,43 @@ pub fn unrolled_sum<A>(mut xs: &[A]) -> A
5152
sum
5253
}
5354

55+
/// Compute the product of the values in `xs`
56+
pub fn unrolled_prod<A>(mut xs: &[A]) -> A
57+
where A: Clone + Mul<Output=A> + libnum::One,
58+
{
59+
// eightfold unrolled so that floating point can be vectorized
60+
// (even with strict floating point accuracy semantics)
61+
let mut prod = A::one();
62+
let (mut p0, mut p1, mut p2, mut p3,
63+
mut p4, mut p5, mut p6, mut p7) =
64+
(A::one(), A::one(), A::one(), A::one(),
65+
A::one(), A::one(), A::one(), A::one());
66+
while xs.len() >= 8 {
67+
p0 = p0 * xs[0].clone();
68+
p1 = p1 * xs[1].clone();
69+
p2 = p2 * xs[2].clone();
70+
p3 = p3 * xs[3].clone();
71+
p4 = p4 * xs[4].clone();
72+
p5 = p5 * xs[5].clone();
73+
p6 = p6 * xs[6].clone();
74+
p7 = p7 * xs[7].clone();
75+
76+
xs = &xs[8..];
77+
}
78+
prod = prod.clone() * (p0 * p4);
79+
prod = prod.clone() * (p1 * p5);
80+
prod = prod.clone() * (p2 * p6);
81+
prod = prod.clone() * (p3 * p7);
82+
83+
// make it clear to the optimizer that this loop is short
84+
// and can not be autovectorized.
85+
for i in 0..xs.len() {
86+
if i >= 7 { break; }
87+
prod = prod.clone() * xs[i].clone()
88+
}
89+
prod
90+
}
91+
5492
/// Compute the dot product.
5593
///
5694
/// `xs` and `ys` must be the same length

tests/oper.rs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,26 @@ fn fold_and_sum() {
271271
}
272272
}
273273

274+
// Checks `scalar_prod` against a fold-based product, first on the full
// 8x16 array and then on strided slices of it.
#[test]
fn scalar_prod() {
    let full = Array::linspace(0.5, 2., 128).into_shape((8, 16)).unwrap();
    let folded = full.fold(1., |acc, &x| acc * x);
    assert_approx_eq(folded, full.scalar_prod(), 1e-5);

    // test different strides
    let stride_limit: Ixs = 8;
    for row_step in 1..stride_limit {
        for col_step in 1..stride_limit {
            let view = full.slice(s![..;row_step, ..;col_step]);
            // Reference product computed through the element iterator.
            let expected = view.iter().fold(1., |acc, &elt| acc * elt);
            assert_approx_eq(view.fold(1., |acc, &x| acc * x), expected, 1e-5);
            assert_approx_eq(expected, view.scalar_prod(), 1e-5);
        }
    }
}
293+
274294
fn range_mat(m: Ix, n: Ix) -> Array2<f32> {
275295
Array::linspace(0., (m * n) as f32 - 1., m * n).into_shape((m, n)).unwrap()
276296
}

0 commit comments

Comments
 (0)