From e34747c8098b36f32dfaf714fde698e4e670de09 Mon Sep 17 00:00:00 2001 From: NewBornRustacean Date: Sat, 15 Mar 2025 16:58:31 +0900 Subject: [PATCH 01/12] impl and tests added --- src/linalg/impl_linalg.rs | 128 +++++++++++++++++++++++++++++++++++++- 1 file changed, 127 insertions(+), 1 deletion(-) diff --git a/src/linalg/impl_linalg.rs b/src/linalg/impl_linalg.rs index 7472d8292..d6e7799ca 100644 --- a/src/linalg/impl_linalg.rs +++ b/src/linalg/impl_linalg.rs @@ -353,7 +353,7 @@ where /// /// If their shapes disagree, `rhs` is broadcast to the shape of `self`. /// - /// **Panics** if broadcasting isn’t possible. + /// **Panics** if broadcasting isn't possible. #[track_caller] pub fn scaled_add(&mut self, alpha: A, rhs: &ArrayBase) where @@ -1067,3 +1067,129 @@ mod blas_tests } } } + +impl Dot> for ArrayBase +where + S: Data, + S2: Data, + A: LinalgScalar, +{ + type Output = Array; + + fn dot(&self, rhs: &ArrayBase) -> Self::Output { + match (self.ndim(), rhs.ndim()) { + (1, 1) => { + // Vector-vector dot product + if self.len() != rhs.len() { + panic!("Vector lengths must match for dot product"); + } + let a = self.view().into_dimensionality::().unwrap(); + let b = rhs.view().into_dimensionality::().unwrap(); + let result = a.dot(&b); + ArrayD::from_elem(vec![], result) + } + (2, 2) => { + // Matrix-matrix multiplication + let a = self.view().into_dimensionality::().unwrap(); + let b = rhs.view().into_dimensionality::().unwrap(); + let result = a.dot(&b); + result.into_dimensionality::().unwrap() + } + (2, 1) => { + // Matrix-vector multiplication + let a = self.view().into_dimensionality::().unwrap(); + let b = rhs.view().into_dimensionality::().unwrap(); + let result = a.dot(&b); + result.into_dimensionality::().unwrap() + } + (1, 2) => { + // Vector-matrix multiplication + let a = self.view().into_dimensionality::().unwrap(); + let b = rhs.view().into_dimensionality::().unwrap(); + let result = a.dot(&b); + result.into_dimensionality::().unwrap() + } + _ => panic!("Dot product for ArrayD is only supported for 1D and 2D arrays"), + } + } +} + +#[cfg(test)] +mod arrayd_dot_tests { + use super::*; + use crate::ArrayD; + + #[test] + fn test_arrayd_dot_2d() { + // Test case from the original issue + let mat1 = ArrayD::from_shape_vec(vec![3, 2], vec![3.0; 6]).unwrap(); + let mat2 = ArrayD::from_shape_vec(vec![2, 3], vec![1.0; 6]).unwrap(); + + let result = mat1.dot(&mat2); + + // Verify the result is correct + assert_eq!(result.ndim(), 2); + assert_eq!(result.shape(), &[3, 3]); + + // Compare with Array2 implementation + let mat1_2d = Array2::from_shape_vec((3, 2), vec![3.0; 6]).unwrap(); + let mat2_2d = Array2::from_shape_vec((2, 3), vec![1.0; 6]).unwrap(); + let expected = mat1_2d.dot(&mat2_2d); + + assert_eq!(result.into_dimensionality::().unwrap(), expected); + } + + #[test] + fn test_arrayd_dot_1d() { + // Test 1D array dot product + let vec1 = ArrayD::from_shape_vec(vec![3], vec![1.0, 2.0, 3.0]).unwrap(); + let vec2 = ArrayD::from_shape_vec(vec![3], vec![4.0, 5.0, 6.0]).unwrap(); + + let result = vec1.dot(&vec2); + + // Verify scalar result + assert_eq!(result.ndim(), 0); + assert_eq!(result.shape(), &[]); + assert_eq!(result[[]], 32.0); // 1*4 + 2*5 + 3*6 + } + + #[test] + #[should_panic(expected = "Dot product for ArrayD is only supported for 1D and 2D arrays")] + fn test_arrayd_dot_3d() { + // Test that 3D arrays are not supported + let arr1 = ArrayD::from_shape_vec(vec![2, 2, 2], vec![1.0; 8]).unwrap(); + let arr2 = ArrayD::from_shape_vec(vec![2, 2, 2], vec![1.0; 8]).unwrap(); + + let _result = arr1.dot(&arr2); // Should panic + } + + #[test] + #[should_panic(expected = "ndarray: inputs 2 × 3 and 4 × 5 are not compatible for matrix multiplication")] + fn test_arrayd_dot_incompatible_dims() { + // Test arrays with incompatible dimensions + let arr1 = ArrayD::from_shape_vec(vec![2, 3], vec![1.0; 6]).unwrap(); + let arr2 = ArrayD::from_shape_vec(vec![4, 5], vec![1.0; 20]).unwrap(); + + let _result = arr1.dot(&arr2); // Should panic + } + + #[test] + fn test_arrayd_dot_matrix_vector() { + // Test matrix-vector multiplication + let mat = ArrayD::from_shape_vec(vec![3, 2], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap(); + let vec = ArrayD::from_shape_vec(vec![2], vec![1.0, 2.0]).unwrap(); + + let result = mat.dot(&vec); + + // Verify result + assert_eq!(result.ndim(), 1); + assert_eq!(result.shape(), &[3]); + + // Compare with Array2 implementation + let mat_2d = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap(); + let vec_1d = Array1::from_vec(vec![1.0, 2.0]); + let expected = mat_2d.dot(&vec_1d); + + assert_eq!(result.into_dimensionality::().unwrap(), expected); + } +} From 3adf2e30d00d48c2d7bf5b8312ba9fbdcce390bd Mon Sep 17 00:00:00 2001 From: NewBornRustacean Date: Sun, 16 Mar 2025 09:41:28 +0900 Subject: [PATCH 02/12] add docstring and exmples --- src/linalg/impl_linalg.rs | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/src/linalg/impl_linalg.rs b/src/linalg/impl_linalg.rs index d6e7799ca..de5bd4922 100644 --- a/src/linalg/impl_linalg.rs +++ b/src/linalg/impl_linalg.rs @@ -1068,6 +1068,39 @@ mod blas_tests } } +/// Dot product for dynamic-dimensional arrays (`ArrayD`). +/// +/// For one-dimensional arrays, computes the vector dot product, which is the sum +/// of the elementwise products (no conjugation of complex operands). +/// Both arrays must have the same length. +/// +/// For two-dimensional arrays, performs matrix multiplication. The array shapes +/// must be compatible in the following ways: +/// - If `self` is *M* × *N*, then `rhs` must be *N* × *K* for matrix-matrix multiplication +/// - If `self` is *M* × *N* and `rhs` is *N*, returns a vector of length *M* +/// - If `self` is *M* and `rhs` is *M* × *N*, returns a vector of length *N* +/// - If both arrays are one-dimensional of length *N*, returns a scalar +/// +/// **Panics** if: +/// - The arrays have dimensions other than 1 or 2 +/// - The array shapes are incompatible for the operation +/// - For vector dot product: the vectors have different lengths +/// +/// # Examples +/// +/// ``` +/// use ndarray::{ArrayD, Array2, Array1}; +/// +/// // Matrix multiplication +/// let a = ArrayD::from_shape_vec(vec![2, 3], vec![1., 2., 3., 4., 5., 6.]).unwrap(); +/// let b = ArrayD::from_shape_vec(vec![3, 2], vec![1., 2., 3., 4., 5., 6.]).unwrap(); +/// let c = a.dot(&b); +/// +/// // Vector dot product +/// let v1 = ArrayD::from_shape_vec(vec![3], vec![1., 2., 3.]).unwrap(); +/// let v2 = ArrayD::from_shape_vec(vec![3], vec![4., 5., 6.]).unwrap(); +/// let scalar = v1.dot(&v2); +/// ``` impl Dot> for ArrayBase where S: Data, From a5f78a9b5a38a158bf8c6c391f7572aa7d2611cd Mon Sep 17 00:00:00 2001 From: NewBornRustacean Date: Sun, 16 Mar 2025 09:50:05 +0900 Subject: [PATCH 03/12] remove duplicated length check --- src/linalg/impl_linalg.rs | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/linalg/impl_linalg.rs b/src/linalg/impl_linalg.rs index de5bd4922..c48967b60 100644 --- a/src/linalg/impl_linalg.rs +++ b/src/linalg/impl_linalg.rs @@ -1112,10 +1112,6 @@ where fn dot(&self, rhs: &ArrayBase) -> Self::Output { match (self.ndim(), rhs.ndim()) { (1, 1) => { - // Vector-vector dot product - if self.len() != rhs.len() { - panic!("Vector lengths must match for dot product"); - } let a = self.view().into_dimensionality::().unwrap(); let b = rhs.view().into_dimensionality::().unwrap(); let result = a.dot(&b); From 1ca94096b2cdbe28c6f1131867782c0a05c358aa Mon Sep 17 00:00:00 2001 From: NewBornRustacean Date: Sun, 16 Mar 2025 09:56:29 +0900 Subject: [PATCH 04/12] cargo fmt --- src/linalg/impl_linalg.rs | 200 ++++++++++++++++---------------------- 1 file changed, 86 insertions(+), 114 deletions(-) diff --git a/src/linalg/impl_linalg.rs b/src/linalg/impl_linalg.rs index c48967b60..b13628880 100644 --- a/src/linalg/impl_linalg.rs +++ b/src/linalg/impl_linalg.rs @@ -41,7 +41,8 @@ const GEMM_BLAS_CUTOFF: usize = 7; type blas_index = c_int; // blas index type impl ArrayBase -where S: Data +where + S: Data, { /// Perform dot product or matrix multiplication of arrays `self` and `rhs`. /// @@ -62,7 +63,8 @@ where S: Data /// layout allows. #[track_caller] pub fn dot(&self, rhs: &Rhs) -> >::Output - where Self: Dot + where + Self: Dot, { Dot::dot(self, rhs) } @@ -111,17 +113,9 @@ where S: Data ($ty:ty, $func:ident) => {{ if blas_compat_1d::<$ty, _>(self) && blas_compat_1d::<$ty, _>(rhs) { unsafe { - let (lhs_ptr, n, incx) = - blas_1d_params(self.ptr.as_ptr(), self.len(), self.strides()[0]); - let (rhs_ptr, _, incy) = - blas_1d_params(rhs.ptr.as_ptr(), rhs.len(), rhs.strides()[0]); - let ret = blas_sys::$func( - n, - lhs_ptr as *const $ty, - incx, - rhs_ptr as *const $ty, - incy, - ); + let (lhs_ptr, n, incx) = blas_1d_params(self.ptr.as_ptr(), self.len(), self.strides()[0]); + let (rhs_ptr, _, incy) = blas_1d_params(rhs.ptr.as_ptr(), rhs.len(), rhs.strides()[0]); + let ret = blas_sys::$func(n, lhs_ptr as *const $ty, incx, rhs_ptr as *const $ty, incy); return cast_as::<$ty, A>(&ret); } } @@ -141,8 +135,7 @@ where S: Data /// which agrees with our pointer for non-negative strides, but /// is at the opposite end for negative strides. #[cfg(feature = "blas")] -unsafe fn blas_1d_params(ptr: *const A, len: usize, stride: isize) -> (*const A, blas_index, blas_index) -{ +unsafe fn blas_1d_params(ptr: *const A, len: usize, stride: isize) -> (*const A, blas_index, blas_index) { // [x x x x] // ^--ptr // stride = -1 @@ -159,8 +152,7 @@ unsafe fn blas_1d_params(ptr: *const A, len: usize, stride: isize) -> (*const /// /// For two-dimensional arrays, the dot method computes the matrix /// multiplication. -pub trait Dot -{ +pub trait Dot { /// The result of the operation. /// /// For two-dimensional arrays: a rectangular array. @@ -185,8 +177,7 @@ where /// *Note:* If enabled, uses blas `dot` for elements of `f32, f64` when memory /// layout allows. #[track_caller] - fn dot(&self, rhs: &ArrayBase) -> A - { + fn dot(&self, rhs: &ArrayBase) -> A { self.dot_impl(rhs) } } @@ -209,14 +200,14 @@ where /// /// **Panics** if shapes are incompatible. #[track_caller] - fn dot(&self, rhs: &ArrayBase) -> Array - { + fn dot(&self, rhs: &ArrayBase) -> Array { rhs.t().dot(self) } } impl ArrayBase -where S: Data +where + S: Data, { /// Perform matrix multiplication of rectangular arrays `self` and `rhs`. /// @@ -249,7 +240,8 @@ where S: Data /// ``` #[track_caller] pub fn dot(&self, rhs: &Rhs) -> >::Output - where Self: Dot + where + Self: Dot, { Dot::dot(self, rhs) } @@ -262,8 +254,7 @@ where A: LinalgScalar, { type Output = Array2; - fn dot(&self, b: &ArrayBase) -> Array2 - { + fn dot(&self, b: &ArrayBase) -> Array2 { let a = self.view(); let b = b.view(); let ((m, k), (k2, n)) = (a.dim(), b.dim()); @@ -289,24 +280,21 @@ where /// Assumes that `m` and `n` are ≤ `isize::MAX`. #[cold] #[inline(never)] -fn dot_shape_error(m: usize, k: usize, k2: usize, n: usize) -> ! -{ +fn dot_shape_error(m: usize, k: usize, k2: usize, n: usize) -> ! { match m.checked_mul(n) { Some(len) if len <= isize::MAX as usize => {} _ => panic!("ndarray: shape {} × {} overflows isize", m, n), } - panic!( - "ndarray: inputs {} × {} and {} × {} are not compatible for matrix multiplication", - m, k, k2, n - ); + panic!("ndarray: inputs {} × {} and {} × {} are not compatible for matrix multiplication", m, k, k2, n); } #[cold] #[inline(never)] -fn general_dot_shape_error(m: usize, k: usize, k2: usize, n: usize, c1: usize, c2: usize) -> ! -{ - panic!("ndarray: inputs {} × {}, {} × {}, and output {} × {} are not compatible for matrix multiplication", - m, k, k2, n, c1, c2); +fn general_dot_shape_error(m: usize, k: usize, k2: usize, n: usize, c1: usize, c2: usize) -> ! { + panic!( + "ndarray: inputs {} × {}, {} × {}, and output {} × {} are not compatible for matrix multiplication", + m, k, k2, n, c1, c2 + ); } /// Perform the matrix multiplication of the rectangular array `self` and @@ -326,8 +314,7 @@ where { type Output = Array; #[track_caller] - fn dot(&self, rhs: &ArrayBase) -> Array - { + fn dot(&self, rhs: &ArrayBase) -> Array { let ((m, a), n) = (self.dim(), rhs.dim()); if a != n { dot_shape_error(m, a, n, 1); @@ -373,7 +360,8 @@ use self::mat_mul_general as mat_mul_impl; #[cfg(feature = "blas")] fn mat_mul_impl(alpha: A, a: &ArrayView2<'_, A>, b: &ArrayView2<'_, A>, beta: A, c: &mut ArrayViewMut2<'_, A>) -where A: LinalgScalar +where + A: LinalgScalar, { let ((m, k), (k2, n)) = (a.dim(), b.dim()); debug_assert_eq!(k, k2); @@ -428,17 +416,17 @@ where A: LinalgScalar cblas_layout, a_trans, b_trans, - m as blas_index, // m, rows of Op(a) - n as blas_index, // n, cols of Op(b) - k as blas_index, // k, cols of Op(a) - gemm_scalar_cast!($ty, alpha), // alpha - a.ptr.as_ptr() as *const _, // a - lda, // lda - b.ptr.as_ptr() as *const _, // b - ldb, // ldb - gemm_scalar_cast!($ty, beta), // beta - c.ptr.as_ptr() as *mut _, // c - ldc, // ldc + m as blas_index, // m, rows of Op(a) + n as blas_index, // n, cols of Op(b) + k as blas_index, // k, cols of Op(a) + gemm_scalar_cast!($ty, alpha), // alpha + a.ptr.as_ptr() as *const _, // a + lda, // lda + b.ptr.as_ptr() as *const _, // b + ldb, // ldb + gemm_scalar_cast!($ty, beta), // beta + c.ptr.as_ptr() as *mut _, // c + ldc, // ldc ); } return; @@ -458,9 +446,9 @@ where A: LinalgScalar } /// C ← α A B + β C -fn mat_mul_general( - alpha: A, lhs: &ArrayView2<'_, A>, rhs: &ArrayView2<'_, A>, beta: A, c: &mut ArrayViewMut2<'_, A>, -) where A: LinalgScalar +fn mat_mul_general(alpha: A, lhs: &ArrayView2<'_, A>, rhs: &ArrayView2<'_, A>, beta: A, c: &mut ArrayViewMut2<'_, A>) +where + A: LinalgScalar, { let ((m, k), (_, n)) = (lhs.dim(), rhs.dim()); @@ -687,8 +675,8 @@ unsafe fn general_mat_vec_mul_impl( a_stride, // lda x_ptr as *const _, // x x_stride, - cast_as(&beta), // beta - y_ptr as *mut _, // y + cast_as(&beta), // beta + y_ptr as *mut _, // y y_stride, ); return; @@ -751,25 +739,26 @@ where #[inline(always)] /// Return `true` if `A` and `B` are the same type -fn same_type() -> bool -{ +fn same_type() -> bool { TypeId::of::() == TypeId::of::() } // Read pointer to type `A` as type `B`. // // **Panics** if `A` and `B` are not the same type -fn cast_as(a: &A) -> B -{ - assert!(same_type::(), "expect type {} and {} to match", - std::any::type_name::(), std::any::type_name::()); +fn cast_as(a: &A) -> B { + assert!( + same_type::(), + "expect type {} and {} to match", + std::any::type_name::(), + std::any::type_name::() + ); unsafe { ::std::ptr::read(a as *const _ as *const B) } } /// Return the complex in the form of an array [re, im] #[inline] -fn complex_array(z: Complex) -> [A; 2] -{ +fn complex_array(z: Complex) -> [A; 2] { [z.re, z.im] } @@ -796,17 +785,14 @@ where #[cfg(feature = "blas")] #[derive(Copy, Clone)] #[cfg_attr(test, derive(PartialEq, Eq, Debug))] -enum BlasOrder -{ +enum BlasOrder { C, F, } #[cfg(feature = "blas")] -impl BlasOrder -{ - fn transpose(self) -> Self - { +impl BlasOrder { + fn transpose(self) -> Self { match self { Self::C => Self::F, Self::F => Self::C, @@ -815,16 +801,14 @@ impl BlasOrder #[inline] /// Axis of leading stride (opposite of contiguous axis) - fn get_blas_lead_axis(self) -> usize - { + fn get_blas_lead_axis(self) -> usize { match self { Self::C => 0, Self::F => 1, } } - fn to_cblas_layout(self) -> CBLAS_LAYOUT - { + fn to_cblas_layout(self) -> CBLAS_LAYOUT { match self { Self::C => CBLAS_LAYOUT::CblasRowMajor, Self::F => CBLAS_LAYOUT::CblasColMajor, @@ -833,8 +817,7 @@ impl BlasOrder /// When using cblas_sgemm (etc) with C matrix using `for_layout`, /// how should this `self` matrix be transposed - fn to_cblas_transpose_for(self, for_layout: CBLAS_LAYOUT) -> CBLAS_TRANSPOSE - { + fn to_cblas_transpose_for(self, for_layout: CBLAS_LAYOUT) -> CBLAS_TRANSPOSE { let effective_order = match for_layout { CBLAS_LAYOUT::CblasRowMajor => self, CBLAS_LAYOUT::CblasColMajor => self.transpose(), @@ -848,8 +831,7 @@ impl BlasOrder } #[cfg(feature = "blas")] -fn is_blas_2d(dim: &Ix2, stride: &Ix2, order: BlasOrder) -> bool -{ +fn is_blas_2d(dim: &Ix2, stride: &Ix2, order: BlasOrder) -> bool { let (m, n) = dim.into_pattern(); let s0 = stride[0] as isize; let s1 = stride[1] as isize; @@ -887,7 +869,8 @@ fn is_blas_2d(dim: &Ix2, stride: &Ix2, order: BlasOrder) -> bool /// Get BLAS compatible layout if any (C or F, preferring the former) #[cfg(feature = "blas")] fn get_blas_compatible_layout(a: &ArrayBase) -> Option -where S: Data +where + S: Data, { if is_blas_2d(&a.dim, &a.strides, BlasOrder::C) { Some(BlasOrder::C) @@ -904,7 +887,8 @@ where S: Data /// Return leading stride (lda, ldb, ldc) of array #[cfg(feature = "blas")] fn blas_stride(a: &ArrayBase, order: BlasOrder) -> blas_index -where S: Data +where + S: Data, { let axis = order.get_blas_lead_axis(); let other_axis = 1 - axis; @@ -953,37 +937,32 @@ where #[cfg(test)] #[cfg(feature = "blas")] -mod blas_tests -{ +mod blas_tests { use super::*; #[test] - fn blas_row_major_2d_normal_matrix() - { + fn blas_row_major_2d_normal_matrix() { let m: Array2 = Array2::zeros((3, 5)); assert!(blas_row_major_2d::(&m)); assert!(!blas_column_major_2d::(&m)); } #[test] - fn blas_row_major_2d_row_matrix() - { + fn blas_row_major_2d_row_matrix() { let m: Array2 = Array2::zeros((1, 5)); assert!(blas_row_major_2d::(&m)); assert!(blas_column_major_2d::(&m)); } #[test] - fn blas_row_major_2d_column_matrix() - { + fn blas_row_major_2d_column_matrix() { let m: Array2 = Array2::zeros((5, 1)); assert!(blas_row_major_2d::(&m)); assert!(blas_column_major_2d::(&m)); } #[test] - fn blas_row_major_2d_transposed_row_matrix() - { + fn blas_row_major_2d_transposed_row_matrix() { let m: Array2 = Array2::zeros((1, 5)); let m_t = m.t(); assert!(blas_row_major_2d::(&m_t)); @@ -991,8 +970,7 @@ mod blas_tests } #[test] - fn blas_row_major_2d_transposed_column_matrix() - { + fn blas_row_major_2d_transposed_column_matrix() { let m: Array2 = Array2::zeros((5, 1)); let m_t = m.t(); assert!(blas_row_major_2d::(&m_t)); @@ -1000,16 +978,14 @@ mod blas_tests } #[test] - fn blas_column_major_2d_normal_matrix() - { + fn blas_column_major_2d_normal_matrix() { let m: Array2 = Array2::zeros((3, 5).f()); assert!(!blas_row_major_2d::(&m)); assert!(blas_column_major_2d::(&m)); } #[test] - fn blas_row_major_2d_skip_rows_ok() - { + fn blas_row_major_2d_skip_rows_ok() { let m: Array2 = Array2::zeros((5, 5)); let mv = m.slice(s![..;2, ..]); assert!(blas_row_major_2d::(&mv)); @@ -1017,8 +993,7 @@ mod blas_tests } #[test] - fn blas_row_major_2d_skip_columns_fail() - { + fn blas_row_major_2d_skip_columns_fail() { let m: Array2 = Array2::zeros((5, 5)); let mv = m.slice(s![.., ..;2]); assert!(!blas_row_major_2d::(&mv)); @@ -1026,8 +1001,7 @@ mod blas_tests } #[test] - fn blas_col_major_2d_skip_columns_ok() - { + fn blas_col_major_2d_skip_columns_ok() { let m: Array2 = Array2::zeros((5, 5).f()); let mv = m.slice(s![.., ..;2]); assert!(blas_column_major_2d::(&mv)); @@ -1035,8 +1009,7 @@ mod blas_tests } #[test] - fn blas_col_major_2d_skip_rows_fail() - { + fn blas_col_major_2d_skip_rows_fail() { let m: Array2 = Array2::zeros((5, 5).f()); let mv = m.slice(s![..;2, ..]); assert!(!blas_column_major_2d::(&mv)); @@ -1044,8 +1017,7 @@ mod blas_tests } #[test] - fn blas_too_short_stride() - { + fn blas_too_short_stride() { // leading stride must be longer than the other dimension // Example, in a 5 x 5 matrix, the leading stride must be >= 5 for BLAS. @@ -1153,18 +1125,18 @@ mod arrayd_dot_tests { // Test case from the original issue let mat1 = ArrayD::from_shape_vec(vec![3, 2], vec![3.0; 6]).unwrap(); let mat2 = ArrayD::from_shape_vec(vec![2, 3], vec![1.0; 6]).unwrap(); - + let result = mat1.dot(&mat2); - + // Verify the result is correct assert_eq!(result.ndim(), 2); assert_eq!(result.shape(), &[3, 3]); - + // Compare with Array2 implementation let mat1_2d = Array2::from_shape_vec((3, 2), vec![3.0; 6]).unwrap(); let mat2_2d = Array2::from_shape_vec((2, 3), vec![1.0; 6]).unwrap(); let expected = mat1_2d.dot(&mat2_2d); - + assert_eq!(result.into_dimensionality::().unwrap(), expected); } @@ -1173,9 +1145,9 @@ mod arrayd_dot_tests { // Test 1D array dot product let vec1 = ArrayD::from_shape_vec(vec![3], vec![1.0, 2.0, 3.0]).unwrap(); let vec2 = ArrayD::from_shape_vec(vec![3], vec![4.0, 5.0, 6.0]).unwrap(); - + let result = vec1.dot(&vec2); - + // Verify scalar result assert_eq!(result.ndim(), 0); assert_eq!(result.shape(), &[]); @@ -1188,7 +1160,7 @@ mod arrayd_dot_tests { // Test that 3D arrays are not supported let arr1 = ArrayD::from_shape_vec(vec![2, 2, 2], vec![1.0; 8]).unwrap(); let arr2 = ArrayD::from_shape_vec(vec![2, 2, 2], vec![1.0; 8]).unwrap(); - + let _result = arr1.dot(&arr2); // Should panic } @@ -1198,7 +1170,7 @@ mod arrayd_dot_tests { // Test arrays with incompatible dimensions let arr1 = ArrayD::from_shape_vec(vec![2, 3], vec![1.0; 6]).unwrap(); let arr2 = ArrayD::from_shape_vec(vec![4, 5], vec![1.0; 20]).unwrap(); - + let _result = arr1.dot(&arr2); // Should panic } @@ -1207,18 +1179,18 @@ mod arrayd_dot_tests { // Test matrix-vector multiplication let mat = ArrayD::from_shape_vec(vec![3, 2], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap(); let vec = ArrayD::from_shape_vec(vec![2], vec![1.0, 2.0]).unwrap(); - + let result = mat.dot(&vec); - + // Verify result assert_eq!(result.ndim(), 1); assert_eq!(result.shape(), &[3]); - + // Compare with Array2 implementation let mat_2d = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap(); let vec_1d = Array1::from_vec(vec![1.0, 2.0]); let expected = mat_2d.dot(&vec_1d); - + assert_eq!(result.into_dimensionality::().unwrap(), expected); } } From 18ca5d53fedc8f9224cbba457f75bea2216c65a4 Mon Sep 17 00:00:00 2001 From: NewBornRustacean Date: Sun, 16 Mar 2025 10:56:14 +0900 Subject: [PATCH 05/12] Revert "cargo fmt" This reverts commit 1ca94096b2cdbe28c6f1131867782c0a05c358aa. --- src/linalg/impl_linalg.rs | 200 ++++++++++++++++++++++---------------- 1 file changed, 114 insertions(+), 86 deletions(-) diff --git a/src/linalg/impl_linalg.rs b/src/linalg/impl_linalg.rs index b13628880..c48967b60 100644 --- a/src/linalg/impl_linalg.rs +++ b/src/linalg/impl_linalg.rs @@ -41,8 +41,7 @@ const GEMM_BLAS_CUTOFF: usize = 7; type blas_index = c_int; // blas index type impl ArrayBase -where - S: Data, +where S: Data { /// Perform dot product or matrix multiplication of arrays `self` and `rhs`. /// @@ -63,8 +62,7 @@ where /// layout allows. #[track_caller] pub fn dot(&self, rhs: &Rhs) -> >::Output - where - Self: Dot, + where Self: Dot { Dot::dot(self, rhs) } @@ -113,9 +111,17 @@ where ($ty:ty, $func:ident) => {{ if blas_compat_1d::<$ty, _>(self) && blas_compat_1d::<$ty, _>(rhs) { unsafe { - let (lhs_ptr, n, incx) = blas_1d_params(self.ptr.as_ptr(), self.len(), self.strides()[0]); - let (rhs_ptr, _, incy) = blas_1d_params(rhs.ptr.as_ptr(), rhs.len(), rhs.strides()[0]); - let ret = blas_sys::$func(n, lhs_ptr as *const $ty, incx, rhs_ptr as *const $ty, incy); + let (lhs_ptr, n, incx) = + blas_1d_params(self.ptr.as_ptr(), self.len(), self.strides()[0]); + let (rhs_ptr, _, incy) = + blas_1d_params(rhs.ptr.as_ptr(), rhs.len(), rhs.strides()[0]); + let ret = blas_sys::$func( + n, + lhs_ptr as *const $ty, + incx, + rhs_ptr as *const $ty, + incy, + ); return cast_as::<$ty, A>(&ret); } } @@ -135,7 +141,8 @@ where /// which agrees with our pointer for non-negative strides, but /// is at the opposite end for negative strides. #[cfg(feature = "blas")] -unsafe fn blas_1d_params(ptr: *const A, len: usize, stride: isize) -> (*const A, blas_index, blas_index) { +unsafe fn blas_1d_params(ptr: *const A, len: usize, stride: isize) -> (*const A, blas_index, blas_index) +{ // [x x x x] // ^--ptr // stride = -1 @@ -152,7 +159,8 @@ unsafe fn blas_1d_params(ptr: *const A, len: usize, stride: isize) -> (*const /// /// For two-dimensional arrays, the dot method computes the matrix /// multiplication. -pub trait Dot { +pub trait Dot +{ /// The result of the operation. /// /// For two-dimensional arrays: a rectangular array. @@ -177,7 +185,8 @@ where /// *Note:* If enabled, uses blas `dot` for elements of `f32, f64` when memory /// layout allows. #[track_caller] - fn dot(&self, rhs: &ArrayBase) -> A { + fn dot(&self, rhs: &ArrayBase) -> A + { self.dot_impl(rhs) } } @@ -200,14 +209,14 @@ where /// /// **Panics** if shapes are incompatible. #[track_caller] - fn dot(&self, rhs: &ArrayBase) -> Array { + fn dot(&self, rhs: &ArrayBase) -> Array + { rhs.t().dot(self) } } impl ArrayBase -where - S: Data, +where S: Data { /// Perform matrix multiplication of rectangular arrays `self` and `rhs`. /// @@ -240,8 +249,7 @@ where /// ``` #[track_caller] pub fn dot(&self, rhs: &Rhs) -> >::Output - where - Self: Dot, + where Self: Dot { Dot::dot(self, rhs) } @@ -254,7 +262,8 @@ where A: LinalgScalar, { type Output = Array2; - fn dot(&self, b: &ArrayBase) -> Array2 { + fn dot(&self, b: &ArrayBase) -> Array2 + { let a = self.view(); let b = b.view(); let ((m, k), (k2, n)) = (a.dim(), b.dim()); @@ -280,21 +289,24 @@ where /// Assumes that `m` and `n` are ≤ `isize::MAX`. #[cold] #[inline(never)] -fn dot_shape_error(m: usize, k: usize, k2: usize, n: usize) -> ! { +fn dot_shape_error(m: usize, k: usize, k2: usize, n: usize) -> ! +{ match m.checked_mul(n) { Some(len) if len <= isize::MAX as usize => {} _ => panic!("ndarray: shape {} × {} overflows isize", m, n), } - panic!("ndarray: inputs {} × {} and {} × {} are not compatible for matrix multiplication", m, k, k2, n); + panic!( + "ndarray: inputs {} × {} and {} × {} are not compatible for matrix multiplication", + m, k, k2, n + ); } #[cold] #[inline(never)] -fn general_dot_shape_error(m: usize, k: usize, k2: usize, n: usize, c1: usize, c2: usize) -> ! { - panic!( - "ndarray: inputs {} × {}, {} × {}, and output {} × {} are not compatible for matrix multiplication", - m, k, k2, n, c1, c2 - ); +fn general_dot_shape_error(m: usize, k: usize, k2: usize, n: usize, c1: usize, c2: usize) -> ! +{ + panic!("ndarray: inputs {} × {}, {} × {}, and output {} × {} are not compatible for matrix multiplication", + m, k, k2, n, c1, c2); } /// Perform the matrix multiplication of the rectangular array `self` and @@ -314,7 +326,8 @@ where { type Output = Array; #[track_caller] - fn dot(&self, rhs: &ArrayBase) -> Array { + fn dot(&self, rhs: &ArrayBase) -> Array + { let ((m, a), n) = (self.dim(), rhs.dim()); if a != n { dot_shape_error(m, a, n, 1); @@ -360,8 +373,7 @@ use self::mat_mul_general as mat_mul_impl; #[cfg(feature = "blas")] fn mat_mul_impl(alpha: A, a: &ArrayView2<'_, A>, b: &ArrayView2<'_, A>, beta: A, c: &mut ArrayViewMut2<'_, A>) -where - A: LinalgScalar, +where A: LinalgScalar { let ((m, k), (k2, n)) = (a.dim(), b.dim()); debug_assert_eq!(k, k2); @@ -416,17 +428,17 @@ where cblas_layout, a_trans, b_trans, - m as blas_index, // m, rows of Op(a) - n as blas_index, // n, cols of Op(b) - k as blas_index, // k, cols of Op(a) - gemm_scalar_cast!($ty, alpha), // alpha - a.ptr.as_ptr() as *const _, // a - lda, // lda - b.ptr.as_ptr() as *const _, // b - ldb, // ldb - gemm_scalar_cast!($ty, beta), // beta - c.ptr.as_ptr() as *mut _, // c - ldc, // ldc + m as blas_index, // m, rows of Op(a) + n as blas_index, // n, cols of Op(b) + k as blas_index, // k, cols of Op(a) + gemm_scalar_cast!($ty, alpha), // alpha + a.ptr.as_ptr() as *const _, // a + lda, // lda + b.ptr.as_ptr() as *const _, // b + ldb, // ldb + gemm_scalar_cast!($ty, beta), // beta + c.ptr.as_ptr() as *mut _, // c + ldc, // ldc ); } return; @@ -446,9 +458,9 @@ where } /// C ← α A B + β C -fn mat_mul_general(alpha: A, lhs: &ArrayView2<'_, A>, rhs: &ArrayView2<'_, A>, beta: A, c: &mut ArrayViewMut2<'_, A>) -where - A: LinalgScalar, +fn mat_mul_general( + alpha: A, lhs: &ArrayView2<'_, A>, rhs: &ArrayView2<'_, A>, beta: A, c: &mut ArrayViewMut2<'_, A>, +) where A: LinalgScalar { let ((m, k), (_, n)) = (lhs.dim(), rhs.dim()); @@ -675,8 +687,8 @@ unsafe fn general_mat_vec_mul_impl( a_stride, // lda x_ptr as *const _, // x x_stride, - cast_as(&beta), // beta - y_ptr as *mut _, // y + cast_as(&beta), // beta + y_ptr as *mut _, // y y_stride, ); return; @@ -739,26 +751,25 @@ where #[inline(always)] /// Return `true` if `A` and `B` are the same type -fn same_type() -> bool { +fn same_type() -> bool +{ TypeId::of::() == TypeId::of::() } // Read pointer to type `A` as type `B`. // // **Panics** if `A` and `B` are not the same type -fn cast_as(a: &A) -> B { - assert!( - same_type::(), - "expect type {} and {} to match", - std::any::type_name::(), - std::any::type_name::() - ); +fn cast_as(a: &A) -> B +{ + assert!(same_type::(), "expect type {} and {} to match", + std::any::type_name::(), std::any::type_name::()); unsafe { ::std::ptr::read(a as *const _ as *const B) } } /// Return the complex in the form of an array [re, im] #[inline] -fn complex_array(z: Complex) -> [A; 2] { +fn complex_array(z: Complex) -> [A; 2] +{ [z.re, z.im] } @@ -785,14 +796,17 @@ where #[cfg(feature = "blas")] #[derive(Copy, Clone)] #[cfg_attr(test, derive(PartialEq, Eq, Debug))] -enum BlasOrder { +enum BlasOrder +{ C, F, } #[cfg(feature = "blas")] -impl BlasOrder { - fn transpose(self) -> Self { +impl BlasOrder +{ + fn transpose(self) -> Self + { match self { Self::C => Self::F, Self::F => Self::C, @@ -801,14 +815,16 @@ impl BlasOrder { #[inline] /// Axis of leading stride (opposite of contiguous axis) - fn get_blas_lead_axis(self) -> usize { + fn get_blas_lead_axis(self) -> usize + { match self { Self::C => 0, Self::F => 1, } } - fn to_cblas_layout(self) -> CBLAS_LAYOUT { + fn to_cblas_layout(self) -> CBLAS_LAYOUT + { match self { Self::C => CBLAS_LAYOUT::CblasRowMajor, Self::F => CBLAS_LAYOUT::CblasColMajor, @@ -817,7 +833,8 @@ impl BlasOrder { /// When using cblas_sgemm (etc) with C matrix using `for_layout`, /// how should this `self` matrix be transposed - fn to_cblas_transpose_for(self, for_layout: CBLAS_LAYOUT) -> CBLAS_TRANSPOSE { + fn to_cblas_transpose_for(self, for_layout: CBLAS_LAYOUT) -> CBLAS_TRANSPOSE + { let effective_order = match for_layout { CBLAS_LAYOUT::CblasRowMajor => self, CBLAS_LAYOUT::CblasColMajor => self.transpose(), @@ -831,7 +848,8 @@ impl BlasOrder { } #[cfg(feature = "blas")] -fn is_blas_2d(dim: &Ix2, stride: &Ix2, order: BlasOrder) -> bool { +fn is_blas_2d(dim: &Ix2, stride: &Ix2, order: BlasOrder) -> bool +{ let (m, n) = dim.into_pattern(); let s0 = stride[0] as isize; let s1 = stride[1] as isize; @@ -869,8 +887,7 @@ fn is_blas_2d(dim: &Ix2, stride: &Ix2, order: BlasOrder) -> bool { /// Get BLAS compatible layout if any (C or F, preferring the former) #[cfg(feature = "blas")] fn get_blas_compatible_layout(a: &ArrayBase) -> Option -where - S: Data, +where S: Data { if is_blas_2d(&a.dim, &a.strides, BlasOrder::C) { Some(BlasOrder::C) @@ -887,8 +904,7 @@ where /// Return leading stride (lda, ldb, ldc) of array #[cfg(feature = "blas")] fn blas_stride(a: &ArrayBase, order: BlasOrder) -> blas_index -where - S: Data, +where S: Data { let axis = order.get_blas_lead_axis(); let other_axis = 1 - axis; @@ -937,32 +953,37 @@ where #[cfg(test)] #[cfg(feature = "blas")] -mod blas_tests { +mod blas_tests +{ use super::*; #[test] - fn blas_row_major_2d_normal_matrix() { + fn blas_row_major_2d_normal_matrix() + { let m: Array2 = Array2::zeros((3, 5)); assert!(blas_row_major_2d::(&m)); assert!(!blas_column_major_2d::(&m)); } #[test] - fn blas_row_major_2d_row_matrix() { + fn blas_row_major_2d_row_matrix() + { let m: Array2 = Array2::zeros((1, 5)); assert!(blas_row_major_2d::(&m)); assert!(blas_column_major_2d::(&m)); } #[test] - fn blas_row_major_2d_column_matrix() { + fn blas_row_major_2d_column_matrix() + { let m: Array2 = Array2::zeros((5, 1)); assert!(blas_row_major_2d::(&m)); assert!(blas_column_major_2d::(&m)); } #[test] - fn blas_row_major_2d_transposed_row_matrix() { + fn blas_row_major_2d_transposed_row_matrix() + { let m: Array2 = Array2::zeros((1, 5)); let m_t = m.t(); assert!(blas_row_major_2d::(&m_t)); @@ -970,7 +991,8 @@ mod blas_tests { } #[test] - fn blas_row_major_2d_transposed_column_matrix() { + fn blas_row_major_2d_transposed_column_matrix() + { let m: Array2 = Array2::zeros((5, 1)); let m_t = m.t(); assert!(blas_row_major_2d::(&m_t)); @@ -978,14 +1000,16 @@ mod blas_tests { } #[test] - fn blas_column_major_2d_normal_matrix() { + fn blas_column_major_2d_normal_matrix() + { let m: Array2 = Array2::zeros((3, 5).f()); assert!(!blas_row_major_2d::(&m)); assert!(blas_column_major_2d::(&m)); } #[test] - fn blas_row_major_2d_skip_rows_ok() { + fn blas_row_major_2d_skip_rows_ok() + { let m: Array2 = Array2::zeros((5, 5)); let mv = m.slice(s![..;2, ..]); assert!(blas_row_major_2d::(&mv)); @@ -993,7 +1017,8 @@ mod blas_tests { } #[test] - fn blas_row_major_2d_skip_columns_fail() { + fn blas_row_major_2d_skip_columns_fail() + { let m: Array2 = Array2::zeros((5, 5)); let mv = m.slice(s![.., ..;2]); assert!(!blas_row_major_2d::(&mv)); @@ -1001,7 +1026,8 @@ mod blas_tests { } #[test] - fn blas_col_major_2d_skip_columns_ok() { + fn blas_col_major_2d_skip_columns_ok() + { let m: Array2 = Array2::zeros((5, 5).f()); let mv = m.slice(s![.., ..;2]); assert!(blas_column_major_2d::(&mv)); @@ -1009,7 +1035,8 @@ mod blas_tests { } #[test] - fn blas_col_major_2d_skip_rows_fail() { + fn blas_col_major_2d_skip_rows_fail() + { let m: Array2 = Array2::zeros((5, 5).f()); let mv = m.slice(s![..;2, ..]); assert!(!blas_column_major_2d::(&mv)); @@ -1017,7 +1044,8 @@ mod blas_tests { } #[test] - fn blas_too_short_stride() { + fn blas_too_short_stride() + { // leading stride must be longer than the other dimension // Example, in a 5 x 5 matrix, the leading stride must be >= 5 for BLAS. @@ -1125,18 +1153,18 @@ mod arrayd_dot_tests { // Test case from the original issue let mat1 = ArrayD::from_shape_vec(vec![3, 2], vec![3.0; 6]).unwrap(); let mat2 = ArrayD::from_shape_vec(vec![2, 3], vec![1.0; 6]).unwrap(); - + let result = mat1.dot(&mat2); - + // Verify the result is correct assert_eq!(result.ndim(), 2); assert_eq!(result.shape(), &[3, 3]); - + // Compare with Array2 implementation let mat1_2d = Array2::from_shape_vec((3, 2), vec![3.0; 6]).unwrap(); let mat2_2d = Array2::from_shape_vec((2, 3), vec![1.0; 6]).unwrap(); let expected = mat1_2d.dot(&mat2_2d); - + assert_eq!(result.into_dimensionality::().unwrap(), expected); } @@ -1145,9 +1173,9 @@ mod arrayd_dot_tests { // Test 1D array dot product let vec1 = ArrayD::from_shape_vec(vec![3], vec![1.0, 2.0, 3.0]).unwrap(); let vec2 = ArrayD::from_shape_vec(vec![3], vec![4.0, 5.0, 6.0]).unwrap(); - + let result = vec1.dot(&vec2); - + // Verify scalar result assert_eq!(result.ndim(), 0); assert_eq!(result.shape(), &[]); @@ -1160,7 +1188,7 @@ mod arrayd_dot_tests { // Test that 3D arrays are not supported let arr1 = ArrayD::from_shape_vec(vec![2, 2, 2], vec![1.0; 8]).unwrap(); let arr2 = ArrayD::from_shape_vec(vec![2, 2, 2], vec![1.0; 8]).unwrap(); - + let _result = arr1.dot(&arr2); // Should panic } @@ -1170,7 +1198,7 @@ mod arrayd_dot_tests { // Test arrays with incompatible dimensions let arr1 = ArrayD::from_shape_vec(vec![2, 3], vec![1.0; 6]).unwrap(); let arr2 = ArrayD::from_shape_vec(vec![4, 5], vec![1.0; 20]).unwrap(); - + let _result = arr1.dot(&arr2); // Should panic } @@ -1179,18 +1207,18 @@ mod arrayd_dot_tests { // Test matrix-vector multiplication let mat = ArrayD::from_shape_vec(vec![3, 2], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap(); let vec = ArrayD::from_shape_vec(vec![2], vec![1.0, 2.0]).unwrap(); - + let result = mat.dot(&vec); - + // Verify result assert_eq!(result.ndim(), 1); assert_eq!(result.shape(), &[3]); - + // Compare with Array2 implementation let mat_2d = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap(); let vec_1d = Array1::from_vec(vec![1.0, 2.0]); let expected = mat_2d.dot(&vec_1d); - + assert_eq!(result.into_dimensionality::().unwrap(), expected); } } From 3748a3b49db5dac0f2dfe2d5a685c25d14c3bd18 Mon Sep 17 00:00:00 2001 From: NewBornRustacean Date: Mon, 17 Mar 2025 06:27:48 +0900 Subject: [PATCH 06/12] cargo nigthly fmt --- src/linalg/impl_linalg.rs | 45 ++++++++++++++++++++++----------------- 1 file changed, 26 insertions(+), 19 deletions(-) diff --git a/src/linalg/impl_linalg.rs b/src/linalg/impl_linalg.rs index c48967b60..40538727d 100644 --- a/src/linalg/impl_linalg.rs +++ b/src/linalg/impl_linalg.rs @@ -1109,7 +1109,8 @@ where { type Output = Array; - fn dot(&self, rhs: &ArrayBase) -> Self::Output { + fn dot(&self, rhs: &ArrayBase) -> Self::Output + { match (self.ndim(), rhs.ndim()) { (1, 1) => { let a = self.view().into_dimensionality::().unwrap(); @@ -1144,38 +1145,41 @@ where } #[cfg(test)] -mod arrayd_dot_tests { +mod arrayd_dot_tests +{ use super::*; use crate::ArrayD; #[test] - fn test_arrayd_dot_2d() { + fn test_arrayd_dot_2d() + { // Test case from the original issue let mat1 = ArrayD::from_shape_vec(vec![3, 2], vec![3.0; 6]).unwrap(); let mat2 = ArrayD::from_shape_vec(vec![2, 3], vec![1.0; 6]).unwrap(); - + let result = mat1.dot(&mat2); - + // Verify the result is correct assert_eq!(result.ndim(), 2); assert_eq!(result.shape(), &[3, 3]); - + // Compare with Array2 implementation let mat1_2d = Array2::from_shape_vec((3, 2), vec![3.0; 6]).unwrap(); let mat2_2d = Array2::from_shape_vec((2, 3), vec![1.0; 6]).unwrap(); let expected = mat1_2d.dot(&mat2_2d); - + assert_eq!(result.into_dimensionality::().unwrap(), expected); } #[test] - fn test_arrayd_dot_1d() { + fn test_arrayd_dot_1d() + { // Test 1D array dot product let vec1 = ArrayD::from_shape_vec(vec![3], vec![1.0, 2.0, 3.0]).unwrap(); let vec2 = ArrayD::from_shape_vec(vec![3], vec![4.0, 5.0, 6.0]).unwrap(); - + let result = vec1.dot(&vec2); - + // Verify scalar result assert_eq!(result.ndim(), 0); assert_eq!(result.shape(), &[]); @@ -1184,41 +1188,44 @@ mod arrayd_dot_tests { #[test] #[should_panic(expected = "Dot product for ArrayD is only supported for 1D and 2D arrays")] - fn test_arrayd_dot_3d() { + fn test_arrayd_dot_3d() + { // Test that 3D arrays are not supported let arr1 = ArrayD::from_shape_vec(vec![2, 2, 2], vec![1.0; 8]).unwrap(); let arr2 = ArrayD::from_shape_vec(vec![2, 2, 2], vec![1.0; 8]).unwrap(); - + let _result = arr1.dot(&arr2); // Should panic } #[test] #[should_panic(expected = "ndarray: inputs 2 × 3 and 4 × 5 are not compatible for matrix multiplication")] - fn test_arrayd_dot_incompatible_dims() { + fn test_arrayd_dot_incompatible_dims() + { // Test arrays with incompatible dimensions let arr1 = ArrayD::from_shape_vec(vec![2, 3], vec![1.0; 6]).unwrap(); let arr2 = ArrayD::from_shape_vec(vec![4, 5], vec![1.0; 20]).unwrap(); - + let _result = arr1.dot(&arr2); // Should panic } #[test] - fn test_arrayd_dot_matrix_vector() { + fn test_arrayd_dot_matrix_vector() + { // Test matrix-vector multiplication let mat = ArrayD::from_shape_vec(vec![3, 2], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap(); let vec = ArrayD::from_shape_vec(vec![2], vec![1.0, 2.0]).unwrap(); - + let result = mat.dot(&vec); - + // Verify result assert_eq!(result.ndim(), 1); assert_eq!(result.shape(), &[3]); - + // Compare with Array2 implementation let mat_2d = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap(); let vec_1d = Array1::from_vec(vec![1.0, 2.0]); let expected = mat_2d.dot(&vec_1d); - + assert_eq!(result.into_dimensionality::().unwrap(), expected); } } From ac7de365b5ecb9ece7dfc0f1b03d9b839f614a2e Mon Sep 17 00:00:00 2001 From: NewBornRustacean Date: Mon, 17 Mar 2025 07:34:56 +0900 Subject: [PATCH 07/12] not to exectue example for CI --- src/linalg/impl_linalg.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/linalg/impl_linalg.rs b/src/linalg/impl_linalg.rs index 40538727d..de60ff76f 100644 --- a/src/linalg/impl_linalg.rs +++ b/src/linalg/impl_linalg.rs @@ -1088,7 +1088,7 @@ mod blas_tests /// /// # Examples /// -/// ``` +/// ```no_run /// use ndarray::{ArrayD, Array2, Array1}; /// /// // Matrix multiplication From da24369cc513ec8193a3278c473e9e672d5f7db1 Mon Sep 17 00:00:00 2001 From: NewBornRustacean Date: Mon, 17 Mar 2025 07:45:56 +0900 Subject: [PATCH 08/12] add missing use block(vec macro) --- src/linalg/impl_linalg.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/linalg/impl_linalg.rs b/src/linalg/impl_linalg.rs index de60ff76f..c11970907 100644 --- a/src/linalg/impl_linalg.rs +++ b/src/linalg/impl_linalg.rs @@ -16,6 +16,9 @@ use crate::{LinalgScalar, Zip}; #[cfg(not(feature = "std"))] use alloc::vec::Vec; +#[cfg(not(feature = "std"))] +use alloc::vec; + use std::any::TypeId; use std::mem::MaybeUninit; From 58418919671f9df568111c5478d66db8a999493f Mon Sep 17 00:00:00 2001 From: NewBornRustacean Date: Mon, 17 Mar 2025 08:03:25 +0900 Subject: [PATCH 09/12] correct example --- src/linalg/impl_linalg.rs | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/linalg/impl_linalg.rs b/src/linalg/impl_linalg.rs index c11970907..4f7dee5ac 100644 --- a/src/linalg/impl_linalg.rs +++ b/src/linalg/impl_linalg.rs @@ -1091,9 +1091,10 @@ mod blas_tests /// /// # Examples /// -/// ```no_run -/// use ndarray::{ArrayD, Array2, Array1}; -/// +/// ``` +/// use ndarray::ArrayD; +/// use alloc::vec; + /// // Matrix multiplication /// let a = ArrayD::from_shape_vec(vec![2, 3], vec![1., 2., 3., 4., 5., 6.]).unwrap(); /// let b = ArrayD::from_shape_vec(vec![3, 2], vec![1., 2., 3., 4., 5., 6.]).unwrap(); From 28cc3b09a193f69ecd51f9315d88476660851d71 Mon Sep 17 00:00:00 2001 From: NewBornRustacean Date: Mon, 17 Mar 2025 08:47:47 +0900 Subject: [PATCH 10/12] make sure example runs OK, apply fmt --- src/linalg/impl_linalg.rs | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/linalg/impl_linalg.rs b/src/linalg/impl_linalg.rs index 4f7dee5ac..59cc375b7 100644 --- a/src/linalg/impl_linalg.rs +++ b/src/linalg/impl_linalg.rs @@ -14,10 +14,10 @@ use crate::numeric_util; use crate::{LinalgScalar, Zip}; -#[cfg(not(feature = "std"))] -use alloc::vec::Vec; #[cfg(not(feature = "std"))] use alloc::vec; +#[cfg(not(feature = "std"))] +use alloc::vec::Vec; use std::any::TypeId; use std::mem::MaybeUninit; @@ -1092,9 +1092,8 @@ mod blas_tests /// # Examples /// /// ``` -/// use ndarray::ArrayD; -/// use alloc::vec; - +/// use ndarray::{ArrayD, linalg::Dot}; +/// /// // Matrix multiplication /// let a = ArrayD::from_shape_vec(vec![2, 3], vec![1., 2., 3., 4., 5., 6.]).unwrap(); /// let b = ArrayD::from_shape_vec(vec![3, 2], vec![1., 2., 3., 4., 5., 6.]).unwrap(); From ba3c02914d758bfb50e8db2fe235ddbc3112398d Mon Sep 17 00:00:00 2001 From: NewBornRustacean Date: Mon, 17 Mar 2025 11:48:24 +0900 Subject: [PATCH 11/12] move tests into blas-tests --- crates/blas-tests/tests/dyn.rs | 75 ++++++++++++++++++++++++ src/linalg/impl_linalg.rs | 103 +-------------------------------- 2 files changed, 76 insertions(+), 102 deletions(-) create mode 100644 crates/blas-tests/tests/dyn.rs diff --git a/crates/blas-tests/tests/dyn.rs b/crates/blas-tests/tests/dyn.rs new file mode 100644 index 000000000..e851e37cc --- /dev/null +++ b/crates/blas-tests/tests/dyn.rs @@ -0,0 +1,75 @@ +extern crate blas_src; +use ndarray::{Array1, Array2, ArrayD, linalg::Dot, Ix1, Ix2}; + +#[test] +fn test_arrayd_dot_2d() { + let mat1 = ArrayD::from_shape_vec(vec![3, 2], vec![3.0; 6]).unwrap(); + let mat2 = ArrayD::from_shape_vec(vec![2, 3], vec![1.0; 6]).unwrap(); + + let result = mat1.dot(&mat2); + + // Verify the result is correct + assert_eq!(result.ndim(), 2); + assert_eq!(result.shape(), &[3, 3]); + + // Compare with Array2 implementation + let mat1_2d = Array2::from_shape_vec((3, 2), vec![3.0; 6]).unwrap(); + let mat2_2d = Array2::from_shape_vec((2, 3), vec![1.0; 6]).unwrap(); + let expected = mat1_2d.dot(&mat2_2d); + + assert_eq!(result.into_dimensionality::().unwrap(), expected); +} + +#[test] +fn test_arrayd_dot_1d() { + // Test 1D array dot product + let vec1 = ArrayD::from_shape_vec(vec![3], vec![1.0, 2.0, 3.0]).unwrap(); + let vec2 = ArrayD::from_shape_vec(vec![3], vec![4.0, 5.0, 6.0]).unwrap(); + + let result = vec1.dot(&vec2); + + // Verify scalar result + assert_eq!(result.ndim(), 0); + assert_eq!(result.shape(), &[]); + assert_eq!(result[[]], 32.0); // 1*4 + 2*5 + 3*6 +} + +#[test] +#[should_panic(expected = "Dot product for ArrayD is only supported for 1D and 2D arrays")] +fn test_arrayd_dot_3d() { + // Test that 3D arrays are not supported + let arr1 = ArrayD::from_shape_vec(vec![2, 2, 2], vec![1.0; 8]).unwrap(); + let arr2 = ArrayD::from_shape_vec(vec![2, 2, 2], vec![1.0; 8]).unwrap(); + + let _result = arr1.dot(&arr2); // Should panic +} + +#[test] +#[should_panic(expected = "ndarray: inputs 2 × 3 and 4 × 5 are not compatible for matrix multiplication")] +fn test_arrayd_dot_incompatible_dims() { + // Test arrays with incompatible dimensions + let arr1 = ArrayD::from_shape_vec(vec![2, 3], vec![1.0; 6]).unwrap(); + let arr2 = ArrayD::from_shape_vec(vec![4, 5], vec![1.0; 20]).unwrap(); + + let _result = arr1.dot(&arr2); // Should panic +} + +#[test] +fn test_arrayd_dot_matrix_vector() { + // Test matrix-vector multiplication + let mat = ArrayD::from_shape_vec(vec![3, 2], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap(); + let vec = ArrayD::from_shape_vec(vec![2], vec![1.0, 2.0]).unwrap(); + + let result = mat.dot(&vec); + + // Verify result + assert_eq!(result.ndim(), 1); + assert_eq!(result.shape(), &[3]); + + // Compare with Array2 implementation + let mat_2d = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap(); + let vec_1d = Array1::from_vec(vec![1.0, 2.0]); + let expected = mat_2d.dot(&vec_1d); + + assert_eq!(result.into_dimensionality::().unwrap(), expected); +} \ No newline at end of file diff --git a/src/linalg/impl_linalg.rs b/src/linalg/impl_linalg.rs index 59cc375b7..b3470778d 100644 --- a/src/linalg/impl_linalg.rs +++ b/src/linalg/impl_linalg.rs @@ -1089,21 +1089,6 @@ mod blas_tests /// - The array shapes are incompatible for the operation /// - For vector dot product: the vectors have different lengths /// -/// # Examples -/// -/// ``` -/// use ndarray::{ArrayD, linalg::Dot}; -/// -/// // Matrix multiplication -/// let a = ArrayD::from_shape_vec(vec![2, 3], vec![1., 2., 3., 4., 5., 6.]).unwrap(); -/// let b = ArrayD::from_shape_vec(vec![3, 2], vec![1., 2., 3., 4., 5., 6.]).unwrap(); -/// let c = a.dot(&b); -/// -/// // Vector dot product -/// let v1 = ArrayD::from_shape_vec(vec![3], vec![1., 2., 3.]).unwrap(); -/// let v2 = ArrayD::from_shape_vec(vec![3], vec![4., 5., 6.]).unwrap(); -/// let scalar = v1.dot(&v2); -/// ``` impl Dot> for ArrayBase where S: Data, @@ -1145,90 +1130,4 @@ where _ => panic!("Dot product for ArrayD is only supported for 1D and 2D arrays"), } } -} - -#[cfg(test)] -mod arrayd_dot_tests -{ - use super::*; - use crate::ArrayD; - - #[test] - fn test_arrayd_dot_2d() - { - // Test case from the original issue - let mat1 = ArrayD::from_shape_vec(vec![3, 2], vec![3.0; 6]).unwrap(); - let mat2 = ArrayD::from_shape_vec(vec![2, 3], vec![1.0; 6]).unwrap(); - - let result = mat1.dot(&mat2); - - // Verify the result is correct - assert_eq!(result.ndim(), 2); - assert_eq!(result.shape(), &[3, 3]); - - // Compare with Array2 implementation - let mat1_2d = Array2::from_shape_vec((3, 2), vec![3.0; 6]).unwrap(); - let mat2_2d = Array2::from_shape_vec((2, 3), vec![1.0; 6]).unwrap(); - let expected = mat1_2d.dot(&mat2_2d); - - assert_eq!(result.into_dimensionality::().unwrap(), expected); - } - - #[test] - fn test_arrayd_dot_1d() - { - // Test 1D array dot product - let vec1 = ArrayD::from_shape_vec(vec![3], vec![1.0, 2.0, 3.0]).unwrap(); - let vec2 = ArrayD::from_shape_vec(vec![3], vec![4.0, 5.0, 6.0]).unwrap(); - - let result = vec1.dot(&vec2); - - // Verify scalar result - assert_eq!(result.ndim(), 0); - assert_eq!(result.shape(), &[]); - assert_eq!(result[[]], 32.0); // 1*4 + 2*5 + 3*6 - } - - #[test] - #[should_panic(expected = "Dot product for ArrayD is only supported for 1D and 2D arrays")] - fn test_arrayd_dot_3d() - { - // Test that 3D arrays are not supported - let arr1 = ArrayD::from_shape_vec(vec![2, 2, 2], vec![1.0; 8]).unwrap(); - let arr2 = ArrayD::from_shape_vec(vec![2, 2, 2], vec![1.0; 8]).unwrap(); - - let _result = arr1.dot(&arr2); // Should panic - } - - #[test] - #[should_panic(expected = "ndarray: inputs 2 × 3 and 4 × 5 are not compatible for matrix multiplication")] - fn test_arrayd_dot_incompatible_dims() - { - // Test arrays with incompatible dimensions - let arr1 = ArrayD::from_shape_vec(vec![2, 3], vec![1.0; 6]).unwrap(); - let arr2 = ArrayD::from_shape_vec(vec![4, 5], vec![1.0; 20]).unwrap(); - - let _result = arr1.dot(&arr2); // Should panic - } - - #[test] - fn test_arrayd_dot_matrix_vector() - { - // Test matrix-vector multiplication - let mat = ArrayD::from_shape_vec(vec![3, 2], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap(); - let vec = ArrayD::from_shape_vec(vec![2], vec![1.0, 2.0]).unwrap(); - - let result = mat.dot(&vec); - - // Verify result - assert_eq!(result.ndim(), 1); - assert_eq!(result.shape(), &[3]); - - // Compare with Array2 implementation - let mat_2d = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap(); - let vec_1d = Array1::from_vec(vec![1.0, 2.0]); - let expected = mat_2d.dot(&vec_1d); - - assert_eq!(result.into_dimensionality::().unwrap(), expected); - } -} +} \ No newline at end of file From e577db35bb53675e215c3003c54ad4656b1272d4 Mon Sep 17 00:00:00 2001 From: NewBornRustacean Date: Mon, 17 Mar 2025 11:49:48 +0900 Subject: [PATCH 12/12] apply nightly formatter --- crates/blas-tests/tests/dyn.rs | 19 ++++++++++++------- src/linalg/impl_linalg.rs | 2 +- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/crates/blas-tests/tests/dyn.rs b/crates/blas-tests/tests/dyn.rs index e851e37cc..6c0fd975e 100644 --- a/crates/blas-tests/tests/dyn.rs +++ b/crates/blas-tests/tests/dyn.rs @@ -1,8 +1,9 @@ extern crate blas_src; -use ndarray::{Array1, Array2, ArrayD, linalg::Dot, Ix1, Ix2}; +use ndarray::{linalg::Dot, Array1, Array2, ArrayD, Ix1, Ix2}; #[test] -fn test_arrayd_dot_2d() { +fn test_arrayd_dot_2d() +{ let mat1 = ArrayD::from_shape_vec(vec![3, 2], vec![3.0; 6]).unwrap(); let mat2 = ArrayD::from_shape_vec(vec![2, 3], vec![1.0; 6]).unwrap(); @@ -21,7 +22,8 @@ fn test_arrayd_dot_2d() { } #[test] -fn test_arrayd_dot_1d() { +fn test_arrayd_dot_1d() +{ // Test 1D array dot product let vec1 = ArrayD::from_shape_vec(vec![3], vec![1.0, 2.0, 3.0]).unwrap(); let vec2 = ArrayD::from_shape_vec(vec![3], vec![4.0, 5.0, 6.0]).unwrap(); @@ -36,7 +38,8 @@ fn test_arrayd_dot_1d() { #[test] #[should_panic(expected = "Dot product for ArrayD is only supported for 1D and 2D arrays")] -fn test_arrayd_dot_3d() { +fn test_arrayd_dot_3d() +{ // Test that 3D arrays are not supported let arr1 = ArrayD::from_shape_vec(vec![2, 2, 2], vec![1.0; 8]).unwrap(); let arr2 = ArrayD::from_shape_vec(vec![2, 2, 2], vec![1.0; 8]).unwrap(); @@ -46,7 +49,8 @@ fn test_arrayd_dot_3d() { #[test] #[should_panic(expected = "ndarray: inputs 2 × 3 and 4 × 5 are not compatible for matrix multiplication")] -fn test_arrayd_dot_incompatible_dims() { +fn test_arrayd_dot_incompatible_dims() +{ // Test arrays with incompatible dimensions let arr1 = ArrayD::from_shape_vec(vec![2, 3], vec![1.0; 6]).unwrap(); let arr2 = ArrayD::from_shape_vec(vec![4, 5], vec![1.0; 20]).unwrap(); @@ -55,7 +59,8 @@ fn test_arrayd_dot_incompatible_dims() { } #[test] -fn test_arrayd_dot_matrix_vector() { +fn test_arrayd_dot_matrix_vector() +{ // Test matrix-vector multiplication let mat = ArrayD::from_shape_vec(vec![3, 2], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap(); let vec = ArrayD::from_shape_vec(vec![2], vec![1.0, 2.0]).unwrap(); @@ -72,4 +77,4 @@ fn test_arrayd_dot_matrix_vector() { let expected = mat_2d.dot(&vec_1d); assert_eq!(result.into_dimensionality::().unwrap(), expected); -} \ No newline at end of file +} diff --git a/src/linalg/impl_linalg.rs b/src/linalg/impl_linalg.rs index b3470778d..e05740378 100644 --- a/src/linalg/impl_linalg.rs +++ b/src/linalg/impl_linalg.rs @@ -1130,4 +1130,4 @@ where _ => panic!("Dot product for ArrayD is only supported for 1D and 2D arrays"), } } -} \ No newline at end of file +}