1 // Licensed to the .NET Foundation under one or more agreements.
2 // The .NET Foundation licenses this file to you under the MIT license.
3 // See the LICENSE file in the project root for more information.
4 
5 using System.Collections.Generic;
6 using System.Diagnostics;
7 using System.IO;
8 using System.Runtime.CompilerServices;
9 using System.Threading;
10 using System.Threading.Tasks;
11 
12 namespace System.IO.Tests
13 {
14     public class CallTrackingStream : Stream
15     {
16         private readonly Dictionary<string, int> _callCounts; // maps names of methods -> how many times they were called
17 
CallTrackingStream(Stream inner)18         public CallTrackingStream(Stream inner)
19         {
20             Debug.Assert(inner != null);
21 
22             Inner = inner;
23             _callCounts = new Dictionary<string, int>();
24         }
25 
26         public Stream Inner { get; }
27 
28         // Overridden Stream properties
29 
30         public override bool CanRead => Read(Inner.CanRead);
31         public override bool CanWrite => Read(Inner.CanWrite);
32         public override bool CanSeek => Read(Inner.CanSeek);
33         public override bool CanTimeout => Read(Inner.CanTimeout);
34 
35         public override long Length => Read(Inner.Length);
36 
37         public override long Position
38         {
39             get { return Read(Inner.Position); }
40             set { Update(() => Inner.Position = value); }
41         }
42 
43         public override int ReadTimeout
44         {
45             get { return Read(Inner.ReadTimeout); }
46             set { Update(() => Inner.ReadTimeout = value); }
47         }
48 
49         public override int WriteTimeout
50         {
51             get { return Read(Inner.WriteTimeout); }
52             set { Update(() => Inner.WriteTimeout = value); }
53         }
54 
55         // Arguments we record
56         // We can just use regular, auto-implemented properties for these,
57         // since we know none of them are going to be called by the framework
58 
59         public Stream CopyToAsyncDestination { get; private set; }
60         public int CopyToAsyncBufferSize { get; private set; }
61         public CancellationToken CopyToAsyncCancellationToken { get; private set; }
62 
63         public bool DisposeDisposing { get; private set; }
64 
65         public CancellationToken FlushAsyncCancellationToken { get; private set; }
66 
67         public byte[] ReadBuffer { get; private set; }
68         public int ReadOffset { get; private set; }
69         public int ReadCount { get; private set; }
70 
71         public byte[] ReadAsyncBuffer { get; private set; }
72         public int ReadAsyncOffset { get; private set; }
73         public int ReadAsyncCount { get; private set; }
74         public CancellationToken ReadAsyncCancellationToken { get; private set; }
75 
76         public long SeekOffset { get; private set; }
77         public SeekOrigin SeekOrigin { get; private set; }
78 
79         public long SetLengthValue { get; private set; }
80 
81         public byte[] WriteBuffer { get; private set; }
82         public int WriteOffset { get; private set; }
83         public int WriteCount { get; private set; }
84 
85         public byte[] WriteAsyncBuffer { get; private set; }
86         public int WriteAsyncOffset { get; private set; }
87         public int WriteAsyncCount { get; private set; }
88         public CancellationToken WriteAsyncCancellationToken { get; private set; }
89 
90         public byte WriteByteValue { get; private set; }
91 
92         // Overridden methods
93 
CopyToAsync(Stream destination, int bufferSize, CancellationToken cancellationToken)94         public override Task CopyToAsync(Stream destination, int bufferSize, CancellationToken cancellationToken)
95         {
96             UpdateCallCount();
97             CopyToAsyncDestination = destination;
98             CopyToAsyncBufferSize = bufferSize;
99             CopyToAsyncCancellationToken = cancellationToken;
100             return Inner.CopyToAsync(destination, bufferSize, cancellationToken);
101         }
102 
103         // Skip Dispose; it's not accessible to us since the virtual overload is protected
104 
Flush()105         public override void Flush()
106         {
107             UpdateCallCount();
108             Inner.Flush();
109         }
110 
FlushAsync(CancellationToken cancellationToken)111         public override Task FlushAsync(CancellationToken cancellationToken)
112         {
113             UpdateCallCount();
114             FlushAsyncCancellationToken = cancellationToken;
115             return Inner.FlushAsync(cancellationToken);
116         }
117 
Read(byte[] buffer, int offset, int count)118         public override int Read(byte[] buffer, int offset, int count)
119         {
120             UpdateCallCount();
121             ReadBuffer = buffer;
122             ReadOffset = offset;
123             ReadCount = count;
124             return Inner.Read(buffer, offset, count);
125         }
126 
ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)127         public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
128         {
129             UpdateCallCount();
130             ReadAsyncBuffer = buffer;
131             ReadAsyncOffset = offset;
132             ReadAsyncCount = count;
133             ReadAsyncCancellationToken = cancellationToken;
134             return Inner.ReadAsync(buffer, offset, count, cancellationToken);
135         }
136 
ReadByte()137         public override int ReadByte()
138         {
139             UpdateCallCount();
140             return Inner.ReadByte();
141         }
142 
Seek(long offset, SeekOrigin origin)143         public override long Seek(long offset, SeekOrigin origin)
144         {
145             UpdateCallCount();
146             SeekOffset = offset;
147             SeekOrigin = origin;
148             return Inner.Seek(offset, origin);
149         }
150 
SetLength(long value)151         public override void SetLength(long value)
152         {
153             UpdateCallCount();
154             SetLengthValue = value;
155             Inner.SetLength(value);
156         }
157 
Write(byte[] buffer, int offset, int count)158         public override void Write(byte[] buffer, int offset, int count)
159         {
160             UpdateCallCount();
161             WriteBuffer = buffer;
162             WriteOffset = offset;
163             WriteCount = count;
164             Inner.Write(buffer, offset, count);
165         }
166 
WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)167         public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
168         {
169             UpdateCallCount();
170             WriteAsyncBuffer = buffer;
171             WriteAsyncOffset = offset;
172             WriteAsyncCount = count;
173             WriteAsyncCancellationToken = cancellationToken;
174             return Inner.WriteAsync(buffer, offset, count, cancellationToken);
175         }
176 
WriteByte(byte value)177         public override void WriteByte(byte value)
178         {
179             UpdateCallCount();
180             WriteByteValue = value;
181             Inner.WriteByte(value);
182         }
183 
184         // Bookkeeping logic
185 
TimesCalled(string member)186         public int TimesCalled(string member)
187         {
188             int result;
189             _callCounts.TryGetValue(member, out result);
190             return result; // not present means we haven't called it yet, so return 0
191         }
192 
193         // [CallerMemberName] causes the member parameter to be set to the name
194         // of the calling member if not specified, e.g. calling this method
195         // from SetLength would pass in member with a value of "SetLength"
Read(T property, [CallerMemberName] string member = null)196         private T Read<T>(T property, [CallerMemberName] string member = null)
197         {
198             UpdateCallCount(member);
199             return property;
200         }
201 
Update(Action setter, [CallerMemberName] string member = null)202         private void Update(Action setter, [CallerMemberName] string member = null)
203         {
204             UpdateCallCount(member);
205             setter();
206         }
207 
UpdateCallCount([CallerMemberName] string member = null)208         private void UpdateCallCount([CallerMemberName] string member = null)
209         {
210             _callCounts[member] = TimesCalled(member) + 1;
211         }
212     }
213 }
214