Skip to content

Commit 32c40b0

Browse files
committed
shape: Add optional order argument for into_shape
1 parent 014290c commit 32c40b0

File tree

2 files changed

+74
-12
lines changed

2 files changed

+74
-12
lines changed

src/impl_methods.rs

Lines changed: 36 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1885,11 +1885,19 @@ where
18851885
}
18861886

18871887
/// Transform the array into `shape`; any shape with the same number of
1888-
/// elements is accepted, but the source array or view must be in standard
1889-
/// or column-major (Fortran) layout.
1888+
/// elements is accepted, but the source array must be in contiguous row-major (C) or
1889+
/// column-major (F) layout.
1890+
///
1891+
/// If a memory ordering is specified (optional) in the shape argument, the operation
1892+
/// will only succeed if the input has this memory order.
18901893
///
18911894
/// **Errors** if the shapes don't have the same number of elements.<br>
1892-
/// **Errors** if the input array is not c- or f-contiguous.
1895+
/// **Errors** if the input array is not c- or f-contiguous.<br>
1896+
/// **Errors** if a memory ordering is requested that is not compatible with the array.<br>
1897+
///
1898+
/// If shape is not given: use memory layout of incoming array. Row major arrays are
1899+
/// reshaped using row major index ordering, column major arrays with column major index
1900+
/// ordering.
18931901
///
18941902
/// ```
18951903
/// use ndarray::{aview1, aview2};
@@ -1899,10 +1907,24 @@ where
18991907
/// == aview2(&[[1., 2.],
19001908
/// [3., 4.]])
19011909
/// );
1910+
///
1911+
/// assert!(
1912+
/// aview1(&[1., 2., 3., 4.]).into_shape(((2, 2), Order::ColumnMajor)).unwrap()
1913+
/// == aview2(&[[1., 3.],
1914+
/// [2., 4.]])
1915+
/// );
19021916
/// ```
19031917
pub fn into_shape<E>(self, shape: E) -> Result<ArrayBase<S, E::Dim>, ShapeError>
19041918
where
1905-
E: IntoDimension,
1919+
E: ShapeArg,
1920+
{
1921+
let (shape, order) = shape.into_shape_and_order();
1922+
self.into_shape_order(shape, order)
1923+
}
1924+
1925+
fn into_shape_order<E>(self, shape: E, order: Option<Order>) -> Result<ArrayBase<S, E>, ShapeError>
1926+
where
1927+
E: Dimension,
19061928
{
19071929
let shape = shape.into_dimension();
19081930
if size_of_shape_checked(&shape) != Ok(self.dim.size()) {
@@ -1911,12 +1933,14 @@ where
19111933
// Check if contiguous, if not => copy all, else just adapt strides
19121934
unsafe {
19131935
// safe because arrays are contiguous and len is unchanged
1914-
if self.is_standard_layout() {
1915-
Ok(self.with_strides_dim(shape.default_strides(), shape))
1916-
} else if self.ndim() > 1 && self.raw_view().reversed_axes().is_standard_layout() {
1917-
Ok(self.with_strides_dim(shape.fortran_strides(), shape))
1918-
} else {
1919-
Err(error::from_kind(error::ErrorKind::IncompatibleLayout))
1936+
match order {
1937+
None | Some(Order::RowMajor) if self.is_standard_layout() => {
1938+
Ok(self.with_strides_dim(shape.default_strides(), shape))
1939+
}
1940+
None | Some(Order::ColumnMajor) if self.raw_view().reversed_axes().is_standard_layout() => {
1941+
Ok(self.with_strides_dim(shape.fortran_strides(), shape))
1942+
}
1943+
_otherwise => Err(error::from_kind(error::ErrorKind::IncompatibleLayout))
19201944
}
19211945
}
19221946
}
@@ -1932,7 +1956,7 @@ where
19321956
self.into_shape_clone_order(shape, order)
19331957
}
19341958

1935-
pub fn into_shape_clone_order<E>(self, shape: E, order: Order)
1959+
fn into_shape_clone_order<E>(self, shape: E, order: Order)
19361960
-> Result<ArrayBase<S, E>, ShapeError>
19371961
where
19381962
S: DataOwned,
@@ -2004,7 +2028,7 @@ where
20042028
A: Clone,
20052029
E: IntoDimension,
20062030
{
2007-
return self.clone().into_shape_clone(shape).unwrap();
2031+
//return self.clone().into_shape_clone(shape).unwrap();
20082032
let shape = shape.into_dimension();
20092033
if size_of_shape_checked(&shape) != Ok(self.dim.size()) {
20102034
panic!(

tests/reshape.rs

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,3 +230,41 @@ fn to_shape_broadcast() {
230230
}
231231
}
232232
}
233+
234+
235+
#[test]
236+
fn into_shape_easy() {
237+
// 1D -> C -> C
238+
let data = [1, 2, 3, 4, 5, 6, 7, 8];
239+
let v = aview1(&data);
240+
let u = v.into_shape(((3, 3), Order::RowMajor));
241+
assert!(u.is_err());
242+
243+
let u = v.into_shape(((2, 2, 2), Order::C));
244+
assert!(u.is_ok());
245+
246+
let u = u.unwrap();
247+
assert_eq!(u.shape(), &[2, 2, 2]);
248+
assert_eq!(u, array![[[1, 2], [3, 4]], [[5, 6], [7, 8]]]);
249+
250+
let s = u.into_shape((4, 2)).unwrap();
251+
assert_eq!(s.shape(), &[4, 2]);
252+
assert_eq!(s, aview2(&[[1, 2], [3, 4], [5, 6], [7, 8]]));
253+
254+
// 1D -> F -> F
255+
let data = [1, 2, 3, 4, 5, 6, 7, 8];
256+
let v = aview1(&data);
257+
let u = v.into_shape(((3, 3), Order::ColumnMajor));
258+
assert!(u.is_err());
259+
260+
let u = v.into_shape(((2, 2, 2), Order::ColumnMajor));
261+
assert!(u.is_ok());
262+
263+
let u = u.unwrap();
264+
assert_eq!(u.shape(), &[2, 2, 2]);
265+
assert_eq!(u, array![[[1, 5], [3, 7]], [[2, 6], [4, 8]]]);
266+
267+
let s = u.into_shape(((4, 2), Order::ColumnMajor)).unwrap();
268+
assert_eq!(s.shape(), &[4, 2]);
269+
assert_eq!(s, array![[1, 5], [2, 6], [3, 7], [4, 8]]);
270+
}

0 commit comments

Comments
 (0)