1 // Licensed to the .NET Foundation under one or more agreements.
2 // The .NET Foundation licenses this file to you under the MIT license.
3 // See the LICENSE file in the project root for more information.
4 
5 using System.Data.Common;
6 using System.Data.SqlClient;
7 using System.Runtime.InteropServices;
8 
9 namespace System.Data.SqlClient
10 {
11     internal static partial class SNINativeMethodWrapper
12     {
13         private const string SNI = "sni.dll";
14 
15         private static int s_sniMaxComposedSpnLength = -1;
16 
17         private const int SniOpenTimeOut = -1; // infinite
18 
19         [UnmanagedFunctionPointer(CallingConvention.StdCall)]
SqlAsyncCallbackDelegate(IntPtr m_ConsKey, IntPtr pPacket, uint dwError)20         internal delegate void SqlAsyncCallbackDelegate(IntPtr m_ConsKey, IntPtr pPacket, uint dwError);
21 
22         internal static int SniMaxComposedSpnLength
23         {
24             get
25             {
26                 if (s_sniMaxComposedSpnLength == -1)
27                 {
28                     s_sniMaxComposedSpnLength = checked((int)GetSniMaxComposedSpnLength());
29                 }
30                 return s_sniMaxComposedSpnLength;
31             }
32         }
33 
34         #region Structs\Enums
35         [StructLayout(LayoutKind.Sequential)]
36         internal struct ConsumerInfo
37         {
38             internal int defaultBufferSize;
39             internal SqlAsyncCallbackDelegate readDelegate;
40             internal SqlAsyncCallbackDelegate writeDelegate;
41             internal IntPtr key;
42         }
43 
44         internal enum ConsumerNumber
45         {
46             SNI_Consumer_SNI,
47             SNI_Consumer_SSB,
48             SNI_Consumer_PacketIsReleased,
49             SNI_Consumer_Invalid,
50         }
51 
52         internal enum IOType
53         {
54             READ,
55             WRITE,
56         }
57 
58         internal enum PrefixEnum
59         {
60             UNKNOWN_PREFIX,
61             SM_PREFIX,
62             TCP_PREFIX,
63             NP_PREFIX,
64             VIA_PREFIX,
65             INVALID_PREFIX,
66         }
67 
68         internal enum ProviderEnum
69         {
70             HTTP_PROV,
71             NP_PROV,
72             SESSION_PROV,
73             SIGN_PROV,
74             SM_PROV,
75             SMUX_PROV,
76             SSL_PROV,
77             TCP_PROV,
78             VIA_PROV,
79             MAX_PROVS,
80             INVALID_PROV,
81         }
82 
83         internal enum QTypes
84         {
85             SNI_QUERY_CONN_INFO,
86             SNI_QUERY_CONN_BUFSIZE,
87             SNI_QUERY_CONN_KEY,
88             SNI_QUERY_CLIENT_ENCRYPT_POSSIBLE,
89             SNI_QUERY_SERVER_ENCRYPT_POSSIBLE,
90             SNI_QUERY_CERTIFICATE,
91             SNI_QUERY_LOCALDB_HMODULE,
92             SNI_QUERY_CONN_ENCRYPT,
93             SNI_QUERY_CONN_PROVIDERNUM,
94             SNI_QUERY_CONN_CONNID,
95             SNI_QUERY_CONN_PARENTCONNID,
96             SNI_QUERY_CONN_SECPKG,
97             SNI_QUERY_CONN_NETPACKETSIZE,
98             SNI_QUERY_CONN_NODENUM,
99             SNI_QUERY_CONN_PACKETSRECD,
100             SNI_QUERY_CONN_PACKETSSENT,
101             SNI_QUERY_CONN_PEERADDR,
102             SNI_QUERY_CONN_PEERPORT,
103             SNI_QUERY_CONN_LASTREADTIME,
104             SNI_QUERY_CONN_LASTWRITETIME,
105             SNI_QUERY_CONN_CONSUMER_ID,
106             SNI_QUERY_CONN_CONNECTTIME,
107             SNI_QUERY_CONN_HTTPENDPOINT,
108             SNI_QUERY_CONN_LOCALADDR,
109             SNI_QUERY_CONN_LOCALPORT,
110             SNI_QUERY_CONN_SSLHANDSHAKESTATE,
111             SNI_QUERY_CONN_SOBUFAUTOTUNING,
112             SNI_QUERY_CONN_SECPKGNAME,
113             SNI_QUERY_CONN_SECPKGMUTUALAUTH,
114             SNI_QUERY_CONN_CONSUMERCONNID,
115             SNI_QUERY_CONN_SNIUCI,
116             SNI_QUERY_CONN_SUPPORTS_EXTENDED_PROTECTION,
117             SNI_QUERY_CONN_CHANNEL_PROVIDES_AUTHENTICATION_CONTEXT,
118             SNI_QUERY_CONN_PEERID,
119             SNI_QUERY_CONN_SUPPORTS_SYNC_OVER_ASYNC,
120         }
121 
122         internal enum TransparentNetworkResolutionMode : byte
123         {
124             DisabledMode = 0,
125             SequentialMode,
126             ParallelMode
127         };
128 
129         [StructLayout(LayoutKind.Sequential)]
130         private struct Sni_Consumer_Info
131         {
132             public int DefaultUserDataLength;
133             public IntPtr ConsumerKey;
134             public IntPtr fnReadComp;
135             public IntPtr fnWriteComp;
136             public IntPtr fnTrace;
137             public IntPtr fnAcceptComp;
138             public uint dwNumProts;
139             public IntPtr rgListenInfo;
140             public IntPtr NodeAffinity;
141         }
142 
143         [StructLayout(LayoutKind.Sequential)]
144         private unsafe struct SNI_CLIENT_CONSUMER_INFO
145         {
146             public Sni_Consumer_Info ConsumerInfo;
147             [MarshalAs(UnmanagedType.LPWStr)]
148             public string wszConnectionString;
149             public PrefixEnum networkLibrary;
150             public byte* szSPN;
151             public uint cchSPN;
152             public byte* szInstanceName;
153             public uint cchInstanceName;
154             [MarshalAs(UnmanagedType.Bool)]
155             public bool fOverrideLastConnectCache;
156             [MarshalAs(UnmanagedType.Bool)]
157             public bool fSynchronousConnection;
158             public int timeout;
159             [MarshalAs(UnmanagedType.Bool)]
160             public bool fParallel;
161             public TransparentNetworkResolutionMode transparentNetworkResolution;
162             public int totalTimeout;
163             public bool isAzureSqlServerEndpoint;
164         }
165 
166         [StructLayout(LayoutKind.Sequential, CharSet = CharSet.Unicode)]
167         internal struct SNI_Error
168         {
169             internal ProviderEnum provider;
170             [MarshalAs(UnmanagedType.ByValTStr, SizeConst = 261)]
171             internal string errorMessage;
172             internal uint nativeError;
173             internal uint sniError;
174             [MarshalAs(UnmanagedType.LPWStr)]
175             internal string fileName;
176             [MarshalAs(UnmanagedType.LPWStr)]
177             internal string function;
178             internal uint lineNumber;
179         }
180 
181         #endregion
182 
183         #region DLL Imports
184         [DllImport(SNI, CallingConvention = CallingConvention.Cdecl, EntryPoint = "SNIAddProviderWrapper")]
SNIAddProvider(SNIHandle pConn, ProviderEnum ProvNum, [In] ref uint pInfo)185         internal static extern uint SNIAddProvider(SNIHandle pConn, ProviderEnum ProvNum, [In] ref uint pInfo);
186 
187         [DllImport(SNI, CallingConvention = CallingConvention.Cdecl, EntryPoint = "SNICheckConnectionWrapper")]
SNICheckConnection([In] SNIHandle pConn)188         internal static extern uint SNICheckConnection([In] SNIHandle pConn);
189 
190         [DllImport(SNI, CallingConvention = CallingConvention.Cdecl, EntryPoint = "SNICloseWrapper")]
SNIClose(IntPtr pConn)191         internal static extern uint SNIClose(IntPtr pConn);
192 
193         [DllImport(SNI, CallingConvention = CallingConvention.Cdecl)]
SNIGetLastError(out SNI_Error pErrorStruct)194         internal static extern void SNIGetLastError(out SNI_Error pErrorStruct);
195 
196         [DllImport(SNI, CallingConvention = CallingConvention.Cdecl)]
SNIPacketRelease(IntPtr pPacket)197         internal static extern void SNIPacketRelease(IntPtr pPacket);
198 
199         [DllImport(SNI, CallingConvention = CallingConvention.Cdecl, EntryPoint = "SNIPacketResetWrapper")]
SNIPacketReset([In] SNIHandle pConn, IOType IOType, SNIPacket pPacket, ConsumerNumber ConsNum)200         internal static extern void SNIPacketReset([In] SNIHandle pConn, IOType IOType, SNIPacket pPacket, ConsumerNumber ConsNum);
201 
202         [DllImport(SNI, CallingConvention = CallingConvention.Cdecl)]
SNIQueryInfo(QTypes QType, ref uint pbQInfo)203         internal static extern uint SNIQueryInfo(QTypes QType, ref uint pbQInfo);
204 
205         [DllImport(SNI, CallingConvention = CallingConvention.Cdecl)]
SNIQueryInfo(QTypes QType, ref IntPtr pbQInfo)206         internal static extern uint SNIQueryInfo(QTypes QType, ref IntPtr pbQInfo);
207 
208         [DllImport(SNI, CallingConvention = CallingConvention.Cdecl, EntryPoint = "SNIReadAsyncWrapper")]
SNIReadAsync(SNIHandle pConn, ref IntPtr ppNewPacket)209         internal static extern uint SNIReadAsync(SNIHandle pConn, ref IntPtr ppNewPacket);
210 
211         [DllImport(SNI, CallingConvention = CallingConvention.Cdecl)]
SNIReadSyncOverAsync(SNIHandle pConn, ref IntPtr ppNewPacket, int timeout)212         internal static extern uint SNIReadSyncOverAsync(SNIHandle pConn, ref IntPtr ppNewPacket, int timeout);
213 
214         [DllImport(SNI, CallingConvention = CallingConvention.Cdecl, EntryPoint = "SNIRemoveProviderWrapper")]
SNIRemoveProvider(SNIHandle pConn, ProviderEnum ProvNum)215         internal static extern uint SNIRemoveProvider(SNIHandle pConn, ProviderEnum ProvNum);
216 
217         [DllImport(SNI, CallingConvention = CallingConvention.Cdecl)]
SNISecInitPackage(ref uint pcbMaxToken)218         internal static extern uint SNISecInitPackage(ref uint pcbMaxToken);
219 
220         [DllImport(SNI, CallingConvention = CallingConvention.Cdecl, EntryPoint = "SNISetInfoWrapper")]
SNISetInfo(SNIHandle pConn, QTypes QType, [In] ref uint pbQInfo)221         internal static extern uint SNISetInfo(SNIHandle pConn, QTypes QType, [In] ref uint pbQInfo);
222 
223         [DllImport(SNI, CallingConvention = CallingConvention.Cdecl)]
SNITerminate()224         internal static extern uint SNITerminate();
225 
226         [DllImport(SNI, CallingConvention = CallingConvention.Cdecl, EntryPoint = "SNIWaitForSSLHandshakeToCompleteWrapper")]
SNIWaitForSSLHandshakeToComplete([In] SNIHandle pConn, int dwMilliseconds)227         internal static extern uint SNIWaitForSSLHandshakeToComplete([In] SNIHandle pConn, int dwMilliseconds);
228 
229         [DllImport(SNI, CallingConvention = CallingConvention.Cdecl)]
UnmanagedIsTokenRestricted([In] IntPtr token, [MarshalAs(UnmanagedType.Bool)] out bool isRestricted)230         internal static extern uint UnmanagedIsTokenRestricted([In] IntPtr token, [MarshalAs(UnmanagedType.Bool)] out bool isRestricted);
231 
232         [DllImport(SNI, CallingConvention = CallingConvention.Cdecl)]
GetSniMaxComposedSpnLength()233         private static extern uint GetSniMaxComposedSpnLength();
234 
235         [DllImport(SNI, CallingConvention = CallingConvention.Cdecl)]
SNIGetInfoWrapper([In] SNIHandle pConn, SNINativeMethodWrapper.QTypes QType, out Guid pbQInfo)236         private static extern uint SNIGetInfoWrapper([In] SNIHandle pConn, SNINativeMethodWrapper.QTypes QType, out Guid pbQInfo);
237 
238         [DllImport(SNI, CallingConvention = CallingConvention.Cdecl)]
SNIInitialize([In] IntPtr pmo)239         private static extern uint SNIInitialize([In] IntPtr pmo);
240 
241         [DllImport(SNI, CallingConvention = CallingConvention.Cdecl)]
SNIOpenSyncExWrapper(ref SNI_CLIENT_CONSUMER_INFO pClientConsumerInfo, out IntPtr ppConn)242         private static extern uint SNIOpenSyncExWrapper(ref SNI_CLIENT_CONSUMER_INFO pClientConsumerInfo, out IntPtr ppConn);
243 
244         [DllImport(SNI, CallingConvention = CallingConvention.Cdecl)]
SNIOpenWrapper( [In] ref Sni_Consumer_Info pConsumerInfo, [MarshalAs(UnmanagedType.LPStr)] string szConnect, [In] SNIHandle pConn, out IntPtr ppConn, [MarshalAs(UnmanagedType.Bool)] bool fSync)245         private static extern uint SNIOpenWrapper(
246             [In] ref Sni_Consumer_Info pConsumerInfo,
247             [MarshalAs(UnmanagedType.LPStr)] string szConnect,
248             [In] SNIHandle pConn,
249             out IntPtr ppConn,
250             [MarshalAs(UnmanagedType.Bool)] bool fSync);
251 
252         [DllImport(SNI, CallingConvention = CallingConvention.Cdecl)]
SNIPacketAllocateWrapper([In] SafeHandle pConn, IOType IOType)253         private static extern IntPtr SNIPacketAllocateWrapper([In] SafeHandle pConn, IOType IOType);
254 
255         [DllImport(SNI, CallingConvention = CallingConvention.Cdecl)]
SNIPacketGetDataWrapper([In] IntPtr packet, [In, Out] byte[] readBuffer, uint readBufferLength, out uint dataSize)256         private static extern uint SNIPacketGetDataWrapper([In] IntPtr packet, [In, Out] byte[] readBuffer, uint readBufferLength, out uint dataSize);
257 
258         [DllImport(SNI, CallingConvention = CallingConvention.Cdecl)]
SNIPacketSetData(SNIPacket pPacket, [In] byte* pbBuf, uint cbBuf)259         private static extern unsafe void SNIPacketSetData(SNIPacket pPacket, [In] byte* pbBuf, uint cbBuf);
260 
261         [DllImport(SNI, CallingConvention = CallingConvention.Cdecl)]
SNISecGenClientContextWrapper( [In] SNIHandle pConn, [In, Out] byte[] pIn, uint cbIn, [In, Out] byte[] pOut, [In] ref uint pcbOut, [MarshalAsAttribute(UnmanagedType.Bool)] out bool pfDone, byte* szServerInfo, uint cbServerInfo, [MarshalAsAttribute(UnmanagedType.LPWStr)] string pwszUserName, [MarshalAsAttribute(UnmanagedType.LPWStr)] string pwszPassword)262         private static extern unsafe uint SNISecGenClientContextWrapper(
263             [In] SNIHandle pConn,
264             [In, Out] byte[] pIn,
265             uint cbIn,
266             [In, Out] byte[] pOut,
267             [In] ref uint pcbOut,
268             [MarshalAsAttribute(UnmanagedType.Bool)] out bool pfDone,
269             byte* szServerInfo,
270             uint cbServerInfo,
271             [MarshalAsAttribute(UnmanagedType.LPWStr)] string pwszUserName,
272             [MarshalAsAttribute(UnmanagedType.LPWStr)] string pwszPassword);
273 
274         [DllImport(SNI, CallingConvention = CallingConvention.Cdecl)]
SNIWriteAsyncWrapper(SNIHandle pConn, [In] SNIPacket pPacket)275         private static extern uint SNIWriteAsyncWrapper(SNIHandle pConn, [In] SNIPacket pPacket);
276 
277         [DllImport(SNI, CallingConvention = CallingConvention.Cdecl)]
SNIWriteSyncOverAsync(SNIHandle pConn, [In] SNIPacket pPacket)278         private static extern uint SNIWriteSyncOverAsync(SNIHandle pConn, [In] SNIPacket pPacket);
279         #endregion
280 
SniGetConnectionId(SNIHandle pConn, ref Guid connId)281         internal static uint SniGetConnectionId(SNIHandle pConn, ref Guid connId)
282         {
283             return SNIGetInfoWrapper(pConn, QTypes.SNI_QUERY_CONN_CONNID, out connId);
284         }
285 
SNIInitialize()286         internal static uint SNIInitialize()
287         {
288             return SNIInitialize(IntPtr.Zero);
289         }
290 
SNIOpenMarsSession(ConsumerInfo consumerInfo, SNIHandle parent, ref IntPtr pConn, bool fSync)291         internal static unsafe uint SNIOpenMarsSession(ConsumerInfo consumerInfo, SNIHandle parent, ref IntPtr pConn, bool fSync)
292         {
293             // initialize consumer info for MARS
294             Sni_Consumer_Info native_consumerInfo = new Sni_Consumer_Info();
295             MarshalConsumerInfo(consumerInfo, ref native_consumerInfo);
296 
297             return SNIOpenWrapper(ref native_consumerInfo, "session:", parent, out pConn, fSync);
298         }
299 
SNIOpenSyncEx(ConsumerInfo consumerInfo, string constring, ref IntPtr pConn, byte[] spnBuffer, byte[] instanceName, bool fOverrideCache, bool fSync, int timeout, bool fParallel)300         internal static unsafe uint SNIOpenSyncEx(ConsumerInfo consumerInfo, string constring, ref IntPtr pConn, byte[] spnBuffer, byte[] instanceName, bool fOverrideCache, bool fSync, int timeout, bool fParallel)
301         {
302             fixed (byte* pin_instanceName = &instanceName[0])
303             {
304                 SNI_CLIENT_CONSUMER_INFO clientConsumerInfo = new SNI_CLIENT_CONSUMER_INFO();
305 
306                 // initialize client ConsumerInfo part first
307                 MarshalConsumerInfo(consumerInfo, ref clientConsumerInfo.ConsumerInfo);
308 
309                 clientConsumerInfo.wszConnectionString = constring;
310                 clientConsumerInfo.networkLibrary = PrefixEnum.UNKNOWN_PREFIX;
311 
312                 clientConsumerInfo.szInstanceName = pin_instanceName;
313                 clientConsumerInfo.cchInstanceName = (uint)instanceName.Length;
314                 clientConsumerInfo.fOverrideLastConnectCache = fOverrideCache;
315                 clientConsumerInfo.fSynchronousConnection = fSync;
316                 clientConsumerInfo.timeout = timeout;
317                 clientConsumerInfo.fParallel = fParallel;
318 
319                 clientConsumerInfo.transparentNetworkResolution = TransparentNetworkResolutionMode.DisabledMode;
320                 clientConsumerInfo.totalTimeout = SniOpenTimeOut;
321                 clientConsumerInfo.isAzureSqlServerEndpoint = ADP.IsAzureSqlServerEndpoint(constring);
322 
323                 if (spnBuffer != null)
324                 {
325                     fixed (byte* pin_spnBuffer = &spnBuffer[0])
326                     {
327                         clientConsumerInfo.szSPN = pin_spnBuffer;
328                         clientConsumerInfo.cchSPN = (uint)spnBuffer.Length;
329                         return SNIOpenSyncExWrapper(ref clientConsumerInfo, out pConn);
330                     }
331                 }
332                 else
333                 {
334                     // else leave szSPN null (SQL Auth)
335                     return SNIOpenSyncExWrapper(ref clientConsumerInfo, out pConn);
336                 }
337             }
338         }
339 
SNIPacketAllocate(SafeHandle pConn, IOType IOType, ref IntPtr pPacket)340         internal static void SNIPacketAllocate(SafeHandle pConn, IOType IOType, ref IntPtr pPacket)
341         {
342             pPacket = SNIPacketAllocateWrapper(pConn, IOType);
343         }
344 
SNIPacketGetData(IntPtr packet, byte[] readBuffer, ref uint dataSize)345         internal static unsafe uint SNIPacketGetData(IntPtr packet, byte[] readBuffer, ref uint dataSize)
346         {
347             return SNIPacketGetDataWrapper(packet, readBuffer, (uint)readBuffer.Length, out dataSize);
348         }
349 
SNIPacketSetData(SNIPacket packet, byte[] data, int length)350         internal static unsafe void SNIPacketSetData(SNIPacket packet, byte[] data, int length)
351         {
352             fixed (byte* pin_data = &data[0])
353             {
354                 SNIPacketSetData(packet, pin_data, (uint)length);
355             }
356         }
357 
SNISecGenClientContext(SNIHandle pConnectionObject, byte[] inBuff, uint receivedLength, byte[] OutBuff, ref uint sendLength, byte[] serverUserName)358         internal static unsafe uint SNISecGenClientContext(SNIHandle pConnectionObject, byte[] inBuff, uint receivedLength, byte[] OutBuff, ref uint sendLength, byte[] serverUserName)
359         {
360             fixed (byte* pin_serverUserName = &serverUserName[0])
361             {
362                 bool local_fDone;
363                 return SNISecGenClientContextWrapper(
364                     pConnectionObject,
365                     inBuff,
366                     receivedLength,
367                     OutBuff,
368                     ref sendLength,
369                     out local_fDone,
370                     pin_serverUserName,
371                     (uint)serverUserName.Length,
372                     null,
373                     null);
374             }
375         }
376 
SNIWritePacket(SNIHandle pConn, SNIPacket packet, bool sync)377         internal static uint SNIWritePacket(SNIHandle pConn, SNIPacket packet, bool sync)
378         {
379             if (sync)
380             {
381                 return SNIWriteSyncOverAsync(pConn, packet);
382             }
383             else
384             {
385                 return SNIWriteAsyncWrapper(pConn, packet);
386             }
387         }
388 
MarshalConsumerInfo(ConsumerInfo consumerInfo, ref Sni_Consumer_Info native_consumerInfo)389         private static void MarshalConsumerInfo(ConsumerInfo consumerInfo, ref Sni_Consumer_Info native_consumerInfo)
390         {
391             native_consumerInfo.DefaultUserDataLength = consumerInfo.defaultBufferSize;
392             native_consumerInfo.fnReadComp = null != consumerInfo.readDelegate
393                 ? Marshal.GetFunctionPointerForDelegate(consumerInfo.readDelegate)
394                 : IntPtr.Zero;
395             native_consumerInfo.fnWriteComp = null != consumerInfo.writeDelegate
396                 ? Marshal.GetFunctionPointerForDelegate(consumerInfo.writeDelegate)
397                 : IntPtr.Zero;
398             native_consumerInfo.ConsumerKey = consumerInfo.key;
399         }
400     }
401 }
402 
403 namespace System.Data
404 {
405     internal static partial class SafeNativeMethods
406     {
407         [DllImport("kernel32.dll", CharSet = CharSet.Ansi, BestFitMapping = false, ThrowOnUnmappableChar = true, SetLastError = true)]
GetProcAddress(IntPtr HModule, [MarshalAs(UnmanagedType.LPStr), In] string funcName)408         internal static extern IntPtr GetProcAddress(IntPtr HModule, [MarshalAs(UnmanagedType.LPStr), In] string funcName);
409     }
410 }
411 
412 namespace System.Data
413 {
414     internal static class Win32NativeMethods
415     {
IsTokenRestrictedWrapper(IntPtr token)416         internal static bool IsTokenRestrictedWrapper(IntPtr token)
417         {
418             bool isRestricted;
419             uint result = SNINativeMethodWrapper.UnmanagedIsTokenRestricted(token, out isRestricted);
420 
421             if (result != 0)
422             {
423                 Marshal.ThrowExceptionForHR(unchecked((int)result));
424             }
425 
426             return isRestricted;
427         }
428     }
429 }
430