1 // Licensed to the .NET Foundation under one or more agreements.
2 // The .NET Foundation licenses this file to you under the MIT license.
3 // See the LICENSE file in the project root for more information.
4 
5 using System.Collections;
6 using System.Collections.Generic;
7 using Xunit;
8 
9 namespace System.Linq.Tests
10 {
11     public static class TestExtensions
12     {
RunOnce(this IEnumerable<T> source)13         public static IEnumerable<T> RunOnce<T>(this IEnumerable<T> source) =>
14             source == null ? null : (source as IList<T>)?.RunOnce() ?? new RunOnceEnumerable<T>(source);
15 
RunOnce(this IList<T> source)16         public static IEnumerable<T> RunOnce<T>(this IList<T> source)
17             => source == null ? null : new RunOnceList<T>(source);
18 
19         private class RunOnceEnumerable<T> : IEnumerable<T>
20         {
21             private readonly IEnumerable<T> _source;
22             private bool _called;
23 
RunOnceEnumerable(IEnumerable<T> source)24             public RunOnceEnumerable(IEnumerable<T> source)
25             {
26                 _source = source;
27             }
28 
GetEnumerator()29             public IEnumerator<T> GetEnumerator()
30             {
31                 Assert.False(_called);
32                 _called = true;
33                 return _source.GetEnumerator();
34             }
35 
IEnumerable.GetEnumerator()36             IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
37         }
38 
39         private class RunOnceList<T> : IList<T>
40         {
41             private readonly IList<T> _source;
42             private readonly HashSet<int> _called = new HashSet<int>();
43 
AssertAll()44             private void AssertAll()
45             {
46                 Assert.Empty(_called);
47                 _called.Add(-1);
48             }
49 
AssertIndex(int index)50             private void AssertIndex(int index)
51             {
52                 Assert.False(_called.Contains(-1));
53                 Assert.True(_called.Add(index));
54             }
55 
RunOnceList(IList<T> source)56             public RunOnceList(IList<T> source)
57             {
58                 _source = source;
59             }
60 
GetEnumerator()61             public IEnumerator<T> GetEnumerator()
62             {
63                 AssertAll();
64                 return _source.GetEnumerator();
65             }
66 
IEnumerable.GetEnumerator()67             IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
68 
Add(T item)69             public void Add(T item)
70             {
71                 throw new NotSupportedException();
72             }
73 
Clear()74             public void Clear()
75             {
76                 throw new NotSupportedException();
77             }
78 
Contains(T item)79             public bool Contains(T item)
80             {
81                 AssertAll();
82                 return _source.Contains(item);
83             }
84 
CopyTo(T[] array, int arrayIndex)85             public void CopyTo(T[] array, int arrayIndex)
86             {
87                 AssertAll();
88                 _source.CopyTo(array, arrayIndex);
89             }
90 
Remove(T item)91             public bool Remove(T item)
92             {
93                 throw new NotSupportedException();
94             }
95 
96             public int Count => _source.Count;
97 
98             public bool IsReadOnly => true;
99 
IndexOf(T item)100             public int IndexOf(T item)
101             {
102                 AssertAll();
103                 return _source.IndexOf(item);
104             }
105 
Insert(int index, T item)106             public void Insert(int index, T item)
107             {
108                 throw new NotSupportedException();
109             }
110 
RemoveAt(int index)111             public void RemoveAt(int index)
112             {
113                 throw new NotSupportedException();
114             }
115 
116             public T this[int index]
117             {
118                 get
119                 {
120                     AssertIndex(index);
121                     return _source[index];
122                 }
123                 set { throw new NotSupportedException(); }
124             }
125         }
126     }
127 }
128