Skip to content

Commit a175675

Browse files
committed
Allow sum_axis and mean_axis for empty arrays
1 parent d4b9801 commit a175675

File tree

2 files changed

+22
-7
lines changed

2 files changed

+22
-7
lines changed

src/numeric/impl_numeric.rs

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -61,14 +61,13 @@ impl<A, S, D> ArrayBase<S, D>
6161
/// );
6262
/// ```
6363
///
64-
/// **Panics** if `axis` is out of bounds or if the length of the axis is
65-
/// zero.
64+
/// **Panics** if `axis` is out of bounds.
6665
pub fn sum_axis(&self, axis: Axis) -> Array<A, D::Smaller>
6766
where A: Clone + Zero + Add<Output=A>,
6867
D: RemoveAxis,
6968
{
7069
let n = self.len_of(axis);
71-
let mut res = self.subview(axis, 0).to_owned();
70+
let mut res = Array::zeros(self.raw_dim().remove_axis(axis));
7271
let stride = self.strides()[axis.index()];
7372
if self.ndim() == 2 && stride == 1 {
7473
// contiguous along the axis we are summing
@@ -77,7 +76,7 @@ impl<A, S, D> ArrayBase<S, D>
7776
*elt = self.subview(Axis(1 - ax), i).scalar_sum();
7877
}
7978
} else {
80-
for i in 1..n {
79+
for i in 0..n {
8180
let view = self.subview(axis, i);
8281
res = res + &view;
8382
}
@@ -88,7 +87,7 @@ impl<A, S, D> ArrayBase<S, D>
8887
/// Return mean along `axis`.
8988
///
9089
/// **Panics** if `axis` is out of bounds or if the length of the axis is
91-
/// zero.
90+
/// zero and division by zero panics for type `A`.
9291
///
9392
/// ```
9493
/// use ndarray::{aview1, arr2, Axis};
@@ -106,8 +105,8 @@ impl<A, S, D> ArrayBase<S, D>
106105
{
107106
let n = self.len_of(axis);
108107
let sum = self.sum_axis(axis);
109-
let mut cnt = A::one();
110-
for _ in 1..n {
108+
let mut cnt = A::zero();
109+
for _ in 0..n {
111110
cnt = cnt + A::one();
112111
}
113112
sum / &aview0(&cnt)

tests/array.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -697,6 +697,22 @@ fn sum_mean()
697697
assert_eq!(a.scalar_sum(), 10.);
698698
}
699699

700+
#[test]
701+
fn sum_mean_empty() {
702+
assert_eq!(Array3::<f32>::ones((2, 0, 3)).scalar_sum(), 0.);
703+
assert_eq!(Array1::<f32>::ones(0).sum_axis(Axis(0)), arr0(0.));
704+
assert_eq!(
705+
Array3::<f32>::ones((2, 0, 3)).sum_axis(Axis(1)),
706+
Array::zeros((2, 3)),
707+
);
708+
let a = Array1::<f32>::ones(0).mean_axis(Axis(0));
709+
assert_eq!(a.shape(), &[]);
710+
assert!(a[()].is_nan());
711+
let a = Array3::<f32>::ones((2, 0, 3)).mean_axis(Axis(1));
712+
assert_eq!(a.shape(), &[2, 3]);
713+
a.mapv(|x| assert!(x.is_nan()));
714+
}
715+
700716
#[test]
701717
fn var_axis() {
702718
let a = array![

0 commit comments

Comments
 (0)