1 // Licensed to the .NET Foundation under one or more agreements.
2 // The .NET Foundation licenses this file to you under the MIT license.
3 // See the LICENSE file in the project root for more information.
4 
5 using System.Collections;
6 using System.Collections.Concurrent;
7 using System.Collections.Generic;
8 using System.Linq;
9 using System.Reflection;
10 
11 namespace System.Data.SqlClient.ManualTesting.Tests.SystemDataInternals
12 {
13     internal static class ConnectionPoolHelper
14     {
15         private static Assembly s_systemDotData = Assembly.Load(new AssemblyName(typeof(SqlConnection).GetTypeInfo().Assembly.FullName));
16         private static Type s_dbConnectionPool = s_systemDotData.GetType("System.Data.ProviderBase.DbConnectionPool");
17         private static Type s_dbConnectionPoolGroup = s_systemDotData.GetType("System.Data.ProviderBase.DbConnectionPoolGroup");
18         private static Type s_dbConnectionPoolIdentity = s_systemDotData.GetType("System.Data.ProviderBase.DbConnectionPoolIdentity");
19         private static Type s_dbConnectionFactory = s_systemDotData.GetType("System.Data.ProviderBase.DbConnectionFactory");
20         private static Type s_sqlConnectionFactory = s_systemDotData.GetType("System.Data.SqlClient.SqlConnectionFactory");
21         private static Type s_dbConnectionPoolKey = s_systemDotData.GetType("System.Data.Common.DbConnectionPoolKey");
22         private static Type s_dictStringPoolGroup = typeof(Dictionary<,>).MakeGenericType(s_dbConnectionPoolKey, s_dbConnectionPoolGroup);
23         private static Type s_dictPoolIdentityPool = typeof(ConcurrentDictionary<,>).MakeGenericType(s_dbConnectionPoolIdentity, s_dbConnectionPool);
24         private static PropertyInfo s_dbConnectionPoolCount = s_dbConnectionPool.GetProperty("Count", BindingFlags.Instance | BindingFlags.NonPublic);
25         private static PropertyInfo s_dictStringPoolGroupGetKeys = s_dictStringPoolGroup.GetProperty("Keys");
26         private static PropertyInfo s_dictPoolIdentityPoolValues = s_dictPoolIdentityPool.GetProperty("Values");
27         private static FieldInfo s_dbConnectionFactoryPoolGroupList = s_dbConnectionFactory.GetField("_connectionPoolGroups", BindingFlags.Instance | BindingFlags.NonPublic);
28         private static FieldInfo s_dbConnectionPoolGroupPoolCollection = s_dbConnectionPoolGroup.GetField("_poolCollection", BindingFlags.Instance | BindingFlags.NonPublic);
29         private static FieldInfo s_sqlConnectionFactorySingleton = s_sqlConnectionFactory.GetField("SingletonInstance", BindingFlags.Static | BindingFlags.Public);
30         private static FieldInfo s_dbConnectionPoolStackOld = s_dbConnectionPool.GetField("_stackOld", BindingFlags.Instance | BindingFlags.NonPublic);
31         private static FieldInfo s_dbConnectionPoolStackNew = s_dbConnectionPool.GetField("_stackNew", BindingFlags.Instance | BindingFlags.NonPublic);
32         private static MethodInfo s_dbConnectionPoolCleanup = s_dbConnectionPool.GetMethod("CleanupCallback", BindingFlags.Instance | BindingFlags.NonPublic);
33         private static MethodInfo s_dictStringPoolGroupTryGetValue = s_dictStringPoolGroup.GetMethod("TryGetValue");
34 
CountFreeConnections(object pool)35         public static int CountFreeConnections(object pool)
36         {
37             VerifyObjectIsPool(pool);
38 
39             ICollection oldStack = (ICollection)s_dbConnectionPoolStackOld.GetValue(pool);
40             ICollection newStack = (ICollection)s_dbConnectionPoolStackNew.GetValue(pool);
41 
42             return (oldStack.Count + newStack.Count);
43         }
44 
45         /// <summary>
46         /// Finds all connection pools
47         /// </summary>
48         /// <returns></returns>
AllConnectionPools()49         public static List<Tuple<object, object>> AllConnectionPools()
50         {
51             List<Tuple<object, object>> connectionPools = new List<Tuple<object, object>>();
52             object factorySingleton = s_sqlConnectionFactorySingleton.GetValue(null);
53             object AllPoolGroups = s_dbConnectionFactoryPoolGroupList.GetValue(factorySingleton);
54             ICollection connectionPoolKeys = (ICollection)s_dictStringPoolGroupGetKeys.GetValue(AllPoolGroups, null);
55             foreach (var item in connectionPoolKeys)
56             {
57                 object[] args = new object[] { item, null };
58                 s_dictStringPoolGroupTryGetValue.Invoke(AllPoolGroups, args);
59                 if (args[1] != null)
60                 {
61                     object poolCollection = s_dbConnectionPoolGroupPoolCollection.GetValue(args[1]);
62                     IEnumerable poolList = (IEnumerable)(s_dictPoolIdentityPoolValues.GetValue(poolCollection));
63                     foreach (object pool in poolList)
64                     {
65                         connectionPools.Add(new Tuple<object, object>(pool, item));
66                     }
67                 }
68             }
69 
70             return connectionPools;
71         }
72 
73         /// <summary>
74         /// Finds a connection pool based on a connection string
75         /// </summary>
76         /// <param name="connectionString"></param>
77         /// <returns></returns>
ConnectionPoolFromString(string connectionString)78         public static object ConnectionPoolFromString(string connectionString)
79         {
80             if (connectionString == null)
81                 throw new ArgumentNullException("connectionString");
82 
83             object pool = null;
84             object factorySingleton = s_sqlConnectionFactorySingleton.GetValue(null);
85             object AllPoolGroups = s_dbConnectionFactoryPoolGroupList.GetValue(factorySingleton);
86             object[] args = new object[] { connectionString, null };
87             bool found = (bool)s_dictStringPoolGroupTryGetValue.Invoke(AllPoolGroups, args);
88             if ((found) && (args[1] != null))
89             {
90                 ICollection poolList = (ICollection)s_dictPoolIdentityPoolValues.GetValue(args[1]);
91                 if (poolList.Count == 1)
92                 {
93                     poolList.Cast<object>().First();
94                 }
95                 else if (poolList.Count > 1)
96                 {
97                     throw new NotSupportedException("Using multiple identities with SSPI is not supported");
98                 }
99             }
100 
101             return pool;
102         }
103 
104         /// <summary>
105         /// Causes the cleanup timer code in the connection pool to be invoked
106         /// </summary>
107         /// <param name="obj">A connection pool object</param>
CleanConnectionPool(object pool)108         internal static void CleanConnectionPool(object pool)
109         {
110             VerifyObjectIsPool(pool);
111             s_dbConnectionPoolCleanup.Invoke(pool, new object[] { null });
112         }
113 
114         /// <summary>
115         /// Counts the number of connections in a connection pool
116         /// </summary>
117         /// <param name="pool">Pool to count connections in</param>
118         /// <returns></returns>
CountConnectionsInPool(object pool)119         internal static int CountConnectionsInPool(object pool)
120         {
121             VerifyObjectIsPool(pool);
122             return (int)s_dbConnectionPoolCount.GetValue(pool, null);
123         }
124 
125 
VerifyObjectIsPool(object pool)126         private static void VerifyObjectIsPool(object pool)
127         {
128             if (pool == null)
129                 throw new ArgumentNullException("pool");
130             if (!s_dbConnectionPool.IsInstanceOfType(pool))
131                 throw new ArgumentException("Object provided was not a DbConnectionPool", "pool");
132         }
133     }
134 }
135