1 // Licensed to the .NET Foundation under one or more agreements.
2 // The .NET Foundation licenses this file to you under the MIT license.
3 // See the LICENSE file in the project root for more information.
4 
5 using System.Collections.Generic;
6 using System.Diagnostics;
7 
8 namespace System.Linq
9 {
10     public static partial class Enumerable
11     {
Concat(this IEnumerable<TSource> first, IEnumerable<TSource> second)12         public static IEnumerable<TSource> Concat<TSource>(this IEnumerable<TSource> first, IEnumerable<TSource> second)
13         {
14             if (first == null)
15             {
16                 throw Error.ArgumentNull(nameof(first));
17             }
18 
19             if (second == null)
20             {
21                 throw Error.ArgumentNull(nameof(second));
22             }
23 
24             return first is ConcatIterator<TSource> firstConcat
25                 ? firstConcat.Concat(second)
26                 : new Concat2Iterator<TSource>(first, second);
27         }
28 
29         /// <summary>
30         /// Represents the concatenation of two <see cref="IEnumerable{TSource}"/>.
31         /// </summary>
32         /// <typeparam name="TSource">The type of the source enumerables.</typeparam>
33         private sealed class Concat2Iterator<TSource> : ConcatIterator<TSource>
34         {
35             /// <summary>
36             /// The first source to concatenate.
37             /// </summary>
38             internal readonly IEnumerable<TSource> _first;
39 
40             /// <summary>
41             /// The second source to concatenate.
42             /// </summary>
43             internal readonly IEnumerable<TSource> _second;
44 
45             /// <summary>
46             /// Initializes a new instance of the <see cref="Concat2Iterator{TSource}"/> class.
47             /// </summary>
48             /// <param name="first">The first source to concatenate.</param>
49             /// <param name="second">The second source to concatenate.</param>
Concat2Iterator(IEnumerable<TSource> first, IEnumerable<TSource> second)50             internal Concat2Iterator(IEnumerable<TSource> first, IEnumerable<TSource> second)
51             {
52                 Debug.Assert(first != null);
53                 Debug.Assert(second != null);
54 
55                 _first = first;
56                 _second = second;
57             }
58 
Clone()59             public override Iterator<TSource> Clone() => new Concat2Iterator<TSource>(_first, _second);
60 
Concat(IEnumerable<TSource> next)61             internal override ConcatIterator<TSource> Concat(IEnumerable<TSource> next)
62             {
63                 bool hasOnlyCollections = next is ICollection<TSource> &&
64                                           _first is ICollection<TSource> &&
65                                           _second is ICollection<TSource>;
66                 return new ConcatNIterator<TSource>(this, next, 2, hasOnlyCollections);
67             }
68 
GetCount(bool onlyIfCheap)69             public override int GetCount(bool onlyIfCheap)
70             {
71                 int firstCount, secondCount;
72                 if (!EnumerableHelpers.TryGetCount(_first, out firstCount))
73                 {
74                     if (onlyIfCheap)
75                     {
76                         return -1;
77                     }
78 
79                     firstCount = _first.Count();
80                 }
81 
82                 if (!EnumerableHelpers.TryGetCount(_second, out secondCount))
83                 {
84                     if (onlyIfCheap)
85                     {
86                         return -1;
87                     }
88 
89                     secondCount = _second.Count();
90                 }
91 
92                 return checked(firstCount + secondCount);
93             }
94 
GetEnumerable(int index)95             internal override IEnumerable<TSource> GetEnumerable(int index)
96             {
97                 Debug.Assert(index >= 0 && index <= 2);
98 
99                 switch (index)
100                 {
101                     case 0: return _first;
102                     case 1: return _second;
103                     default: return null;
104                 }
105             }
106 
ToArray()107             public override TSource[] ToArray()
108             {
109                 var builder = new SparseArrayBuilder<TSource>(initialize: true);
110 
111                 bool reservedFirst = builder.ReserveOrAdd(_first);
112                 bool reservedSecond = builder.ReserveOrAdd(_second);
113 
114                 TSource[] array = builder.ToArray();
115 
116                 if (reservedFirst)
117                 {
118                     Marker marker = builder.Markers.First();
119                     Debug.Assert(marker.Index == 0);
120                     EnumerableHelpers.Copy(_first, array, 0, marker.Count);
121                 }
122 
123                 if (reservedSecond)
124                 {
125                     Marker marker = builder.Markers.Last();
126                     EnumerableHelpers.Copy(_second, array, marker.Index, marker.Count);
127                 }
128 
129                 return array;
130             }
131         }
132 
        /// <summary>
        /// Represents the concatenation of three or more <see cref="IEnumerable{TSource}"/>.
        /// </summary>
        /// <typeparam name="TSource">The type of the source enumerables.</typeparam>
        /// <remarks>
        /// <para>
        /// The sources form a singly linked chain. Each node holds one source (<see cref="_head"/>) at logical
        /// index <see cref="_headIndex"/> and points at the previous iterator (<see cref="_tail"/>). Following the
        /// <see cref="_tail"/> links ends at a <see cref="Concat2Iterator{TSource}"/>, which supplies indices 0 and 1.
        /// </para>
        /// <para>
        /// To handle chains of three or more sources, each call to <see cref="Concat"/> links a new iterator onto
        /// the existing one, and <see cref="GetEnumerable"/> fetches enumerables from the previous iterators. This
        /// means that, rather than every <see cref="IEnumerator{T}.MoveNext"/> and <see cref="IEnumerator{T}.Current"/>
        /// call having to traverse all of the previous sources, we only traverse them once per chained enumerable.
        /// An alternative would be to store all of the enumerables in an array, but this linked representation has
        /// a much better memory profile without much additional run-time cost.
        /// </para>
        /// </remarks>
        private sealed class ConcatNIterator<TSource> : ConcatIterator<TSource>
        {
            /// <summary>
            /// The previous iterator in the chain. It is either another <see cref="ConcatNIterator{TSource}"/>
            /// or, at the end of the chain, the <see cref="Concat2Iterator{TSource}"/> holding indices 0 and 1.
            /// </summary>
            private readonly ConcatIterator<TSource> _tail;

            /// <summary>
            /// The source associated with this iterator.
            /// </summary>
            private readonly IEnumerable<TSource> _head;

            /// <summary>
            /// The logical index of <see cref="_head"/> within the whole concatenation (always at least 2).
            /// </summary>
            private readonly int _headIndex;

            /// <summary>
            /// <c>true</c> if all sources this iterator concatenates implement <see cref="ICollection{TSource}"/>;
            /// otherwise, <c>false</c>.
            /// </summary>
            /// <remarks>
            /// This flag allows us to determine in O(1) time whether we can preallocate for <see cref="ToArray"/>
            /// and <see cref="ConcatIterator{TSource}.ToList"/>, and whether we can get the count of the iterator cheaply.
            /// </remarks>
            private readonly bool _hasOnlyCollections;

            /// <summary>
            /// Initializes a new instance of the <see cref="ConcatNIterator{TSource}"/> class.
            /// </summary>
            /// <param name="tail">The linked list of previous sources.</param>
            /// <param name="head">The source associated with this iterator.</param>
            /// <param name="headIndex">The logical index associated with this iterator.</param>
            /// <param name="hasOnlyCollections">
            /// <c>true</c> if all sources this iterator concatenates implement <see cref="ICollection{TSource}"/>;
            /// otherwise, <c>false</c>.
            /// </param>
            internal ConcatNIterator(ConcatIterator<TSource> tail, IEnumerable<TSource> head, int headIndex, bool hasOnlyCollections)
            {
                Debug.Assert(tail != null);
                Debug.Assert(head != null);
                Debug.Assert(headIndex >= 2);

                _tail = tail;
                _head = head;
                _headIndex = headIndex;
                _hasOnlyCollections = hasOnlyCollections;
            }

            /// <summary>
            /// The previous node when it is also a <see cref="ConcatNIterator{TSource}"/>. It is <c>null</c> when
            /// <see cref="_tail"/> is not one, i.e. when this node is the last one before the chain's
            /// <see cref="Concat2Iterator{TSource}"/>.
            /// </summary>
            private ConcatNIterator<TSource> PreviousN => _tail as ConcatNIterator<TSource>;

            /// <summary>
            /// Returns a new iterator sharing this node's tail, head, index and collection flag.
            /// </summary>
            public override Iterator<TSource> Clone() => new ConcatNIterator<TSource>(_tail, _head, _headIndex, _hasOnlyCollections);

            /// <summary>
            /// Appends <paramref name="next"/> as the source at logical index <c>_headIndex + 1</c>.
            /// </summary>
            /// <param name="next">The sequence to append.</param>
            /// <returns>
            /// A new <see cref="ConcatNIterator{TSource}"/> linked to this one. Once the index cap is reached, it
            /// instead returns a <see cref="Concat2Iterator{TSource}"/> pairing this iterator with <paramref name="next"/>.
            /// </returns>
            internal override ConcatIterator<TSource> Concat(IEnumerable<TSource> next)
            {
                if (_headIndex == int.MaxValue - 2)
                {
                    // In the unlikely case of this many concatenations, if we produced a ConcatNIterator
                    // with int.MaxValue then state would overflow before it matched its index.
                    // (MoveNext keeps _state at the current source's index + 2.)
                    // So we use the naïve approach of just having a left and right sequence.
                    return new Concat2Iterator<TSource>(this, next);
                }

                // The chain stays "collections only" only if the new source is also an ICollection<T>.
                bool hasOnlyCollections = _hasOnlyCollections && next is ICollection<TSource>;
                return new ConcatNIterator<TSource>(this, next, _headIndex + 1, hasOnlyCollections);
            }

            /// <summary>
            /// Sums the counts of every source. It walks from this node back through the chain and then adds the
            /// count reported by the terminating <see cref="Concat2Iterator{TSource}"/>.
            /// </summary>
            /// <param name="onlyIfCheap">
            /// When <c>true</c>, return -1 immediately unless every source is an <see cref="ICollection{T}"/>.
            /// </param>
            /// <returns>The total count, or -1 as described for <paramref name="onlyIfCheap"/>.</returns>
            /// <exception cref="OverflowException">The total exceeds <see cref="int.MaxValue"/>.</exception>
            public override int GetCount(bool onlyIfCheap)
            {
                if (onlyIfCheap && !_hasOnlyCollections)
                {
                    return -1;
                }

                int count = 0;
                ConcatNIterator<TSource> node, previousN = this;

                do
                {
                    node = previousN;
                    IEnumerable<TSource> source = node._head;

                    // Enumerable.Count() handles ICollections in O(1) time, but check for them here anyway
                    // to avoid a method call because 1) they're common and 2) this code is run in a loop.
                    // Non-collection sources go through Count(), which may enumerate them.
                    var collection = source as ICollection<TSource>;
                    Debug.Assert(!_hasOnlyCollections || collection != null);
                    int sourceCount = collection?.Count ?? source.Count();

                    checked
                    {
                        count += sourceCount;
                    }
                }
                while ((previousN = node.PreviousN) != null);

                // node is now the last ConcatNIterator; its tail supplies sources 0 and 1.
                Debug.Assert(node._tail is Concat2Iterator<TSource>);
                return checked(count + node._tail.GetCount(onlyIfCheap));
            }

            /// <summary>
            /// Returns the source at logical <paramref name="index"/>, or <c>null</c> when
            /// <paramref name="index"/> is past the last source.
            /// </summary>
            /// <remarks>
            /// Walks back through the chain until the node owning <paramref name="index"/> is found, so the cost
            /// grows with the distance between <paramref name="index"/> and <see cref="_headIndex"/>.
            /// </remarks>
            internal override IEnumerable<TSource> GetEnumerable(int index)
            {
                Debug.Assert(index >= 0);

                if (index > _headIndex)
                {
                    return null;
                }

                ConcatNIterator<TSource> node, previousN = this;
                do
                {
                    node = previousN;
                    if (index == node._headIndex)
                    {
                        return node._head;
                    }
                }
                while ((previousN = node.PreviousN) != null);

                // Not owned by any ConcatNIterator node, so it must be one of the Concat2Iterator's two sources.
                Debug.Assert(index == 0 || index == 1);
                Debug.Assert(node._tail is Concat2Iterator<TSource>);
                return node._tail.GetEnumerable(index);
            }

            /// <summary>
            /// Materializes all sources into an array. When every source is an <see cref="ICollection{T}"/>
            /// the array is presized and filled directly; otherwise it is assembled with a sparse builder.
            /// </summary>
            public override TSource[] ToArray() => _hasOnlyCollections ? PreallocatingToArray() : LazyToArray();

            /// <summary>
            /// Builds the array with <c>SparseArrayBuilder</c>. The index of each source for which
            /// <c>ReserveOrAdd</c> returns <c>true</c> is recorded in <c>deferredCopies</c>. That source is then
            /// copied into its reserved range once the array is allocated; the i-th marker pairs with the i-th
            /// recorded index.
            /// </summary>
            private TSource[] LazyToArray()
            {
                Debug.Assert(!_hasOnlyCollections);

                var builder = new SparseArrayBuilder<TSource>(initialize: true);
                var deferredCopies = new ArrayBuilder<int>();

                for (int i = 0; ; i++)
                {
                    // Unfortunately, we can't escape re-walking the linked list for each source, which has
                    // quadratic behavior, because we need to add the sources in order.
                    // On the bright side, the bottleneck will usually be iterating, buffering, and copying
                    // each of the enumerables, so this shouldn't be a noticeable perf hit for most scenarios.

                    IEnumerable<TSource> source = GetEnumerable(i);
                    if (source == null)
                    {
                        break;
                    }

                    if (builder.ReserveOrAdd(source))
                    {
                        deferredCopies.Add(i);
                    }
                }

                TSource[] array = builder.ToArray();

                ArrayBuilder<Marker> markers = builder.Markers;
                for (int i = 0; i < markers.Count; i++)
                {
                    // Markers were created in the same order the deferred indices were recorded.
                    Marker marker = markers[i];
                    IEnumerable<TSource> source = GetEnumerable(deferredCopies[i]);
                    EnumerableHelpers.Copy(source, array, marker.Index, marker.Count);
                }

                return array;
            }

            /// <summary>
            /// Allocates an array of exactly <see cref="GetCount"/> elements and fills it from the back. Each
            /// chain head is copied just below the previously copied one, then the
            /// <see cref="Concat2Iterator{TSource}"/>'s second source, and finally its first source at index 0.
            /// </summary>
            /// <remarks>
            /// NOTE(review): this relies on every collection's <c>Count</c> being unchanged between the
            /// <see cref="GetCount"/> call and the <c>CopyTo</c> calls; confirm callers don't mutate sources concurrently.
            /// </remarks>
            private TSource[] PreallocatingToArray()
            {
                // If there are only ICollections in this iterator, then we can just get the count, preallocate the
                // array, and copy them as we go. This has better time complexity than continuously re-walking the
                // linked list via GetEnumerable, and better memory usage than buffering the collections.

                Debug.Assert(_hasOnlyCollections);

                int count = GetCount(onlyIfCheap: true);
                Debug.Assert(count >= 0);

                if (count == 0)
                {
                    return Array.Empty<TSource>();
                }

                var array = new TSource[count];
                int arrayIndex = array.Length; // We start copying in collection-sized chunks from the end of the array.

                ConcatNIterator<TSource> node, previousN = this;
                do
                {
                    node = previousN;
                    ICollection<TSource> source = (ICollection<TSource>)node._head;
                    int sourceCount = source.Count;
                    if (sourceCount > 0)
                    {
                        checked
                        {
                            arrayIndex -= sourceCount;
                        }
                        source.CopyTo(array, arrayIndex);
                    }
                }
                while ((previousN = node.PreviousN) != null);

                // node is now the last ConcatNIterator; its tail holds the first two sources.
                var previous2 = (Concat2Iterator<TSource>)node._tail;
                var second = (ICollection<TSource>)previous2._second;
                int secondCount = second.Count;

                if (secondCount > 0)
                {
                    second.CopyTo(array, checked(arrayIndex - secondCount));
                }

                // Whatever space remains below _second belongs to _first; skip the copy when _first contributed nothing.
                if (arrayIndex > secondCount)
                {
                    var first = (ICollection<TSource>)previous2._first;
                    first.CopyTo(array, 0);
                }

                return array;
            }
        }
364 
365         /// <summary>
366         /// Represents the concatenation of two or more <see cref="IEnumerable{TSource}"/>.
367         /// </summary>
368         /// <typeparam name="TSource">The type of the source enumerables.</typeparam>
369         private abstract class ConcatIterator<TSource> : Iterator<TSource>, IIListProvider<TSource>
370         {
371             /// <summary>
372             /// The enumerator of the current source, if <see cref="MoveNext"/> has been called.
373             /// </summary>
374             private IEnumerator<TSource> _enumerator;
375 
Dispose()376             public override void Dispose()
377             {
378                 if (_enumerator != null)
379                 {
380                     _enumerator.Dispose();
381                     _enumerator = null;
382                 }
383 
384                 base.Dispose();
385             }
386 
387             /// <summary>
388             /// Gets the enumerable at a logical index in this iterator.
389             /// If the index is equal to the number of enumerables this iterator holds, <c>null</c> is returned.
390             /// </summary>
391             /// <param name="index">The logical index.</param>
GetEnumerable(int index)392             internal abstract IEnumerable<TSource> GetEnumerable(int index);
393 
394             /// <summary>
395             /// Creates a new iterator that concatenates this iterator with an enumerable.
396             /// </summary>
397             /// <param name="next">The next enumerable.</param>
Concat(IEnumerable<TSource> next)398             internal abstract ConcatIterator<TSource> Concat(IEnumerable<TSource> next);
399 
MoveNext()400             public override bool MoveNext()
401             {
402                 if (_state == 1)
403                 {
404                     _enumerator = GetEnumerable(0).GetEnumerator();
405                     _state = 2;
406                 }
407 
408                 if (_state > 1)
409                 {
410                     while (true)
411                     {
412                         if (_enumerator.MoveNext())
413                         {
414                             _current = _enumerator.Current;
415                             return true;
416                         }
417 
418                         IEnumerable<TSource> next = GetEnumerable(_state++ - 1);
419                         if (next != null)
420                         {
421                             _enumerator.Dispose();
422                             _enumerator = next.GetEnumerator();
423                             continue;
424                         }
425 
426                         Dispose();
427                         break;
428                     }
429                 }
430 
431                 return false;
432             }
433 
GetCount(bool onlyIfCheap)434             public abstract int GetCount(bool onlyIfCheap);
435 
ToArray()436             public abstract TSource[] ToArray();
437 
ToList()438             public List<TSource> ToList()
439             {
440                 int count = GetCount(onlyIfCheap: true);
441                 var list = count != -1 ? new List<TSource>(count) : new List<TSource>();
442 
443                 for (int i = 0; ; i++)
444                 {
445                     IEnumerable<TSource> source = GetEnumerable(i);
446                     if (source == null)
447                     {
448                         break;
449                     }
450 
451                     list.AddRange(source);
452                 }
453 
454                 return list;
455             }
456         }
457     }
458 }
459