1 // Licensed to the .NET Foundation under one or more agreements.
2 // The .NET Foundation licenses this file to you under the MIT license.
3 // See the LICENSE file in the project root for more information.
4 
5 using Microsoft.Win32.SafeHandles;
6 using System.Collections.Generic;
7 using System.Diagnostics;
8 using System.IO;
9 using System.Runtime.InteropServices;
10 using System.Threading;
11 using System.Threading.Tasks;
12 
13 namespace System.Net.WebSockets
14 {
15     internal sealed class WebSocketHttpListenerDuplexStream : Stream, WebSocketBase.IWebSocketStream
16     {
        // Completion handlers attached to the read/write event args in SwitchToOpaqueMode; kept in static
        // fields so the same delegate instances are reused for every stream.
        private static readonly EventHandler<HttpListenerAsyncEventArgs> s_OnReadCompleted =
            new EventHandler<HttpListenerAsyncEventArgs>(OnReadCompleted);
        private static readonly EventHandler<HttpListenerAsyncEventArgs> s_OnWriteCompleted =
            new EventHandler<HttpListenerAsyncEventArgs>(OnWriteCompleted);
        // Filter used by the async catch blocks: when it returns true and the caller's token has been
        // canceled, the failure is surfaced as an OperationCanceledException instead.
        private static readonly Func<Exception, bool> s_CanHandleException = new Func<Exception, bool>(CanHandleException);
        // Callback registered on caller cancellation tokens, with this stream as the state object.
        private static readonly Action<object> s_OnCancel = new Action<object>(OnCancel);
        // Underlying request/response streams. Synchronous calls, and async calls made before opaque mode,
        // are forwarded to them.
        private readonly HttpRequestStream _inputStream;
        private readonly HttpResponseStream _outputStream;
        // HTTP context that OnCancel aborts.
        private HttpListenerContext _context;
        // Set once by SwitchToOpaqueMode. When true, the async read/write paths call http.sys directly.
        private bool _inOpaqueMode;
        // Owner of the internal buffer that ReadAsyncFast receives into; assigned in SwitchToOpaqueMode.
        private WebSocketBase _webSocket;
        // Reusable fast-path operation state for each direction; null until SwitchToOpaqueMode is called.
        private HttpListenerAsyncEventArgs _writeEventArgs;
        private HttpListenerAsyncEventArgs _readEventArgs;
        // Completion sources for the current fast-path write/read. A new one is created per operation,
        // and both OnCancel and Dispose cancel whichever ones exist.
        private TaskCompletionSource<object> _writeTaskCompletionSource;
        private TaskCompletionSource<int> _readTaskCompletionSource;
        // 0 until Dispose(true) has performed its cleanup, then 1 (set via Interlocked.Exchange).
        private int _cleanedUp;

#if DEBUG
        // Debug-only counters of in-flight fast-path operations. They back the assertions that at most one
        // read and one write are outstanding at any time.
        private class OutstandingOperations
        {
            internal int _reads;
            internal int _writes;
        }

        private readonly OutstandingOperations _outstandingOperations = new OutstandingOperations();
#endif //DEBUG
43 
WebSocketHttpListenerDuplexStream(HttpRequestStream inputStream, HttpResponseStream outputStream, HttpListenerContext context)44         public WebSocketHttpListenerDuplexStream(HttpRequestStream inputStream,
45             HttpResponseStream outputStream,
46             HttpListenerContext context)
47         {
48             Debug.Assert(inputStream != null, "'inputStream' MUST NOT be NULL.");
49             Debug.Assert(outputStream != null, "'outputStream' MUST NOT be NULL.");
50             Debug.Assert(context != null, "'context' MUST NOT be NULL.");
51             Debug.Assert(inputStream.CanRead, "'inputStream' MUST support read operations.");
52             Debug.Assert(outputStream.CanWrite, "'outputStream' MUST support write operations.");
53 
54             _inputStream = inputStream;
55             _outputStream = outputStream;
56             _context = context;
57 
58             if (NetEventSource.IsEnabled)
59             {
60                 NetEventSource.Associate(inputStream, this);
61                 NetEventSource.Associate(outputStream, this);
62             }
63         }
64 
65         public override bool CanRead
66         {
67             get
68             {
69                 return _inputStream.CanRead;
70             }
71         }
72 
73         public override bool CanSeek
74         {
75             get
76             {
77                 return false;
78             }
79         }
80 
81         public override bool CanTimeout
82         {
83             get
84             {
85                 return _inputStream.CanTimeout && _outputStream.CanTimeout;
86             }
87         }
88 
89         public override bool CanWrite
90         {
91             get
92             {
93                 return _outputStream.CanWrite;
94             }
95         }
96 
97         public override long Length
98         {
99             get
100             {
101                 throw new NotSupportedException(SR.net_noseek);
102             }
103         }
104 
105         public override long Position
106         {
107             get
108             {
109                 throw new NotSupportedException(SR.net_noseek);
110             }
111             set
112             {
113                 throw new NotSupportedException(SR.net_noseek);
114             }
115         }
116 
Read(byte[] buffer, int offset, int count)117         public override int Read(byte[] buffer, int offset, int count)
118         {
119             return _inputStream.Read(buffer, offset, count);
120         }
121 
ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)122         public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
123         {
124             WebSocketValidate.ValidateBuffer(buffer, offset, count);
125 
126             return ReadAsyncCore(buffer, offset, count, cancellationToken);
127         }
128 
ReadAsyncCore(byte[] buffer, int offset, int count, CancellationToken cancellationToken)129         private async Task<int> ReadAsyncCore(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
130         {
131             if (NetEventSource.IsEnabled)
132             {
133                 NetEventSource.Enter(this, HttpWebSocket.GetTraceMsgForParameters(offset, count, cancellationToken));
134             }
135 
136             CancellationTokenRegistration cancellationTokenRegistration = new CancellationTokenRegistration();
137 
138             int bytesRead = 0;
139             try
140             {
141                 if (cancellationToken.CanBeCanceled)
142                 {
143                     cancellationTokenRegistration = cancellationToken.Register(s_OnCancel, this, false);
144                 }
145 
146                 if (!_inOpaqueMode)
147                 {
148                     bytesRead = await _inputStream.ReadAsync(buffer, offset, count, cancellationToken).SuppressContextFlow<int>();
149                 }
150                 else
151                 {
152 #if DEBUG
153                     // When using fast path only one outstanding read is permitted. By switching into opaque mode
154                     // via IWebSocketStream.SwitchToOpaqueMode (see more detailed comments in interface definition)
155                     // caller takes responsibility for enforcing this constraint.
156                     Debug.Assert(Interlocked.Increment(ref _outstandingOperations._reads) == 1,
157                         "Only one outstanding read allowed at any given time.");
158 #endif
159                     _readTaskCompletionSource = new TaskCompletionSource<int>();
160                     _readEventArgs.SetBuffer(buffer, offset, count);
161                     if (!ReadAsyncFast(_readEventArgs))
162                     {
163                         if (_readEventArgs.Exception != null)
164                         {
165                             throw _readEventArgs.Exception;
166                         }
167 
168                         bytesRead = _readEventArgs.BytesTransferred;
169                     }
170                     else
171                     {
172                         bytesRead = await _readTaskCompletionSource.Task.SuppressContextFlow<int>();
173                     }
174                 }
175             }
176             catch (Exception error)
177             {
178                 if (s_CanHandleException(error))
179                 {
180                     cancellationToken.ThrowIfCancellationRequested();
181                 }
182 
183                 throw;
184             }
185             finally
186             {
187                 cancellationTokenRegistration.Dispose();
188 
189                 if (NetEventSource.IsEnabled)
190                 {
191                     NetEventSource.Exit(this, bytesRead);
192                 }
193             }
194 
195             return bytesRead;
196         }
197 
        // Opaque-mode fast path: receives entity-body bytes directly from http.sys into the segment described
        // by eventArgs, first draining any chunks the request stream has already buffered.
        // Return value indicates sync vs async completion:
        //   false: sync completion - eventArgs.FinishOperationSuccess has already been called.
        //   true:  async completion - the result arrives later through the event args' completion callback
        //          (presumably OnReadCompleted, attached in SwitchToOpaqueMode).
        // On failure: fails the read event args, marks the output stream closed, aborts the HTTP context,
        // and rethrows.
        private unsafe bool ReadAsyncFast(HttpListenerAsyncEventArgs eventArgs)
        {
            if (NetEventSource.IsEnabled)
            {
                NetEventSource.Enter(this);
            }

            eventArgs.StartOperationCommon(this, _inputStream.InternalHttpContext.RequestQueueBoundHandle);
            eventArgs.StartOperationReceive();

            uint statusCode = 0;
            bool completedAsynchronously = false;
            try
            {
                Debug.Assert(eventArgs.Buffer != null, "'BufferList' is not supported for read operations.");
                // Zero-byte reads, and reads after the request stream is closed, complete immediately with 0 bytes.
                if (eventArgs.Count == 0 || _inputStream.Closed)
                {
                    eventArgs.FinishOperationSuccess(0, true);
                    return false;
                }

                uint dataRead = 0;
                int offset = eventArgs.Offset;
                int remainingCount = eventArgs.Count;

                // Copy out entity-body chunks the request stream already holds before asking http.sys for more.
                if (_inputStream.BufferedDataChunksAvailable)
                {
                    dataRead = _inputStream.GetChunks(eventArgs.Buffer, eventArgs.Offset, eventArgs.Count);
                    // The buffer was filled entirely from buffered chunks and more remain: no native call needed.
                    if (_inputStream.BufferedDataChunksAvailable && dataRead == eventArgs.Count)
                    {
                        eventArgs.FinishOperationSuccess(eventArgs.Count, true);
                        return false;
                    }
                }

                Debug.Assert(!_inputStream.BufferedDataChunksAvailable, "'m_InputStream.BufferedDataChunksAvailable' MUST BE 'FALSE' at this point.");
                Debug.Assert(dataRead <= eventArgs.Count, "'dataRead' MUST NOT be bigger than 'eventArgs.Count'.");

                // Narrow the segment to the space left after buffered data, capped at HttpRequestStream.MaxReadSize.
                // NOTE(review): if dataRead == eventArgs.Count here (buffered chunks exactly filled the buffer and none
                // remain), remainingCount becomes 0 and a zero-length native receive is still issued - confirm intended.
                // NOTE(review): the dataRead bytes are not added to the count reported below (bytesReturned), nor
                // visibly to the async completion - confirm whether the completion path accounts for them.
                if (dataRead != 0)
                {
                    offset += (int)dataRead;
                    remainingCount -= (int)dataRead;
                    //the http.sys team recommends that we limit the size to 128kb
                    if (remainingCount > HttpRequestStream.MaxReadSize)
                    {
                        remainingCount = HttpRequestStream.MaxReadSize;
                    }

                    eventArgs.SetBuffer(eventArgs.Buffer, offset, remainingCount);
                }
                else if (remainingCount > HttpRequestStream.MaxReadSize)
                {
                    remainingCount = HttpRequestStream.MaxReadSize;
                    eventArgs.SetBuffer(eventArgs.Buffer, offset, remainingCount);
                }

                uint flags = 0;
                uint bytesReturned = 0;
                // The destination pointer comes from the WebSocket's internal buffer at eventArgs.Offset.
                // NOTE(review): this assumes eventArgs.Buffer is the array behind _webSocket.InternalBuffer - TODO confirm.
                statusCode =
                    Interop.HttpApi.HttpReceiveRequestEntityBody(
                        _inputStream.InternalHttpContext.RequestQueueHandle,
                        _inputStream.InternalHttpContext.RequestId,
                        flags,
                        (byte*)_webSocket.InternalBuffer.ToIntPtr(eventArgs.Offset),
                        (uint)eventArgs.Count,
                        out bytesReturned,
                        eventArgs.NativeOverlapped);

                if (statusCode != Interop.HttpApi.ERROR_SUCCESS &&
                    statusCode != Interop.HttpApi.ERROR_IO_PENDING &&
                    statusCode != Interop.HttpApi.ERROR_HANDLE_EOF)
                {
                    throw new HttpListenerException((int)statusCode);
                }
                else if (statusCode == Interop.HttpApi.ERROR_SUCCESS &&
                    HttpListener.SkipIOCPCallbackOnSuccess)
                {
                    // IO operation completed synchronously. No IO completion port callback is expected because
                    // HttpListener.SkipIOCPCallbackOnSuccess is set, so the operation is finished here.
                    eventArgs.FinishOperationSuccess((int)bytesReturned, true);
                    completedAsynchronously = false;
                }
                else
                {
                    // ERROR_IO_PENDING, ERROR_SUCCESS without SkipIOCPCallbackOnSuccess, or ERROR_HANDLE_EOF: all are
                    // left to the completion callback.
                    // NOTE(review): confirm that a completion is actually queued for ERROR_HANDLE_EOF; otherwise the
                    // awaiting read would only finish through cancellation or Dispose.
                    completedAsynchronously = true;
                }
            }
            catch (Exception e)
            {
                // Fail the operation, then mark the output closed and abort the HTTP context before rethrowing.
                // NOTE(review): uses the _readEventArgs field rather than the eventArgs parameter; the only visible
                // caller (ReadAsyncCore) passes _readEventArgs, so they are the same object there.
                _readEventArgs.FinishOperationFailure(e, true);
                _outputStream.SetClosedFlag();
                _outputStream.InternalHttpContext.Abort();

                throw;
            }
            finally
            {
                if (NetEventSource.IsEnabled)
                {
                    NetEventSource.Exit(this, completedAsynchronously);
                }
            }

            return completedAsynchronously;
        }
306 
ReadByte()307         public override int ReadByte()
308         {
309             return _inputStream.ReadByte();
310         }
311 
312         public bool SupportsMultipleWrite
313         {
314             get
315             {
316                 return true;
317             }
318         }
319 
BeginRead(byte[] buffer, int offset, int count, AsyncCallback callback, object state)320         public override IAsyncResult BeginRead(byte[] buffer,
321             int offset,
322             int count,
323             AsyncCallback callback,
324             object state)
325         {
326             return _inputStream.BeginRead(buffer, offset, count, callback, state);
327         }
328 
EndRead(IAsyncResult asyncResult)329         public override int EndRead(IAsyncResult asyncResult)
330         {
331             return _inputStream.EndRead(asyncResult);
332         }
333 
MultipleWriteAsync(IList<ArraySegment<byte>> sendBuffers, CancellationToken cancellationToken)334         public Task MultipleWriteAsync(IList<ArraySegment<byte>> sendBuffers, CancellationToken cancellationToken)
335         {
336             Debug.Assert(_inOpaqueMode, "The stream MUST be in opaque mode at this point.");
337             Debug.Assert(sendBuffers != null, "'sendBuffers' MUST NOT be NULL.");
338             Debug.Assert(sendBuffers.Count == 1 || sendBuffers.Count == 2,
339                 "'sendBuffers.Count' MUST be either '1' or '2'.");
340 
341             if (sendBuffers.Count == 1)
342             {
343                 ArraySegment<byte> buffer = sendBuffers[0];
344                 return WriteAsync(buffer.Array, buffer.Offset, buffer.Count, cancellationToken);
345             }
346 
347             return MultipleWriteAsyncCore(sendBuffers, cancellationToken);
348         }
349 
MultipleWriteAsyncCore(IList<ArraySegment<byte>> sendBuffers, CancellationToken cancellationToken)350         private async Task MultipleWriteAsyncCore(IList<ArraySegment<byte>> sendBuffers, CancellationToken cancellationToken)
351         {
352             Debug.Assert(sendBuffers != null, "'sendBuffers' MUST NOT be NULL.");
353             Debug.Assert(sendBuffers.Count == 2, "'sendBuffers.Count' MUST be '2' at this point.");
354 
355             if (NetEventSource.IsEnabled)
356             {
357                 NetEventSource.Enter(this);
358             }
359 
360             CancellationTokenRegistration cancellationTokenRegistration = new CancellationTokenRegistration();
361 
362             try
363             {
364                 if (cancellationToken.CanBeCanceled)
365                 {
366                     cancellationTokenRegistration = cancellationToken.Register(s_OnCancel, this, false);
367                 }
368 #if DEBUG
369                 // When using fast path only one outstanding read is permitted. By switching into opaque mode
370                 // via IWebSocketStream.SwitchToOpaqueMode (see more detailed comments in interface definition)
371                 // caller takes responsibility for enforcing this constraint.
372                 Debug.Assert(Interlocked.Increment(ref _outstandingOperations._writes) == 1,
373                     "Only one outstanding write allowed at any given time.");
374 #endif
375                 _writeTaskCompletionSource = new TaskCompletionSource<object>();
376                 _writeEventArgs.SetBuffer(null, 0, 0);
377                 _writeEventArgs.BufferList = sendBuffers;
378                 if (WriteAsyncFast(_writeEventArgs))
379                 {
380                     await _writeTaskCompletionSource.Task.SuppressContextFlow();
381                 }
382             }
383             catch (Exception error)
384             {
385                 if (s_CanHandleException(error))
386                 {
387                     cancellationToken.ThrowIfCancellationRequested();
388                 }
389 
390                 throw;
391             }
392             finally
393             {
394                 cancellationTokenRegistration.Dispose();
395 
396                 if (NetEventSource.IsEnabled)
397                 {
398                     NetEventSource.Exit(this);
399                 }
400             }
401         }
402 
Write(byte[] buffer, int offset, int count)403         public override void Write(byte[] buffer, int offset, int count)
404         {
405             _outputStream.Write(buffer, offset, count);
406         }
407 
WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)408         public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
409         {
410             WebSocketValidate.ValidateBuffer(buffer, offset, count);
411 
412             return WriteAsyncCore(buffer, offset, count, cancellationToken);
413         }
414 
WriteAsyncCore(byte[] buffer, int offset, int count, CancellationToken cancellationToken)415         private async Task WriteAsyncCore(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
416         {
417             if (NetEventSource.IsEnabled)
418             {
419                 NetEventSource.Enter(this, HttpWebSocket.GetTraceMsgForParameters(offset, count, cancellationToken));
420             }
421 
422             CancellationTokenRegistration cancellationTokenRegistration = new CancellationTokenRegistration();
423 
424             try
425             {
426                 if (cancellationToken.CanBeCanceled)
427                 {
428                     cancellationTokenRegistration = cancellationToken.Register(s_OnCancel, this, false);
429                 }
430 
431                 if (!_inOpaqueMode)
432                 {
433                     await _outputStream.WriteAsync(buffer, offset, count, cancellationToken).SuppressContextFlow();
434                 }
435                 else
436                 {
437 #if DEBUG
438                     // When using fast path only one outstanding read is permitted. By switching into opaque mode
439                     // via IWebSocketStream.SwitchToOpaqueMode (see more detailed comments in interface definition)
440                     // caller takes responsibility for enforcing this constraint.
441                     Debug.Assert(Interlocked.Increment(ref _outstandingOperations._writes) == 1,
442                         "Only one outstanding write allowed at any given time.");
443 #endif
444                     _writeTaskCompletionSource = new TaskCompletionSource<object>();
445                     _writeEventArgs.BufferList = null;
446                     _writeEventArgs.SetBuffer(buffer, offset, count);
447                     if (WriteAsyncFast(_writeEventArgs))
448                     {
449                         await _writeTaskCompletionSource.Task.SuppressContextFlow();
450                     }
451                 }
452             }
453             catch (Exception error)
454             {
455                 if (s_CanHandleException(error))
456                 {
457                     cancellationToken.ThrowIfCancellationRequested();
458                 }
459 
460                 throw;
461             }
462             finally
463             {
464                 cancellationTokenRegistration.Dispose();
465 
466                 if (NetEventSource.IsEnabled)
467                 {
468                     NetEventSource.Exit(this);
469                 }
470             }
471         }
472 
        // Opaque-mode fast path: sends the entity chunks described by eventArgs (a single buffer or a
        // BufferList) directly through http.sys.
        // Return value indicates sync vs async completion:
        //   false: sync completion - eventArgs.FinishOperationSuccess has already been called.
        //   true:  async completion - the result arrives later through the event args' completion callback
        //          (presumably OnWriteCompleted, attached in SwitchToOpaqueMode).
        // On failure: fails the write event args, marks the output stream closed, aborts the HTTP context,
        // and rethrows.
        private unsafe bool WriteAsyncFast(HttpListenerAsyncEventArgs eventArgs)
        {
            if (NetEventSource.IsEnabled)
            {
                NetEventSource.Enter(this);
            }

            Interop.HttpApi.HTTP_FLAGS flags = Interop.HttpApi.HTTP_FLAGS.NONE;

            eventArgs.StartOperationCommon(this, _outputStream.InternalHttpContext.RequestQueueBoundHandle);
            eventArgs.StartOperationSend();

            uint statusCode;
            bool completedAsynchronously = false;
            try
            {
                // Short-circuit when the response is already closed, or when a single-buffer write has zero
                // bytes. A BufferList write (Buffer == null) always reaches http.sys.
                // NOTE(review): when _outputStream.Closed, eventArgs.Count is reported as transferred even though
                // nothing is sent - confirm callers rely on this.
                if (_outputStream.Closed ||
                    (eventArgs.Buffer != null && eventArgs.Count == 0))
                {
                    eventArgs.FinishOperationSuccess(eventArgs.Count, true);
                    return false;
                }

                // A close request (ShouldCloseOutput, presumably set by SetShouldCloseOutput in
                // CloseNetworkConnectionAsync) disconnects; any other send keeps the response open for more data.
                if (eventArgs.ShouldCloseOutput)
                {
                    flags |= Interop.HttpApi.HTTP_FLAGS.HTTP_SEND_RESPONSE_FLAG_DISCONNECT;
                }
                else
                {
                    flags |= Interop.HttpApi.HTTP_FLAGS.HTTP_SEND_RESPONSE_FLAG_MORE_DATA;
                    // When using HTTP_SEND_RESPONSE_FLAG_BUFFER_DATA HTTP.SYS will copy the payload to
                    // kernel memory (Non-Paged Pool). Http.Sys will buffer up to
                    // Math.Min(16 MB, current TCP window size)
                    flags |= Interop.HttpApi.HTTP_FLAGS.HTTP_SEND_RESPONSE_FLAG_BUFFER_DATA;
                }

                uint bytesSent;
                statusCode =
                    Interop.HttpApi.HttpSendResponseEntityBody(
                        _outputStream.InternalHttpContext.RequestQueueHandle,
                        _outputStream.InternalHttpContext.RequestId,
                        (uint)flags,
                        eventArgs.EntityChunkCount,
                        (Interop.HttpApi.HTTP_DATA_CHUNK*)eventArgs.EntityChunks,
                        &bytesSent,
                        SafeLocalAllocHandle.Zero,
                        0,
                        eventArgs.NativeOverlapped,
                        null);

                if (statusCode != Interop.HttpApi.ERROR_SUCCESS &&
                    statusCode != Interop.HttpApi.ERROR_IO_PENDING)
                {
                    throw new HttpListenerException((int)statusCode);
                }
                else if (statusCode == Interop.HttpApi.ERROR_SUCCESS &&
                    HttpListener.SkipIOCPCallbackOnSuccess)
                {
                    // IO operation completed synchronously - callback won't be called to signal completion.
                    eventArgs.FinishOperationSuccess((int)bytesSent, true);
                    completedAsynchronously = false;
                }
                else
                {
                    // ERROR_IO_PENDING, or ERROR_SUCCESS without SkipIOCPCallbackOnSuccess: left to the
                    // completion callback.
                    completedAsynchronously = true;
                }
            }
            catch (Exception e)
            {
                // NOTE(review): fails the _writeEventArgs field rather than the eventArgs parameter; every visible
                // caller passes _writeEventArgs, so they are the same object there.
                _writeEventArgs.FinishOperationFailure(e, true);
                _outputStream.SetClosedFlag();
                _outputStream.InternalHttpContext.Abort();

                throw;
            }
            finally
            {
                if (NetEventSource.IsEnabled)
                {
                    NetEventSource.Exit(this, completedAsynchronously);
                }
            }

            return completedAsynchronously;
        }
561 
WriteByte(byte value)562         public override void WriteByte(byte value)
563         {
564             _outputStream.WriteByte(value);
565         }
566 
BeginWrite(byte[] buffer, int offset, int count, AsyncCallback callback, object state)567         public override IAsyncResult BeginWrite(byte[] buffer,
568             int offset,
569             int count,
570             AsyncCallback callback,
571             object state)
572         {
573             return _outputStream.BeginWrite(buffer, offset, count, callback, state);
574         }
575 
EndWrite(IAsyncResult asyncResult)576         public override void EndWrite(IAsyncResult asyncResult)
577         {
578             _outputStream.EndWrite(asyncResult);
579         }
580 
Flush()581         public override void Flush()
582         {
583             _outputStream.Flush();
584         }
585 
FlushAsync(CancellationToken cancellationToken)586         public override Task FlushAsync(CancellationToken cancellationToken)
587         {
588             return _outputStream.FlushAsync(cancellationToken);
589         }
590 
Seek(long offset, SeekOrigin origin)591         public override long Seek(long offset, SeekOrigin origin)
592         {
593             throw new NotSupportedException(SR.net_noseek);
594         }
595 
SetLength(long value)596         public override void SetLength(long value)
597         {
598             throw new NotSupportedException(SR.net_noseek);
599         }
600 
CloseNetworkConnectionAsync(CancellationToken cancellationToken)601         public async Task CloseNetworkConnectionAsync(CancellationToken cancellationToken)
602         {
603             // need to yield here to make sure that we don't get any exception synchronously
604             await Task.Yield();
605 
606             if (NetEventSource.IsEnabled)
607             {
608                 NetEventSource.Enter(this);
609             }
610 
611             CancellationTokenRegistration cancellationTokenRegistration = new CancellationTokenRegistration();
612 
613             try
614             {
615                 if (cancellationToken.CanBeCanceled)
616                 {
617                     cancellationTokenRegistration = cancellationToken.Register(s_OnCancel, this, false);
618                 }
619 #if DEBUG
620                 // When using fast path only one outstanding read is permitted. By switching into opaque mode
621                 // via IWebSocketStream.SwitchToOpaqueMode (see more detailed comments in interface definition)
622                 // caller takes responsibility for enforcing this constraint.
623                 Debug.Assert(Interlocked.Increment(ref _outstandingOperations._writes) == 1,
624                     "Only one outstanding write allowed at any given time.");
625 #endif
626                 _writeTaskCompletionSource = new TaskCompletionSource<object>();
627                 _writeEventArgs.SetShouldCloseOutput();
628                 if (WriteAsyncFast(_writeEventArgs))
629                 {
630                     await _writeTaskCompletionSource.Task.SuppressContextFlow();
631                 }
632             }
633             catch (Exception error)
634             {
635                 if (!s_CanHandleException(error))
636                 {
637                     throw;
638                 }
639 
640                 // throw OperationCancelledException when canceled by the caller
641                 // otherwise swallow the exception
642                 cancellationToken.ThrowIfCancellationRequested();
643             }
644             finally
645             {
646                 cancellationTokenRegistration.Dispose();
647 
648                 if (NetEventSource.IsEnabled)
649                 {
650                     NetEventSource.Exit(this);
651                 }
652             }
653         }
654 
Dispose(bool disposing)655         protected override void Dispose(bool disposing)
656         {
657             if (disposing && Interlocked.Exchange(ref _cleanedUp, 1) == 0)
658             {
659                 if (_readTaskCompletionSource != null)
660                 {
661                     _readTaskCompletionSource.TrySetCanceled();
662                 }
663 
664                 if (_writeTaskCompletionSource != null)
665                 {
666                     _writeTaskCompletionSource.TrySetCanceled();
667                 }
668 
669                 if (_readEventArgs != null)
670                 {
671                     _readEventArgs.Dispose();
672                 }
673 
674                 if (_writeEventArgs != null)
675                 {
676                     _writeEventArgs.Dispose();
677                 }
678 
679                 try
680                 {
681                     _inputStream.Close();
682                 }
683                 finally
684                 {
685                     _outputStream.Close();
686                 }
687             }
688         }
689 
Abort()690         public void Abort()
691         {
692             OnCancel(this);
693         }
694 
CanHandleException(Exception error)695         private static bool CanHandleException(Exception error)
696         {
697             return error is HttpListenerException ||
698                 error is ObjectDisposedException ||
699                 error is IOException;
700         }
701 
OnCancel(object state)702         private static void OnCancel(object state)
703         {
704             Debug.Assert(state != null, "'state' MUST NOT be NULL.");
705             WebSocketHttpListenerDuplexStream thisPtr = state as WebSocketHttpListenerDuplexStream;
706             Debug.Assert(thisPtr != null, "'thisPtr' MUST NOT be NULL.");
707 
708             if (NetEventSource.IsEnabled)
709             {
710                 NetEventSource.Enter(state);
711             }
712 
713             try
714             {
715                 thisPtr._outputStream.SetClosedFlag();
716                 thisPtr._context.Abort();
717             }
718             catch { }
719 
720             TaskCompletionSource<int> readTaskCompletionSourceSnapshot = thisPtr._readTaskCompletionSource;
721 
722             if (readTaskCompletionSourceSnapshot != null)
723             {
724                 readTaskCompletionSourceSnapshot.TrySetCanceled();
725             }
726 
727             TaskCompletionSource<object> writeTaskCompletionSourceSnapshot = thisPtr._writeTaskCompletionSource;
728 
729             if (writeTaskCompletionSourceSnapshot != null)
730             {
731                 writeTaskCompletionSourceSnapshot.TrySetCanceled();
732             }
733 
734             if (NetEventSource.IsEnabled)
735             {
736                 NetEventSource.Exit(state);
737             }
738         }
739 
SwitchToOpaqueMode(WebSocketBase webSocket)740         public void SwitchToOpaqueMode(WebSocketBase webSocket)
741         {
742             Debug.Assert(webSocket != null, "'webSocket' MUST NOT be NULL.");
743             Debug.Assert(_outputStream != null, "'m_OutputStream' MUST NOT be NULL.");
744             Debug.Assert(_outputStream.InternalHttpContext != null,
745                 "'m_OutputStream.InternalHttpContext' MUST NOT be NULL.");
746             Debug.Assert(_outputStream.InternalHttpContext.Response != null,
747                 "'m_OutputStream.InternalHttpContext.Response' MUST NOT be NULL.");
748             Debug.Assert(_outputStream.InternalHttpContext.Response.SentHeaders,
749                 "Headers MUST have been sent at this point.");
750             Debug.Assert(!_inOpaqueMode, "SwitchToOpaqueMode MUST NOT be called multiple times.");
751 
752             if (_inOpaqueMode)
753             {
754                 throw new InvalidOperationException();
755             }
756 
757             _webSocket = webSocket;
758             _inOpaqueMode = true;
759             _readEventArgs = new HttpListenerAsyncEventArgs(webSocket, this);
760             _readEventArgs.Completed += s_OnReadCompleted;
761             _writeEventArgs = new HttpListenerAsyncEventArgs(webSocket, this);
762             _writeEventArgs.Completed += s_OnWriteCompleted;
763 
764             if (NetEventSource.IsEnabled)
765             {
766                 NetEventSource.Associate(this, webSocket);
767             }
768         }
769 
OnWriteCompleted(object sender, HttpListenerAsyncEventArgs eventArgs)770         private static void OnWriteCompleted(object sender, HttpListenerAsyncEventArgs eventArgs)
771         {
772             Debug.Assert(eventArgs != null, "'eventArgs' MUST NOT be NULL.");
773             WebSocketHttpListenerDuplexStream thisPtr = eventArgs.CurrentStream;
774             Debug.Assert(thisPtr != null, "'thisPtr' MUST NOT be NULL.");
775 #if DEBUG
776             Debug.Assert(Interlocked.Decrement(ref thisPtr._outstandingOperations._writes) >= 0,
777                 "'thisPtr.m_OutstandingOperations.m_Writes' MUST NOT be negative.");
778 #endif
779 
780             if (NetEventSource.IsEnabled)
781             {
782                 NetEventSource.Enter(thisPtr);
783             }
784 
785             if (eventArgs.Exception != null)
786             {
787                 thisPtr._writeTaskCompletionSource.TrySetException(eventArgs.Exception);
788             }
789             else
790             {
791                 thisPtr._writeTaskCompletionSource.TrySetResult(null);
792             }
793 
794             if (NetEventSource.IsEnabled)
795             {
796                 NetEventSource.Exit(thisPtr);
797             }
798         }
799 
OnReadCompleted(object sender, HttpListenerAsyncEventArgs eventArgs)800         private static void OnReadCompleted(object sender, HttpListenerAsyncEventArgs eventArgs)
801         {
802             Debug.Assert(eventArgs != null, "'eventArgs' MUST NOT be NULL.");
803             WebSocketHttpListenerDuplexStream thisPtr = eventArgs.CurrentStream;
804             Debug.Assert(thisPtr != null, "'thisPtr' MUST NOT be NULL.");
805 #if DEBUG
806             Debug.Assert(Interlocked.Decrement(ref thisPtr._outstandingOperations._reads) >= 0,
807                 "'thisPtr.m_OutstandingOperations.m_Reads' MUST NOT be negative.");
808 #endif
809 
810             if (NetEventSource.IsEnabled)
811             {
812                 NetEventSource.Enter(thisPtr);
813             }
814 
815             if (eventArgs.Exception != null)
816             {
817                 thisPtr._readTaskCompletionSource.TrySetException(eventArgs.Exception);
818             }
819             else
820             {
821                 thisPtr._readTaskCompletionSource.TrySetResult(eventArgs.BytesTransferred);
822             }
823 
824             if (NetEventSource.IsEnabled)
825             {
826                 NetEventSource.Exit(thisPtr);
827             }
828         }
829 
830         internal class HttpListenerAsyncEventArgs : EventArgs, IDisposable
831         {
832             private const int Free = 0;
833             private const int InProgress = 1;
834             private const int Disposed = 2;
835             private int _operating;
836 
837             private bool _disposeCalled;
838             private unsafe NativeOverlapped* _ptrNativeOverlapped;
839             private ThreadPoolBoundHandle _boundHandle;
840             private event EventHandler<HttpListenerAsyncEventArgs> m_Completed;
841             private byte[] _buffer;
842             private IList<ArraySegment<byte>> _bufferList;
843             private int _count;
844             private int _offset;
845             private int _bytesTransferred;
846             private HttpListenerAsyncOperation _completedOperation;
847             private Interop.HttpApi.HTTP_DATA_CHUNK[] _dataChunks;
848             private GCHandle _dataChunksGCHandle;
849             private ushort _dataChunkCount;
850             private Exception _exception;
851             private bool _shouldCloseOutput;
852             private readonly WebSocketBase _webSocket;
853             private readonly WebSocketHttpListenerDuplexStream _currentStream;
854 
855 #if DEBUG
856             private volatile int _nativeOverlappedCounter = 0;
857             private volatile int _nativeOverlappedUsed = 0;
858 
DebugRefCountReleaseNativeOverlapped()859             private void DebugRefCountReleaseNativeOverlapped()
860             {
861                 Debug.Assert(Interlocked.Decrement(ref _nativeOverlappedCounter) == 0, "NativeOverlapped released too many times.");
862                 Interlocked.Decrement(ref _nativeOverlappedUsed);
863             }
864 
DebugRefCountAllocNativeOverlapped()865             private void DebugRefCountAllocNativeOverlapped()
866             {
867                 Debug.Assert(Interlocked.Increment(ref _nativeOverlappedCounter) == 1, "NativeOverlapped allocated without release.");
868             }
869 #endif
870 
HttpListenerAsyncEventArgs(WebSocketBase webSocket, WebSocketHttpListenerDuplexStream stream)871             public HttpListenerAsyncEventArgs(WebSocketBase webSocket, WebSocketHttpListenerDuplexStream stream)
872                 : base()
873             {
874                 _webSocket = webSocket;
875                 _currentStream = stream;
876             }
877 
878             public int BytesTransferred
879             {
880                 get { return _bytesTransferred; }
881             }
882 
883             public byte[] Buffer
884             {
885                 get { return _buffer; }
886             }
887 
888             // BufferList property.
889             // Mutually exclusive with Buffer.
890             // Setting this property with an existing non-null Buffer will cause an assert.
891             public IList<ArraySegment<byte>> BufferList
892             {
893                 get { return _bufferList; }
894                 set
895                 {
896                     Debug.Assert(!_shouldCloseOutput, "'m_ShouldCloseOutput' MUST be 'false' at this point.");
897                     Debug.Assert(value == null || _buffer == null,
898                         "Either 'm_Buffer' or 'm_BufferList' MUST be NULL.");
899                     Debug.Assert(_operating == Free,
900                         "This property can only be modified if no IO operation is outstanding.");
901                     Debug.Assert(value == null || value.Count == 2,
902                         "This list can only be 'NULL' or MUST have exactly '2' items.");
903                     _bufferList = value;
904                 }
905             }
906 
907             public bool ShouldCloseOutput
908             {
909                 get { return _shouldCloseOutput; }
910             }
911 
912             public int Offset
913             {
914                 get { return _offset; }
915             }
916 
917             public int Count
918             {
919                 get { return _count; }
920             }
921 
922             public Exception Exception
923             {
924                 get { return _exception; }
925             }
926 
927             public ushort EntityChunkCount
928             {
929                 get
930                 {
931                     if (_dataChunks == null)
932                     {
933                         return 0;
934                     }
935 
936                     return _dataChunkCount;
937                 }
938             }
939 
940             internal unsafe NativeOverlapped* NativeOverlapped
941             {
942                 get
943                 {
944 #if DEBUG
945                     Debug.Assert(Interlocked.Increment(ref _nativeOverlappedUsed) == 1, "NativeOverlapped reused.");
946 #endif
947                     return _ptrNativeOverlapped;
948                 }
949             }
950 
951             public IntPtr EntityChunks
952             {
953                 get
954                 {
955                     if (_dataChunks == null)
956                     {
957                         return IntPtr.Zero;
958                     }
959 
960                     return Marshal.UnsafeAddrOfPinnedArrayElement(_dataChunks, 0);
961                 }
962             }
963 
964             public WebSocketHttpListenerDuplexStream CurrentStream
965             {
966                 get { return _currentStream; }
967             }
968 
969             public event EventHandler<HttpListenerAsyncEventArgs> Completed
970             {
971                 add
972                 {
973                     m_Completed += value;
974                 }
975                 remove
976                 {
977                     m_Completed -= value;
978                 }
979             }
980 
OnCompleted(HttpListenerAsyncEventArgs e)981             protected virtual void OnCompleted(HttpListenerAsyncEventArgs e)
982             {
983                 m_Completed?.Invoke(e._currentStream, e);
984             }
985 
SetShouldCloseOutput()986             public void SetShouldCloseOutput()
987             {
988                 _bufferList = null;
989                 _buffer = null;
990                 _shouldCloseOutput = true;
991             }
992 
Dispose()993             public void Dispose()
994             {
995                 // Remember that Dispose was called.
996                 _disposeCalled = true;
997 
998                 // Check if this object is in-use for an async socket operation.
999                 if (Interlocked.CompareExchange(ref _operating, Disposed, Free) != Free)
1000                 {
1001                     // Either already disposed or will be disposed when current operation completes.
1002                     return;
1003                 }
1004 
1005                 // Don't bother finalizing later.
1006                 GC.SuppressFinalize(this);
1007             }
1008 
InitializeOverlapped(ThreadPoolBoundHandle boundHandle)1009             private unsafe void InitializeOverlapped(ThreadPoolBoundHandle boundHandle)
1010             {
1011 #if DEBUG
1012                 DebugRefCountAllocNativeOverlapped();
1013 #endif
1014                 _boundHandle = boundHandle;
1015                 _ptrNativeOverlapped = boundHandle.AllocateNativeOverlapped(CompletionPortCallback, null, null);
1016             }
1017 
1018             // Method to clean up any existing Overlapped object and related state variables.
FreeOverlapped(bool checkForShutdown)1019             private unsafe void FreeOverlapped(bool checkForShutdown)
1020             {
1021                 if (!checkForShutdown || !Environment.HasShutdownStarted)
1022                 {
1023                     // Free the overlapped object
1024                     if (_ptrNativeOverlapped != null)
1025                     {
1026 #if DEBUG
1027                         DebugRefCountReleaseNativeOverlapped();
1028 #endif
1029                         _boundHandle.FreeNativeOverlapped(_ptrNativeOverlapped);
1030                         _ptrNativeOverlapped = null;
1031                     }
1032 
1033                     if (_dataChunksGCHandle.IsAllocated)
1034                     {
1035                         _dataChunksGCHandle.Free();
1036                         _dataChunks = null;
1037                     }
1038                 }
1039             }
1040 
1041             // Method called to prepare for a native async http.sys call.
1042             // This method performs the tasks common to all http.sys operations.
StartOperationCommon(WebSocketHttpListenerDuplexStream currentStream, ThreadPoolBoundHandle boundHandle)1043             internal void StartOperationCommon(WebSocketHttpListenerDuplexStream currentStream, ThreadPoolBoundHandle boundHandle)
1044             {
1045                 // Change status to "in-use".
1046                 if (Interlocked.CompareExchange(ref _operating, InProgress, Free) != Free)
1047                 {
1048                     // If it was already "in-use" check if Dispose was called.
1049                     if (_disposeCalled)
1050                     {
1051                         // Dispose was called - throw ObjectDisposed.
1052                         throw new ObjectDisposedException(GetType().FullName);
1053                     }
1054 
1055                     Debug.Assert(false, "Only one outstanding async operation is allowed per HttpListenerAsyncEventArgs instance.");
1056                     // Only one at a time.
1057                     throw new InvalidOperationException();
1058                 }
1059 
1060                 // HttpSendResponseEntityBody can return ERROR_INVALID_PARAMETER if the InternalHigh field of the overlapped
1061                 // is not IntPtr.Zero, so we have to reset this field because we are reusing the Overlapped.
1062                 // When using the IAsyncResult based approach of HttpListenerResponseStream the Overlapped is reinitialized
1063                 // for each operation by the CLR when returned from the OverlappedDataCache.
1064 
1065                 InitializeOverlapped(boundHandle);
1066 
1067                 _exception = null;
1068                 _bytesTransferred = 0;
1069             }
1070 
StartOperationReceive()1071             internal void StartOperationReceive()
1072             {
1073                 // Remember the operation type.
1074                 _completedOperation = HttpListenerAsyncOperation.Receive;
1075             }
1076 
StartOperationSend()1077             internal void StartOperationSend()
1078             {
1079                 UpdateDataChunk();
1080 
1081                 // Remember the operation type.
1082                 _completedOperation = HttpListenerAsyncOperation.Send;
1083             }
1084 
SetBuffer(byte[] buffer, int offset, int count)1085             public void SetBuffer(byte[] buffer, int offset, int count)
1086             {
1087                 Debug.Assert(!_shouldCloseOutput, "'m_ShouldCloseOutput' MUST be 'false' at this point.");
1088                 Debug.Assert(buffer == null || _bufferList == null, "Either 'm_Buffer' or 'm_BufferList' MUST be NULL.");
1089                 _buffer = buffer;
1090                 _offset = offset;
1091                 _count = count;
1092             }
1093 
UpdateDataChunk()1094             private unsafe void UpdateDataChunk()
1095             {
1096                 if (_dataChunks == null)
1097                 {
1098                     _dataChunks = new Interop.HttpApi.HTTP_DATA_CHUNK[2];
1099                     _dataChunksGCHandle = GCHandle.Alloc(_dataChunks, GCHandleType.Pinned);
1100                     _dataChunks[0] = new Interop.HttpApi.HTTP_DATA_CHUNK();
1101                     _dataChunks[0].DataChunkType = Interop.HttpApi.HTTP_DATA_CHUNK_TYPE.HttpDataChunkFromMemory;
1102                     _dataChunks[1] = new Interop.HttpApi.HTTP_DATA_CHUNK();
1103                     _dataChunks[1].DataChunkType = Interop.HttpApi.HTTP_DATA_CHUNK_TYPE.HttpDataChunkFromMemory;
1104                 }
1105 
1106                 Debug.Assert(_buffer == null || _bufferList == null, "Either 'm_Buffer' or 'm_BufferList' MUST be NULL.");
1107                 Debug.Assert(_shouldCloseOutput || _buffer != null || _bufferList != null, "Either 'm_Buffer' or 'm_BufferList' MUST NOT be NULL.");
1108 
1109                 // The underlying byte[] m_Buffer or each m_BufferList[].Array are pinned already
1110                 if (_buffer != null)
1111                 {
1112                     UpdateDataChunk(0, _buffer, _offset, _count);
1113                     UpdateDataChunk(1, null, 0, 0);
1114                     _dataChunkCount = 1;
1115                 }
1116                 else if (_bufferList != null)
1117                 {
1118                     Debug.Assert(_bufferList != null && _bufferList.Count == 2,
1119                         "'m_BufferList' MUST NOT be NULL and have exactly '2' items at this point.");
1120                     UpdateDataChunk(0, _bufferList[0].Array, _bufferList[0].Offset, _bufferList[0].Count);
1121                     UpdateDataChunk(1, _bufferList[1].Array, _bufferList[1].Offset, _bufferList[1].Count);
1122                     _dataChunkCount = 2;
1123                 }
1124                 else
1125                 {
1126                     Debug.Assert(_shouldCloseOutput, "'m_ShouldCloseOutput' MUST be 'true' at this point.");
1127                     _dataChunks = null;
1128                 }
1129             }
1130 
UpdateDataChunk(int index, byte[] buffer, int offset, int count)1131             private unsafe void UpdateDataChunk(int index, byte[] buffer, int offset, int count)
1132             {
1133                 if (buffer == null)
1134                 {
1135                     _dataChunks[index].pBuffer = null;
1136                     _dataChunks[index].BufferLength = 0;
1137                     return;
1138                 }
1139 
1140                 if (_webSocket.InternalBuffer.IsInternalBuffer(buffer, offset, count))
1141                 {
1142                     _dataChunks[index].pBuffer = (byte*)(_webSocket.InternalBuffer.ToIntPtr(offset));
1143                 }
1144                 else
1145                 {
1146                     _dataChunks[index].pBuffer =
1147                         (byte*)_webSocket.InternalBuffer.ConvertPinnedSendPayloadToNative(buffer, offset, count);
1148                 }
1149 
1150                 _dataChunks[index].BufferLength = (uint)count;
1151             }
1152 
1153             // Method to mark this object as no longer "in-use".
1154             // Will also execute a Dispose deferred because I/O was in progress.
Complete()1155             internal void Complete()
1156             {
1157                 FreeOverlapped(false);
1158                 // Mark as not in-use
1159                 Interlocked.Exchange(ref _operating, Free);
1160 
1161                 // Check for deferred Dispose().
1162                 // The deferred Dispose is not guaranteed if Dispose is called while an operation is in progress.
1163                 // The m_DisposeCalled variable is not managed in a thread-safe manner on purpose for performance.
1164                 if (_disposeCalled)
1165                 {
1166                     Dispose();
1167                 }
1168             }
1169 
1170             // Method to update internal state after sync or async completion.
SetResults(Exception exception, int bytesTransferred)1171             private void SetResults(Exception exception, int bytesTransferred)
1172             {
1173                 _exception = exception;
1174                 _bytesTransferred = bytesTransferred;
1175             }
1176 
FinishOperationFailure(Exception exception, bool syncCompletion)1177             internal void FinishOperationFailure(Exception exception, bool syncCompletion)
1178             {
1179                 SetResults(exception, 0);
1180 
1181                 if (NetEventSource.IsEnabled)
1182                 {
1183                     string methodName = _completedOperation == HttpListenerAsyncOperation.Receive ? nameof(ReadAsyncFast) : nameof(WriteAsyncFast);
1184                     NetEventSource.Error(_currentStream, $"{methodName} {exception.ToString()}");
1185                 }
1186 
1187                 Complete();
1188                 OnCompleted(this);
1189             }
1190 
FinishOperationSuccess(int bytesTransferred, bool syncCompletion)1191             internal void FinishOperationSuccess(int bytesTransferred, bool syncCompletion)
1192             {
1193                 SetResults(null, bytesTransferred);
1194 
1195                 if (NetEventSource.IsEnabled)
1196                 {
1197                     if (_buffer != null && NetEventSource.IsEnabled)
1198                     {
1199                         string methodName = _completedOperation == HttpListenerAsyncOperation.Receive ? nameof(ReadAsyncFast) : nameof(WriteAsyncFast);
1200                         NetEventSource.DumpBuffer(_currentStream, _buffer, _offset, bytesTransferred, methodName);
1201                     }
1202                     else if (_bufferList != null)
1203                     {
1204                         Debug.Assert(_completedOperation == HttpListenerAsyncOperation.Send,
1205                             "'BufferList' is only supported for send operations.");
1206 
1207                         foreach (ArraySegment<byte> buffer in BufferList)
1208                         {
1209                             NetEventSource.DumpBuffer(this, buffer.Array, buffer.Offset, buffer.Count, nameof(WriteAsyncFast));
1210                         }
1211                     }
1212                 }
1213 
1214                 if (_shouldCloseOutput)
1215                 {
1216                     _currentStream._outputStream.SetClosedFlag();
1217                 }
1218 
1219                 // Complete the operation and raise completion event.
1220                 Complete();
1221                 OnCompleted(this);
1222             }
1223 
CompletionPortCallback(uint errorCode, uint numBytes, NativeOverlapped* nativeOverlapped)1224             private unsafe void CompletionPortCallback(uint errorCode, uint numBytes, NativeOverlapped* nativeOverlapped)
1225             {
1226                 if (errorCode == Interop.HttpApi.ERROR_SUCCESS ||
1227                     errorCode == Interop.HttpApi.ERROR_HANDLE_EOF)
1228                 {
1229                     FinishOperationSuccess((int)numBytes, false);
1230                 }
1231                 else
1232                 {
1233                     FinishOperationFailure(new HttpListenerException((int)errorCode), false);
1234                 }
1235             }
1236 
1237             public enum HttpListenerAsyncOperation
1238             {
1239                 None,
1240                 Receive,
1241                 Send
1242             }
1243         }
1244     }
1245 }
1246