1 //------------------------------------------------------------------------------
2 // <copyright file="SqlConnectionHelper.cs" company="Microsoft">
3 //     Copyright (c) Microsoft Corporation.  All rights reserved.
4 // </copyright>
5 //------------------------------------------------------------------------------
6 
7 namespace System.Web.DataAccess {
8 
9     using System;
10     using System.Collections.Specialized;
11     using System.Configuration;
12     using System.Configuration.Provider;
13     using System.Data;
14     using System.Data.SqlClient;
15     using System.Diagnostics;
16     using System.Globalization;
17     using System.IO;
18     using System.Security.Permissions;
19     using System.Threading;
20     using System.Web.Configuration;
21     using System.Web.Hosting;
22     using System.Web.Management;
23     using System.Web.Util;
24 
25     internal static class SqlConnectionHelper {
26         internal const string s_strDataDir = "DataDirectory";
27         internal const string s_strUpperDataDirWithToken = "|DATADIRECTORY|";
28         internal const string s_strSqlExprFileExt = ".MDF";
29         internal const string s_strUpperUserInstance = "USER INSTANCE";
30         private const string s_localDbName = "(LOCALDB)";
31         private static object s_lock = new object();
32 
EnsureNoUserInstance(string connectionString)33         internal static void EnsureNoUserInstance(string connectionString) {
34             SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder(connectionString);
35             if (builder.UserInstance) {
36                 throw new ProviderException(SR.GetString(SR.LocalDB_cannot_have_userinstance_flag));
37             }
38         }
39 
GetConnection(string connectionString, bool revertImpersonation)40         internal static SqlConnectionHolder GetConnection(string connectionString, bool revertImpersonation) {
41             string strTempConnection = connectionString.ToUpperInvariant();
42             if (strTempConnection.Contains(s_strUpperDataDirWithToken)) {
43                 EnsureDBFile(connectionString);
44             }
45 
46             // Only block UserInstance for LocalDB connections
47             if (strTempConnection.Contains(s_localDbName)) {
48                 EnsureNoUserInstance(connectionString);
49             }
50 
51             SqlConnectionHolder holder = new SqlConnectionHolder(connectionString);
52             bool closeConn = true;
53             try {
54                 try {
55                     holder.Open(null, revertImpersonation);
56                     closeConn = false;
57                 }
58                 finally {
59                     if (closeConn) {
60                         holder.Close();
61                         holder = null;
62                     }
63                 }
64             }
65             catch {
66                 throw;
67             }
68             return holder;
69         }
70 
GetConnectionString(string specifiedConnectionString, bool lookupConnectionString, bool appLevel)71         internal static string GetConnectionString(string specifiedConnectionString, bool lookupConnectionString, bool appLevel) {
72             System.Web.Util.Debug.Assert((specifiedConnectionString != null) && (specifiedConnectionString.Length != 0));
73             if (specifiedConnectionString == null || specifiedConnectionString.Length < 1)
74                 return null;
75 
76             string connectionString = null;
77 
78             // Step 1: Check <connectionStrings> config section for this connection string
79             if (lookupConnectionString) {
80                 RuntimeConfig config = (appLevel) ? RuntimeConfig.GetAppConfig() : RuntimeConfig.GetConfig();
81                 ConnectionStringSettings connObj = config.ConnectionStrings.ConnectionStrings[specifiedConnectionString];
82                 if (connObj != null)
83                     connectionString = connObj.ConnectionString;
84 
85                 if (connectionString == null)
86                     return null;
87 
88                 //HandlerBase.CheckAndReadRegistryValue (ref connectionString, true);
89             }
90             else {
91                 connectionString = specifiedConnectionString;
92             }
93 
94             return connectionString;
95         }
96 
97         [PermissionSet(SecurityAction.Assert, Unrestricted = true)]
GetDataDirectory()98         internal static string GetDataDirectory() {
99             if (HostingEnvironment.IsHosted)
100                 return Path.Combine(HttpRuntime.AppDomainAppPath, HttpRuntime.DataDirectoryName);
101 
102             string dataDir = AppDomain.CurrentDomain.GetData(s_strDataDir) as string;
103             if (string.IsNullOrEmpty(dataDir)) {
104                 string appPath = null;
105 
106 #if !FEATURE_PAL // FEATURE_PAL does not support ProcessModule
107                 Process p = Process.GetCurrentProcess();
108                 ProcessModule pm = (p != null ? p.MainModule : null);
109                 string exeName = (pm != null ? pm.FileName : null);
110 
111                 if (!string.IsNullOrEmpty(exeName))
112                     appPath = Path.GetDirectoryName(exeName);
113 #endif // !FEATURE_PAL
114 
115                 if (string.IsNullOrEmpty(appPath))
116                     appPath = Environment.CurrentDirectory;
117 
118                 dataDir = Path.Combine(appPath, HttpRuntime.DataDirectoryName);
119                 AppDomain.CurrentDomain.SetData(s_strDataDir, dataDir, new FileIOPermission(FileIOPermissionAccess.PathDiscovery, dataDir));
120             }
121 
122             return dataDir;
123         }
124 
EnsureDBFile(string connectionString)125         private static void EnsureDBFile(string connectionString) {
126             string partialFileName = null;
127             string fullFileName = null;
128             string dataDir = GetDataDirectory();
129             bool lookingForDataDir = true;
130             bool lookingForDB = true;
131             string[] splitedConnStr = connectionString.Split(new char[] { ';' }, StringSplitOptions.RemoveEmptyEntries);
132             bool lookingForUserInstance = !connectionString.ToUpperInvariant().Contains(s_localDbName); // We don't require UserInstance=True for LocalDb
133             bool lookingForTimeout = true;
134 
135             foreach (string str in splitedConnStr) {
136                 string strUpper = str.ToUpper(CultureInfo.InvariantCulture).Trim();
137 
138                 if (lookingForDataDir && strUpper.Contains(s_strUpperDataDirWithToken)) {
139                     lookingForDataDir = false;
140 
141                     // Replace the AttachDBFilename part with "Pooling=false"
142                     connectionString = connectionString.Replace(str, "Pooling=false");
143 
144                     // Extract the filenames
145                     int startPos = strUpper.IndexOf(s_strUpperDataDirWithToken, StringComparison.Ordinal) + s_strUpperDataDirWithToken.Length;
146                     partialFileName = strUpper.Substring(startPos).Trim();
147                     while (partialFileName.StartsWith("\\", StringComparison.Ordinal))
148                         partialFileName = partialFileName.Substring(1);
149                     if (partialFileName.Contains("..")) // don't allow it to traverse-up
150                         partialFileName = null;
151                     else
152                         fullFileName = Path.Combine(dataDir, partialFileName);
153                     if (!lookingForDB)
154                         break; // done
155                 }
156                 else if (lookingForDB && (strUpper.StartsWith("INITIAL CATALOG", StringComparison.Ordinal) || strUpper.StartsWith("DATABASE", StringComparison.Ordinal))) {
157                     lookingForDB = false;
158                     connectionString = connectionString.Replace(str, "Database=master");
159                     if (!lookingForDataDir)
160                         break; // done
161                 }
162                 else if (lookingForUserInstance && strUpper.StartsWith(s_strUpperUserInstance, StringComparison.Ordinal)) {
163                     lookingForUserInstance = false;
164                     int pos = strUpper.IndexOf('=');
165                     if (pos < 0)
166                         return;
167                     string strTemp = strUpper.Substring(pos + 1).Trim();
168                     if (strTemp != "TRUE")
169                         return;
170                 }
171                 else if (lookingForTimeout && strUpper.StartsWith("CONNECT TIMEOUT", StringComparison.Ordinal)) {
172                     lookingForTimeout = false;
173                 }
174             }
175             if (lookingForUserInstance)
176                 return;
177 
178             if (fullFileName == null)
179                 throw new ProviderException(SR.GetString(SR.SqlExpress_file_not_found_in_connection_string));
180 
181             if (File.Exists(fullFileName))
182                 return;
183 
184             if (!HttpRuntime.HasAspNetHostingPermission(AspNetHostingPermissionLevel.High))
185                 throw new ProviderException(SR.GetString(SR.Provider_can_not_create_file_in_this_trust_level));
186 
187             if (!connectionString.Contains("Database=master"))
188                 connectionString += ";Database=master";
189             if (lookingForTimeout)
190                 connectionString += ";Connect Timeout=45";
191             using (new ApplicationImpersonationContext())
192                 lock (s_lock)
193                     if (!File.Exists(fullFileName))
194                         CreateMdfFile(fullFileName, dataDir, connectionString);
195         }
196 
197         [PermissionSet(SecurityAction.Assert, Unrestricted = true)]
CreateMdfFile(string fullFileName, string dataDir, string connectionString)198         private static void CreateMdfFile(string fullFileName, string dataDir, string connectionString) {
199             bool creatingDir = false;
200             string databaseName = null;
201             HttpContext context = HttpContext.Current;
202             string tempFileName = null;
203 
204             try {
205                 if (!Directory.Exists(dataDir)) {
206                     creatingDir = true;
207                     Directory.CreateDirectory(dataDir);
208                     creatingDir = false;
209                     try {
210                         if (context != null)
211                             HttpRuntime.RestrictIISFolders(context);
212                     }
213                     catch { }
214                 }
215 
216                 fullFileName = fullFileName.ToUpper(CultureInfo.InvariantCulture);
217                 char[] strippedFileNameChars = Path.GetFileNameWithoutExtension(fullFileName).ToCharArray();
218                 for (int iter = 0; iter < strippedFileNameChars.Length; iter++)
219                     if (!char.IsLetterOrDigit(strippedFileNameChars[iter]))
220                         strippedFileNameChars[iter] = '_';
221                 string strippedFileName = new string(strippedFileNameChars);
222                 if (strippedFileName.Length > 30)
223                     databaseName = strippedFileName.Substring(0, 30) + "_" + Guid.NewGuid().ToString("N", CultureInfo.InvariantCulture);
224                 else
225                     databaseName = strippedFileName + "_" + Guid.NewGuid().ToString("N", CultureInfo.InvariantCulture);
226 
227                 tempFileName = Path.Combine(Path.GetDirectoryName(fullFileName), strippedFileName + "_TMP" + s_strSqlExprFileExt);
228 
229                 // Auto create the temporary database
230                 SqlServices.Install(databaseName, tempFileName, connectionString);
231                 DetachDB(databaseName, connectionString);
232                 try {
233                     File.Move(tempFileName, fullFileName);
234                 }
235                 catch {
236                     if (!File.Exists(fullFileName)) {
237                         File.Copy(tempFileName, fullFileName);
238                         try {
239                             File.Delete(tempFileName);
240                         }
241                         catch { }
242                     }
243                 }
244                 try {
245                     File.Delete(tempFileName.Replace("_TMP.MDF", "_TMP_log.LDF"));
246                 }
247                 catch { }
248             }
249             catch (Exception e) {
250                 if (context == null || context.IsCustomErrorEnabled)
251                     throw;
252                 HttpException httpExec = new HttpException(e.Message, e);
253                 if (e is UnauthorizedAccessException)
254                     httpExec.SetFormatter(new SqlExpressConnectionErrorFormatter(creatingDir ? DataConnectionErrorEnum.CanNotCreateDataDir : DataConnectionErrorEnum.CanNotWriteToDataDir));
255                 else
256                     httpExec.SetFormatter(new SqlExpressDBFileAutoCreationErrorFormatter(e));
257                 throw httpExec;
258             }
259         }
260 
DetachDB(string databaseName, string connectionString)261         private static void DetachDB(string databaseName, string connectionString) {
262             SqlConnection connection = new SqlConnection(connectionString);
263             try {
264                 connection.Open();
265                 SqlCommand command = new SqlCommand("USE master", connection);
266                 command.ExecuteNonQuery();
267                 command = new SqlCommand("sp_detach_db", connection);
268                 command.CommandType = CommandType.StoredProcedure;
269                 command.Parameters.AddWithValue("@dbname", databaseName);
270                 command.Parameters.AddWithValue("@skipchecks", "true");
271                 command.ExecuteNonQuery();
272             }
273             catch {
274             }
275             finally {
276                 connection.Close();
277             }
278         }
279     }
280 
281     internal sealed class SqlConnectionHolder {
282         internal SqlConnection _Connection;
283         private bool _Opened;
284 
285         internal SqlConnection Connection {
286             get { return _Connection; }
287         }
288 
SqlConnectionHolder(string connectionString)289         internal SqlConnectionHolder(string connectionString) {
290             try {
291                 _Connection = new SqlConnection(connectionString);
292                 System.Web.Util.Debug.Assert(_Connection != null);
293             }
294             catch (ArgumentException e) {
295                 throw new ArgumentException(SR.GetString(SR.SqlError_Connection_String), "connectionString", e);
296             }
297         }
298 
Open(HttpContext context, bool revertImpersonate)299         internal void Open(HttpContext context, bool revertImpersonate) {
300             if (_Opened)
301                 return; // Already opened
302 
303             if (revertImpersonate) {
304                 using (new ApplicationImpersonationContext()) {
305                     Connection.Open();
306                 }
307             }
308             else {
309                 Connection.Open();
310             }
311 
312             _Opened = true; // Open worked!
313         }
314 
Close()315         internal void Close() {
316             if (!_Opened) // Not open!
317                 return;
318             // Close connection
319             Connection.Close();
320             _Opened = false;
321         }
322     }
323 }
324 
325