// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Data.Common;
using System.Data.SqlClient;
using System.Runtime.InteropServices;

namespace System.Data.SqlClient
{
    /// <summary>
    /// P/Invoke declarations for the native SNI library (sni.dll), plus thin managed helpers
    /// that pin managed buffers and translate <see cref="ConsumerInfo"/> into the native layout
    /// before calling into it.
    /// </summary>
    internal static partial class SNINativeMethodWrapper
    {
        private const string SNI = "sni.dll";

        // Cached result of GetSniMaxComposedSpnLength(); -1 means "not queried yet".
        private static int s_sniMaxComposedSpnLength = -1;

        private const int SniOpenTimeOut = -1; // infinite

        /// <summary>
        /// Callback signature whose function pointer is handed to native SNI as the read/write
        /// completion routine (see <see cref="MarshalConsumerInfo"/>).
        /// </summary>
        [UnmanagedFunctionPointer(CallingConvention.StdCall)]
        internal delegate void SqlAsyncCallbackDelegate(IntPtr m_ConsKey, IntPtr pPacket, uint dwError);

        /// <summary>
        /// Maximum composed SPN length as reported by the native library. The value is queried on
        /// first access and cached for later reads.
        /// </summary>
        /// <exception cref="OverflowException">The native value does not fit in an <see cref="int"/>.</exception>
        internal static int SniMaxComposedSpnLength
        {
            get
            {
                // Unsynchronized lazy initialization: concurrent first readers may each call into
                // native code. Presumably the native value is constant, so the race only costs a
                // redundant call.
                if (s_sniMaxComposedSpnLength == -1)
                {
                    s_sniMaxComposedSpnLength = checked((int)GetSniMaxComposedSpnLength());
                }
                return s_sniMaxComposedSpnLength;
            }
        }

        #region Structs\Enums
        /// <summary>
        /// Managed description of an SNI consumer. It is copied into the native
        /// <see cref="Sni_Consumer_Info"/> layout by <see cref="MarshalConsumerInfo"/>.
        /// </summary>
        [StructLayout(LayoutKind.Sequential)]
        internal struct ConsumerInfo
        {
            internal int defaultBufferSize;
            internal SqlAsyncCallbackDelegate readDelegate;
            internal SqlAsyncCallbackDelegate writeDelegate;
            internal IntPtr key;
        }

        // The enums below are passed by value to native SNI (directly or inside marshaled structs),
        // so their implicit integer values (declaration order) are part of the native contract.
        // Do not reorder or insert members.

        internal enum ConsumerNumber
        {
            SNI_Consumer_SNI,
            SNI_Consumer_SSB,
            SNI_Consumer_PacketIsReleased,
            SNI_Consumer_Invalid,
        }

        internal enum IOType
        {
            READ,
            WRITE,
        }

        internal enum PrefixEnum
        {
            UNKNOWN_PREFIX,
            SM_PREFIX,
            TCP_PREFIX,
            NP_PREFIX,
            VIA_PREFIX,
            INVALID_PREFIX,
        }

        internal enum ProviderEnum
        {
            HTTP_PROV,
            NP_PROV,
            SESSION_PROV,
            SIGN_PROV,
            SM_PROV,
            SMUX_PROV,
            SSL_PROV,
            TCP_PROV,
            VIA_PROV,
            MAX_PROVS,
            INVALID_PROV,
        }

        /// <summary>Query/set selectors for SNIQueryInfo, SNISetInfo and SNIGetInfoWrapper.</summary>
        internal enum QTypes
        {
            SNI_QUERY_CONN_INFO,
            SNI_QUERY_CONN_BUFSIZE,
            SNI_QUERY_CONN_KEY,
            SNI_QUERY_CLIENT_ENCRYPT_POSSIBLE,
            SNI_QUERY_SERVER_ENCRYPT_POSSIBLE,
            SNI_QUERY_CERTIFICATE,
            SNI_QUERY_LOCALDB_HMODULE,
            SNI_QUERY_CONN_ENCRYPT,
            SNI_QUERY_CONN_PROVIDERNUM,
            SNI_QUERY_CONN_CONNID,
            SNI_QUERY_CONN_PARENTCONNID,
            SNI_QUERY_CONN_SECPKG,
            SNI_QUERY_CONN_NETPACKETSIZE,
            SNI_QUERY_CONN_NODENUM,
            SNI_QUERY_CONN_PACKETSRECD,
            SNI_QUERY_CONN_PACKETSSENT,
            SNI_QUERY_CONN_PEERADDR,
            SNI_QUERY_CONN_PEERPORT,
            SNI_QUERY_CONN_LASTREADTIME,
            SNI_QUERY_CONN_LASTWRITETIME,
            SNI_QUERY_CONN_CONSUMER_ID,
            SNI_QUERY_CONN_CONNECTTIME,
            SNI_QUERY_CONN_HTTPENDPOINT,
            SNI_QUERY_CONN_LOCALADDR,
            SNI_QUERY_CONN_LOCALPORT,
            SNI_QUERY_CONN_SSLHANDSHAKESTATE,
            SNI_QUERY_CONN_SOBUFAUTOTUNING,
            SNI_QUERY_CONN_SECPKGNAME,
            SNI_QUERY_CONN_SECPKGMUTUALAUTH,
            SNI_QUERY_CONN_CONSUMERCONNID,
            SNI_QUERY_CONN_SNIUCI,
            SNI_QUERY_CONN_SUPPORTS_EXTENDED_PROTECTION,
            SNI_QUERY_CONN_CHANNEL_PROVIDES_AUTHENTICATION_CONTEXT,
            SNI_QUERY_CONN_PEERID,
            SNI_QUERY_CONN_SUPPORTS_SYNC_OVER_ASYNC,
        }

        // Marshaled as a single byte inside SNI_CLIENT_CONSUMER_INFO.
        internal enum TransparentNetworkResolutionMode : byte
        {
            DisabledMode = 0,
            SequentialMode,
            ParallelMode
        }

        /// <summary>Native consumer-info layout; filled from <see cref="ConsumerInfo"/>.</summary>
        [StructLayout(LayoutKind.Sequential)]
        private struct Sni_Consumer_Info
        {
            public int DefaultUserDataLength;
            public IntPtr ConsumerKey;
            public IntPtr fnReadComp;
            public IntPtr fnWriteComp;
            public IntPtr fnTrace;
            public IntPtr fnAcceptComp;
            public uint dwNumProts;
            public IntPtr rgListenInfo;
            public IntPtr NodeAffinity;
        }

        /// <summary>Argument block for SNIOpenSyncExWrapper.</summary>
        [StructLayout(LayoutKind.Sequential)]
        private unsafe struct SNI_CLIENT_CONSUMER_INFO
        {
            public Sni_Consumer_Info ConsumerInfo;
            [MarshalAs(UnmanagedType.LPWStr)]
            public string wszConnectionString;
            public PrefixEnum networkLibrary;
            public byte* szSPN;
            public uint cchSPN;
            public byte* szInstanceName;
            public uint cchInstanceName;
            [MarshalAs(UnmanagedType.Bool)]
            public bool fOverrideLastConnectCache;
            [MarshalAs(UnmanagedType.Bool)]
            public bool fSynchronousConnection;
            public int timeout;
            [MarshalAs(UnmanagedType.Bool)]
            public bool fParallel;
            public TransparentNetworkResolutionMode transparentNetworkResolution;
            public int totalTimeout;
            // NOTE(review): no MarshalAs here, so this uses the default struct-field bool marshaling
            // (4-byte Win32 BOOL). Confirm this matches the native field's declared type/size.
            public bool isAzureSqlServerEndpoint;
        }

        /// <summary>Error details filled in by <see cref="SNIGetLastError"/>.</summary>
        [StructLayout(LayoutKind.Sequential, CharSet = CharSet.Unicode)]
        internal struct SNI_Error
        {
            internal ProviderEnum provider;
            // Inline fixed-size character buffer (261 chars) in the native struct.
            [MarshalAs(UnmanagedType.ByValTStr, SizeConst = 261)]
            internal string errorMessage;
            internal uint nativeError;
            internal uint sniError;
            [MarshalAs(UnmanagedType.LPWStr)]
            internal string fileName;
            [MarshalAs(UnmanagedType.LPWStr)]
            internal string function;
            internal uint lineNumber;
        }

        #endregion

        #region DLL Imports
        [DllImport(SNI, CallingConvention = CallingConvention.Cdecl, EntryPoint = "SNIAddProviderWrapper")]
        internal static extern uint SNIAddProvider(SNIHandle pConn, ProviderEnum ProvNum, [In] ref uint pInfo);

        [DllImport(SNI, CallingConvention = CallingConvention.Cdecl, EntryPoint = "SNICheckConnectionWrapper")]
        internal static extern uint SNICheckConnection([In] SNIHandle pConn);

        [DllImport(SNI, CallingConvention = CallingConvention.Cdecl, EntryPoint = "SNICloseWrapper")]
        internal static extern uint SNIClose(IntPtr pConn);

        [DllImport(SNI, CallingConvention = CallingConvention.Cdecl)]
        internal static extern void SNIGetLastError(out SNI_Error pErrorStruct);

        [DllImport(SNI, CallingConvention = CallingConvention.Cdecl)]
        internal static extern void SNIPacketRelease(IntPtr pPacket);

        [DllImport(SNI, CallingConvention = CallingConvention.Cdecl, EntryPoint = "SNIPacketResetWrapper")]
        internal static extern void SNIPacketReset([In] SNIHandle pConn, IOType IOType, SNIPacket pPacket, ConsumerNumber ConsNum);

        [DllImport(SNI, CallingConvention = CallingConvention.Cdecl)]
        internal static extern uint SNIQueryInfo(QTypes QType, ref uint pbQInfo);

        [DllImport(SNI, CallingConvention = CallingConvention.Cdecl)]
        internal static extern uint SNIQueryInfo(QTypes QType, ref IntPtr pbQInfo);

        [DllImport(SNI, CallingConvention = CallingConvention.Cdecl, EntryPoint = "SNIReadAsyncWrapper")]
        internal static extern uint SNIReadAsync(SNIHandle pConn, ref IntPtr ppNewPacket);

        [DllImport(SNI, CallingConvention = CallingConvention.Cdecl)]
        internal static extern uint SNIReadSyncOverAsync(SNIHandle pConn, ref IntPtr ppNewPacket, int timeout);

        [DllImport(SNI, CallingConvention = CallingConvention.Cdecl, EntryPoint = "SNIRemoveProviderWrapper")]
        internal static extern uint SNIRemoveProvider(SNIHandle pConn, ProviderEnum ProvNum);

        [DllImport(SNI, CallingConvention = CallingConvention.Cdecl)]
        internal static extern uint SNISecInitPackage(ref uint pcbMaxToken);

        [DllImport(SNI, CallingConvention = CallingConvention.Cdecl, EntryPoint = "SNISetInfoWrapper")]
        internal static extern uint SNISetInfo(SNIHandle pConn, QTypes QType, [In] ref uint pbQInfo);

        [DllImport(SNI, CallingConvention = CallingConvention.Cdecl)]
        internal static extern uint SNITerminate();

        [DllImport(SNI, CallingConvention = CallingConvention.Cdecl, EntryPoint = "SNIWaitForSSLHandshakeToCompleteWrapper")]
        internal static extern uint SNIWaitForSSLHandshakeToComplete([In] SNIHandle pConn, int dwMilliseconds);

        [DllImport(SNI, CallingConvention = CallingConvention.Cdecl)]
        internal static extern uint UnmanagedIsTokenRestricted([In] IntPtr token, [MarshalAs(UnmanagedType.Bool)] out bool isRestricted);

        [DllImport(SNI, CallingConvention = CallingConvention.Cdecl)]
        private static extern uint GetSniMaxComposedSpnLength();

        [DllImport(SNI, CallingConvention = CallingConvention.Cdecl)]
        private static extern uint SNIGetInfoWrapper([In] SNIHandle pConn, SNINativeMethodWrapper.QTypes QType, out Guid pbQInfo);

        [DllImport(SNI, CallingConvention = CallingConvention.Cdecl)]
        private static extern uint SNIInitialize([In] IntPtr pmo);

        [DllImport(SNI, CallingConvention = CallingConvention.Cdecl)]
        private static extern uint SNIOpenSyncExWrapper(ref SNI_CLIENT_CONSUMER_INFO pClientConsumerInfo, out IntPtr ppConn);

        [DllImport(SNI, CallingConvention = CallingConvention.Cdecl)]
        private static extern uint SNIOpenWrapper(
            [In] ref Sni_Consumer_Info pConsumerInfo,
            [MarshalAs(UnmanagedType.LPStr)] string szConnect,
            [In] SNIHandle pConn,
            out IntPtr ppConn,
            [MarshalAs(UnmanagedType.Bool)] bool fSync);

        [DllImport(SNI, CallingConvention = CallingConvention.Cdecl)]
        private static extern IntPtr SNIPacketAllocateWrapper([In] SafeHandle pConn, IOType IOType);

        [DllImport(SNI, CallingConvention = CallingConvention.Cdecl)]
        private static extern uint SNIPacketGetDataWrapper([In] IntPtr packet, [In, Out] byte[] readBuffer, uint readBufferLength, out uint dataSize);

        [DllImport(SNI, CallingConvention = CallingConvention.Cdecl)]
        private static extern unsafe void SNIPacketSetData(SNIPacket pPacket, [In] byte* pbBuf, uint cbBuf);

        [DllImport(SNI, CallingConvention = CallingConvention.Cdecl)]
        private static extern unsafe uint SNISecGenClientContextWrapper(
            [In] SNIHandle pConn,
            [In, Out] byte[] pIn,
            uint cbIn,
            [In, Out] byte[] pOut,
            [In] ref uint pcbOut,
            [MarshalAsAttribute(UnmanagedType.Bool)] out bool pfDone,
            byte* szServerInfo,
            uint cbServerInfo,
            [MarshalAsAttribute(UnmanagedType.LPWStr)] string pwszUserName,
            [MarshalAsAttribute(UnmanagedType.LPWStr)] string pwszPassword);

        [DllImport(SNI, CallingConvention = CallingConvention.Cdecl)]
        private static extern uint SNIWriteAsyncWrapper(SNIHandle pConn, [In] SNIPacket pPacket);

        [DllImport(SNI, CallingConvention = CallingConvention.Cdecl)]
        private static extern uint SNIWriteSyncOverAsync(SNIHandle pConn, [In] SNIPacket pPacket);
        #endregion

        /// <summary>
        /// Reads the connection id of <paramref name="pConn"/> via the
        /// <see cref="QTypes.SNI_QUERY_CONN_CONNID"/> query.
        /// </summary>
        /// <returns>The native return code (presumably 0 on success; confirm against SNI).</returns>
        internal static uint SniGetConnectionId(SNIHandle pConn, ref Guid connId)
        {
            return SNIGetInfoWrapper(pConn, QTypes.SNI_QUERY_CONN_CONNID, out connId);
        }

        /// <summary>Initializes the native SNI library, passing a null argument block.</summary>
        internal static uint SNIInitialize()
        {
            return SNIInitialize(IntPtr.Zero);
        }

        /// <summary>
        /// Opens a MARS session on top of <paramref name="parent"/> by calling SNIOpenWrapper with
        /// the "session:" connect string.
        /// </summary>
        /// <param name="pConn">Receives the native handle of the new session.</param>
        internal static uint SNIOpenMarsSession(ConsumerInfo consumerInfo, SNIHandle parent, ref IntPtr pConn, bool fSync)
        {
            // initialize consumer info for MARS
            Sni_Consumer_Info native_consumerInfo = new Sni_Consumer_Info();
            MarshalConsumerInfo(consumerInfo, ref native_consumerInfo);

            return SNIOpenWrapper(ref native_consumerInfo, "session:", parent, out pConn, fSync);
        }

        /// <summary>
        /// Opens a client connection through SNIOpenSyncExWrapper. The connection string, SPN and
        /// instance-name buffers are packed into an <see cref="SNI_CLIENT_CONSUMER_INFO"/> block.
        /// </summary>
        /// <param name="pConn">Receives the native connection handle.</param>
        /// <param name="spnBuffer">SPN bytes, or null to leave szSPN null (SQL Auth).</param>
        /// <param name="instanceName">Buffer pinned and passed as szInstanceName/cchInstanceName; must not be null.</param>
        internal static unsafe uint SNIOpenSyncEx(ConsumerInfo consumerInfo, string constring, ref IntPtr pConn, byte[] spnBuffer, byte[] instanceName, bool fOverrideCache, bool fSync, int timeout, bool fParallel)
        {
            // Array-form 'fixed' pins each buffer for the duration of the native call and yields a
            // null pointer for a null or empty array. The previous '&array[0]' form threw
            // IndexOutOfRangeException on empty arrays.
            fixed (byte* pin_instanceName = instanceName)
            fixed (byte* pin_spnBuffer = spnBuffer)
            {
                SNI_CLIENT_CONSUMER_INFO clientConsumerInfo = new SNI_CLIENT_CONSUMER_INFO();

                // initialize client ConsumerInfo part first
                MarshalConsumerInfo(consumerInfo, ref clientConsumerInfo.ConsumerInfo);

                clientConsumerInfo.wszConnectionString = constring;
                clientConsumerInfo.networkLibrary = PrefixEnum.UNKNOWN_PREFIX;

                clientConsumerInfo.szInstanceName = pin_instanceName;
                // Still throws NullReferenceException for a null instanceName before any native call.
                clientConsumerInfo.cchInstanceName = (uint)instanceName.Length;
                clientConsumerInfo.fOverrideLastConnectCache = fOverrideCache;
                clientConsumerInfo.fSynchronousConnection = fSync;
                clientConsumerInfo.timeout = timeout;
                clientConsumerInfo.fParallel = fParallel;

                clientConsumerInfo.transparentNetworkResolution = TransparentNetworkResolutionMode.DisabledMode;
                clientConsumerInfo.totalTimeout = SniOpenTimeOut;
                clientConsumerInfo.isAzureSqlServerEndpoint = ADP.IsAzureSqlServerEndpoint(constring);

                // A null spnBuffer leaves szSPN null and cchSPN 0 (SQL Auth), as before.
                clientConsumerInfo.szSPN = pin_spnBuffer;
                clientConsumerInfo.cchSPN = spnBuffer == null ? 0u : (uint)spnBuffer.Length;

                return SNIOpenSyncExWrapper(ref clientConsumerInfo, out pConn);
            }
        }

        /// <summary>Allocates a native packet of the given <paramref name="IOType"/> for <paramref name="pConn"/>.</summary>
        /// <param name="pPacket">Receives the pointer returned by SNIPacketAllocateWrapper.</param>
        internal static void SNIPacketAllocate(SafeHandle pConn, IOType IOType, ref IntPtr pPacket)
        {
            pPacket = SNIPacketAllocateWrapper(pConn, IOType);
        }

        /// <summary>
        /// Reads the data of a native packet into <paramref name="readBuffer"/>, passing the whole
        /// buffer length as its capacity.
        /// </summary>
        /// <param name="dataSize">Receives the size reported by the native call.</param>
        internal static uint SNIPacketGetData(IntPtr packet, byte[] readBuffer, ref uint dataSize)
        {
            return SNIPacketGetDataWrapper(packet, readBuffer, (uint)readBuffer.Length, out dataSize);
        }

        /// <summary>
        /// Passes the first <paramref name="length"/> bytes of <paramref name="data"/> to the native
        /// SNIPacketSetData for <paramref name="packet"/>.
        /// </summary>
        /// <exception cref="ArgumentOutOfRangeException">
        /// <paramref name="length"/> is negative or exceeds the length of <paramref name="data"/>.
        /// </exception>
        internal static unsafe void SNIPacketSetData(SNIPacket packet, byte[] data, int length)
        {
            // cbBuf presumably tells native code how many bytes to read from pbBuf. A negative length
            // would wrap to a huge uint, and a too-large one would read past the pinned array.
            int available = data == null ? 0 : data.Length;
            if (length < 0 || length > available)
            {
                throw new ArgumentOutOfRangeException(nameof(length));
            }

            // Array-form 'fixed' gives a null pointer for an empty array, so a zero-length payload
            // no longer throws IndexOutOfRangeException the way '&data[0]' did.
            fixed (byte* pin_data = data)
            {
                SNIPacketSetData(packet, pin_data, (uint)length);
            }
        }

        /// <summary>
        /// Runs one step of client security-context generation. <paramref name="serverUserName"/> is
        /// passed as the server info, and null user name/password are passed to the native call.
        /// </summary>
        /// <param name="sendLength">In/out byte count for <paramref name="OutBuff"/> (passed as pcbOut).</param>
        /// <param name="serverUserName">Server info bytes; must not be null.</param>
        internal static unsafe uint SNISecGenClientContext(SNIHandle pConnectionObject, byte[] inBuff, uint receivedLength, byte[] OutBuff, ref uint sendLength, byte[] serverUserName)
        {
            // Array-form 'fixed' tolerates an empty array. A null array still throws
            // NullReferenceException on .Length below, before the native call.
            fixed (byte* pin_serverUserName = serverUserName)
            {
                bool local_fDone; // the native "done" flag is not surfaced to callers
                return SNISecGenClientContextWrapper(
                    pConnectionObject,
                    inBuff,
                    receivedLength,
                    OutBuff,
                    ref sendLength,
                    out local_fDone,
                    pin_serverUserName,
                    (uint)serverUserName.Length,
                    null,
                    null);
            }
        }

        /// <summary>
        /// Writes <paramref name="packet"/> on <paramref name="pConn"/>, using SNIWriteSyncOverAsync
        /// when <paramref name="sync"/> is true and SNIWriteAsyncWrapper otherwise.
        /// </summary>
        internal static uint SNIWritePacket(SNIHandle pConn, SNIPacket packet, bool sync)
        {
            if (sync)
            {
                return SNIWriteSyncOverAsync(pConn, packet);
            }
            else
            {
                return SNIWriteAsyncWrapper(pConn, packet);
            }
        }

        /// <summary>
        /// Copies buffer size, consumer key and read/write completion callbacks from
        /// <paramref name="consumerInfo"/> into <paramref name="native_consumerInfo"/>. Other native
        /// fields are left untouched.
        /// </summary>
        /// <remarks>
        /// Marshal.GetFunctionPointerForDelegate does not keep the delegate alive. The caller must
        /// keep the delegates reachable for as long as native code may invoke them.
        /// </remarks>
        private static void MarshalConsumerInfo(ConsumerInfo consumerInfo, ref Sni_Consumer_Info native_consumerInfo)
        {
            native_consumerInfo.DefaultUserDataLength = consumerInfo.defaultBufferSize;
            native_consumerInfo.fnReadComp = null != consumerInfo.readDelegate
                ? Marshal.GetFunctionPointerForDelegate(consumerInfo.readDelegate)
                : IntPtr.Zero;
            native_consumerInfo.fnWriteComp = null != consumerInfo.writeDelegate
                ? Marshal.GetFunctionPointerForDelegate(consumerInfo.writeDelegate)
                : IntPtr.Zero;
            native_consumerInfo.ConsumerKey = consumerInfo.key;
        }
    }
}

namespace System.Data
{
    internal static partial class SafeNativeMethods
    {
        [DllImport("kernel32.dll", CharSet = CharSet.Ansi, BestFitMapping = false, ThrowOnUnmappableChar = true, SetLastError = true)]
        internal static extern IntPtr GetProcAddress(IntPtr HModule, [MarshalAs(UnmanagedType.LPStr), In] string funcName);
    }
}

namespace System.Data
{
    internal static class Win32NativeMethods
    {
        /// <summary>
        /// Asks native SNI whether <paramref name="token"/> is restricted. A non-zero native result is
        /// passed to Marshal.ThrowExceptionForHR.
        /// </summary>
        internal static bool IsTokenRestrictedWrapper(IntPtr token)
        {
            bool isRestricted;
            uint result = SNINativeMethodWrapper.UnmanagedIsTokenRestricted(token, out isRestricted);

            if (result != 0)
            {
                // NOTE(review): ThrowExceptionForHR only throws for failure HRESULTs (negative values).
                // If the native side returns a plain Win32 error code (positive), nothing is thrown
                // and isRestricted is returned as-is. Confirm the native return convention.
                Marshal.ThrowExceptionForHR(unchecked((int)result));
            }

            return isRestricted;
        }
    }
}