diff --git a/src/numeric/impl_numeric.rs b/src/numeric/impl_numeric.rs index 906021341..4e7127c96 100644 --- a/src/numeric/impl_numeric.rs +++ b/src/numeric/impl_numeric.rs @@ -65,14 +65,13 @@ impl ArrayBase /// ); /// ``` /// - /// **Panics** if `axis` is out of bounds or if the length of the axis is - /// zero. + /// **Panics** if `axis` is out of bounds. pub fn sum_axis(&self, axis: Axis) -> Array where A: Clone + Zero + Add, D: RemoveAxis, { let n = self.len_of(axis); - let mut res = self.subview(axis, 0).to_owned(); + let mut res = Array::zeros(self.raw_dim().remove_axis(axis)); let stride = self.strides()[axis.index()]; if self.ndim() == 2 && stride == 1 { // contiguous along the axis we are summing @@ -81,7 +80,7 @@ impl ArrayBase *elt = self.subview(Axis(1 - ax), i).scalar_sum(); } } else { - for i in 1..n { + for i in 0..n { let view = self.subview(axis, i); res = res + &view; } @@ -92,7 +91,7 @@ impl ArrayBase /// Return mean along `axis`. /// /// **Panics** if `axis` is out of bounds or if the length of the axis is - /// zero. + /// zero and division by zero panics for type `A`. /// /// ``` /// use ndarray::{aview1, arr2, Axis}; @@ -110,8 +109,8 @@ impl ArrayBase { let n = self.len_of(axis); let sum = self.sum_axis(axis); - let mut cnt = A::one(); - for _ in 1..n { + let mut cnt = A::zero(); + for _ in 0..n { cnt = cnt + A::one(); } sum / &aview0(&cnt) diff --git a/tests/array.rs b/tests/array.rs index e694a6cf7..69db22c8d 100644 --- a/tests/array.rs +++ b/tests/array.rs @@ -697,6 +697,22 @@ fn sum_mean() assert_eq!(a.scalar_sum(), 10.); } +#[test] +fn sum_mean_empty() { + assert_eq!(Array3::::ones((2, 0, 3)).scalar_sum(), 0.); + assert_eq!(Array1::::ones(0).sum_axis(Axis(0)), arr0(0.)); + assert_eq!( + Array3::::ones((2, 0, 3)).sum_axis(Axis(1)), + Array::zeros((2, 3)), + ); + let a = Array1::::ones(0).mean_axis(Axis(0)); + assert_eq!(a.shape(), &[]); + assert!(a[()].is_nan()); + let a = Array3::::ones((2, 0, 3)).mean_axis(Axis(1)); + assert_eq!(a.shape(), &[2, 3]); + a.mapv(|x| assert!(x.is_nan())); +} + #[test] fn var_axis() { let a = array![