// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Collections;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;

namespace System.Data.SqlClient.ManualTesting.Tests.SystemDataInternals
{
    /// <summary>
    /// Reflection-based helpers that let tests inspect the non-public connection pool
    /// types living in the assembly that defines <see cref="SqlConnection"/>.
    /// Pools are passed around as <see cref="object"/> because their types are looked up by name at run time.
    /// </summary>
    internal static class ConnectionPoolHelper
    {
        // Assembly and types, resolved by name from the assembly that contains SqlConnection.
        private static readonly Assembly s_systemDotData = Assembly.Load(new AssemblyName(typeof(SqlConnection).GetTypeInfo().Assembly.FullName));
        private static readonly Type s_dbConnectionPool = s_systemDotData.GetType("System.Data.ProviderBase.DbConnectionPool");
        private static readonly Type s_dbConnectionPoolGroup = s_systemDotData.GetType("System.Data.ProviderBase.DbConnectionPoolGroup");
        private static readonly Type s_dbConnectionPoolIdentity = s_systemDotData.GetType("System.Data.ProviderBase.DbConnectionPoolIdentity");
        private static readonly Type s_dbConnectionFactory = s_systemDotData.GetType("System.Data.ProviderBase.DbConnectionFactory");
        private static readonly Type s_sqlConnectionFactory = s_systemDotData.GetType("System.Data.SqlClient.SqlConnectionFactory");
        private static readonly Type s_dbConnectionPoolKey = s_systemDotData.GetType("System.Data.Common.DbConnectionPoolKey");

        // Closed generic types: Dictionary<DbConnectionPoolKey, DbConnectionPoolGroup> and
        // ConcurrentDictionary<DbConnectionPoolIdentity, DbConnectionPool>.
        private static readonly Type s_dictStringPoolGroup = typeof(Dictionary<,>).MakeGenericType(s_dbConnectionPoolKey, s_dbConnectionPoolGroup);
        private static readonly Type s_dictPoolIdentityPool = typeof(ConcurrentDictionary<,>).MakeGenericType(s_dbConnectionPoolIdentity, s_dbConnectionPool);

        // Properties.
        private static readonly PropertyInfo s_dbConnectionPoolCount = s_dbConnectionPool.GetProperty("Count", BindingFlags.Instance | BindingFlags.NonPublic);
        private static readonly PropertyInfo s_dictStringPoolGroupGetKeys = s_dictStringPoolGroup.GetProperty("Keys");
        private static readonly PropertyInfo s_dictPoolIdentityPoolValues = s_dictPoolIdentityPool.GetProperty("Values");

        // Fields.
        private static readonly FieldInfo s_dbConnectionFactoryPoolGroupList = s_dbConnectionFactory.GetField("_connectionPoolGroups", BindingFlags.Instance | BindingFlags.NonPublic);
        private static readonly FieldInfo s_dbConnectionPoolGroupPoolCollection = s_dbConnectionPoolGroup.GetField("_poolCollection", BindingFlags.Instance | BindingFlags.NonPublic);
        private static readonly FieldInfo s_sqlConnectionFactorySingleton = s_sqlConnectionFactory.GetField("SingletonInstance", BindingFlags.Static | BindingFlags.Public);
        private static readonly FieldInfo s_dbConnectionPoolStackOld = s_dbConnectionPool.GetField("_stackOld", BindingFlags.Instance | BindingFlags.NonPublic);
        private static readonly FieldInfo s_dbConnectionPoolStackNew = s_dbConnectionPool.GetField("_stackNew", BindingFlags.Instance | BindingFlags.NonPublic);

        // Methods.
        private static readonly MethodInfo s_dbConnectionPoolCleanup = s_dbConnectionPool.GetMethod("CleanupCallback", BindingFlags.Instance | BindingFlags.NonPublic);
        private static readonly MethodInfo s_dictStringPoolGroupTryGetValue = s_dictStringPoolGroup.GetMethod("TryGetValue");

        /// <summary>
        /// Counts the connections held in the pool's <c>_stackOld</c> and <c>_stackNew</c> collections.
        /// </summary>
        /// <param name="pool">A connection pool object</param>
        /// <returns>The sum of the element counts of the two stacks</returns>
        /// <exception cref="ArgumentNullException"><paramref name="pool"/> is null</exception>
        /// <exception cref="ArgumentException"><paramref name="pool"/> is not a DbConnectionPool</exception>
        public static int CountFreeConnections(object pool)
        {
            VerifyObjectIsPool(pool);

            ICollection oldStack = (ICollection)s_dbConnectionPoolStackOld.GetValue(pool);
            ICollection newStack = (ICollection)s_dbConnectionPoolStackNew.GetValue(pool);

            return (oldStack.Count + newStack.Count);
        }

        /// <summary>
        /// Finds all connection pools
        /// </summary>
        /// <returns>
        /// One tuple per pool: <c>Item1</c> is the pool object and <c>Item2</c> is the key of the
        /// pool group (in the factory's <c>_connectionPoolGroups</c> dictionary) that the pool belongs to.
        /// </returns>
        public static List<Tuple<object, object>> AllConnectionPools()
        {
            List<Tuple<object, object>> connectionPools = new List<Tuple<object, object>>();
            object factorySingleton = s_sqlConnectionFactorySingleton.GetValue(null);
            object allPoolGroups = s_dbConnectionFactoryPoolGroupList.GetValue(factorySingleton);
            ICollection connectionPoolKeys = (ICollection)s_dictStringPoolGroupGetKeys.GetValue(allPoolGroups, null);
            foreach (var item in connectionPoolKeys)
            {
                // TryGetValue(key, out value): reflection writes the out value back into args[1].
                object[] args = new object[] { item, null };
                s_dictStringPoolGroupTryGetValue.Invoke(allPoolGroups, args);
                if (args[1] != null)
                {
                    object poolCollection = s_dbConnectionPoolGroupPoolCollection.GetValue(args[1]);
                    IEnumerable poolList = (IEnumerable)(s_dictPoolIdentityPoolValues.GetValue(poolCollection));
                    foreach (object pool in poolList)
                    {
                        connectionPools.Add(new Tuple<object, object>(pool, item));
                    }
                }
            }

            return connectionPools;
        }

        /// <summary>
        /// Finds a connection pool based on a connection string
        /// </summary>
        /// <param name="connectionString">Value used as the lookup key into the factory's pool group dictionary</param>
        /// <returns>The single pool of the matching pool group, or null if no group or no pool was found</returns>
        /// <exception cref="ArgumentNullException"><paramref name="connectionString"/> is null</exception>
        /// <exception cref="NotSupportedException">The matching pool group contains more than one pool</exception>
        public static object ConnectionPoolFromString(string connectionString)
        {
            if (connectionString == null)
                throw new ArgumentNullException("connectionString");

            object pool = null;
            object factorySingleton = s_sqlConnectionFactorySingleton.GetValue(null);
            object allPoolGroups = s_dbConnectionFactoryPoolGroupList.GetValue(factorySingleton);

            // NOTE(review): the dictionary type used here is keyed by DbConnectionPoolKey, yet a raw
            // string is passed as the key. The reflective call presumably needs a key object built
            // from the connection string instead - confirm against the pool key type and fix.
            object[] args = new object[] { connectionString, null };
            bool found = (bool)s_dictStringPoolGroupTryGetValue.Invoke(allPoolGroups, args);
            if ((found) && (args[1] != null))
            {
                // args[1] is the pool group; its pools live in the _poolCollection dictionary,
                // so read that field first (same path as AllConnectionPools).
                object poolCollection = s_dbConnectionPoolGroupPoolCollection.GetValue(args[1]);
                ICollection poolList = (ICollection)s_dictPoolIdentityPoolValues.GetValue(poolCollection);
                if (poolList.Count == 1)
                {
                    // Keep the result: previously the value was computed and thrown away,
                    // so the method could only ever return null.
                    pool = poolList.Cast<object>().First();
                }
                else if (poolList.Count > 1)
                {
                    throw new NotSupportedException("Using multiple identities with SSPI is not supported");
                }
            }

            return pool;
        }

        /// <summary>
        /// Causes the cleanup timer code in the connection pool to be invoked
        /// </summary>
        /// <param name="pool">A connection pool object</param>
        /// <exception cref="ArgumentNullException"><paramref name="pool"/> is null</exception>
        /// <exception cref="ArgumentException"><paramref name="pool"/> is not a DbConnectionPool</exception>
        internal static void CleanConnectionPool(object pool)
        {
            VerifyObjectIsPool(pool);

            // CleanupCallback is invoked with a single null argument.
            s_dbConnectionPoolCleanup.Invoke(pool, new object[] { null });
        }

        /// <summary>
        /// Counts the number of connections in a connection pool
        /// </summary>
        /// <param name="pool">Pool to count connections in</param>
        /// <returns>The value of the pool's non-public <c>Count</c> property</returns>
        /// <exception cref="ArgumentNullException"><paramref name="pool"/> is null</exception>
        /// <exception cref="ArgumentException"><paramref name="pool"/> is not a DbConnectionPool</exception>
        internal static int CountConnectionsInPool(object pool)
        {
            VerifyObjectIsPool(pool);
            return (int)s_dbConnectionPoolCount.GetValue(pool, null);
        }

        /// <summary>
        /// Validates that the supplied object is a non-null DbConnectionPool instance.
        /// </summary>
        /// <param name="pool">Object to validate</param>
        /// <exception cref="ArgumentNullException"><paramref name="pool"/> is null</exception>
        /// <exception cref="ArgumentException"><paramref name="pool"/> is not a DbConnectionPool</exception>
        private static void VerifyObjectIsPool(object pool)
        {
            if (pool == null)
                throw new ArgumentNullException("pool");
            if (!s_dbConnectionPool.IsInstanceOfType(pool))
                throw new ArgumentException("Object provided was not a DbConnectionPool", "pool");
        }
    }
}