@@ -61,14 +61,13 @@ impl<A, S, D> ArrayBase<S, D>
61
61
/// );
62
62
/// ```
63
63
///
64
- /// **Panics** if `axis` is out of bounds or if the length of the axis is
65
- /// zero.
64
+ /// **Panics** if `axis` is out of bounds.
66
65
pub fn sum_axis ( & self , axis : Axis ) -> Array < A , D :: Smaller >
67
66
where A : Clone + Zero + Add < Output =A > ,
68
67
D : RemoveAxis ,
69
68
{
70
69
let n = self . len_of ( axis) ;
71
- let mut res = self . subview ( axis , 0 ) . to_owned ( ) ;
70
+ let mut res = Array :: zeros ( self . raw_dim ( ) . remove_axis ( axis ) ) ;
72
71
let stride = self . strides ( ) [ axis. index ( ) ] ;
73
72
if self . ndim ( ) == 2 && stride == 1 {
74
73
// contiguous along the axis we are summing
@@ -77,7 +76,7 @@ impl<A, S, D> ArrayBase<S, D>
77
76
* elt = self . subview ( Axis ( 1 - ax) , i) . scalar_sum ( ) ;
78
77
}
79
78
} else {
80
- for i in 1 ..n {
79
+ for i in 0 ..n {
81
80
let view = self . subview ( axis, i) ;
82
81
res = res + & view;
83
82
}
@@ -88,7 +87,7 @@ impl<A, S, D> ArrayBase<S, D>
88
87
/// Return mean along `axis`.
89
88
///
90
89
/// **Panics** if `axis` is out of bounds or if the length of the axis is
91
- /// zero.
90
+ /// zero and division by zero panics for type `A` .
92
91
///
93
92
/// ```
94
93
/// use ndarray::{aview1, arr2, Axis};
@@ -106,8 +105,8 @@ impl<A, S, D> ArrayBase<S, D>
106
105
{
107
106
let n = self . len_of ( axis) ;
108
107
let sum = self . sum_axis ( axis) ;
109
- let mut cnt = A :: one ( ) ;
110
- for _ in 1 ..n {
108
+ let mut cnt = A :: zero ( ) ;
109
+ for _ in 0 ..n {
111
110
cnt = cnt + A :: one ( ) ;
112
111
}
113
112
sum / & aview0 ( & cnt)
0 commit comments