Skip to content

Commit dad6bbd

Browse files
committed
shape: Add optional order argument for into_shape
1 parent 014290c commit dad6bbd

File tree

2 files changed

+84
-13
lines changed

2 files changed

+84
-13
lines changed

src/impl_methods.rs

Lines changed: 46 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1885,38 +1885,71 @@ where
18851885
}
18861886

18871887
/// Transform the array into `shape`; any shape with the same number of
1888-
/// elements is accepted, but the source array or view must be in standard
1889-
/// or column-major (Fortran) layout.
1888+
/// elements is accepted, but the source array must be in contiguous row-major (C) or
1889+
/// column-major (F) layout.
1890+
///
1891+
/// **Note** that `.into_shape()` "moves" elements differently depending on if the input array
1892+
/// is C-contig or F-contig, it follows the index order that corresponds to the memory
1893+
/// order. If this is not wanted, use `.to_shape()`.
1894+
///
1895+
/// If a memory ordering is specified (optional) in the shape argument, the operation
1896+
/// will only succeed if the input has this memory order.
18901897
///
18911898
/// **Errors** if the shapes don't have the same number of elements.<br>
1892-
/// **Errors** if the input array is not c- or f-contiguous.
1899+
/// **Errors** if the input array is not c- or f-contiguous.<br>
1900+
/// **Errors** if a memory ordering is requested that is not compatible with the array.<br>
1901+
///
1902+
/// If shape is not given: use memory layout of incoming array. Row major arrays are
1903+
/// reshaped using row major index ordering, column major arrays with column major index
1904+
/// ordering.
18931905
///
18941906
/// ```
18951907
/// use ndarray::{aview1, aview2};
1908+
/// use ndarray::Order;
18961909
///
18971910
/// assert!(
18981911
/// aview1(&[1., 2., 3., 4.]).into_shape((2, 2)).unwrap()
18991912
/// == aview2(&[[1., 2.],
19001913
/// [3., 4.]])
19011914
/// );
1915+
///
1916+
/// assert!(
1917+
/// aview1(&[1., 2., 3., 4.]).into_shape(((2, 2), Order::ColumnMajor)).unwrap()
1918+
/// == aview2(&[[1., 3.],
1919+
/// [2., 4.]])
1920+
/// );
19021921
/// ```
19031922
pub fn into_shape<E>(self, shape: E) -> Result<ArrayBase<S, E::Dim>, ShapeError>
19041923
where
1905-
E: IntoDimension,
1924+
E: ShapeArg,
1925+
{
1926+
let (shape, order) = shape.into_shape_and_order();
1927+
self.into_shape_order(shape, order)
1928+
}
1929+
1930+
fn into_shape_order<E>(self, shape: E, order: Option<Order>) -> Result<ArrayBase<S, E>, ShapeError>
1931+
where
1932+
E: Dimension,
19061933
{
19071934
let shape = shape.into_dimension();
19081935
if size_of_shape_checked(&shape) != Ok(self.dim.size()) {
19091936
return Err(error::incompatible_shapes(&self.dim, &shape));
19101937
}
1911-
// Check if contiguous, if not => copy all, else just adapt strides
1938+
1939+
// Check if contiguous, then we can change shape
1940+
let require_order = order.is_some();
19121941
unsafe {
19131942
// safe because arrays are contiguous and len is unchanged
1914-
if self.is_standard_layout() {
1915-
Ok(self.with_strides_dim(shape.default_strides(), shape))
1916-
} else if self.ndim() > 1 && self.raw_view().reversed_axes().is_standard_layout() {
1917-
Ok(self.with_strides_dim(shape.fortran_strides(), shape))
1918-
} else {
1919-
Err(error::from_kind(error::ErrorKind::IncompatibleLayout))
1943+
match order {
1944+
None | Some(Order::RowMajor) if self.is_standard_layout() => {
1945+
Ok(self.with_strides_dim(shape.default_strides(), shape))
1946+
}
1947+
None | Some(Order::ColumnMajor) if (require_order || self.ndim() > 1) &&
1948+
self.raw_view().reversed_axes().is_standard_layout() =>
1949+
{
1950+
Ok(self.with_strides_dim(shape.fortran_strides(), shape))
1951+
}
1952+
_otherwise => Err(error::from_kind(error::ErrorKind::IncompatibleLayout))
19201953
}
19211954
}
19221955
}
@@ -1932,7 +1965,7 @@ where
19321965
self.into_shape_clone_order(shape, order)
19331966
}
19341967

1935-
pub fn into_shape_clone_order<E>(self, shape: E, order: Order)
1968+
fn into_shape_clone_order<E>(self, shape: E, order: Order)
19361969
-> Result<ArrayBase<S, E>, ShapeError>
19371970
where
19381971
S: DataOwned,
@@ -2004,7 +2037,7 @@ where
20042037
A: Clone,
20052038
E: IntoDimension,
20062039
{
2007-
return self.clone().into_shape_clone(shape).unwrap();
2040+
//return self.clone().into_shape_clone(shape).unwrap();
20082041
let shape = shape.into_dimension();
20092042
if size_of_shape_checked(&shape) != Ok(self.dim.size()) {
20102043
panic!(

tests/reshape.rs

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,3 +230,41 @@ fn to_shape_broadcast() {
230230
}
231231
}
232232
}
233+
234+
235+
#[test]
236+
fn into_shape_easy() {
237+
// 1D -> C -> C
238+
let data = [1, 2, 3, 4, 5, 6, 7, 8];
239+
let v = aview1(&data);
240+
let u = v.into_shape(((3, 3), Order::RowMajor));
241+
assert!(u.is_err());
242+
243+
let u = v.into_shape(((2, 2, 2), Order::C));
244+
assert!(u.is_ok());
245+
246+
let u = u.unwrap();
247+
assert_eq!(u.shape(), &[2, 2, 2]);
248+
assert_eq!(u, array![[[1, 2], [3, 4]], [[5, 6], [7, 8]]]);
249+
250+
let s = u.into_shape((4, 2)).unwrap();
251+
assert_eq!(s.shape(), &[4, 2]);
252+
assert_eq!(s, aview2(&[[1, 2], [3, 4], [5, 6], [7, 8]]));
253+
254+
// 1D -> F -> F
255+
let data = [1, 2, 3, 4, 5, 6, 7, 8];
256+
let v = aview1(&data);
257+
let u = v.into_shape(((3, 3), Order::ColumnMajor));
258+
assert!(u.is_err());
259+
260+
let u = v.into_shape(((2, 2, 2), Order::ColumnMajor));
261+
assert!(u.is_ok());
262+
263+
let u = u.unwrap();
264+
assert_eq!(u.shape(), &[2, 2, 2]);
265+
assert_eq!(u, array![[[1, 5], [3, 7]], [[2, 6], [4, 8]]]);
266+
267+
let s = u.into_shape(((4, 2), Order::ColumnMajor)).unwrap();
268+
assert_eq!(s.shape(), &[4, 2]);
269+
assert_eq!(s, array![[1, 5], [2, 6], [3, 7], [4, 8]]);
270+
}

0 commit comments

Comments
 (0)