1 using Microsoft.Win32.SafeHandles;
2 using System;
3 using System.Collections;
4 using System.IO;
5 using System.Linq;
6 using System.Runtime.ConstrainedExecution;
7 using System.Runtime.InteropServices;
8 using System.Text;
9 using System.Threading;
10 
11 namespace Ansible.Process
12 {
    /// <summary>
    /// Managed definitions of the Win32 structures and flag enums used by the P/Invoke
    /// signatures in NativeMethods. The classes use LayoutKind.Sequential, so field order and
    /// field types are the native layout: do not reorder, add, or remove fields.
    /// </summary>
    internal class NativeHelpers
    {
        /// <summary>
        /// Win32 SECURITY_ATTRIBUTES. Passed to CreatePipe to control whether the created pipe
        /// handles are inheritable by child processes (bInheritHandle).
        /// </summary>
        [StructLayout(LayoutKind.Sequential)]
        public class SECURITY_ATTRIBUTES
        {
            public UInt32 nLength;
            public IntPtr lpSecurityDescriptor;
            public bool bInheritHandle = false;
            // The Win32 API requires nLength to hold the size of the marshaled structure.
            public SECURITY_ATTRIBUTES()
            {
                nLength = (UInt32)Marshal.SizeOf(this);
            }
        }

        /// <summary>
        /// Win32 STARTUPINFOW, embedded in STARTUPINFOEX below. The hStd* handles are only used by
        /// CreateProcess when dwFlags contains USESTDHANDLES.
        /// </summary>
        [StructLayout(LayoutKind.Sequential)]
        public class STARTUPINFO
        {
            public UInt32 cb;
            public IntPtr lpReserved;
            [MarshalAs(UnmanagedType.LPWStr)] public string lpDesktop;
            [MarshalAs(UnmanagedType.LPWStr)] public string lpTitle;
            public UInt32 dwX;
            public UInt32 dwY;
            public UInt32 dwXSize;
            public UInt32 dwYSize;
            public UInt32 dwXCountChars;
            public UInt32 dwYCountChars;
            public UInt32 dwFillAttribute;
            public StartupInfoFlags dwFlags;
            public UInt16 wShowWindow;
            public UInt16 cbReserved2;
            public IntPtr lpReserved2;
            public SafeFileHandle hStdInput;
            public SafeFileHandle hStdOutput;
            public SafeFileHandle hStdError;
            // cb defaults to the size of this structure alone; STARTUPINFOEX overwrites it.
            public STARTUPINFO()
            {
                cb = (UInt32)Marshal.SizeOf(this);
            }
        }

        /// <summary>
        /// Win32 STARTUPINFOEXW: a STARTUPINFO followed by a pointer to a process/thread attribute
        /// list. The constructor leaves lpAttributeList as IntPtr.Zero.
        /// </summary>
        [StructLayout(LayoutKind.Sequential)]
        public class STARTUPINFOEX
        {
            public STARTUPINFO startupInfo;
            public IntPtr lpAttributeList;
            // When CreateProcess is called with EXTENDED_STARTUPINFO_PRESENT, startupInfo.cb must be
            // the size of the whole extended structure, not just the embedded STARTUPINFO.
            public STARTUPINFOEX()
            {
                startupInfo = new STARTUPINFO();
                startupInfo.cb = (UInt32)Marshal.SizeOf(this);
            }
        }

        /// <summary>
        /// Win32 PROCESS_INFORMATION filled in by CreateProcessW. hProcess and hThread are raw
        /// handles that the receiver is responsible for closing.
        /// </summary>
        [StructLayout(LayoutKind.Sequential)]
        public struct PROCESS_INFORMATION
        {
            public IntPtr hProcess;
            public IntPtr hThread;
            public int dwProcessId;
            public int dwThreadId;
        }

        /// <summary>Subset of the Win32 process creation flags (dwCreationFlags).</summary>
        [Flags]
        public enum ProcessCreationFlags : uint
        {
            CREATE_NEW_CONSOLE = 0x00000010,
            // lpEnvironment is a UTF-16 block rather than an ANSI one.
            CREATE_UNICODE_ENVIRONMENT = 0x00000400,
            // lpStartupInfo points to a STARTUPINFOEX rather than a STARTUPINFO.
            EXTENDED_STARTUPINFO_PRESENT = 0x00080000
        }

        /// <summary>Subset of the STARTUPINFO.dwFlags values.</summary>
        [Flags]
        public enum StartupInfoFlags : uint
        {
            // STARTF_USESTDHANDLES: use hStdInput/hStdOutput/hStdError from STARTUPINFO.
            USESTDHANDLES = 0x00000100
        }

        /// <summary>Handle flags for SetHandleInformation.</summary>
        [Flags]
        public enum HandleFlags : uint
        {
            None = 0,
            // HANDLE_FLAG_INHERIT: the handle can be inherited by child processes.
            INHERIT = 1
        }
    }
96 
    /// <summary>
    /// P/Invoke declarations for the kernel32/shell32 functions used by ProcessUtil. Functions
    /// declared with SetLastError = true can be followed by Marshal.GetLastWin32Error() (for
    /// example via the Win32Exception(string) constructor) to read the failure code.
    /// </summary>
    internal class NativeMethods
    {
        /// <summary>Allocates a new console for the calling process; returns false on failure.</summary>
        [DllImport("kernel32.dll", SetLastError = true)]
        public static extern bool AllocConsole();

        /// <summary>
        /// Splits a command line into an array of pointers to UTF-16 strings, writing the count to
        /// pNumArgs. The returned block is released by SafeMemoryBuffer through Marshal.FreeHGlobal.
        /// NOTE(review): Win32 requires this block to be freed with LocalFree, which is what
        /// FreeHGlobal uses on Windows.
        /// </summary>
        [DllImport("shell32.dll", SetLastError = true)]
        public static extern SafeMemoryBuffer CommandLineToArgvW(
            [MarshalAs(UnmanagedType.LPWStr)] string lpCmdLine,
            out int pNumArgs);

        /// <summary>
        /// Creates an anonymous pipe. lpPipeAttributes controls handle inheritance; nSize 0 uses the
        /// system default buffer size.
        /// </summary>
        [DllImport("kernel32.dll", SetLastError = true)]
        public static extern bool CreatePipe(
            out SafeFileHandle hReadPipe,
            out SafeFileHandle hWritePipe,
            NativeHelpers.SECURITY_ATTRIBUTES lpPipeAttributes,
            UInt32 nSize);

        /// <summary>
        /// Creates a new process and its primary thread. lpCommandLine is a StringBuilder because the
        /// Unicode version of CreateProcess may modify the command line buffer in place.
        /// lpEnvironment may be an invalid (zero) buffer to inherit the caller's environment.
        /// </summary>
        [DllImport("kernel32.dll", SetLastError = true, CharSet = CharSet.Unicode)]
        public static extern bool CreateProcessW(
            [MarshalAs(UnmanagedType.LPWStr)] string lpApplicationName,
            StringBuilder lpCommandLine,
            IntPtr lpProcessAttributes,
            IntPtr lpThreadAttributes,
            bool bInheritHandles,
            NativeHelpers.ProcessCreationFlags dwCreationFlags,
            SafeMemoryBuffer lpEnvironment,
            [MarshalAs(UnmanagedType.LPWStr)] string lpCurrentDirectory,
            NativeHelpers.STARTUPINFOEX lpStartupInfo,
            out NativeHelpers.PROCESS_INFORMATION lpProcessInformation);

        /// <summary>Detaches the calling process from its console.</summary>
        [DllImport("kernel32.dll", SetLastError = true)]
        public static extern bool FreeConsole();

        /// <summary>Returns the console window handle, or IntPtr.Zero if there is none.</summary>
        [DllImport("kernel32.dll", SetLastError = true)]
        public static extern IntPtr GetConsoleWindow();

        /// <summary>Retrieves the termination status of the given process.</summary>
        [DllImport("kernel32.dll", SetLastError = true)]
        public static extern bool GetExitCodeProcess(
            SafeWaitHandle hProcess,
            out UInt32 lpExitCode);

        /// <summary>
        /// Searches for a file. Returns 0 on failure. When nBufferLength is too small it returns the
        /// required size in characters; otherwise it returns the length copied into lpBuffer.
        /// (LPTStr marshals as UTF-16 on Windows NT-based systems.)
        /// </summary>
        [DllImport("kernel32.dll", SetLastError = true, CharSet = CharSet.Unicode)]
        public static extern uint SearchPathW(
            [MarshalAs(UnmanagedType.LPWStr)] string lpPath,
            [MarshalAs(UnmanagedType.LPWStr)] string lpFileName,
            [MarshalAs(UnmanagedType.LPWStr)] string lpExtension,
            UInt32 nBufferLength,
            [MarshalAs(UnmanagedType.LPTStr)] StringBuilder lpBuffer,
            out IntPtr lpFilePart);

        /// <summary>Sets the input code page of the attached console.</summary>
        [DllImport("kernel32.dll", SetLastError = true)]
        public static extern bool SetConsoleCP(
            UInt32 wCodePageID);

        /// <summary>Sets the output code page of the attached console.</summary>
        [DllImport("kernel32.dll", SetLastError = true)]
        public static extern bool SetConsoleOutputCP(
            UInt32 wCodePageID);

        /// <summary>Sets the bits selected by dwMask on a handle to the values in dwFlags.</summary>
        [DllImport("kernel32.dll", SetLastError = true)]
        public static extern bool SetHandleInformation(
            SafeFileHandle hObject,
            NativeHelpers.HandleFlags dwMask,
            NativeHelpers.HandleFlags dwFlags);

        /// <summary>
        /// Waits until the object is signaled or the timeout elapses (0xFFFFFFFF = INFINITE).
        /// Declared without SetLastError, so Marshal.GetLastWin32Error() is not reliable after it.
        /// </summary>
        [DllImport("kernel32.dll")]
        public static extern UInt32 WaitForSingleObject(
            SafeWaitHandle hHandle,
            UInt32 dwMilliseconds);
    }
166 
167     internal class SafeMemoryBuffer : SafeHandleZeroOrMinusOneIsInvalid
168     {
SafeMemoryBuffer()169         public SafeMemoryBuffer() : base(true) { }
SafeMemoryBuffer(int cb)170         public SafeMemoryBuffer(int cb) : base(true)
171         {
172             base.SetHandle(Marshal.AllocHGlobal(cb));
173         }
SafeMemoryBuffer(IntPtr handle)174         public SafeMemoryBuffer(IntPtr handle) : base(true)
175         {
176             base.SetHandle(handle);
177         }
178 
179         [ReliabilityContract(Consistency.WillNotCorruptState, Cer.MayFail)]
ReleaseHandle()180         protected override bool ReleaseHandle()
181         {
182             Marshal.FreeHGlobal(handle);
183             return true;
184         }
185     }
186 
187     public class Win32Exception : System.ComponentModel.Win32Exception
188     {
189         private string _msg;
190 
Win32Exception(string message)191         public Win32Exception(string message) : this(Marshal.GetLastWin32Error(), message) { }
Win32Exception(int errorCode, string message)192         public Win32Exception(int errorCode, string message) : base(errorCode)
193         {
194             _msg = String.Format("{0} ({1}, Win32ErrorCode {2})", message, base.Message, errorCode);
195         }
196 
197         public override string Message { get { return _msg; } }
operator Win32Exception(string message)198         public static explicit operator Win32Exception(string message) { return new Win32Exception(message); }
199     }
200 
    /// <summary>
    /// Outcome of a process launched by ProcessUtil: the decoded stdout/stderr text and the exit
    /// code. The setters are internal so only this assembly populates it.
    /// </summary>
    public class Result
    {
        /// <summary>Everything the process wrote to stdout, decoded with the requested output encoding.</summary>
        public string StandardOut { get; internal set; }
        /// <summary>Everything the process wrote to stderr, decoded with the requested output encoding.</summary>
        public string StandardError { get; internal set; }
        /// <summary>The value reported by GetExitCodeProcess after the process exited.</summary>
        public uint ExitCode { get; internal set; }
    }
207 
208     public class ProcessUtil
209     {
210         /// <summary>
211         /// Parses a command line string into an argv array according to the Windows rules
212         /// </summary>
213         /// <param name="lpCommandLine">The command line to parse</param>
214         /// <returns>An array of arguments interpreted by Windows</returns>
ParseCommandLine(string lpCommandLine)215         public static string[] ParseCommandLine(string lpCommandLine)
216         {
217             int numArgs;
218             using (SafeMemoryBuffer buf = NativeMethods.CommandLineToArgvW(lpCommandLine, out numArgs))
219             {
220                 if (buf.IsInvalid)
221                     throw new Win32Exception("Error parsing command line");
222                 IntPtr[] strptrs = new IntPtr[numArgs];
223                 Marshal.Copy(buf.DangerousGetHandle(), strptrs, 0, numArgs);
224                 return strptrs.Select(s => Marshal.PtrToStringUni(s)).ToArray();
225             }
226         }
227 
228         /// <summary>
229         /// Searches the path for the executable specified. Will throw a Win32Exception if the file is not found.
230         /// </summary>
231         /// <param name="lpFileName">The executable to search for</param>
232         /// <returns>The full path of the executable to search for</returns>
SearchPath(string lpFileName)233         public static string SearchPath(string lpFileName)
234         {
235             StringBuilder sbOut = new StringBuilder(0);
236             IntPtr filePartOut = IntPtr.Zero;
237             UInt32 res = NativeMethods.SearchPathW(null, lpFileName, null, (UInt32)sbOut.Capacity, sbOut, out filePartOut);
238             if (res == 0)
239             {
240                 int lastErr = Marshal.GetLastWin32Error();
241                 if (lastErr == 2)  // ERROR_FILE_NOT_FOUND
242                     throw new FileNotFoundException(String.Format("Could not find file '{0}'.", lpFileName));
243                 else
244                     throw new Win32Exception(String.Format("SearchPathW({0}) failed to get buffer length", lpFileName));
245             }
246 
247             sbOut.EnsureCapacity((int)res);
248             if (NativeMethods.SearchPathW(null, lpFileName, null, (UInt32)sbOut.Capacity, sbOut, out filePartOut) == 0)
249                 throw new Win32Exception(String.Format("SearchPathW({0}) failed", lpFileName));
250 
251             return sbOut.ToString();
252         }
253 
CreateProcess(string command)254         public static Result CreateProcess(string command)
255         {
256             return CreateProcess(null, command, null, null, String.Empty);
257         }
258 
CreateProcess(string lpApplicationName, string lpCommandLine, string lpCurrentDirectory, IDictionary environment)259         public static Result CreateProcess(string lpApplicationName, string lpCommandLine, string lpCurrentDirectory,
260             IDictionary environment)
261         {
262             return CreateProcess(lpApplicationName, lpCommandLine, lpCurrentDirectory, environment, String.Empty);
263         }
264 
CreateProcess(string lpApplicationName, string lpCommandLine, string lpCurrentDirectory, IDictionary environment, string stdin)265         public static Result CreateProcess(string lpApplicationName, string lpCommandLine, string lpCurrentDirectory,
266             IDictionary environment, string stdin)
267         {
268             return CreateProcess(lpApplicationName, lpCommandLine, lpCurrentDirectory, environment, stdin, null);
269         }
270 
CreateProcess(string lpApplicationName, string lpCommandLine, string lpCurrentDirectory, IDictionary environment, byte[] stdin)271         public static Result CreateProcess(string lpApplicationName, string lpCommandLine, string lpCurrentDirectory,
272             IDictionary environment, byte[] stdin)
273         {
274             return CreateProcess(lpApplicationName, lpCommandLine, lpCurrentDirectory, environment, stdin, null);
275         }
276 
CreateProcess(string lpApplicationName, string lpCommandLine, string lpCurrentDirectory, IDictionary environment, string stdin, string outputEncoding)277         public static Result CreateProcess(string lpApplicationName, string lpCommandLine, string lpCurrentDirectory,
278             IDictionary environment, string stdin, string outputEncoding)
279         {
280             byte[] stdinBytes;
281             if (String.IsNullOrEmpty(stdin))
282                 stdinBytes = new byte[0];
283             else
284             {
285                 if (!stdin.EndsWith(Environment.NewLine))
286                     stdin += Environment.NewLine;
287                 stdinBytes = new UTF8Encoding(false).GetBytes(stdin);
288             }
289             return CreateProcess(lpApplicationName, lpCommandLine, lpCurrentDirectory, environment, stdinBytes, outputEncoding);
290         }
291 
292         /// <summary>
293         /// Creates a process based on the CreateProcess API call.
294         /// </summary>
295         /// <param name="lpApplicationName">The name of the executable or batch file to execute</param>
296         /// <param name="lpCommandLine">The command line to execute, typically this includes lpApplication as the first argument</param>
297         /// <param name="lpCurrentDirectory">The full path to the current directory for the process, null will have the same cwd as the calling process</param>
298         /// <param name="environment">A dictionary of key/value pairs to define the new process environment</param>
299         /// <param name="stdin">A byte array to send over the stdin pipe</param>
300         /// <param name="outputEncoding">The character encoding for decoding stdout/stderr output of the process.</param>
301         /// <returns>Result object that contains the command output and return code</returns>
CreateProcess(string lpApplicationName, string lpCommandLine, string lpCurrentDirectory, IDictionary environment, byte[] stdin, string outputEncoding)302         public static Result CreateProcess(string lpApplicationName, string lpCommandLine, string lpCurrentDirectory,
303             IDictionary environment, byte[] stdin, string outputEncoding)
304         {
305             NativeHelpers.ProcessCreationFlags creationFlags = NativeHelpers.ProcessCreationFlags.CREATE_UNICODE_ENVIRONMENT |
306                 NativeHelpers.ProcessCreationFlags.EXTENDED_STARTUPINFO_PRESENT;
307             NativeHelpers.PROCESS_INFORMATION pi = new NativeHelpers.PROCESS_INFORMATION();
308             NativeHelpers.STARTUPINFOEX si = new NativeHelpers.STARTUPINFOEX();
309             si.startupInfo.dwFlags = NativeHelpers.StartupInfoFlags.USESTDHANDLES;
310 
311             SafeFileHandle stdoutRead, stdoutWrite, stderrRead, stderrWrite, stdinRead, stdinWrite;
312             CreateStdioPipes(si, out stdoutRead, out stdoutWrite, out stderrRead, out stderrWrite, out stdinRead,
313                 out stdinWrite);
314             FileStream stdinStream = new FileStream(stdinWrite, FileAccess.Write);
315 
316             // $null from PowerShell ends up as an empty string, we need to convert back as an empty string doesn't
317             // make sense for these parameters
318             if (lpApplicationName == "")
319                 lpApplicationName = null;
320 
321             if (lpCurrentDirectory == "")
322                 lpCurrentDirectory = null;
323 
324             using (SafeMemoryBuffer lpEnvironment = CreateEnvironmentPointer(environment))
325             {
326                 // Create console with utf-8 CP if no existing console is present
327                 bool isConsole = false;
328                 if (NativeMethods.GetConsoleWindow() == IntPtr.Zero)
329                 {
330                     isConsole = NativeMethods.AllocConsole();
331 
332                     // Set console input/output codepage to UTF-8
333                     NativeMethods.SetConsoleCP(65001);
334                     NativeMethods.SetConsoleOutputCP(65001);
335                 }
336 
337                 try
338                 {
339                     StringBuilder commandLine = new StringBuilder(lpCommandLine);
340                     if (!NativeMethods.CreateProcessW(lpApplicationName, commandLine, IntPtr.Zero, IntPtr.Zero,
341                         true, creationFlags, lpEnvironment, lpCurrentDirectory, si, out pi))
342                     {
343                         throw new Win32Exception("CreateProcessW() failed");
344                     }
345                 }
346                 finally
347                 {
348                     if (isConsole)
349                         NativeMethods.FreeConsole();
350                 }
351             }
352 
353             return WaitProcess(stdoutRead, stdoutWrite, stderrRead, stderrWrite, stdinStream, stdin, pi.hProcess,
354                 outputEncoding);
355         }
356 
CreateStdioPipes(NativeHelpers.STARTUPINFOEX si, out SafeFileHandle stdoutRead, out SafeFileHandle stdoutWrite, out SafeFileHandle stderrRead, out SafeFileHandle stderrWrite, out SafeFileHandle stdinRead, out SafeFileHandle stdinWrite)357         internal static void CreateStdioPipes(NativeHelpers.STARTUPINFOEX si, out SafeFileHandle stdoutRead,
358             out SafeFileHandle stdoutWrite, out SafeFileHandle stderrRead, out SafeFileHandle stderrWrite,
359             out SafeFileHandle stdinRead, out SafeFileHandle stdinWrite)
360         {
361             NativeHelpers.SECURITY_ATTRIBUTES pipesec = new NativeHelpers.SECURITY_ATTRIBUTES();
362             pipesec.bInheritHandle = true;
363 
364             if (!NativeMethods.CreatePipe(out stdoutRead, out stdoutWrite, pipesec, 0))
365                 throw new Win32Exception("STDOUT pipe setup failed");
366             if (!NativeMethods.SetHandleInformation(stdoutRead, NativeHelpers.HandleFlags.INHERIT, 0))
367                 throw new Win32Exception("STDOUT pipe handle setup failed");
368 
369             if (!NativeMethods.CreatePipe(out stderrRead, out stderrWrite, pipesec, 0))
370                 throw new Win32Exception("STDERR pipe setup failed");
371             if (!NativeMethods.SetHandleInformation(stderrRead, NativeHelpers.HandleFlags.INHERIT, 0))
372                 throw new Win32Exception("STDERR pipe handle setup failed");
373 
374             if (!NativeMethods.CreatePipe(out stdinRead, out stdinWrite, pipesec, 0))
375                 throw new Win32Exception("STDIN pipe setup failed");
376             if (!NativeMethods.SetHandleInformation(stdinWrite, NativeHelpers.HandleFlags.INHERIT, 0))
377                 throw new Win32Exception("STDIN pipe handle setup failed");
378 
379             si.startupInfo.hStdOutput = stdoutWrite;
380             si.startupInfo.hStdError = stderrWrite;
381             si.startupInfo.hStdInput = stdinRead;
382         }
383 
CreateEnvironmentPointer(IDictionary environment)384         internal static SafeMemoryBuffer CreateEnvironmentPointer(IDictionary environment)
385         {
386             IntPtr lpEnvironment = IntPtr.Zero;
387             if (environment != null && environment.Count > 0)
388             {
389                 StringBuilder environmentString = new StringBuilder();
390                 foreach (DictionaryEntry kv in environment)
391                     environmentString.AppendFormat("{0}={1}\0", kv.Key, kv.Value);
392                 environmentString.Append('\0');
393 
394                 lpEnvironment = Marshal.StringToHGlobalUni(environmentString.ToString());
395             }
396             return new SafeMemoryBuffer(lpEnvironment);
397         }
398 
WaitProcess(SafeFileHandle stdoutRead, SafeFileHandle stdoutWrite, SafeFileHandle stderrRead, SafeFileHandle stderrWrite, FileStream stdinStream, byte[] stdin, IntPtr hProcess, string outputEncoding = null)399         internal static Result WaitProcess(SafeFileHandle stdoutRead, SafeFileHandle stdoutWrite, SafeFileHandle stderrRead,
400             SafeFileHandle stderrWrite, FileStream stdinStream, byte[] stdin, IntPtr hProcess, string outputEncoding = null)
401         {
402             // Default to using UTF-8 as the output encoding, this should be a sane default for most scenarios.
403             outputEncoding = String.IsNullOrEmpty(outputEncoding) ? "utf-8" : outputEncoding;
404             Encoding encodingInstance = Encoding.GetEncoding(outputEncoding);
405 
406             FileStream stdoutFS = new FileStream(stdoutRead, FileAccess.Read, 4096);
407             StreamReader stdout = new StreamReader(stdoutFS, encodingInstance, true, 4096);
408             stdoutWrite.Close();
409 
410             FileStream stderrFS = new FileStream(stderrRead, FileAccess.Read, 4096);
411             StreamReader stderr = new StreamReader(stderrFS, encodingInstance, true, 4096);
412             stderrWrite.Close();
413 
414             stdinStream.Write(stdin, 0, stdin.Length);
415             stdinStream.Close();
416 
417             string stdoutStr, stderrStr = null;
418             GetProcessOutput(stdout, stderr, out stdoutStr, out stderrStr);
419             UInt32 rc = GetProcessExitCode(hProcess);
420 
421             return new Result
422             {
423                 StandardOut = stdoutStr,
424                 StandardError = stderrStr,
425                 ExitCode = rc
426             };
427         }
428 
GetProcessOutput(StreamReader stdoutStream, StreamReader stderrStream, out string stdout, out string stderr)429         internal static void GetProcessOutput(StreamReader stdoutStream, StreamReader stderrStream, out string stdout, out string stderr)
430         {
431             var sowait = new EventWaitHandle(false, EventResetMode.ManualReset);
432             var sewait = new EventWaitHandle(false, EventResetMode.ManualReset);
433             string so = null, se = null;
434             ThreadPool.QueueUserWorkItem((s) =>
435             {
436                 so = stdoutStream.ReadToEnd();
437                 sowait.Set();
438             });
439             ThreadPool.QueueUserWorkItem((s) =>
440             {
441                 se = stderrStream.ReadToEnd();
442                 sewait.Set();
443             });
444             foreach (var wh in new WaitHandle[] { sowait, sewait })
445                 wh.WaitOne();
446             stdout = so;
447             stderr = se;
448         }
449 
GetProcessExitCode(IntPtr processHandle)450         internal static UInt32 GetProcessExitCode(IntPtr processHandle)
451         {
452             SafeWaitHandle hProcess = new SafeWaitHandle(processHandle, true);
453             NativeMethods.WaitForSingleObject(hProcess, 0xFFFFFFFF);
454 
455             UInt32 exitCode;
456             if (!NativeMethods.GetExitCodeProcess(hProcess, out exitCode))
457                 throw new Win32Exception("GetExitCodeProcess() failed");
458             return exitCode;
459         }
460     }
461 }
462