Skip to content

Commit 480f45b

Browse files
authored
Merge pull request #906 from rust-ndarray/faster-into-dimensionality
Add fast case (no-op case) for into_dimensionality and into_dyn
2 parents e98551c + 83cf00d commit 480f45b

File tree

3 files changed

+99
-4
lines changed

3 files changed

+99
-4
lines changed

benches/bench1.rs

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ use std::mem::MaybeUninit;
1414
use ndarray::ShapeBuilder;
1515
use ndarray::{arr0, arr1, arr2, azip, s};
1616
use ndarray::{Array, Array1, Array2, Axis, Ix, Zip};
17+
use ndarray::{Ix1, Ix2, Ix3, Ix5, IxDyn};
1718

1819
use test::black_box;
1920

@@ -941,3 +942,59 @@ fn sum_axis1(bench: &mut test::Bencher) {
941942
let a = range_mat(MEAN_SUM_N, MEAN_SUM_N);
942943
bench.iter(|| a.sum_axis(Axis(1)));
943944
}
945+
946+
#[bench]
947+
fn into_dimensionality_ix1_ok(bench: &mut test::Bencher) {
948+
let a = Array::<f32, _>::zeros(Ix1(10));
949+
let a = a.view();
950+
bench.iter(|| a.into_dimensionality::<Ix1>());
951+
}
952+
953+
#[bench]
954+
fn into_dimensionality_ix3_ok(bench: &mut test::Bencher) {
955+
let a = Array::<f32, _>::zeros(Ix3(10, 10, 10));
956+
let a = a.view();
957+
bench.iter(|| a.into_dimensionality::<Ix3>());
958+
}
959+
960+
#[bench]
961+
fn into_dimensionality_ix3_err(bench: &mut test::Bencher) {
962+
let a = Array::<f32, _>::zeros(Ix3(10, 10, 10));
963+
let a = a.view();
964+
bench.iter(|| a.into_dimensionality::<Ix2>());
965+
}
966+
967+
#[bench]
968+
fn into_dimensionality_dyn_to_ix3(bench: &mut test::Bencher) {
969+
let a = Array::<f32, _>::zeros(IxDyn(&[10, 10, 10]));
970+
let a = a.view();
971+
bench.iter(|| a.clone().into_dimensionality::<Ix3>());
972+
}
973+
974+
#[bench]
975+
fn into_dimensionality_dyn_to_dyn(bench: &mut test::Bencher) {
976+
let a = Array::<f32, _>::zeros(IxDyn(&[10, 10, 10]));
977+
let a = a.view();
978+
bench.iter(|| a.clone().into_dimensionality::<IxDyn>());
979+
}
980+
981+
#[bench]
982+
fn into_dyn_ix3(bench: &mut test::Bencher) {
983+
let a = Array::<f32, _>::zeros(Ix3(10, 10, 10));
984+
let a = a.view();
985+
bench.iter(|| a.into_dyn());
986+
}
987+
988+
#[bench]
989+
fn into_dyn_ix5(bench: &mut test::Bencher) {
990+
let a = Array::<f32, _>::zeros(Ix5(2, 2, 2, 2, 2));
991+
let a = a.view();
992+
bench.iter(|| a.into_dyn());
993+
}
994+
995+
#[bench]
996+
fn into_dyn_dyn(bench: &mut test::Bencher) {
997+
let a = Array::<f32, _>::zeros(IxDyn(&[10, 10, 10]));
998+
let a = a.view();
999+
bench.iter(|| a.clone().into_dyn());
1000+
}

src/dimension/dimension_trait.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -929,6 +929,11 @@ impl Dimension for IxDyn {
929929
fn from_dimension<D2: Dimension>(d: &D2) -> Option<Self> {
930930
Some(IxDyn(d.slice()))
931931
}
932+
933+
fn into_dyn(self) -> IxDyn {
934+
self
935+
}
936+
932937
private_impl! {}
933938
}
934939

src/impl_methods.rs

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
// option. This file may not be copied, modified, or distributed
77
// except according to those terms.
88

9+
use std::mem::{size_of, ManuallyDrop};
910
use alloc::slice;
1011
use alloc::vec;
1112
use alloc::vec::Vec;
@@ -1583,8 +1584,11 @@ where
15831584
}
15841585
}
15851586

1586-
/// Convert an array or array view to another with the same type, but
1587-
/// different dimensionality type. Errors if the dimensions don't agree.
1587+
/// Convert an array or array view to another with the same type, but different dimensionality
1588+
/// type. Errors if the dimensions don't agree (the number of axes must match).
1589+
///
1590+
/// Note that conversion to a dynamic dimensional array will never fail (and is equivalent to
1591+
/// the `into_dyn` method).
15881592
///
15891593
/// ```
15901594
/// use ndarray::{ArrayD, Ix2, IxDyn};
@@ -1600,15 +1604,29 @@ where
16001604
where
16011605
D2: Dimension,
16021606
{
1603-
if let Some(dim) = D2::from_dimension(&self.dim) {
1604-
if let Some(strides) = D2::from_dimension(&self.strides) {
1607+
if D::NDIM == D2::NDIM {
1608+
// safe because D == D2
1609+
unsafe {
1610+
let dim = unlimited_transmute::<D, D2>(self.dim);
1611+
let strides = unlimited_transmute::<D, D2>(self.strides);
16051612
return Ok(ArrayBase {
16061613
data: self.data,
16071614
ptr: self.ptr,
16081615
dim,
16091616
strides,
16101617
});
16111618
}
1619+
} else if D::NDIM == None || D2::NDIM == None { // one is dynamic dim
1620+
if let Some(dim) = D2::from_dimension(&self.dim) {
1621+
if let Some(strides) = D2::from_dimension(&self.strides) {
1622+
return Ok(ArrayBase {
1623+
data: self.data,
1624+
ptr: self.ptr,
1625+
dim,
1626+
strides,
1627+
});
1628+
}
1629+
}
16121630
}
16131631
Err(ShapeError::from_kind(ErrorKind::IncompatibleShape))
16141632
}
@@ -2375,3 +2393,18 @@ where
23752393
});
23762394
}
23772395
}
2396+
2397+
2398+
/// Transmute from A to B.
2399+
///
2400+
/// Like transmute, but does not have the compile-time size check which blocks
2401+
/// using regular transmute in some cases.
2402+
///
2403+
/// **Panics** if the size of A and B are different.
2404+
#[inline]
2405+
unsafe fn unlimited_transmute<A, B>(data: A) -> B {
2406+
// safe when sizes are equal and caller guarantees that representations are equal
2407+
assert_eq!(size_of::<A>(), size_of::<B>());
2408+
let old_data = ManuallyDrop::new(data);
2409+
(&*old_data as *const A as *const B).read()
2410+
}

0 commit comments

Comments
 (0)