1 // Copyright (c) Microsoft. All rights reserved.
2 // Licensed under the MIT license. See LICENSE file in the project root for full license information.
3 //-----------------------------------------------------------------------
4 // </copyright>
5 // <summary>Helper methods for dealing with 'await'able objects.</summary>
6 //-----------------------------------------------------------------------
7 
8 using System;
9 using System.Collections.Concurrent;
10 using System.Globalization;
11 using System.Runtime.CompilerServices;
12 using System.Threading;
13 using System.Threading.Tasks;
14 
namespace Microsoft.Build.Shared
{
    /// <summary>
    /// Class defining extension methods for awaitable objects.
    /// </summary>
    internal static class AwaitExtensions
    {
        /// <summary>
        /// Synchronizes creation of the singleton STA scheduler.
        /// </summary>
        private static readonly Object s_staSchedulerSync = new Object();

        /// <summary>
        /// The singleton STA scheduler object, created lazily by <see cref="OneSTAThreadPerTaskSchedulerInstance"/>.
        /// </summary>
        /// <remarks>
        /// Declared volatile because it is read outside the lock in the double-checked
        /// initialization below; this guarantees a reader never observes the reference
        /// before the scheduler's construction is visible.
        /// </remarks>
        private static volatile TaskScheduler s_staScheduler;

        /// <summary>
        /// Gets the STA scheduler, creating it on first access.
        /// </summary>
        internal static TaskScheduler OneSTAThreadPerTaskSchedulerInstance
        {
            get
            {
                // Fast path avoids taking the lock once the scheduler exists.
                if (s_staScheduler == null)
                {
                    lock (s_staSchedulerSync)
                    {
                        // Re-check under the lock: another thread may have created it meanwhile.
                        if (s_staScheduler == null)
                        {
                            s_staScheduler = new OneSTAThreadPerTaskScheduler();
                        }
                    }
                }

                return s_staScheduler;
            }
        }

        /// <summary>
        /// Provides await functionality for ordinary <see cref="WaitHandle"/>s.
        /// </summary>
        /// <param name="handle">The handle to wait on.</param>
        /// <returns>The awaiter.</returns>
        internal static TaskAwaiter GetAwaiter(this WaitHandle handle)
        {
            ErrorUtilities.VerifyThrowArgumentNull(handle, "handle");
            return handle.ToTask().GetAwaiter();
        }

        /// <summary>
        /// Provides await functionality for an array of ordinary <see cref="WaitHandle"/>s.
        /// </summary>
        /// <param name="handles">The handles to wait on.</param>
        /// <returns>The awaiter, whose result is the index of the signaled handle.</returns>
        internal static TaskAwaiter<int> GetAwaiter(this WaitHandle[] handles)
        {
            ErrorUtilities.VerifyThrowArgumentNull(handles, "handles");
            return handles.ToTask().GetAwaiter();
        }

        /// <summary>
        /// Creates a TPL Task that is marked as completed when a <see cref="WaitHandle"/> is signaled.
        /// </summary>
        /// <param name="handle">The handle whose signal triggers the task to be completed.  Do not use a <see cref="Mutex"/> here.</param>
        /// <param name="timeout">
        /// The timeout (in milliseconds) after which the task is completed even if the handle has not been signaled.
        /// The task does NOT fault on timeout; it completes successfully, so callers of this overload cannot
        /// distinguish a timeout from a signal. Use the <see cref="WaitHandle"/> array overload to observe
        /// <see cref="WaitHandle.WaitTimeout"/>.
        /// </param>
        /// <returns>A Task that is completed after the handle is signaled or the timeout elapses.</returns>
        /// <remarks>
        /// There is a (brief) time delay between when the handle is signaled and when the task is marked as completed.
        /// </remarks>
        internal static Task ToTask(this WaitHandle handle, int timeout = Timeout.Infinite)
        {
            ErrorUtilities.VerifyThrowArgumentNull(handle, "handle");
            return ToTask(new WaitHandle[] { handle }, timeout);
        }

        /// <summary>
        /// Creates a TPL Task that is marked as completed when any <see cref="WaitHandle"/> in the array is signaled.
        /// </summary>
        /// <param name="handles">The handles whose signals triggers the task to be completed.  Do not use a <see cref="Mutex"/> here.</param>
        /// <param name="timeout">The timeout (in milliseconds) after which the task will return a value of WaitTimeout.</param>
        /// <returns>
        /// A Task that is completed after any handle is signaled. Its result is the index of the signaled
        /// handle, or <see cref="WaitHandle.WaitTimeout"/> if the timeout elapsed first.
        /// </returns>
        /// <remarks>
        /// There is a (brief) time delay between when the handles are signaled and when the task is marked as completed.
        /// </remarks>
        internal static Task<int> ToTask(this WaitHandle[] handles, int timeout = Timeout.Infinite)
        {
            ErrorUtilities.VerifyThrowArgumentNull(handles, "handles");

            var tcs = new TaskCompletionSource<int>();

            // Optimization: if a handle is already signaled, return a completed task
            // without registering any thread-pool waits.
            int signalledHandle = WaitHandle.WaitAny(handles, 0);
            if (signalledHandle != WaitHandle.WaitTimeout)
            {
                tcs.SetResult(signalledHandle);
                return tcs.Task;
            }

            // Held while the registrations are being created. A callback may fire before the
            // loop below has finished filling callbackHandles; taking the same lock in the
            // callback makes it wait until every registration has been stored.
            var localVariableInitLock = new object();
            RegisteredWaitHandle[] callbackHandles = new RegisteredWaitHandle[handles.Length];

            lock (localVariableInitLock)
            {
                for (int i = 0; i < handles.Length; i++)
                {
                    callbackHandles[i] = ThreadPool.RegisterWaitForSingleObject(
                        handles[i],
                        (state, timedOut) =>
                        {
                            // The first callback to complete the task wins; later calls are no-ops.
                            tcs.TrySetResult(timedOut ? WaitHandle.WaitTimeout : (int)state);

                            lock (localVariableInitLock)
                            {
                                foreach (RegisteredWaitHandle registration in callbackHandles)
                                {
                                    // An entry is null only if RegisterWaitForSingleObject threw part-way
                                    // through the loop; skip it rather than throwing on a thread-pool thread.
                                    if (registration != null)
                                    {
                                        registration.Unregister(null);
                                    }
                                }
                            }
                        },
                        state: i,
                        millisecondsTimeOutInterval: timeout,
                        executeOnlyOnce: true);
                }
            }

            return tcs.Task;
        }

        /// <summary>
        /// A class which acts as a task scheduler and ensures each scheduled task gets its
        /// own STA thread.
        /// </summary>
        private sealed class OneSTAThreadPerTaskScheduler : TaskScheduler
        {
            /// <summary>
            /// The current queue of tasks.
            /// </summary>
            private readonly ConcurrentQueue<Task> _queuedTasks = new ConcurrentQueue<Task>();

            /// <summary>
            /// Returns the list of queued tasks.
            /// </summary>
            protected override System.Collections.Generic.IEnumerable<Task> GetScheduledTasks()
            {
                return _queuedTasks;
            }

            /// <summary>
            /// Queues a task to the scheduler and starts a dedicated thread to run a queued task.
            /// </summary>
            protected override void QueueTask(Task task)
            {
                _queuedTasks.Enqueue(task);

                // One thread is started per enqueued task and each thread dequeues exactly one task,
                // so every queued task runs on a thread of its own (not necessarily the one started for it).
                ParameterizedThreadStart threadStart = new ParameterizedThreadStart((_) =>
                {
                    Task t;
                    if (_queuedTasks.TryDequeue(out t))
                    {
                        base.TryExecuteTask(t);
                    }
                });

                Thread thread = new Thread(threadStart);
#if FEATURE_APARTMENT_STATE
                thread.SetApartmentState(ApartmentState.STA);
#endif
                thread.Start(task);
            }

            /// <summary>
            /// Tries to execute the task immediately.  This method will always return false for the STA scheduler.
            /// </summary>
            protected override bool TryExecuteTaskInline(Task task, bool taskWasPreviouslyQueued)
            {
                // We don't get STA threads back here, so just deny the inline execution.
                return false;
            }
        }
    }
}
207