diff --git a/ndarray-linalg/src/least_squares.rs b/ndarray-linalg/src/least_squares.rs index 03583a25..25e700e4 100644 --- a/ndarray-linalg/src/least_squares.rs +++ b/ndarray-linalg/src/least_squares.rs @@ -149,12 +149,13 @@ where /// Solve least squares for immutable references and a single /// column vector as a right-hand side. -/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D` can be any -/// valid representation for `ArrayBase`. -impl LeastSquaresSvd for ArrayBase +/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D1`, `D2` can be any +/// valid representation for `ArrayBase` (over `E`). +impl LeastSquaresSvd for ArrayBase where E: Scalar + Lapack, - D: Data, + D1: Data, + D2: Data, { /// Solve a least squares problem of the form `Ax = rhs` /// by calling `A.least_squares(&rhs)`, where `rhs` is a @@ -163,7 +164,7 @@ where /// `A` and `rhs` must have the same layout, i.e. they must /// be both either row- or column-major format, otherwise a /// `IncompatibleShape` error is raised. - fn least_squares(&self, rhs: &ArrayBase) -> Result> { + fn least_squares(&self, rhs: &ArrayBase) -> Result> { let a = self.to_owned(); let b = rhs.to_owned(); a.least_squares_into(b) @@ -172,12 +173,13 @@ where /// Solve least squares for immutable references and matrix /// (=mulitipe vectors) as a right-hand side. -/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D` can be any -/// valid representation for `ArrayBase`. -impl LeastSquaresSvd for ArrayBase +/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D1`, `D2` can be any +/// valid representation for `ArrayBase` (over `E`). +impl LeastSquaresSvd for ArrayBase where E: Scalar + Lapack, - D: Data, + D1: Data, + D2: Data, { /// Solve a least squares problem of the form `Ax = rhs` /// by calling `A.least_squares(&rhs)`, where `rhs` is @@ -186,7 +188,7 @@ where /// `A` and `rhs` must have the same layout, i.e. they must /// be both either row- or column-major format, otherwise a /// `IncompatibleShape` error is raised. - fn least_squares(&self, rhs: &ArrayBase) -> Result> { + fn least_squares(&self, rhs: &ArrayBase) -> Result> { let a = self.to_owned(); let b = rhs.to_owned(); a.least_squares_into(b) @@ -199,10 +201,11 @@ where /// /// `E` is one of `f32`, `f64`, `c32`, `c64`. `D` can be any /// valid representation for `ArrayBase`. -impl LeastSquaresSvdInto for ArrayBase +impl LeastSquaresSvdInto for ArrayBase where E: Scalar + Lapack, - D: DataMut, + D1: DataMut, + D2: DataMut, { /// Solve a least squares problem of the form `Ax = rhs` /// by calling `A.least_squares(rhs)`, where `rhs` is a @@ -213,7 +216,7 @@ where /// `IncompatibleShape` error is raised. fn least_squares_into( mut self, - mut rhs: ArrayBase, + mut rhs: ArrayBase, ) -> Result> { self.least_squares_in_place(&mut rhs) } @@ -223,12 +226,13 @@ where /// as a right-hand side. The matrix and the RHS matrix /// are consumed. /// -/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D` can be any -/// valid representation for `ArrayBase`. -impl LeastSquaresSvdInto for ArrayBase +/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D1`, `D2` can be any +/// valid representation for `ArrayBase` (over `E`). +impl LeastSquaresSvdInto for ArrayBase where E: Scalar + Lapack, - D: DataMut, + D1: DataMut, + D2: DataMut, { /// Solve a least squares problem of the form `Ax = rhs` /// by calling `A.least_squares(rhs)`, where `rhs` is a @@ -239,7 +243,7 @@ where /// `IncompatibleShape` error is raised. fn least_squares_into( mut self, - mut rhs: ArrayBase, + mut rhs: ArrayBase, ) -> Result> { self.least_squares_in_place(&mut rhs) } @@ -249,12 +253,13 @@ where /// as a right-hand side. Both values are overwritten in the /// call. /// -/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D` can be any -/// valid representation for `ArrayBase`. -impl LeastSquaresSvdInPlace for ArrayBase +/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D1`, `D2` can be any +/// valid representation for `ArrayBase` (over `E`). +impl LeastSquaresSvdInPlace for ArrayBase where E: Scalar + Lapack, - D: DataMut, + D1: DataMut, + D2: DataMut, { /// Solve a least squares problem of the form `Ax = rhs` /// by calling `A.least_squares(rhs)`, where `rhs` is a @@ -265,7 +270,7 @@ where /// `IncompatibleShape` error is raised. fn least_squares_in_place( &mut self, - rhs: &mut ArrayBase, + rhs: &mut ArrayBase, ) -> Result> { if self.shape()[0] != rhs.shape()[0] { return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape).into()); @@ -331,12 +336,13 @@ fn compute_residual_scalar>( /// as a right-hand side. Both values are overwritten in the /// call. /// -/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D` can be any -/// valid representation for `ArrayBase`. -impl LeastSquaresSvdInPlace for ArrayBase +/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D1`, `D2` can be any +/// valid representation for `ArrayBase` (over `E`). +impl LeastSquaresSvdInPlace for ArrayBase where E: Scalar + Lapack + LeastSquaresSvdDivideConquer_, - D: DataMut, + D1: DataMut, + D2: DataMut, { /// Solve a least squares problem of the form `Ax = rhs` /// by calling `A.least_squares(rhs)`, where `rhs` is a @@ -347,7 +353,7 @@ where /// `IncompatibleShape` error is raised. fn least_squares_in_place( &mut self, - rhs: &mut ArrayBase, + rhs: &mut ArrayBase, ) -> Result> { if self.shape()[0] != rhs.shape()[0] { return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape).into()); @@ -425,7 +431,7 @@ mod tests { use ndarray::*; // - // Test that the different lest squares traits work as intended on the + // Test that the different least squares traits work as intended on the // different array types. // // | least_squares | ls_into | ls_in_place | @@ -437,9 +443,9 @@ mod tests { // ArrayViewMut | yes | no | yes | // - fn assert_result>( - a: &ArrayBase, - b: &ArrayBase, + fn assert_result, D2: Data>( + a: &ArrayBase, + b: &ArrayBase, res: &LeastSquaresResult, ) { assert_eq!(res.rank, 2); @@ -487,6 +493,15 @@ mod tests { assert_result(&av, &bv, &res); } + #[test] + fn on_cow_view() { + let a = CowArray::from(array![[1., 2.], [4., 5.], [3., 4.]]); + let b: Array1 = array![1., 2., 3.]; + let bv = b.view(); + let res = a.least_squares(&bv).unwrap(); + assert_result(&a, &bv, &res); + } + #[test] fn into_on_owned() { let a: Array2 = array![[1., 2.], [4., 5.], [3., 4.]]; @@ -517,6 +532,16 @@ mod tests { assert_result(&a, &b, &res); } + #[test] + fn into_on_owned_cow() { + let a: Array2 = array![[1., 2.], [4., 5.], [3., 4.]]; + let b = CowArray::from(array![1., 2., 3.]); + let ac = a.clone(); + let b2 = b.clone(); + let res = ac.least_squares_into(b2).unwrap(); + assert_result(&a, &b, &res); + } + #[test] fn in_place_on_owned() { let a = array![[1., 2.], [4., 5.], [3., 4.]]; @@ -549,6 +574,16 @@ mod tests { assert_result(&a, &b, &res); } + #[test] + fn in_place_on_owned_cow() { + let a = array![[1., 2.], [4., 5.], [3., 4.]]; + let b = CowArray::from(array![1., 2., 3.]); + let mut a2 = a.clone(); + let mut b2 = b.clone(); + let res = a2.least_squares_in_place(&mut b2).unwrap(); + assert_result(&a, &b, &res); + } + // // Testing error cases //