diff --git a/library/core/src/iter/adapters/take.rs b/library/core/src/iter/adapters/take.rs
index ce18bffe7146f..c1d8cc4ff57bd 100644
--- a/library/core/src/iter/adapters/take.rs
+++ b/library/core/src/iter/adapters/take.rs
@@ -1,5 +1,7 @@
 use crate::cmp;
-use crate::iter::{adapters::SourceIter, FusedIterator, InPlaceIterable, TrustedLen};
+use crate::iter::{
+    adapters::SourceIter, FusedIterator, InPlaceIterable, TrustedLen, TrustedRandomAccess,
+};
 use crate::num::NonZeroUsize;
 use crate::ops::{ControlFlow, Try};
 
@@ -98,26 +100,18 @@ where
         }
     }
 
-    impl_fold_via_try_fold! { fold -> try_fold }
-
     #[inline]
-    fn for_each<F: FnMut(Self::Item)>(mut self, f: F) {
-        // The default implementation would use a unit accumulator, so we can
-        // avoid a stateful closure by folding over the remaining number
-        // of items we wish to return instead.
-        fn check<'a, Item>(
-            mut action: impl FnMut(Item) + 'a,
-        ) -> impl FnMut(usize, Item) -> Option<usize> + 'a {
-            move |more, x| {
-                action(x);
-                more.checked_sub(1)
-            }
-        }
+    fn fold<B, F>(self, init: B, f: F) -> B
+    where
+        Self: Sized,
+        F: FnMut(B, Self::Item) -> B,
+    {
+        Self::spec_fold(self, init, f)
+    }
 
-        let remaining = self.n;
-        if remaining > 0 {
-            self.iter.try_fold(remaining - 1, check(f));
-        }
+    #[inline]
+    fn for_each<F: FnMut(Self::Item)>(self, f: F) {
+        Self::spec_for_each(self, f)
     }
 
     #[inline]
@@ -249,3 +243,72 @@ impl<I> FusedIterator for Take<I> where I: FusedIterator {}
 
 #[unstable(feature = "trusted_len", issue = "37572")]
 unsafe impl<I: TrustedLen> TrustedLen for Take<I> {}
+
+trait SpecTake: Iterator {
+    fn spec_fold<B, F>(self, init: B, f: F) -> B
+    where
+        Self: Sized,
+        F: FnMut(B, Self::Item) -> B;
+
+    fn spec_for_each<F: FnMut(Self::Item)>(self, f: F);
+}
+
+impl<I: Iterator> SpecTake for Take<I> {
+    #[inline]
+    default fn spec_fold<B, F>(mut self, init: B, f: F) -> B
+    where
+        Self: Sized,
+        F: FnMut(B, Self::Item) -> B,
+    {
+        use crate::ops::NeverShortCircuit;
+        self.try_fold(init, NeverShortCircuit::wrap_mut_2(f)).0
+    }
+
+    #[inline]
+    default fn spec_for_each<F: FnMut(Self::Item)>(mut self, f: F) {
+        // The default implementation would use a unit accumulator, so we can
+        // avoid a stateful closure by folding over the remaining number
+        // of items we wish to return instead.
+        fn check<'a, Item>(
+            mut action: impl FnMut(Item) + 'a,
+        ) -> impl FnMut(usize, Item) -> Option<usize> + 'a {
+            move |more, x| {
+                action(x);
+                more.checked_sub(1)
+            }
+        }
+
+        let remaining = self.n;
+        if remaining > 0 {
+            self.iter.try_fold(remaining - 1, check(f));
+        }
+    }
+}
+
+impl<I: Iterator + TrustedRandomAccess> SpecTake for Take<I> {
+    #[inline]
+    fn spec_fold<B, F>(mut self, init: B, mut f: F) -> B
+    where
+        Self: Sized,
+        F: FnMut(B, Self::Item) -> B,
+    {
+        let mut acc = init;
+        let end = self.n.min(self.iter.size());
+        for i in 0..end {
+            // SAFETY: i < end <= self.iter.size() and we discard the iterator at the end
+            let val = unsafe { self.iter.__iterator_get_unchecked(i) };
+            acc = f(acc, val);
+        }
+        acc
+    }
+
+    #[inline]
+    fn spec_for_each<F: FnMut(Self::Item)>(mut self, mut f: F) {
+        let end = self.n.min(self.iter.size());
+        for i in 0..end {
+            // SAFETY: i < end <= self.iter.size() and we discard the iterator at the end
+            let val = unsafe { self.iter.__iterator_get_unchecked(i) };
+            f(val);
+        }
+    }
+}
diff --git a/tests/codegen/lib-optimizations/iter-sum.rs b/tests/codegen/lib-optimizations/iter-sum.rs
new file mode 100644
index 0000000000000..ff7ca6ef6c11e
--- /dev/null
+++ b/tests/codegen/lib-optimizations/iter-sum.rs
@@ -0,0 +1,15 @@
+// ignore-debug: the debug assertions get in the way
+// compile-flags: -O
+// only-x86_64 (vectorization varies between architectures)
+#![crate_type = "lib"]
+
+
+// Ensure that slice + take + sum gets vectorized.
+// Currently this relies on the slice::Iter::try_fold implementation
+// CHECK-LABEL: @slice_take_sum
+#[no_mangle]
+pub fn slice_take_sum(s: &[u64], l: usize) -> u64 {
+    // CHECK: vector.body:
+    // CHECK: ret
+    s.iter().take(l).sum()
+}