1 use super::plumbing::*;
2 use super::*;
3 use rayon_core::join;
4 use std::cmp;
5 use std::iter;
6 
7 /// `Chain` is an iterator that joins `b` after `a` in one continuous iterator.
8 /// This struct is created by the [`chain()`] method on [`ParallelIterator`]
9 ///
10 /// [`chain()`]: trait.ParallelIterator.html#method.chain
11 /// [`ParallelIterator`]: trait.ParallelIterator.html
12 #[must_use = "iterator adaptors are lazy and do nothing unless consumed"]
13 #[derive(Debug, Clone)]
14 pub struct Chain<A, B>
15 where
16     A: ParallelIterator,
17     B: ParallelIterator<Item = A::Item>,
18 {
19     a: A,
20     b: B,
21 }
22 
23 impl<A, B> Chain<A, B>
24 where
25     A: ParallelIterator,
26     B: ParallelIterator<Item = A::Item>,
27 {
28     /// Creates a new `Chain` iterator.
new(a: A, b: B) -> Self29     pub(super) fn new(a: A, b: B) -> Self {
30         Chain { a, b }
31     }
32 }
33 
34 impl<A, B> ParallelIterator for Chain<A, B>
35 where
36     A: ParallelIterator,
37     B: ParallelIterator<Item = A::Item>,
38 {
39     type Item = A::Item;
40 
drive_unindexed<C>(self, consumer: C) -> C::Result where C: UnindexedConsumer<Self::Item>,41     fn drive_unindexed<C>(self, consumer: C) -> C::Result
42     where
43         C: UnindexedConsumer<Self::Item>,
44     {
45         let Chain { a, b } = self;
46 
47         // If we returned a value from our own `opt_len`, then the collect consumer in particular
48         // will balk at being treated like an actual `UnindexedConsumer`.  But when we do know the
49         // length, we can use `Consumer::split_at` instead, and this is still harmless for other
50         // truly-unindexed consumers too.
51         let (left, right, reducer) = if let Some(len) = a.opt_len() {
52             consumer.split_at(len)
53         } else {
54             let reducer = consumer.to_reducer();
55             (consumer.split_off_left(), consumer, reducer)
56         };
57 
58         let (a, b) = join(|| a.drive_unindexed(left), || b.drive_unindexed(right));
59         reducer.reduce(a, b)
60     }
61 
opt_len(&self) -> Option<usize>62     fn opt_len(&self) -> Option<usize> {
63         self.a.opt_len()?.checked_add(self.b.opt_len()?)
64     }
65 }
66 
67 impl<A, B> IndexedParallelIterator for Chain<A, B>
68 where
69     A: IndexedParallelIterator,
70     B: IndexedParallelIterator<Item = A::Item>,
71 {
drive<C>(self, consumer: C) -> C::Result where C: Consumer<Self::Item>,72     fn drive<C>(self, consumer: C) -> C::Result
73     where
74         C: Consumer<Self::Item>,
75     {
76         let Chain { a, b } = self;
77         let (left, right, reducer) = consumer.split_at(a.len());
78         let (a, b) = join(|| a.drive(left), || b.drive(right));
79         reducer.reduce(a, b)
80     }
81 
len(&self) -> usize82     fn len(&self) -> usize {
83         self.a.len().checked_add(self.b.len()).expect("overflow")
84     }
85 
with_producer<CB>(self, callback: CB) -> CB::Output where CB: ProducerCallback<Self::Item>,86     fn with_producer<CB>(self, callback: CB) -> CB::Output
87     where
88         CB: ProducerCallback<Self::Item>,
89     {
90         let a_len = self.a.len();
91         return self.a.with_producer(CallbackA {
92             callback,
93             a_len,
94             b: self.b,
95         });
96 
97         struct CallbackA<CB, B> {
98             callback: CB,
99             a_len: usize,
100             b: B,
101         }
102 
103         impl<CB, B> ProducerCallback<B::Item> for CallbackA<CB, B>
104         where
105             B: IndexedParallelIterator,
106             CB: ProducerCallback<B::Item>,
107         {
108             type Output = CB::Output;
109 
110             fn callback<A>(self, a_producer: A) -> Self::Output
111             where
112                 A: Producer<Item = B::Item>,
113             {
114                 self.b.with_producer(CallbackB {
115                     callback: self.callback,
116                     a_len: self.a_len,
117                     a_producer,
118                 })
119             }
120         }
121 
122         struct CallbackB<CB, A> {
123             callback: CB,
124             a_len: usize,
125             a_producer: A,
126         }
127 
128         impl<CB, A> ProducerCallback<A::Item> for CallbackB<CB, A>
129         where
130             A: Producer,
131             CB: ProducerCallback<A::Item>,
132         {
133             type Output = CB::Output;
134 
135             fn callback<B>(self, b_producer: B) -> Self::Output
136             where
137                 B: Producer<Item = A::Item>,
138             {
139                 let producer = ChainProducer::new(self.a_len, self.a_producer, b_producer);
140                 self.callback.callback(producer)
141             }
142         }
143     }
144 }
145 
146 /// ////////////////////////////////////////////////////////////////////////
147 
148 struct ChainProducer<A, B>
149 where
150     A: Producer,
151     B: Producer<Item = A::Item>,
152 {
153     a_len: usize,
154     a: A,
155     b: B,
156 }
157 
158 impl<A, B> ChainProducer<A, B>
159 where
160     A: Producer,
161     B: Producer<Item = A::Item>,
162 {
new(a_len: usize, a: A, b: B) -> Self163     fn new(a_len: usize, a: A, b: B) -> Self {
164         ChainProducer { a_len, a, b }
165     }
166 }
167 
168 impl<A, B> Producer for ChainProducer<A, B>
169 where
170     A: Producer,
171     B: Producer<Item = A::Item>,
172 {
173     type Item = A::Item;
174     type IntoIter = ChainSeq<A::IntoIter, B::IntoIter>;
175 
into_iter(self) -> Self::IntoIter176     fn into_iter(self) -> Self::IntoIter {
177         ChainSeq::new(self.a.into_iter(), self.b.into_iter())
178     }
179 
min_len(&self) -> usize180     fn min_len(&self) -> usize {
181         cmp::max(self.a.min_len(), self.b.min_len())
182     }
183 
max_len(&self) -> usize184     fn max_len(&self) -> usize {
185         cmp::min(self.a.max_len(), self.b.max_len())
186     }
187 
split_at(self, index: usize) -> (Self, Self)188     fn split_at(self, index: usize) -> (Self, Self) {
189         if index <= self.a_len {
190             let a_rem = self.a_len - index;
191             let (a_left, a_right) = self.a.split_at(index);
192             let (b_left, b_right) = self.b.split_at(0);
193             (
194                 ChainProducer::new(index, a_left, b_left),
195                 ChainProducer::new(a_rem, a_right, b_right),
196             )
197         } else {
198             let (a_left, a_right) = self.a.split_at(self.a_len);
199             let (b_left, b_right) = self.b.split_at(index - self.a_len);
200             (
201                 ChainProducer::new(self.a_len, a_left, b_left),
202                 ChainProducer::new(0, a_right, b_right),
203             )
204         }
205     }
206 
fold_with<F>(self, mut folder: F) -> F where F: Folder<A::Item>,207     fn fold_with<F>(self, mut folder: F) -> F
208     where
209         F: Folder<A::Item>,
210     {
211         folder = self.a.fold_with(folder);
212         if folder.full() {
213             folder
214         } else {
215             self.b.fold_with(folder)
216         }
217     }
218 }
219 
// ---------------------------------------------------------------------------

/// Sequential iterator produced by `ChainProducer`: every item of `A`, then every
/// item of `B`.
///
/// This is a thin wrapper around `std::iter::Chain`, which does not implement
/// `ExactSizeIterator` itself. The wrapper adds that impl when both halves are
/// exact-size.
struct ChainSeq<A, B> {
    inner: iter::Chain<A, B>,
}

impl<A, B> ChainSeq<A, B> {
    /// Chains `a` and `b`, both exact-size iterators over the same item type.
    fn new(a: A, b: B) -> Self
    where
        A: ExactSizeIterator,
        B: ExactSizeIterator<Item = A::Item>,
    {
        let inner = a.chain(b);
        Self { inner }
    }
}

impl<A, B> Iterator for ChainSeq<A, B>
where
    A: Iterator,
    B: Iterator<Item = A::Item>,
{
    type Item = A::Item;

    fn next(&mut self) -> Option<A::Item> {
        self.inner.next()
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.inner.size_hint()
    }
}

// Uses the default `len`, which is derived from the forwarded `size_hint`.
impl<A, B> ExactSizeIterator for ChainSeq<A, B>
where
    A: ExactSizeIterator,
    B: ExactSizeIterator<Item = A::Item>,
{
}

impl<A, B> DoubleEndedIterator for ChainSeq<A, B>
where
    A: DoubleEndedIterator,
    B: DoubleEndedIterator<Item = A::Item>,
{
    fn next_back(&mut self) -> Option<A::Item> {
        self.inner.next_back()
    }
}
269