1 use crate::cmp;
2 use crate::iter::{adapters::SourceIter, FusedIterator, InPlaceIterable, TrustedLen};
3 use crate::ops::{ControlFlow, Try};
4 
5 /// An iterator that only iterates over the first `n` iterations of `iter`.
6 ///
7 /// This `struct` is created by the [`take`] method on [`Iterator`]. See its
8 /// documentation for more.
9 ///
10 /// [`take`]: Iterator::take
11 /// [`Iterator`]: trait.Iterator.html
12 #[derive(Clone, Debug)]
13 #[must_use = "iterators are lazy and do nothing unless consumed"]
14 #[stable(feature = "rust1", since = "1.0.0")]
15 pub struct Take<I> {
16     iter: I,
17     n: usize,
18 }
19 
20 impl<I> Take<I> {
new(iter: I, n: usize) -> Take<I>21     pub(in crate::iter) fn new(iter: I, n: usize) -> Take<I> {
22         Take { iter, n }
23     }
24 }
25 
26 #[stable(feature = "rust1", since = "1.0.0")]
27 impl<I> Iterator for Take<I>
28 where
29     I: Iterator,
30 {
31     type Item = <I as Iterator>::Item;
32 
33     #[inline]
next(&mut self) -> Option<<I as Iterator>::Item>34     fn next(&mut self) -> Option<<I as Iterator>::Item> {
35         if self.n != 0 {
36             self.n -= 1;
37             self.iter.next()
38         } else {
39             None
40         }
41     }
42 
43     #[inline]
nth(&mut self, n: usize) -> Option<I::Item>44     fn nth(&mut self, n: usize) -> Option<I::Item> {
45         if self.n > n {
46             self.n -= n + 1;
47             self.iter.nth(n)
48         } else {
49             if self.n > 0 {
50                 self.iter.nth(self.n - 1);
51                 self.n = 0;
52             }
53             None
54         }
55     }
56 
57     #[inline]
size_hint(&self) -> (usize, Option<usize>)58     fn size_hint(&self) -> (usize, Option<usize>) {
59         if self.n == 0 {
60             return (0, Some(0));
61         }
62 
63         let (lower, upper) = self.iter.size_hint();
64 
65         let lower = cmp::min(lower, self.n);
66 
67         let upper = match upper {
68             Some(x) if x < self.n => Some(x),
69             _ => Some(self.n),
70         };
71 
72         (lower, upper)
73     }
74 
75     #[inline]
try_fold<Acc, Fold, R>(&mut self, init: Acc, fold: Fold) -> R where Self: Sized, Fold: FnMut(Acc, Self::Item) -> R, R: Try<Output = Acc>,76     fn try_fold<Acc, Fold, R>(&mut self, init: Acc, fold: Fold) -> R
77     where
78         Self: Sized,
79         Fold: FnMut(Acc, Self::Item) -> R,
80         R: Try<Output = Acc>,
81     {
82         fn check<'a, T, Acc, R: Try<Output = Acc>>(
83             n: &'a mut usize,
84             mut fold: impl FnMut(Acc, T) -> R + 'a,
85         ) -> impl FnMut(Acc, T) -> ControlFlow<R, Acc> + 'a {
86             move |acc, x| {
87                 *n -= 1;
88                 let r = fold(acc, x);
89                 if *n == 0 { ControlFlow::Break(r) } else { ControlFlow::from_try(r) }
90             }
91         }
92 
93         if self.n == 0 {
94             try { init }
95         } else {
96             let n = &mut self.n;
97             self.iter.try_fold(init, check(n, fold)).into_try()
98         }
99     }
100 
101     #[inline]
fold<Acc, Fold>(mut self, init: Acc, fold: Fold) -> Acc where Self: Sized, Fold: FnMut(Acc, Self::Item) -> Acc,102     fn fold<Acc, Fold>(mut self, init: Acc, fold: Fold) -> Acc
103     where
104         Self: Sized,
105         Fold: FnMut(Acc, Self::Item) -> Acc,
106     {
107         #[inline]
108         fn ok<B, T>(mut f: impl FnMut(B, T) -> B) -> impl FnMut(B, T) -> Result<B, !> {
109             move |acc, x| Ok(f(acc, x))
110         }
111 
112         self.try_fold(init, ok(fold)).unwrap()
113     }
114 
115     #[inline]
116     #[rustc_inherit_overflow_checks]
advance_by(&mut self, n: usize) -> Result<(), usize>117     fn advance_by(&mut self, n: usize) -> Result<(), usize> {
118         let min = self.n.min(n);
119         match self.iter.advance_by(min) {
120             Ok(_) => {
121                 self.n -= min;
122                 if min < n { Err(min) } else { Ok(()) }
123             }
124             ret @ Err(advanced) => {
125                 self.n -= advanced;
126                 ret
127             }
128         }
129     }
130 }
131 
132 #[unstable(issue = "none", feature = "inplace_iteration")]
133 unsafe impl<I> SourceIter for Take<I>
134 where
135     I: SourceIter,
136 {
137     type Source = I::Source;
138 
139     #[inline]
as_inner(&mut self) -> &mut I::Source140     unsafe fn as_inner(&mut self) -> &mut I::Source {
141         // SAFETY: unsafe function forwarding to unsafe function with the same requirements
142         unsafe { SourceIter::as_inner(&mut self.iter) }
143     }
144 }
145 
146 #[unstable(issue = "none", feature = "inplace_iteration")]
147 unsafe impl<I: InPlaceIterable> InPlaceIterable for Take<I> {}
148 
149 #[stable(feature = "double_ended_take_iterator", since = "1.38.0")]
150 impl<I> DoubleEndedIterator for Take<I>
151 where
152     I: DoubleEndedIterator + ExactSizeIterator,
153 {
154     #[inline]
next_back(&mut self) -> Option<Self::Item>155     fn next_back(&mut self) -> Option<Self::Item> {
156         if self.n == 0 {
157             None
158         } else {
159             let n = self.n;
160             self.n -= 1;
161             self.iter.nth_back(self.iter.len().saturating_sub(n))
162         }
163     }
164 
165     #[inline]
nth_back(&mut self, n: usize) -> Option<Self::Item>166     fn nth_back(&mut self, n: usize) -> Option<Self::Item> {
167         let len = self.iter.len();
168         if self.n > n {
169             let m = len.saturating_sub(self.n) + n;
170             self.n -= n + 1;
171             self.iter.nth_back(m)
172         } else {
173             if len > 0 {
174                 self.iter.nth_back(len - 1);
175             }
176             None
177         }
178     }
179 
180     #[inline]
try_rfold<Acc, Fold, R>(&mut self, init: Acc, fold: Fold) -> R where Self: Sized, Fold: FnMut(Acc, Self::Item) -> R, R: Try<Output = Acc>,181     fn try_rfold<Acc, Fold, R>(&mut self, init: Acc, fold: Fold) -> R
182     where
183         Self: Sized,
184         Fold: FnMut(Acc, Self::Item) -> R,
185         R: Try<Output = Acc>,
186     {
187         if self.n == 0 {
188             try { init }
189         } else {
190             let len = self.iter.len();
191             if len > self.n && self.iter.nth_back(len - self.n - 1).is_none() {
192                 try { init }
193             } else {
194                 self.iter.try_rfold(init, fold)
195             }
196         }
197     }
198 
199     #[inline]
rfold<Acc, Fold>(mut self, init: Acc, fold: Fold) -> Acc where Self: Sized, Fold: FnMut(Acc, Self::Item) -> Acc,200     fn rfold<Acc, Fold>(mut self, init: Acc, fold: Fold) -> Acc
201     where
202         Self: Sized,
203         Fold: FnMut(Acc, Self::Item) -> Acc,
204     {
205         if self.n == 0 {
206             init
207         } else {
208             let len = self.iter.len();
209             if len > self.n && self.iter.nth_back(len - self.n - 1).is_none() {
210                 init
211             } else {
212                 self.iter.rfold(init, fold)
213             }
214         }
215     }
216 
217     #[inline]
advance_back_by(&mut self, n: usize) -> Result<(), usize>218     fn advance_back_by(&mut self, n: usize) -> Result<(), usize> {
219         let inner_len = self.iter.len();
220         let len = self.n;
221         let remainder = len.saturating_sub(n);
222         let to_advance = inner_len - remainder;
223         match self.iter.advance_back_by(to_advance) {
224             Ok(_) => {
225                 self.n = remainder;
226                 if n > len {
227                     return Err(len);
228                 }
229                 return Ok(());
230             }
231             _ => panic!("ExactSizeIterator contract violation"),
232         }
233     }
234 }
235 
236 #[stable(feature = "rust1", since = "1.0.0")]
237 impl<I> ExactSizeIterator for Take<I> where I: ExactSizeIterator {}
238 
239 #[stable(feature = "fused", since = "1.26.0")]
240 impl<I> FusedIterator for Take<I> where I: FusedIterator {}
241 
242 #[unstable(feature = "trusted_len", issue = "37572")]
243 unsafe impl<I: TrustedLen> TrustedLen for Take<I> {}
244