1 // Licensed to the .NET Foundation under one or more agreements.
2 // The .NET Foundation licenses this file to you under the MIT license.
3 // See the LICENSE file in the project root for more information.
4 
5 using System.Collections.Generic;
6 using System.IO;
7 using System.Linq;
8 using System.Net.Http;
9 using System.Net.NetworkInformation;
10 using System.Net.Security;
11 using System.Net.Sockets;
12 using System.Security.Authentication;
13 using System.Security.Cryptography;
14 using System.Security.Cryptography.X509Certificates;
15 using System.Text;
16 using System.Threading;
17 using System.Threading.Tasks;
18 
namespace System.Net.Test.Common
{
    /// <summary>
    /// Helpers for tests that need a minimal HTTP / WebSocket server listening on a local TCP socket.
    /// </summary>
    public class LoopbackServer
    {
        /// <summary>
        /// Certificate validation callback that accepts every certificate it is asked about.
        /// Its delegate type matches HttpClientHandler.ServerCertificateCustomValidationCallback.
        /// </summary>
        public static Func<HttpRequestMessage, X509Certificate2, X509Chain, SslPolicyErrors, bool> AllowAllCertificates = (_, __, ___, ____) => true;

        /// <summary>Settings controlling how the loopback server listens and serves connections.</summary>
        public class Options
        {
            /// <summary>Address to bind to; the port is always chosen by the OS (port 0).</summary>
            public IPAddress Address { get; set; } = IPAddress.Loopback;

            /// <summary>Backlog passed to <c>Socket.Listen</c>.</summary>
            public int ListenBacklog { get; set; } = 1;

            /// <summary>When true, accepted connections are wrapped in an <see cref="SslStream"/> and the scheme becomes https/wss.</summary>
            public bool UseSsl { get; set; } = false;

            /// <summary>Protocols enabled for the server side of the TLS handshake when <see cref="UseSsl"/> is true.</summary>
            public SslProtocols SslProtocols { get; set; } = SslProtocols.Tls | SslProtocols.Tls11 | SslProtocols.Tls12;

            /// <summary>When true, the URI handed to the server callback uses ws/wss instead of http/https.</summary>
            public bool WebSocketEndpoint { get; set; } = false;

            /// <summary>Optional wrapper applied to the connection stream before the response writer is created over it.</summary>
            public Func<Stream, Stream> ResponseStreamWrapper { get; set; }
        }

        /// <summary>
        /// Creates a listening socket and runs <paramref name="funcAsync"/> against it, discarding the bound endpoint.
        /// </summary>
        /// <param name="funcAsync">Server logic; receives the listening socket and the URI clients should use.</param>
        /// <param name="options">Server options; defaults are used when null.</param>
        /// <returns>A task that completes after <paramref name="funcAsync"/> completes and the socket is disposed.</returns>
        public static Task CreateServerAsync(Func<Socket, Uri, Task> funcAsync, Options options = null)
        {
            IPEndPoint ignored;
            return CreateServerAsync(funcAsync, out ignored, options);
        }

        /// <summary>
        /// Binds a TCP socket to an OS-chosen port on <see cref="Options.Address"/>, starts listening, and runs
        /// <paramref name="funcAsync"/> with the socket and a URI (http/https/ws/wss per the options) pointing at it.
        /// </summary>
        /// <param name="funcAsync">Server logic; receives the listening socket and the URI clients should use.</param>
        /// <param name="localEndPoint">The bound endpoint, or null if setup failed.</param>
        /// <param name="options">Server options; defaults are used when null.</param>
        /// <returns>
        /// A task that completes once <paramref name="funcAsync"/>'s task completes and the listening socket has been
        /// disposed. Failures (including setup failures) are reported through the returned task, not thrown.
        /// </returns>
        public static Task CreateServerAsync(Func<Socket, Uri, Task> funcAsync, out IPEndPoint localEndPoint, Options options = null)
        {
            options = options ?? new Options();
            Socket server = null;
            try
            {
                server = new Socket(options.Address.AddressFamily, SocketType.Stream, ProtocolType.Tcp);

                server.Bind(new IPEndPoint(options.Address, 0));
                server.Listen(options.ListenBacklog);

                localEndPoint = (IPEndPoint)server.LocalEndPoint;

                // IPv6 literals must be bracketed inside a URI authority.
                string host = options.Address.AddressFamily == AddressFamily.InterNetworkV6 ?
                    $"[{localEndPoint.Address}]" :
                    localEndPoint.Address.ToString();

                string scheme = options.UseSsl ? "https" : "http";
                if (options.WebSocketEndpoint)
                {
                    scheme = options.UseSsl ? "wss" : "ws";
                }

                var url = new Uri($"{scheme}://{host}:{localEndPoint.Port}/");

                Socket listener = server;
                return funcAsync(listener, url).ContinueWith(t =>
                {
                    listener.Dispose();
                    // The antecedent has already completed, so this does not block; it rethrows the
                    // original (unwrapped) exception if funcAsync's task faulted or was canceled.
                    t.GetAwaiter().GetResult();
                }, CancellationToken.None, TaskContinuationOptions.None, TaskScheduler.Default);
            }
            catch (Exception e)
            {
                // Setup (or a synchronous throw from funcAsync) failed: release the socket rather than leak it.
                server?.Dispose();
                localEndPoint = null;
                return Task.FromException(e);
            }
        }

        /// <summary>A minimal 200 OK response with an empty body, stamped with the current UTC date.</summary>
        public static string DefaultHttpResponse => $"HTTP/1.1 200 OK\r\nDate: {DateTimeOffset.UtcNow:R}\r\nContent-Length: 0\r\n\r\n";

        /// <summary>
        /// Returns the first IPv6 link-local unicast address found on any local network interface, or null if none.
        /// </summary>
        public static IPAddress GetIPv6LinkLocalAddress() =>
            NetworkInterface
                .GetAllNetworkInterfaces()
                .SelectMany(i => i.GetIPProperties().UnicastAddresses)
                .Select(a => a.Address)
                .FirstOrDefault(a => a.IsIPv6LinkLocal);

        /// <summary>
        /// Accepts one connection on <paramref name="server"/>, reads the request head, and writes
        /// <paramref name="response"/> (or <see cref="DefaultHttpResponse"/> when null).
        /// </summary>
        /// <returns>The request line and header lines that were read.</returns>
        public static Task<List<string>> ReadRequestAndSendResponseAsync(Socket server, string response = null, Options options = null)
        {
            return AcceptSocketAsync(server, (s, stream, reader, writer) => ReadWriteAcceptedAsync(s, reader, writer, response), options);
        }

        /// <summary>
        /// Reads the request line and headers from an accepted connection, then writes a response.
        /// </summary>
        /// <param name="s">The accepted socket (not used by this method).</param>
        /// <param name="reader">Reader over the connection.</param>
        /// <param name="writer">Writer over the connection.</param>
        /// <param name="response">Raw response text; <see cref="DefaultHttpResponse"/> is used when null.</param>
        /// <returns>The request line and header lines, excluding the terminating empty line.</returns>
        public static async Task<List<string>> ReadWriteAcceptedAsync(Socket s, StreamReader reader, StreamWriter writer, string response = null)
        {
            // Read lines until the empty line that ends the headers (or end of stream).
            // Any request body is left unread.
            var lines = new List<string>();
            string line;
            while (!string.IsNullOrEmpty(line = await reader.ReadLineAsync().ConfigureAwait(false)))
            {
                lines.Add(line);
            }

            await writer.WriteAsync(response ?? DefaultHttpResponse).ConfigureAwait(false);

            return lines;
        }

        /// <summary>
        /// Reads an HTTP request head and, if it contains a Sec-WebSocket-Key header, writes a
        /// "101 Switching Protocols" response carrying the matching Sec-WebSocket-Accept value.
        /// </summary>
        /// <param name="s">The accepted socket (not used by this method).</param>
        /// <param name="reader">Reader over the connection.</param>
        /// <param name="writer">Writer over the connection.</param>
        /// <returns>True if the handshake response was sent; false if no Sec-WebSocket-Key header was seen.</returns>
        public static async Task<bool> WebSocketHandshakeAsync(Socket s, StreamReader reader, StreamWriter writer)
        {
            string serverResponse = null;
            string currentRequestLine;
            while (!string.IsNullOrEmpty(currentRequestLine = await reader.ReadLineAsync().ConfigureAwait(false)))
            {
                string[] tokens = currentRequestLine.Split(new char[] { ':' }, 2);
                if (tokens.Length == 2)
                {
                    string headerName = tokens[0];
                    // HTTP header field names are case-insensitive.
                    if (string.Equals(headerName, "Sec-WebSocket-Key", StringComparison.OrdinalIgnoreCase))
                    {
                        string headerValue = tokens[1].Trim();
                        string responseSecurityAcceptValue = ComputeWebSocketHandshakeSecurityAcceptValue(headerValue);
                        serverResponse =
                            "HTTP/1.1 101 Switching Protocols\r\n" +
                            "Upgrade: websocket\r\n" +
                            "Connection: Upgrade\r\n" +
                            "Sec-WebSocket-Accept: " + responseSecurityAcceptValue + "\r\n\r\n";
                    }
                }
            }

            if (serverResponse != null)
            {
                // We received a valid WebSocket opening handshake. Send the appropriate response.
                await writer.WriteAsync(serverResponse).ConfigureAwait(false);
                return true;
            }

            return false;
        }

        /// <summary>
        /// Computes the Sec-WebSocket-Accept value: base64(SHA-1(key + RFC 6455 GUID)).
        /// </summary>
        private static string ComputeWebSocketHandshakeSecurityAcceptValue(string secWebSocketKey)
        {
            // GUID specified by RFC 6455.
            const string Rfc6455Guid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
            string combinedKey = secWebSocketKey + Rfc6455Guid;

            // Use of SHA1 hash is required by RFC 6455.
            using (SHA1 sha1 = SHA1.Create())
            {
                byte[] sha1Hash = sha1.ComputeHash(Encoding.UTF8.GetBytes(combinedKey));
                return Convert.ToBase64String(sha1Hash);
            }
        }

        /// <summary>
        /// Accepts one connection on <paramref name="server"/>. The connection is optionally wrapped in TLS.
        /// ASCII reader/writer objects (the writer auto-flushes) are created over it, and then
        /// <paramref name="funcAsync"/> is invoked.
        /// </summary>
        /// <param name="server">A listening socket.</param>
        /// <param name="funcAsync">Per-connection logic; receives the accepted socket, stream, reader and writer.</param>
        /// <param name="options">Server options; defaults are used when null.</param>
        /// <returns>Whatever <paramref name="funcAsync"/> returns.</returns>
        public static async Task<List<string>> AcceptSocketAsync(Socket server, Func<Socket, Stream, StreamReader, StreamWriter, Task<List<string>>> funcAsync, Options options = null)
        {
            options = options ?? new Options();
            Socket s = await server.AcceptAsync().ConfigureAwait(false);
            try
            {
                Stream stream = new NetworkStream(s, ownsSocket: false);
                if (options.UseSsl)
                {
                    var sslStream = new SslStream(stream, false, delegate { return true; });
                    try
                    {
                        using (var cert = Configuration.Certificates.GetServerCertificate())
                        {
                            await sslStream.AuthenticateAsServerAsync(
                                cert,
                                // Ask the client for a certificate. The validation callback above accepts
                                // anything, including no certificate, so one is allowed but not required.
                                clientCertificateRequired: true,
                                enabledSslProtocols: options.SslProtocols,
                                checkCertificateRevocation: false).ConfigureAwait(false);
                        }
                    }
                    catch
                    {
                        // Handshake failed: release the SslStream (and the NetworkStream it owns) before propagating.
                        sslStream.Dispose();
                        throw;
                    }
                    stream = sslStream;
                }

                using (var reader = new StreamReader(stream, Encoding.ASCII))
                using (var writer = new StreamWriter(options.ResponseStreamWrapper?.Invoke(stream) ?? stream, Encoding.ASCII) { AutoFlush = true })
                {
                    return await funcAsync(s, stream, reader, writer).ConfigureAwait(false);
                }
            }
            finally
            {
                try
                {
                    s.Shutdown(SocketShutdown.Send);
                }
                catch (ObjectDisposedException)
                {
                    // In case the test itself disposes of the socket.
                }
                catch (SocketException)
                {
                    // The peer may already have closed/reset the connection. Shutdown is best-effort and
                    // must not mask funcAsync's result or exception, nor prevent the Dispose below.
                }

                // Always release the socket; Dispose is a no-op if it was already disposed.
                s.Dispose();
            }
        }

        /// <summary>How <see cref="StartTransferTypeAndErrorServer"/> frames the response body.</summary>
        public enum TransferType
        {
            None = 0,
            ContentLength,
            Chunked
        }

        /// <summary>Deliberate framing error injected by <see cref="StartTransferTypeAndErrorServer"/>.</summary>
        public enum TransferError
        {
            None = 0,
            ContentLengthTooLarge,
            ChunkSizeTooLarge,
            MissingChunkTerminator
        }

        /// <summary>
        /// Starts a server that accepts one connection and reads the request head. It then sends a
        /// 200 response whose body is framed according to <paramref name="transferType"/>, optionally
        /// corrupted per <paramref name="transferError"/>.
        /// </summary>
        /// <param name="transferType">Body framing: none (read to close), Content-Length, or chunked.</param>
        /// <param name="transferError">Framing error to inject, if any.</param>
        /// <param name="localEndPoint">The endpoint the server is listening on.</param>
        public static Task StartTransferTypeAndErrorServer(
            TransferType transferType,
            TransferError transferError,
            out IPEndPoint localEndPoint)
        {
            return CreateServerAsync((server, url) => AcceptSocketAsync(server, async (client, stream, reader, writer) =>
            {
                // Read past the request line and headers without blocking a thread.
                while (!string.IsNullOrEmpty(await reader.ReadLineAsync().ConfigureAwait(false)))
                {
                }

                // Determine response transfer headers.
                string transferHeader = null;
                string content = "This is some response content.";
                if (transferType == TransferType.ContentLength)
                {
                    transferHeader = transferError == TransferError.ContentLengthTooLarge ?
                        $"Content-Length: {content.Length + 42}\r\n" :
                        $"Content-Length: {content.Length}\r\n";
                }
                else if (transferType == TransferType.Chunked)
                {
                    transferHeader = "Transfer-Encoding: chunked\r\n";
                }

                // Write response header
                await writer.WriteAsync("HTTP/1.1 200 OK\r\n").ConfigureAwait(false);
                await writer.WriteAsync($"Date: {DateTimeOffset.UtcNow:R}\r\n").ConfigureAwait(false);
                await writer.WriteAsync("Content-Type: text/plain\r\n").ConfigureAwait(false);
                if (!string.IsNullOrEmpty(transferHeader))
                {
                    await writer.WriteAsync(transferHeader).ConfigureAwait(false);
                }
                await writer.WriteAsync("\r\n").ConfigureAwait(false);

                // Write response body
                if (transferType == TransferType.Chunked)
                {
                    // Chunk sizes are hexadecimal; ChunkSizeTooLarge advertises 42 more bytes than are sent.
                    int chunkSize = content.Length + (transferError == TransferError.ChunkSizeTooLarge ? 42 : 0);
                    await writer.WriteAsync($"{chunkSize:x}\r\n").ConfigureAwait(false);
                    await writer.WriteAsync($"{content}\r\n").ConfigureAwait(false);
                    if (transferError != TransferError.MissingChunkTerminator)
                    {
                        await writer.WriteAsync("0\r\n\r\n").ConfigureAwait(false);
                    }
                }
                else
                {
                    // Note: a trailing CRLF follows the content here too. With an exact Content-Length it lies
                    // beyond the declared body; with TransferType.None it is part of the read-to-close body.
                    await writer.WriteAsync($"{content}\r\n").ConfigureAwait(false);
                }

                return null;
            }), out localEndPoint);
        }
    }
}
266