diff --git a/examples/matmul.rs b/examples/matmul.rs index a6b04dab9..80e500eec 100644 --- a/examples/matmul.rs +++ b/examples/matmul.rs @@ -6,7 +6,7 @@ use ndarray::Array; fn main() { - let mat = Array::range(0.0f32, 16.0).reshape((2, 4, 2)); + let mat = Array::range(0.0f32, 16.0).reshape_clone((2, 4, 2)); println!("{a:?}\n times \n{b:?}\nis equal to:\n{c:?}", a=mat.subview(2,1), b=mat.subview(0,1), diff --git a/src/arrayformat.rs b/src/arrayformat.rs index 65509f15c..34c1f7df5 100644 --- a/src/arrayformat.rs +++ b/src/arrayformat.rs @@ -1,12 +1,14 @@ use std::fmt; use super::{Array, Dimension}; +use std::ops::Deref; /// HACK: fmt::rt::FlagAlternate has been hidden away const FLAG_ALTERNATE: usize = 2; -fn format_array(view: &Array, f: &mut fmt::Formatter, - mut format: F) -> fmt::Result where - F: FnMut(&mut fmt::Formatter, &A) -> fmt::Result, +fn format_array(view: &Array, + f: &mut fmt::Formatter, + mut format: F) -> fmt::Result where + F: FnMut(&mut fmt::Formatter, &A) -> fmt::Result, S: Deref { let ndim = view.dim.slice().len(); /* private nowadays @@ -71,7 +73,8 @@ fn format_array(view: &Array, f: &mut fmt::Formatter, } // NOTE: We can impl other fmt traits here -impl<'a, A: fmt::Display, D: Dimension> fmt::Display for Array +impl<'a, A: fmt::Display, S, D: Dimension> fmt::Display for Array +where S: Deref { /// Format the array using `Display` and apply the formatting parameters used /// to each element. @@ -83,7 +86,8 @@ impl<'a, A: fmt::Display, D: Dimension> fmt::Display for Array } } -impl<'a, A: fmt::Debug, D: Dimension> fmt::Debug for Array +impl<'a, A: fmt::Debug, S, D: Dimension> fmt::Debug for Array +where S: Deref { /// Format the array using `Debug` and apply the formatting parameters used /// to each element. @@ -95,7 +99,8 @@ impl<'a, A: fmt::Debug, D: Dimension> fmt::Debug for Array } } -impl<'a, A: fmt::LowerExp, D: Dimension> fmt::LowerExp for Array +impl<'a, A: fmt::LowerExp, S, D: Dimension> fmt::LowerExp for Array +where S: Deref { /// Format the array using `LowerExp` and apply the formatting parameters used /// to each element. @@ -107,7 +112,8 @@ impl<'a, A: fmt::LowerExp, D: Dimension> fmt::LowerExp for Array } } -impl<'a, A: fmt::UpperExp, D: Dimension> fmt::UpperExp for Array +impl<'a, A: fmt::UpperExp, S, D: Dimension> fmt::UpperExp for Array +where S: Deref { /// Format the array using `UpperExp` and apply the formatting parameters used /// to each element. diff --git a/src/arraytraits.rs b/src/arraytraits.rs index bf074e5c0..f64c6001b 100644 --- a/src/arraytraits.rs +++ b/src/arraytraits.rs @@ -7,11 +7,14 @@ use std::iter::IntoIterator; use std::ops::{ Index, IndexMut, + Deref, + DerefMut }; use super::{Array, Dimension, Ix, Elements, ElementsMut}; -impl<'a, A, D: Dimension> Index for Array +impl<'a, A, S, D: Dimension> Index for Array +where S: Deref { type Output = A; #[inline] @@ -23,7 +26,8 @@ impl<'a, A, D: Dimension> Index for Array } } -impl<'a, A: Clone, D: Dimension> IndexMut for Array +impl<'a, A: Clone, S, D: Dimension> IndexMut for Array +where S: DerefMut { #[inline] /// Access the element at **index** mutably. @@ -35,31 +39,34 @@ impl<'a, A: Clone, D: Dimension> IndexMut for Array } -impl -PartialEq for Array +impl +PartialEq for Array +where S: Deref { /// Return `true` if the array shapes and all elements of `self` and /// `other` are equal. Return `false` otherwise. - fn eq(&self, other: &Array) -> bool + fn eq(&self, other: &Array) -> bool { self.shape() == other.shape() && self.iter().zip(other.iter()).all(|(a, b)| a == b) } } -impl -Eq for Array {} +impl +Eq for Array +where S: Deref +{} -impl FromIterator for Array +impl FromIterator for Array, Ix> { - fn from_iter>(it: I) -> Array + fn from_iter>(it: I) -> Array, Ix> { Array::from_iter(it.into_iter()) } } -impl<'a, A, D> IntoIterator for &'a Array where - D: Dimension, +impl<'a, A, S, D> IntoIterator for &'a Array where + D: Dimension, S: Deref { type Item = &'a A; type IntoIter = Elements<'a, A, D>; @@ -70,9 +77,10 @@ impl<'a, A, D> IntoIterator for &'a Array where } } -impl<'a, A, D> IntoIterator for &'a mut Array where +impl<'a, A, S, D> IntoIterator for &'a mut Array where A: Clone, D: Dimension, + S: Deref, { type Item = &'a mut A; type IntoIter = ElementsMut<'a, A, D>; @@ -83,10 +91,10 @@ impl<'a, A, D> IntoIterator for &'a mut Array where } } -impl -hash::Hash for Array +impl, D: Dimension> +hash::Hash for Array { - fn hash(&self, state: &mut S) + fn hash(&self, state: &mut H) { self.shape().hash(state); for elt in self.iter() { @@ -100,7 +108,8 @@ hash::Hash for Array static ARRAY_FORMAT_VERSION: u8 = 1u8; #[cfg(feature = "rustc-serialize")] -impl Encodable for Array +impl Encodable for Array +where S: Deref { fn encode(&self, s: &mut S) -> Result<(), S::Error> { @@ -129,10 +138,11 @@ impl Encodable for Array } #[cfg(feature = "rustc-serialize")] -impl - Decodable for Array +impl + Decodable for Array +where S: Deref { - fn decode(d: &mut S) -> Result, S::Error> + fn decode(d: &mut Dec) -> Result, Dec::Error> { d.read_struct("Array", 3, |d| { let version: u8 = try!(d.read_struct_field("v", 0, Decodable::decode)); diff --git a/src/lib.rs b/src/lib.rs index 83b21161c..d6e62e750 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -14,8 +14,8 @@ extern crate itertools as it; #[cfg(not(nocomplex))] extern crate num as libnum; +use std::ops::{Deref, DerefMut}; use std::mem; -use std::rc::Rc; use libnum::Float; use std::ops::{Add, Sub, Mul, Div, Rem, Neg, Not, Shr, Shl, BitAnd, @@ -45,9 +45,9 @@ mod si; // NOTE: In theory, the whole library should compile // and pass tests even if you change Ix and Ixs. /// Array index type -pub type Ix = u32; +pub type Ix = usize; /// Array index type (signed) -pub type Ixs = i32; +pub type Ixs = isize; /// The **Array** type is an *N-dimensional array*. /// @@ -101,10 +101,10 @@ pub type Ixs = i32; /// ); /// ``` /// -pub struct Array { +pub struct Array where S: Deref { // FIXME: Unsafecell around vec needed? /// Rc data when used as view, Uniquely held data when being mutated - data: Rc>, + data: S, /// A pointer into the buffer held by data, may point anywhere /// in its range. ptr: *mut A, @@ -114,37 +114,49 @@ pub struct Array { strides: D, } -impl Clone for Array +pub type ArrayOwned = Array, D>; +pub type ArrayView<'a, A, D> = Array; +pub type ArrayViewMut<'a, A, D> = Array; + +impl Clone for Array, D> +where A: Clone { - fn clone(&self) -> Array { + fn clone(&self) -> ArrayOwned { + let mut data = self.data.clone(); + let mut ptr = data.as_mut_ptr(); + let offset = (self.ptr as isize - self.data.as_ptr() as isize) + / mem::size_of::() as isize; + unsafe { + ptr = ptr.offset(offset); + } Array { - data: self.data.clone(), - ptr: self.ptr, + data: data, + ptr: ptr, dim: self.dim.clone(), strides: self.strides.clone(), } } } -impl Array +impl Array, Ix> { /// Create a one-dimensional array from a vector (no allocation needed). - pub fn from_vec(v: Vec) -> Array { + pub fn from_vec(v: Vec) -> ArrayOwned { unsafe { Array::from_vec_dim(v.len() as Ix, v) } } /// Create a one-dimensional array from an iterator. - pub fn from_iter>(it: I) -> Array { + pub fn from_iter>(it: I) -> ArrayOwned { Array::from_vec(it.collect()) } } -impl Array +impl Array, Ix> { /// Create a one-dimensional Array from interval **[begin, end)** - pub fn range(begin: f32, end: f32) -> Array + pub fn range(begin: f32, end: f32) -> ArrayOwned { let n = (end - begin) as usize; let span = if n > 0 { (n - 1) as f32 } else { 0. }; @@ -154,10 +166,24 @@ impl Array } } -impl Array where D: Dimension +impl Array, D> where D: Dimension { + /// Create an array from a vector (with no allocation needed). + /// + /// Unsafe because dimension is unchecked, and must be correct. + pub unsafe fn from_vec_dim(dim: D, mut v: Vec) -> ArrayOwned + { + debug_assert!(dim.size() == v.len()); + Array { + ptr: v.as_mut_ptr(), + data: v, + strides: dim.default_strides(), + dim: dim + } + } + /// Construct an Array with zeros. - pub fn zeros(dim: D) -> Array where A: Clone + libnum::Zero + pub fn zeros(dim: D) -> ArrayOwned where A: Clone + libnum::Zero { Array::from_elem(dim, libnum::zero()) } @@ -179,7 +205,7 @@ impl Array where D: Dimension /// [1., 1.]]]) /// ); /// ``` - pub fn from_elem(dim: D, elem: A) -> Array where A: Clone + pub fn from_elem(dim: D, elem: A) -> ArrayOwned where A: Clone { let v = std::iter::repeat(elem).take(dim.size()).collect(); unsafe { @@ -187,20 +213,10 @@ impl Array where D: Dimension } } - /// Create an array from a vector (with no allocation needed). - /// - /// Unsafe because dimension is unchecked, and must be correct. - pub unsafe fn from_vec_dim(dim: D, mut v: Vec) -> Array - { - debug_assert!(dim.size() == v.len()); - Array { - ptr: v.as_mut_ptr(), - data: std::rc::Rc::new(v), - strides: dim.default_strides(), - dim: dim - } - } +} +impl Array where D: Dimension, S: Deref +{ /// Return the total number of elements in the Array. pub fn len(&self) -> usize { @@ -233,17 +249,56 @@ impl Array where D: Dimension /// **Note:** Data memory order may not correspond to the index order /// of the array. Neither is the raw data slice is restricted to just the /// Array's view. - pub fn raw_data<'a>(&'a self) -> &'a [A] + pub fn raw_data(& self) -> &[A] { - &self.data + &self.data[..] + } + + /// Get a view (borrow) into this array + pub fn view(&self) -> ArrayView { + ArrayView { + data: &self.data[..], + ptr: self.ptr.clone(), + dim: self.dim.clone(), + strides: self.strides.clone(), + } + } + + /// Get a mutable view (borrow) into this array + pub fn view_mut(&mut self) -> ArrayViewMut + where S: DerefMut { + ArrayViewMut { + data: &mut self.data[..], + ptr: self.ptr.clone(), + dim: self.dim.clone(), + strides: self.strides.clone(), + } + } + + /// Get an owned copy of this array + pub fn to_owned(&self) -> ArrayOwned + where A: Clone { + let mut res = ArrayOwned { + data: self.data.to_vec(), + ptr: self.ptr.clone(), + dim: self.dim.clone(), + strides: self.strides.clone(), + }; + res.ptr = res.data.as_mut_ptr(); + let offset = (self.ptr as isize - self.data.as_ptr() as isize) + / mem::size_of::() as isize; + unsafe { + res.ptr = res.ptr.offset(offset); + } + res } /// Return a sliced array. /// /// **Panics** if **indexes** does not match the number of array axes. - pub fn slice(&self, indexes: &[Si]) -> Array + pub fn slice(&self, indexes: &[Si]) -> ArrayView { - let mut arr = self.clone(); + let mut arr = self.view(); arr.islice(indexes); arr } @@ -298,10 +353,8 @@ impl Array where D: Dimension /// Return a mutable reference to the element at **index**. /// /// **Note:** Only unchecked for non-debug builds of ndarray.
- /// **Note:** The array must be uniquely held when mutating it. #[inline] pub unsafe fn uchk_at_mut(&mut self, index: D) -> &mut A { - debug_assert!(Rc::get_mut(&mut self.data).is_some()); debug_assert!(self.dim.stride_offset_checked(&self.strides, &index).is_some()); let off = Dimension::stride_offset(&index, &self.strides); &mut *self.ptr.offset(off) @@ -486,10 +539,10 @@ impl Array where D: Dimension } /// Return the diagonal as a one-dimensional array. - pub fn diag(&self) -> Array { + pub fn diag(&self) -> ArrayView { let (len, stride) = self.diag_params(); Array { - data: self.data.clone(), + data: &self.data[..], ptr: self.ptr, dim: len, strides: stride as Ix, @@ -511,7 +564,7 @@ impl Array where D: Dimension /// == arr2(&[[0, 1], [1, 2]]) /// ); /// ``` - pub fn map<'a, B, F>(&'a self, mut f: F) -> Array where + pub fn map<'a, B, F>(&'a self, mut f: F) -> ArrayOwned where F: FnMut(&'a A) -> B { let mut res = Vec::::with_capacity(self.dim.size()); @@ -535,18 +588,19 @@ impl Array where D: Dimension /// [3., 4.]]); /// /// assert!( - /// a.subview(0, 0) == arr1(&[1., 2.]) && - /// a.subview(1, 1) == arr1(&[2., 4.]) + /// a.subview(0, 0) == arr1(&[1., 2.]).view() && + /// a.subview(1, 1) == arr1(&[2., 4.]).view() /// ); /// ``` - pub fn subview(&self, axis: usize, index: Ix) -> Array::Smaller> where + pub fn subview(&self, axis: usize, + index: Ix) -> ArrayView::Smaller> where D: RemoveAxis { - let mut res = self.clone(); + let mut res = self.view(); res.isubview(axis, index); // don't use reshape -- we always know it will fit the size, // and we can use remove_axis on the strides as well - Array{ + ArrayView { data: res.data, ptr: res.ptr, dim: res.dim.remove_axis(axis), @@ -554,34 +608,10 @@ impl Array where D: Dimension } } - /// Make the array unshared. - /// - /// This method is mostly only useful with unsafe code. - pub fn ensure_unique(&mut self) where A: Clone - { - if Rc::get_mut(&mut self.data).is_some() { - return - } - if self.dim.size() <= self.data.len() / 2 { - unsafe { - *self = Array::from_vec_dim(self.dim.clone(), - self.iter().map(|x| x.clone()).collect()); - } - return; - } - let our_off = (self.ptr as isize - self.data.as_ptr() as isize) - / mem::size_of::
() as isize; - let rvec = Rc::make_mut(&mut self.data); - unsafe { - self.ptr = rvec.as_mut_ptr().offset(our_off); - } - } - /// Return a mutable reference to the element at **index**, or return **None** /// if the index is out of bounds. pub fn at_mut<'a>(&'a mut self, index: D) -> Option<&'a mut A> where A: Clone { - self.ensure_unique(); self.dim.stride_offset_checked(&self.strides, &index) .map(|offset| unsafe { &mut *self.ptr.offset(offset) @@ -593,7 +623,6 @@ impl Array where D: Dimension /// Iterator element type is **&'a mut A**. pub fn iter_mut<'a>(&'a mut self) -> ElementsMut<'a, A, D> where A: Clone { - self.ensure_unique(); ElementsMut { inner: self.base_iter() } } @@ -638,7 +667,6 @@ impl Array where D: Dimension /// Return an iterator over the diagonal elements of the array. pub fn diag_iter_mut<'a>(&'a mut self) -> ElementsMut<'a, A, Ix> where A: Clone { - self.ensure_unique(); let (len, stride) = self.diag_params(); unsafe { ElementsMut { inner: @@ -655,47 +683,95 @@ impl Array where D: Dimension /// /// **Note:** The data is uniquely held and nonaliased /// while it is mutably borrowed. - pub fn raw_data_mut<'a>(&'a mut self) -> &'a mut [A] - where A: Clone + pub fn raw_data_mut(&mut self) -> &mut [A] + where A: Clone, S: DerefMut { - &mut Rc::make_mut(&mut self.data)[..] + &mut self.data[..] } /// Transform the array into **shape**; any other shape /// with the same number of elements is accepted. /// - /// **Panics** if sizes are incompatible. + /// **Panics** if sizes are incompatible or the reshape can't be done + /// without cloning /// /// ``` /// use ndarray::{arr1, arr2}; /// /// assert!( - /// arr1(&[1., 2., 3., 4.]).reshape((2, 2)) + /// arr1(&[1., 2., 3., 4.]).reshape_view((2, 2)) /// == arr2(&[[1., 2.], - /// [3., 4.]]) + /// [3., 4.]]).view() /// ); /// ``` - pub fn reshape(&self, shape: E) -> Array where A: Clone + pub fn reshape_view(&self, shape: E + ) -> ArrayView { if shape.size() != self.dim.size() { panic!("Incompatible sizes in reshape, attempted from: {:?}, to: {:?}", self.dim.slice(), shape.slice()) } - // Check if contiguous, if not => copy all, else just adapt strides - if self.is_standard_layout() { - let cl = self.clone(); - Array{ - data: cl.data, - ptr: cl.ptr, - strides: shape.default_strides(), - dim: shape, - } - } else { - let v = self.iter().map(|x| x.clone()).collect::>(); - unsafe { - Array::from_vec_dim(shape, v) - } + // Check if contiguous, if not => panic + if ! self.is_standard_layout() { + panic!("cannot reshape without allocating, you should use reshape_clone") + } + ArrayView { + data: &self.data[..], + ptr: self.ptr, + strides: shape.default_strides(), + dim: shape, + } + } + + /// Transform the array into **shape**; any other shape + /// with the same number of elements is accepted. + /// + /// **Panics** if sizes are incompatible or the reshape can't be done + /// without cloning + pub fn reshape_into(self, shape: E) -> Array + { + if shape.size() != self.dim.size() { + panic!("Incompatible sizes in reshape, attempted from: {:?}, to: {:?}", + self.dim.slice(), shape.slice()) + } + // Check if contiguous, if not => panic + if ! self.is_standard_layout() { + panic!("cannot reshape without allocating, you should use reshape_clone") + } + Array{ + data: self.data, + ptr: self.ptr, + strides: shape.default_strides(), + dim: shape, + } + } + + /// Clone the array into **shape**; any other shape + /// with the same number of elements is accepted. + /// + /// **Panics** if sizes are incompatible + /// + /// ``` + /// use ndarray::{arr1, arr2}; + /// + /// assert!( + /// arr1(&[1., 2., 3., 4.]).reshape_clone((2, 2)) + /// == arr2(&[[1., 2.], + /// [3., 4.]]) + /// ); + /// ``` + pub fn reshape_clone(&self, shape: E + ) -> ArrayOwned + where A: Clone { + if shape.size() != self.dim.size() { + panic!("Incompatible sizes in reshape, attempted from: {:?}, to: {:?}", + self.dim.slice(), shape.slice()) + } + + let v = self.iter().map(|x| x.clone()).collect::>(); + unsafe { + Array::from_vec_dim(shape, v) } } @@ -704,7 +780,9 @@ impl Array where D: Dimension /// If their shapes disagree, **other** is broadcast to the shape of **self**. /// /// **Panics** if broadcasting isn't possible. - pub fn assign(&mut self, other: &Array) where A: Clone + pub fn assign(&mut self, + other: &Array) where + A: Clone, S2: Deref { if self.shape() == other.shape() { for (x, y) in self.iter_mut().zip(other.iter()) { @@ -719,7 +797,8 @@ impl Array where D: Dimension } /// Perform an elementwise assigment to **self** from scalar **x**. - pub fn assign_scalar(&mut self, x: &A) where A: Clone + pub fn assign_scalar(&mut self, x: &A) + where A: Clone, S: DerefMut { for elt in self.raw_data_mut().iter_mut() { *elt = x.clone(); @@ -728,7 +807,7 @@ impl Array where D: Dimension } /// Return a zero-dimensional array with the element **x**. -pub fn arr0(x: A) -> Array +pub fn arr0(x: A) -> Array, ()> { let mut v = Vec::with_capacity(1); v.push(x); @@ -736,7 +815,7 @@ pub fn arr0(x: A) -> Array } /// Return a one-dimensional array with elements from **xs**. -pub fn arr1(xs: &[A]) -> Array +pub fn arr1(xs: &[A]) -> Array, Ix> { Array::from_vec(xs.to_vec()) } @@ -785,7 +864,7 @@ impl_arr_init!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,); /// a.shape() == [2, 3] /// ); /// ``` -pub fn arr2>(xs: &[V]) -> Array +pub fn arr2>(xs: &[V]) -> Array, (Ix, Ix)> { // FIXME: Simplify this when V is fix size array let (m, n) = (xs.len() as Ix, @@ -819,7 +898,8 @@ pub fn arr2>(xs: &[V]) -> Array /// a.shape() == [3, 2, 2] /// ); /// ``` -pub fn arr3, U: ArrInit>(xs: &[V]) -> Array +pub fn arr3(xs: &[V]) -> Array, (Ix, Ix, Ix)> +where A: Clone, V: ArrInit, U: ArrInit, { // FIXME: Simplify this when U/V are fix size arrays let m = xs.len() as Ix; @@ -844,8 +924,9 @@ pub fn arr3, U: ArrInit>(xs: &[V]) -> Array Array where +impl Array where A: Clone + Add, + S: Deref, D: RemoveAxis, { /// Return sum along **axis**. @@ -864,19 +945,21 @@ impl Array where /// ``` /// /// **Panics** if **axis** is out of bounds. - pub fn sum(&self, axis: usize) -> Array::Smaller> + pub fn sum(&self, axis: usize) -> Array, ::Smaller> { let n = self.shape()[axis]; - let mut res = self.subview(axis, 0); + let mut res = self.subview(axis, 0).to_owned(); for i in (1..n) { - res.iadd(&self.subview(axis, i)) + let slice = self.subview(axis, i); + res.iadd(&slice) } res } } -impl Array where +impl Array where A: Copy + linalg::Field, + S: Deref, D: RemoveAxis, { /// Return mean along **axis**. @@ -894,7 +977,8 @@ impl Array where /// /// /// **Panics** if **axis** is out of bounds. - pub fn mean(&self, axis: usize) -> Array::Smaller> + pub fn mean(&self, axis: usize + ) -> Array, ::Smaller> { let n = self.shape()[axis]; let mut sum = self.sum(axis); @@ -910,7 +994,7 @@ impl Array where } } -impl Array +impl> Array { /// Return an iterator over the elements of row **index**. /// @@ -946,7 +1030,7 @@ impl Array // Matrix multiplication only defined for simple types to // avoid trouble with failing + and *, and destructors -impl<'a, A: Copy + linalg::Ring> Array +impl<'a, A: Copy + linalg::Ring, S: Deref> Array { /// Perform matrix multiplication of rectangular arrays **self** and **other**. /// @@ -971,7 +1055,10 @@ impl<'a, A: Copy + linalg::Ring> Array /// ); /// ``` /// - pub fn mat_mul(&self, other: &Array) -> Array + pub fn mat_mul(&self, + other: &Array + ) -> ArrayOwned + where S2: Deref { let ((m, a), (b, n)) = (self.dim, other.dim); let (self_columns, other_rows) = (a, b); @@ -1011,7 +1098,8 @@ impl<'a, A: Copy + linalg::Ring> Array /// Return a result array with shape *M*. /// /// **Panics** if sizes are incompatible. - pub fn mat_mul_col(&self, other: &Array) -> Array + pub fn mat_mul_col(&self, other: &Array) -> ArrayOwned + where S2: Deref { let ((m, a), n) = (self.dim, other.dim); let (self_columns, other_rows) = (a, n); @@ -1039,12 +1127,13 @@ impl<'a, A: Copy + linalg::Ring> Array } -impl Array +impl, D: Dimension> Array { /// Return **true** if the arrays' elementwise differences are all within /// the given absolute tolerance.
/// Return **false** otherwise, or if the shapes disagree. - pub fn allclose(&self, other: &Array, tol: A) -> bool + pub fn allclose(&self, other: &Array, tol: A) -> bool + where S2: Deref { self.shape() == other.shape() && self.iter().zip(other.iter()).all(|(x, y)| (*x - *y).abs() <= tol) @@ -1056,8 +1145,9 @@ impl Array macro_rules! impl_binary_op( ($trt:ident, $mth:ident, $imethod:ident, $imth_scalar:ident) => ( -impl Array where +impl Array where A: Clone + $trt, + S: Deref, D: Dimension, { /// Perform an elementwise arithmetic operation between **self** and **other**, @@ -1066,7 +1156,8 @@ impl Array where /// If their shapes disagree, **other** is broadcast to the shape of **self**. /// /// **Panics** if broadcasting isn't possible. - pub fn $imethod (&mut self, other: &Array) + pub fn $imethod (&mut self, other: &Array) + where S2: Deref { if self.dim.ndim() == other.dim.ndim() && self.shape() == other.shape() { @@ -1091,19 +1182,21 @@ impl Array where } } -impl<'a, A, D, E> $trt> for Array where +impl<'a, A, S1, S2, D, E> $trt> for Array where A: Clone + $trt, + S1: DerefMut, + S2: Deref, D: Dimension, E: Dimension, { - type Output = Array; + type Output = Array; /// Perform an elementwise arithmetic operation between **self** and **other**, /// and return the result. /// /// If their shapes disagree, **other** is broadcast to the shape of **self**. /// /// **Panics** if broadcasting isn't possible. - fn $mth (mut self, other: Array) -> Array + fn $mth (mut self, other: Array) -> Array { // FIXME: Can we co-broadcast arrays here? And how? if self.shape() == other.shape() { @@ -1120,19 +1213,21 @@ impl<'a, A, D, E> $trt> for Array where } } -impl<'a, A, D, E> $trt<&'a Array> for &'a Array where +impl<'a, A, S1, S2, D, E> $trt<&'a Array> for &'a Array where A: Clone + $trt, + S1: Deref, + S2: Deref, D: Dimension, E: Dimension, { - type Output = Array; + type Output = ArrayOwned; /// Perform an elementwise arithmetic operation between **self** and **other**, /// and return the result. /// /// If their shapes disagree, **other** is broadcast to the shape of **self**. /// /// **Panics** if broadcasting isn't possible. - fn $mth (self, other: &'a Array) -> Array + fn $mth (self, other: &'a Array) -> ArrayOwned { // FIXME: Can we co-broadcast arrays here? And how? let mut result = Vec::
::with_capacity(self.dim.size()); @@ -1165,8 +1260,8 @@ impl_binary_op!(BitXor, bitxor, ibitxor, ibitxor_scalar); impl_binary_op!(Shl, shl, ishl, ishl_scalar); impl_binary_op!(Shr, shr, ishr, ishr_scalar); -impl, D: Dimension> -Array +impl, S: DerefMut, D: Dimension> +Array { /// Perform an elementwise negation of **self**, *in place*. pub fn ineg(&mut self) @@ -1177,20 +1272,20 @@ Array } } -impl, D: Dimension> -Neg for Array +impl, S: DerefMut, D: Dimension> +Neg for Array { type Output = Self; /// Perform an elementwise negation of **self** and return the result. - fn neg(mut self) -> Array + fn neg(mut self) -> Array { self.ineg(); self } } -impl, D: Dimension> -Array +impl, S: DerefMut, D: Dimension> +Array { /// Perform an elementwise unary not of **self**, *in place*. pub fn inot(&mut self) @@ -1201,12 +1296,12 @@ Array } } -impl, D: Dimension> -Not for Array +impl, S: DerefMut, D: Dimension> +Not for Array { type Output = Self; /// Perform an elementwise unary not of **self** and return the result. - fn not(mut self) -> Array + fn not(mut self) -> Array { self.inot(); self diff --git a/src/linalg.rs b/src/linalg.rs index 1d9d39ce5..7a63f26ea 100644 --- a/src/linalg.rs +++ b/src/linalg.rs @@ -5,14 +5,14 @@ use libnum::{Num, zero, one, Zero, One}; use libnum::Float; use libnum::Complex; -use std::ops::{Add, Sub, Mul, Div}; +use std::ops::{Add, Sub, Mul, Div, Deref, DerefMut}; use super::{Array, Ix}; /// Column vector. -pub type Col = Array; +pub type Col = Array; /// Rectangular matrix. -pub type Mat = Array; +pub type Mat = Array; /// Trait union for a ring with 1. pub trait Ring : Clone + Zero + Add + Sub @@ -56,7 +56,7 @@ impl ComplexField for Complex } /// Return the identity matrix of dimension *n*. -pub fn eye(n: Ix) -> Mat +pub fn eye(n: Ix) -> Mat> { let mut eye = Array::zeros((n, n)); for a_ii in eye.diag_iter_mut() { @@ -80,7 +80,9 @@ pub fn inverse(a: &Mat) -> Mat /// unknowns *x*. /// /// Return best fit for *x*. -pub fn least_squares(a: &Mat, b: &Col) -> Col +pub fn least_squares(a: &Mat, + b: &Col) -> Col> +where S1: DerefMut, S2: Deref { // Using transpose: a.T a x = a.T b; // a.T a being square gives naive solution @@ -93,12 +95,12 @@ pub fn least_squares(a: &Mat, b: &Col) -> Col // // L L.T x = aT b // - // => L z = aT b + // => L z = aT b // fw subst for z // => L.T x = z // bw subst for x estimate - // - let mut aT = a.clone(); + // + let mut aT = a.to_owned(); aT.swap_axes(0, 1); if ::is_complex() { // conjugate transpose @@ -106,10 +108,10 @@ pub fn least_squares(a: &Mat, b: &Col) -> Col *elt = elt.conjugate(); } } - + let rhs = aT.mat_mul_col(b); let aT_a = aT.mat_mul(a); + let mut L = cholesky(aT_a); - let rhs = aT.mat_mul_col(b); // Solve L z = aT b let z = subst_fw(&L, &rhs); @@ -147,7 +149,8 @@ pub fn least_squares(a: &Mat, b: &Col) -> Col /// substitution.” /// /// Return L. -pub fn cholesky(a: Mat) -> Mat +pub fn cholesky(a: Mat) -> Mat +where S: DerefMut { let z = zero::(); let (m, n) = a.dim(); @@ -206,7 +209,9 @@ fn vec_elem(elt: A, n: usize) -> Vec } /// Solve *L x = b* where *L* is a lower triangular matrix. -pub fn subst_fw(l: &Mat, b: &Col) -> Col +pub fn subst_fw(l: &Mat, + b: &Col) -> Col> +where S1: Deref, S2: Deref { let (m, n) = l.dim(); assert!(m == n); @@ -224,7 +229,9 @@ pub fn subst_fw(l: &Mat, b: &Col) -> Col } /// Solve *U x = b* where *U* is an upper triangular matrix. -pub fn subst_bw(u: &Mat, b: &Col) -> Col +pub fn subst_bw(u: &Mat, + b: &Col) -> Col> +where S1: Deref, S2: Deref { let (m, n) = u.dim(); assert!(m == n); diff --git a/tests/array.rs b/tests/array.rs index 40b77b2f7..e0b5fcfa6 100644 --- a/tests/array.rs +++ b/tests/array.rs @@ -86,7 +86,7 @@ fn test_add() #[test] fn test_multidim() { - let mut mat = Array::zeros(2*3*4*5*6).reshape((2,3,4,5,6)); + let mut mat = Array::zeros(2*3*4*5*6).reshape_into((2,3,4,5,6)); mat[(0,0,0,0,0)] = 22u8; { for (i, elt) in mat.iter_mut().enumerate() { @@ -123,7 +123,7 @@ fn test_negative_stride_rcarray() assert_eq!(vi.dim(), (2,4,2)); // Test against sequential iterator let seq = [7f32,6., 5.,4.,3.,2.,1.,0.,15.,14.,13., 12.,11., 10., 9., 8.]; - for (a, b) in vi.clone().iter().zip(seq.iter()) { + for (a, b) in vi.iter().zip(seq.iter()) { assert_eq!(*a, *b); } } @@ -136,62 +136,64 @@ fn test_negative_stride_rcarray() } } -#[test] -fn test_cow() -{ - let mut mat = Array::::zeros((2,2)); - mat[(0, 0)] = 1; - let n = mat.clone(); - mat[(0, 1)] = 2; - mat[(1, 0)] = 3; - mat[(1, 1)] = 4; - assert_eq!(mat[(0,0)], 1); - assert_eq!(mat[(0,1)], 2); - assert_eq!(n[(0,0)], 1); - assert_eq!(n[(0,1)], 0); - let mut rev = mat.reshape(4).slice(&[Si(0, None, -1)]); - assert_eq!(rev[0], 4); - assert_eq!(rev[1], 3); - assert_eq!(rev[2], 2); - assert_eq!(rev[3], 1); - let before = rev.clone(); - // mutation - rev[0] = 5; - assert_eq!(rev[0], 5); - assert_eq!(rev[1], 3); - assert_eq!(rev[2], 2); - assert_eq!(rev[3], 1); - assert_eq!(before[0], 4); - assert_eq!(before[1], 3); - assert_eq!(before[2], 2); - assert_eq!(before[3], 1); -} +// Removed copy on write test, only makes sense with Rc storage +// #[test] +// fn test_cow() +// { +// let mut mat = Array::::zeros((2,2)); +// mat[(0, 0)] = 1; +// let n = mat.clone(); +// mat[(0, 1)] = 2; +// mat[(1, 0)] = 3; +// mat[(1, 1)] = 4; +// assert_eq!(mat[(0,0)], 1); +// assert_eq!(mat[(0,1)], 2); +// assert_eq!(n[(0,0)], 1); +// assert_eq!(n[(0,1)], 0); +// let mat = mat.reshape_into(4); +// let mut rev = mat.slice(&[Si(0, None, -1)]); +// assert_eq!(rev[0], 4); +// assert_eq!(rev[1], 3); +// assert_eq!(rev[2], 2); +// assert_eq!(rev[3], 1); +// let before = rev.clone(); +// // mutation +// rev[0] = 5; +// assert_eq!(rev[0], 5); +// assert_eq!(rev[1], 3); +// assert_eq!(rev[2], 2); +// assert_eq!(rev[3], 1); +// assert_eq!(before[0], 4); +// assert_eq!(before[1], 3); +// assert_eq!(before[2], 2); +// assert_eq!(before[3], 1); +// } #[test] fn test_sub() { - let mat = Array::range(0.0f32, 16.0).reshape((2, 4, 2)); + let mat = Array::range(0.0f32, 16.0).reshape_into((2, 4, 2)); let s1 = mat.subview(0,0); let s2 = mat.subview(0,1); assert_eq!(s1.dim(), (4, 2)); assert_eq!(s2.dim(), (4, 2)); - let n = Array::range(8.0f32, 16.0).reshape((4,2)); - assert_eq!(n, s2); - let m = Array::from_vec(vec![2f32, 3., 10., 11.]).reshape((2, 2)); - assert_eq!(m, mat.subview(1, 1)); + let n = Array::range(8.0f32, 16.0).reshape_into((4,2)); + assert_eq!(n.view(), s2); + let m = Array::from_vec(vec![2f32, 3., 10., 11.]).reshape_into((2, 2)); + assert_eq!(m.view(), mat.subview(1, 1)); } #[test] fn diag() { - let d = arr2(&[[1., 2., 3.0f32]]).diag(); - assert_eq!(d.dim(), 1); - let d = arr2(&[[1., 2., 3.0f32], [0., 0., 0.]]).diag(); - assert_eq!(d.dim(), 2); - let d = arr2::(&[[]]).diag(); - assert_eq!(d.dim(), 0); - let d = Array::::zeros(()).diag(); - assert_eq!(d.dim(), 1); + let a = arr2(&[[1., 2., 3.0f32]]); + assert_eq!(a.diag().dim(), 1); + let a = arr2(&[[1., 2., 3.0f32], [0., 0., 0.]]); + assert_eq!(a.diag().dim(), 2); + let a = arr2::(&[[]]); + assert_eq!(a.diag().dim(), 0); + let a = Array::::zeros(()); + assert_eq!(a.diag().dim(), 1); } #[test] @@ -239,7 +241,7 @@ fn assign() #[test] fn dyn_dimension() { - let a = arr2(&[[1., 2.], [3., 4.0]]).reshape(vec![2, 2]); + let a = arr2(&[[1., 2.], [3., 4.0]]).reshape_into(vec![2, 2]); assert_eq!(&a - &a, Array::zeros(vec![2, 2])); let mut dim = vec![1; 1024]; diff --git a/tests/broadcast.rs b/tests/broadcast.rs index 66866ddea..e98470b15 100644 --- a/tests/broadcast.rs +++ b/tests/broadcast.rs @@ -8,12 +8,12 @@ fn broadcast_1() { let a_dim = (2, 4, 2, 2); let b_dim = (2, 1, 2, 1); - let a = Array::range(0.0, a_dim.size() as f32).reshape(a_dim); - let b = Array::range(0.0, b_dim.size() as f32).reshape(b_dim); + let a = Array::range(0.0, a_dim.size() as f32).reshape_into(a_dim); + let b = Array::range(0.0, b_dim.size() as f32).reshape_into(b_dim); assert!(b.broadcast_iter(a.dim()).is_some()); let c_dim = (2, 1); - let c = Array::range(0.0, c_dim.size() as f32).reshape(c_dim); + let c = Array::range(0.0, c_dim.size() as f32).reshape_into(c_dim); assert!(c.broadcast_iter(1).is_none()); assert!(c.broadcast_iter(()).is_none()); assert!(c.broadcast_iter((2, 1)).is_some()); @@ -34,8 +34,8 @@ fn test_add() { let a_dim = (2, 4, 2, 2); let b_dim = (2, 1, 2, 1); - let mut a = Array::range(0.0, a_dim.size() as f32).reshape(a_dim); - let b = Array::range(0.0, b_dim.size() as f32).reshape(b_dim); + let mut a = Array::range(0.0, a_dim.size() as f32).reshape_into(a_dim); + let b = Array::range(0.0, b_dim.size() as f32).reshape_into(b_dim); a.iadd(&b); let t = Array::from_elem((), 1.0f32); a.iadd(&t); @@ -45,7 +45,7 @@ fn test_add() fn test_add_incompat() { let a_dim = (2, 4, 2, 2); - let mut a = Array::range(0.0, a_dim.size() as f32).reshape(a_dim); + let mut a = Array::range(0.0, a_dim.size() as f32).reshape_into(a_dim); let incompat = Array::from_elem(3, 1.0f32); a.iadd(&incompat); } diff --git a/tests/dimension.rs b/tests/dimension.rs index 88cd3ca59..f8e74fdb7 100644 --- a/tests/dimension.rs +++ b/tests/dimension.rs @@ -18,6 +18,5 @@ fn remove_axis() assert_eq!(vec![4, 5, 6].remove_axis(1), vec![4, 6]); let a = Array::::zeros(vec![4,5,6]); - let b = a.subview(1, 0).reshape((4, 6)).reshape(vec![2, 3, 4]); - + let b = a.subview(1, 0).reshape_clone((4, 6)).reshape_clone(vec![2, 3, 4]); } diff --git a/tests/format.rs b/tests/format.rs index da67b3236..f719e16b3 100644 --- a/tests/format.rs +++ b/tests/format.rs @@ -12,11 +12,11 @@ fn formatting() "[1, 2, 3, 4]"); assert_eq!(format!("{:4?}", a), "[ 1, 2, 3, 4]"); - let a = a.reshape((4, 1, 1)); + let a = a.reshape_into((4, 1, 1)); assert_eq!(format!("{:4?}", a), "[[[ 1]],\n [[ 2]],\n [[ 3]],\n [[ 4]]]"); - let a = a.reshape((2, 2)); + let a = a.reshape_into((2, 2)); assert_eq!(format!("{}", a), "[[1, 2],\n [3, 4]]"); assert_eq!(format!("{:?}", a), diff --git a/tests/iterators.rs b/tests/iterators.rs index e3e7d9e28..5f27c9a6a 100644 --- a/tests/iterators.rs +++ b/tests/iterators.rs @@ -21,7 +21,7 @@ fn indexed() for (i, elt) in a.indexed_iter() { assert_eq!(i, *elt as Ix); } - let a = a.reshape((2, 4, 1)); + let a = a.reshape_into((2, 4, 1)); let (mut i, mut j, k) = (0, 0, 0); for (idx, elt) in a.indexed_iter() { assert_eq!(idx, (i, j, k)); @@ -38,12 +38,14 @@ fn indexed() fn indexed2() { let a = Array::range(0.0, 8.0f32); - let mut iter = a.iter(); - iter.next(); - for (i, elt) in iter.indexed() { - assert_eq!(i, *elt as Ix); + { + let mut iter = a.iter(); + iter.next(); + for (i, elt) in iter.indexed() { + assert_eq!(i, *elt as Ix); + } } - let a = a.reshape((2, 4, 1)); + let a = a.reshape_into((2, 4, 1)); let (mut i, mut j, k) = (0, 0, 0); for (idx, elt) in a.iter().indexed() { assert_eq!(idx, (i, j, k)); @@ -60,7 +62,7 @@ fn indexed2() fn indexed3() { let a = Array::range(0.0, 8.0f32); - let mut a = a.reshape((2, 4, 1)); + let mut a = a.reshape_into((2, 4, 1)); let (mut i, mut j, k) = (0, 0, 0); for (idx, elt) in a.slice_iter_mut(&[S, Si(1, None, 2), S]).indexed() { @@ -73,7 +75,7 @@ fn indexed3() *elt = -1.; println!("{:?}", (idx, elt)); } - let a = a.reshape((2, 4)); + let a = a.reshape_into((2, 4)); assert_eq!( a, arr2(&[[0., -1., 2., -1.], [4., -1., 6., -1.]])); } diff --git a/tests/linalg.rs b/tests/linalg.rs index 1af93e01e..67585a806 100644 --- a/tests/linalg.rs +++ b/tests/linalg.rs @@ -34,7 +34,7 @@ fn chol() assert!(ans.allclose(&chol, 0.001)); // Compute bT b for a pos def matrix - let b = Array::range(0.0f32, 9.).reshape((3, 3)); + let b = Array::range(0.0f32, 9.).reshape_into((3, 3)); let mut bt = b.clone(); bt.swap_axes(0, 1); let bpd = bt.mat_mul(&b); diff --git a/tests/oper.rs b/tests/oper.rs index cacc84598..ed3591bbd 100644 --- a/tests/oper.rs +++ b/tests/oper.rs @@ -1,7 +1,7 @@ extern crate ndarray; extern crate num as libnum; -use ndarray::Array; +use ndarray::{Array, ArrayOwned}; use ndarray::{arr0, arr1, arr2}; use std::fmt; @@ -14,19 +14,20 @@ fn test_oper(op: &str, a: &[f32], b: &[f32], c: &[f32]) let cc = arr1(c); test_oper_arr(op, aa.clone(), bb.clone(), cc.clone()); let dim = (2, 2); - let aa = aa.reshape(dim); - let bb = bb.reshape(dim); - let cc = cc.reshape(dim); + let aa = aa.reshape_into(dim); + let bb = bb.reshape_into(dim); + let cc = cc.reshape_into(dim); test_oper_arr(op, aa.clone(), bb.clone(), cc.clone()); let dim = (1, 2, 1, 2); - let aa = aa.reshape(dim); - let bb = bb.reshape(dim); - let cc = cc.reshape(dim); + let aa = aa.reshape_into(dim); + let bb = bb.reshape_into(dim); + let cc = cc.reshape_into(dim); test_oper_arr(op, aa.clone(), bb.clone(), cc.clone()); } fn test_oper_arr - (op: &str, mut aa: Array, bb: Array, cc: Array) + (op: &str, mut aa: ArrayOwned, + bb: ArrayOwned, cc: ArrayOwned) { match op { "+" => { diff --git a/tests/tests.rs b/tests/tests.rs index 5a78260d6..4dd71a918 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -6,6 +6,6 @@ use ndarray::Array; fn char_array() { // test compilation & basics of non-numerical array - let cc = Array::from_iter("alphabet".chars()).reshape((4, 2)); - assert!(cc.subview(1, 0) == Array::from_iter("apae".chars())); + let cc = Array::from_iter("alphabet".chars()).reshape_into((4, 2)); + assert!(cc.subview(1, 0) == Array::from_iter("apae".chars()).view()); }