// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Net.Http;
using System.Net.NetworkInformation;
using System.Net.Security;
using System.Net.Sockets;
using System.Security.Authentication;
using System.Security.Cryptography;
using System.Security.Cryptography.X509Certificates;
using System.Text;
using System.Threading;
using System.Threading.Tasks;

namespace System.Net.Test.Common
{
    /// <summary>
    /// Minimal raw-socket server for networking tests. It listens on an OS-chosen port and
    /// hands the listening socket (or accepted connections) to caller-supplied callbacks.
    /// </summary>
    public class LoopbackServer
    {
        /// <summary>
        /// Certificate validation callback that accepts every certificate regardless of policy errors.
        /// </summary>
        public static Func<HttpRequestMessage, X509Certificate2, X509Chain, SslPolicyErrors, bool> AllowAllCertificates = (_, __, ___, ____) => true;

        /// <summary>Configuration for the loopback server helpers.</summary>
        public class Options
        {
            /// <summary>Address the listening socket binds to. Defaults to IPv4 loopback.</summary>
            public IPAddress Address { get; set; } = IPAddress.Loopback;

            /// <summary>Backlog passed to <see cref="Socket.Listen(int)"/>.</summary>
            public int ListenBacklog { get; set; } = 1;

            /// <summary>When true, accepted connections are wrapped in TLS and the URL uses https/wss.</summary>
            public bool UseSsl { get; set; } = false;

            /// <summary>TLS protocol versions the server accepts when <see cref="UseSsl"/> is set.</summary>
            public SslProtocols SslProtocols { get; set; } = SslProtocols.Tls | SslProtocols.Tls11 | SslProtocols.Tls12;

            /// <summary>When true, the generated URL uses the ws/wss scheme instead of http/https.</summary>
            public bool WebSocketEndpoint { get; set; } = false;

            /// <summary>Optional wrapper applied to the connection stream used by the response writer.</summary>
            public Func<Stream, Stream> ResponseStreamWrapper { get; set; }
        }

        /// <summary>
        /// Creates a listening socket and invokes <paramref name="funcAsync"/> with it and its URL.
        /// </summary>
        /// <param name="funcAsync">Callback that drives the server using the listening socket.</param>
        /// <param name="options">Server options; defaults are used when null.</param>
        /// <returns>A task that completes after the callback finishes and the listener is disposed.</returns>
        public static Task CreateServerAsync(Func<Socket, Uri, Task> funcAsync, Options options = null)
        {
            IPEndPoint ignored;
            return CreateServerAsync(funcAsync, out ignored, options);
        }

        /// <summary>
        /// Creates a listening socket, reports its endpoint, and invokes <paramref name="funcAsync"/>
        /// with the socket and a URL pointing at it.
        /// </summary>
        /// <param name="funcAsync">Callback that drives the server using the listening socket.</param>
        /// <param name="localEndPoint">Receives the bound endpoint, or null if setup failed.</param>
        /// <param name="options">Server options; defaults are used when null.</param>
        /// <returns>
        /// A task that completes after the callback finishes (propagating its exception, if any) and the
        /// listener has been disposed, or a faulted task if setup failed.
        /// </returns>
        public static Task CreateServerAsync(Func<Socket, Uri, Task> funcAsync, out IPEndPoint localEndPoint, Options options = null)
        {
            options = options ?? new Options();
            Socket server = null;
            try
            {
                server = new Socket(options.Address.AddressFamily, SocketType.Stream, ProtocolType.Tcp);

                // Port 0 lets the OS choose a free port; the real port is read back from LocalEndPoint.
                server.Bind(new IPEndPoint(options.Address, 0));
                server.Listen(options.ListenBacklog);

                localEndPoint = (IPEndPoint)server.LocalEndPoint;

                // IPv6 literals must be bracketed inside a URI authority.
                string host = options.Address.AddressFamily == AddressFamily.InterNetworkV6 ?
                    $"[{localEndPoint.Address}]" :
                    localEndPoint.Address.ToString();

                string scheme = options.WebSocketEndpoint ?
                    (options.UseSsl ? "wss" : "ws") :
                    (options.UseSsl ? "https" : "http");

                var url = new Uri($"{scheme}://{host}:{localEndPoint.Port}/");

                return funcAsync(server, url).ContinueWith(t =>
                {
                    server.Dispose();
                    // Rethrow the callback's exception (if any) so it surfaces on the returned task.
                    t.GetAwaiter().GetResult();
                }, CancellationToken.None, TaskContinuationOptions.None, TaskScheduler.Default);
            }
            catch (Exception e)
            {
                // Setup or a synchronous callback failure: don't leak the listening socket.
                server?.Dispose();
                localEndPoint = null;
                return Task.FromException(e);
            }
        }

        /// <summary>A "200 OK" response with an empty body and the current Date header.</summary>
        public static string DefaultHttpResponse => $"HTTP/1.1 200 OK\r\nDate: {DateTimeOffset.UtcNow:R}\r\nContent-Length: 0\r\n\r\n";

        /// <summary>
        /// Returns the first IPv6 link-local unicast address found on any network interface, or null if none.
        /// </summary>
        public static IPAddress GetIPv6LinkLocalAddress() =>
            NetworkInterface
                .GetAllNetworkInterfaces()
                .SelectMany(i => i.GetIPProperties().UnicastAddresses)
                .Select(a => a.Address)
                .FirstOrDefault(a => a.IsIPv6LinkLocal);

        /// <summary>
        /// Accepts one connection on <paramref name="server"/>, reads the request line and headers,
        /// and writes <paramref name="response"/> (or <see cref="DefaultHttpResponse"/>).
        /// </summary>
        /// <returns>The request line and header lines that were read.</returns>
        public static Task<List<string>> ReadRequestAndSendResponseAsync(Socket server, string response = null, Options options = null)
        {
            return AcceptSocketAsync(server, (s, stream, reader, writer) => ReadWriteAcceptedAsync(s, reader, writer, response), options);
        }

        /// <summary>
        /// Reads lines up to the first empty line (end of headers), then writes the response.
        /// </summary>
        /// <returns>The request line and header lines that were read.</returns>
        public static async Task<List<string>> ReadWriteAcceptedAsync(Socket s, StreamReader reader, StreamWriter writer, string response = null)
        {
            // Read the request line and headers. Any request body is left unread.
            var lines = new List<string>();
            string line;
            while (!string.IsNullOrEmpty(line = await reader.ReadLineAsync().ConfigureAwait(false)))
            {
                lines.Add(line);
            }

            await writer.WriteAsync(response ?? DefaultHttpResponse).ConfigureAwait(false);

            return lines;
        }

        /// <summary>
        /// Reads an HTTP request's headers and, if a Sec-WebSocket-Key header is present, writes a
        /// "101 Switching Protocols" response with the matching Sec-WebSocket-Accept value.
        /// </summary>
        /// <returns>True if the handshake response was sent; false if no Sec-WebSocket-Key header was seen.</returns>
        public static async Task<bool> WebSocketHandshakeAsync(Socket s, StreamReader reader, StreamWriter writer)
        {
            string serverResponse = null;
            string currentRequestLine;
            while (!string.IsNullOrEmpty(currentRequestLine = await reader.ReadLineAsync().ConfigureAwait(false)))
            {
                string[] tokens = currentRequestLine.Split(new char[] { ':' }, 2);
                if (tokens.Length == 2)
                {
                    string headerName = tokens[0];

                    // HTTP header field names are case-insensitive.
                    if (string.Equals(headerName, "Sec-WebSocket-Key", StringComparison.OrdinalIgnoreCase))
                    {
                        string headerValue = tokens[1].Trim();
                        string responseSecurityAcceptValue = ComputeWebSocketHandshakeSecurityAcceptValue(headerValue);
                        serverResponse =
                            "HTTP/1.1 101 Switching Protocols\r\n" +
                            "Upgrade: websocket\r\n" +
                            "Connection: Upgrade\r\n" +
                            "Sec-WebSocket-Accept: " + responseSecurityAcceptValue + "\r\n\r\n";
                    }
                }
            }

            if (serverResponse != null)
            {
                // We received a valid WebSocket opening handshake. Send the appropriate response.
                await writer.WriteAsync(serverResponse).ConfigureAwait(false);
                return true;
            }

            return false;
        }

        /// <summary>
        /// Computes the Sec-WebSocket-Accept value: base64(SHA-1(key + RFC 6455 GUID)).
        /// </summary>
        private static string ComputeWebSocketHandshakeSecurityAcceptValue(string secWebSocketKey)
        {
            // GUID specified by RFC 6455.
            const string Rfc6455Guid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
            string combinedKey = secWebSocketKey + Rfc6455Guid;

            // Use of SHA1 hash is required by RFC 6455. The hash object is disposable.
            using (SHA1 sha1 = SHA1.Create())
            {
                byte[] sha1Hash = sha1.ComputeHash(Encoding.UTF8.GetBytes(combinedKey));
                return Convert.ToBase64String(sha1Hash);
            }
        }

        /// <summary>
        /// Accepts one connection, optionally performs a TLS server handshake, and invokes
        /// <paramref name="funcAsync"/> with ASCII reader/writer wrappers over the connection stream.
        /// The accepted socket is shut down (send side) and disposed afterwards.
        /// </summary>
        /// <returns>Whatever <paramref name="funcAsync"/> returns.</returns>
        public static async Task<List<string>> AcceptSocketAsync(Socket server, Func<Socket, Stream, StreamReader, StreamWriter, Task<List<string>>> funcAsync, Options options = null)
        {
            options = options ?? new Options();
            Socket s = await server.AcceptAsync().ConfigureAwait(false);
            try
            {
                // ownsSocket: false, so disposing the stream layers does not close s; s is cleaned up in finally.
                Stream stream = new NetworkStream(s, ownsSocket: false);
                if (options.UseSsl)
                {
                    var sslStream = new SslStream(stream, false, delegate { return true; });
                    try
                    {
                        using (var cert = Configuration.Certificates.GetServerCertificate())
                        {
                            await sslStream.AuthenticateAsServerAsync(
                                cert,
                                // Requests a client certificate; because the validation callback above
                                // always returns true, a missing one is still accepted.
                                clientCertificateRequired: true,
                                enabledSslProtocols: options.SslProtocols,
                                checkCertificateRevocation: false).ConfigureAwait(false);
                        }
                    }
                    catch
                    {
                        // Handshake failed: release the TLS stream before propagating.
                        sslStream.Dispose();
                        throw;
                    }
                    stream = sslStream;
                }

                using (var reader = new StreamReader(stream, Encoding.ASCII))
                using (var writer = new StreamWriter(options.ResponseStreamWrapper?.Invoke(stream) ?? stream, Encoding.ASCII) { AutoFlush = true })
                {
                    return await funcAsync(s, stream, reader, writer).ConfigureAwait(false);
                }
            }
            finally
            {
                try
                {
                    s.Shutdown(SocketShutdown.Send);
                }
                catch (ObjectDisposedException)
                {
                    // In case the test itself disposes of the socket
                }
                finally
                {
                    // Always release the socket, even if Shutdown failed (e.g. the peer reset the connection).
                    s.Dispose();
                }
            }
        }

        /// <summary>How the test server frames its response body.</summary>
        public enum TransferType
        {
            None = 0,
            ContentLength,
            Chunked
        }

        /// <summary>Framing error the test server deliberately injects into its response.</summary>
        public enum TransferError
        {
            None = 0,
            ContentLengthTooLarge,
            ChunkSizeTooLarge,
            MissingChunkTerminator
        }

        /// <summary>
        /// Starts a server that answers one request with a fixed text body, framed according to
        /// <paramref name="transferType"/> and optionally corrupted according to <paramref name="transferError"/>.
        /// </summary>
        /// <param name="transferType">Response framing: none, Content-Length, or chunked.</param>
        /// <param name="transferError">Framing error to inject (declared length/chunk size 42 too large, or a missing final chunk).</param>
        /// <param name="localEndPoint">Receives the server's bound endpoint.</param>
        public static Task StartTransferTypeAndErrorServer(
            TransferType transferType,
            TransferError transferError,
            out IPEndPoint localEndPoint)
        {
            return CreateServerAsync((server, url) => AcceptSocketAsync(server, async (client, stream, reader, writer) =>
            {
                // Read past request headers (everything up to the first empty line).
                while (!string.IsNullOrEmpty(await reader.ReadLineAsync().ConfigureAwait(false)))
                {
                }

                // Determine response transfer headers.
                string transferHeader = null;
                string content = "This is some response content.";
                if (transferType == TransferType.ContentLength)
                {
                    transferHeader = transferError == TransferError.ContentLengthTooLarge ?
                        $"Content-Length: {content.Length + 42}\r\n" :
                        $"Content-Length: {content.Length}\r\n";
                }
                else if (transferType == TransferType.Chunked)
                {
                    transferHeader = "Transfer-Encoding: chunked\r\n";
                }

                // Write response header
                await writer.WriteAsync("HTTP/1.1 200 OK\r\n").ConfigureAwait(false);
                await writer.WriteAsync($"Date: {DateTimeOffset.UtcNow:R}\r\n").ConfigureAwait(false);
                await writer.WriteAsync("Content-Type: text/plain\r\n").ConfigureAwait(false);
                if (!string.IsNullOrEmpty(transferHeader))
                {
                    await writer.WriteAsync(transferHeader).ConfigureAwait(false);
                }
                await writer.WriteAsync("\r\n").ConfigureAwait(false);

                // Write response body
                if (transferType == TransferType.Chunked)
                {
                    int chunkSize = content.Length + (transferError == TransferError.ChunkSizeTooLarge ? 42 : 0);
                    await writer.WriteAsync($"{chunkSize:x}\r\n").ConfigureAwait(false);
                    await writer.WriteAsync($"{content}\r\n").ConfigureAwait(false);
                    if (transferError != TransferError.MissingChunkTerminator)
                    {
                        await writer.WriteAsync("0\r\n\r\n").ConfigureAwait(false);
                    }
                }
                else
                {
                    await writer.WriteAsync($"{content}\r\n").ConfigureAwait(false);
                }

                return null;
            }), out localEndPoint);
        }
    }
}