Skip to content

Commit 728794a

Browse files
committed
Add .into_shape_clone()
1 parent fe46a38 commit 728794a

File tree

1 file changed

+54
-0
lines changed

1 file changed

+54
-0
lines changed

src/impl_methods.rs

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1921,6 +1921,59 @@ where
19211921
}
19221922
}
19231923

1924+
pub fn into_shape_clone<E>(self, shape: E) -> Result<ArrayBase<S, E::Dim>, ShapeError>
1925+
where
1926+
S: DataOwned,
1927+
A: Clone,
1928+
E: ShapeArg,
1929+
{
1930+
let (shape, order) = shape.into_shape_and_order();
1931+
let order = order.unwrap_or(Order::RowMajor);
1932+
self.into_shape_clone_order(shape, order)
1933+
}
1934+
1935+
pub fn into_shape_clone_order<E>(self, shape: E, order: Order)
1936+
-> Result<ArrayBase<S, E>, ShapeError>
1937+
where
1938+
S: DataOwned,
1939+
A: Clone,
1940+
E: Dimension,
1941+
{
1942+
let len = self.dim.size();
1943+
if size_of_shape_checked(&shape) != Ok(len) {
1944+
return Err(error::incompatible_shapes(&self.dim, &shape));
1945+
}
1946+
1947+
// Safe because the array and new shape is empty.
1948+
if len == 0 {
1949+
unsafe {
1950+
return Ok(self.with_strides_dim(shape.default_strides(), shape));
1951+
}
1952+
}
1953+
1954+
// Try to reshape the array's current data
1955+
match reshape_dim(&self.dim, &self.strides, &shape, order) {
1956+
Ok(to_strides) => unsafe {
1957+
return Ok(self.with_strides_dim(to_strides, shape));
1958+
}
1959+
Err(err) if err.kind() == ErrorKind::IncompatibleShape => {
1960+
return Err(error::incompatible_shapes(&self.dim, &shape));
1961+
}
1962+
_otherwise => { }
1963+
}
1964+
1965+
// otherwise, clone and allocate a new array
1966+
unsafe {
1967+
let (shape, view) = match order {
1968+
Order::RowMajor => (shape.set_f(false), self.view()),
1969+
Order::ColumnMajor => (shape.set_f(true), self.t()),
1970+
};
1971+
1972+
Ok(ArrayBase::from_shape_trusted_iter_unchecked(
1973+
shape, view.into_iter(), A::clone))
1974+
}
1975+
}
1976+
19241977
/// *Note: Reshape is for `ArcArray` only. Use `.into_shape()` for
19251978
/// other arrays and array views.*
19261979
///
@@ -1951,6 +2004,7 @@ where
19512004
A: Clone,
19522005
E: IntoDimension,
19532006
{
2007+
return self.clone().into_shape_clone(shape).unwrap();
19542008
let shape = shape.into_dimension();
19552009
if size_of_shape_checked(&shape) != Ok(self.dim.size()) {
19562010
panic!(

0 commit comments

Comments
 (0)