//------------------------------------------------------------------------------
// <copyright file="SqlConnectionHelper.cs" company="Microsoft">
//     Copyright (c) Microsoft Corporation. All rights reserved.
// </copyright>
//------------------------------------------------------------------------------

namespace System.Web.DataAccess {

    using System;
    using System.Collections.Specialized;
    using System.Configuration;
    using System.Configuration.Provider;
    using System.Data;
    using System.Data.SqlClient;
    using System.Diagnostics;
    using System.Globalization;
    using System.IO;
    using System.Security.Permissions;
    using System.Threading;
    using System.Web.Configuration;
    using System.Web.Hosting;
    using System.Web.Management;
    using System.Web.Util;

    // Helpers for resolving connection strings, locating the data directory,
    // auto-creating a missing .MDF file referenced through |DataDirectory|,
    // and opening SQL connections for the providers in this assembly.
    internal static class SqlConnectionHelper {
        internal const string s_strDataDir = "DataDirectory";
        internal const string s_strUpperDataDirWithToken = "|DATADIRECTORY|";
        internal const string s_strSqlExprFileExt = ".MDF";
        internal const string s_strUpperUserInstance = "USER INSTANCE";
        private const string s_localDbName = "(LOCALDB)";

        // Serializes database file creation in EnsureDBFile within this AppDomain.
        private static readonly object s_lock = new object();

        // Throws a ProviderException when the connection string has
        // "User Instance" set to true. Called for LocalDB connection strings.
        internal static void EnsureNoUserInstance(string connectionString) {
            SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder(connectionString);
            if (builder.UserInstance) {
                throw new ProviderException(SR.GetString(SR.LocalDB_cannot_have_userinstance_flag));
            }
        }

        // Returns an opened SqlConnectionHolder for connectionString.
        //
        // If the connection string uses the |DataDirectory| token, the database
        // file is created first when it is missing (see EnsureDBFile). If it
        // targets LocalDB, "User Instance=true" is rejected.
        //
        // revertImpersonation is forwarded to SqlConnectionHolder.Open, which
        // opens the connection inside an ApplicationImpersonationContext when
        // it is true.
        internal static SqlConnectionHolder GetConnection(string connectionString, bool revertImpersonation) {
            string strTempConnection = connectionString.ToUpperInvariant();
            if (strTempConnection.Contains(s_strUpperDataDirWithToken)) {
                EnsureDBFile(connectionString);
            }

            // Only block UserInstance for LocalDB connections
            if (strTempConnection.Contains(s_localDbName)) {
                EnsureNoUserInstance(connectionString);
            }

            SqlConnectionHolder holder = new SqlConnectionHolder(connectionString);
            bool closeConn = true;

            // NOTE(review): the try/finally nested in catch { throw; } is kept
            // exactly as written. It looks like the usual guard that makes the
            // finally block run before any exception filter further up the
            // stack gets control - confirm before simplifying.
            try {
                try {
                    holder.Open(null, revertImpersonation);
                    closeConn = false;
                }
                finally {
                    if (closeConn) {
                        holder.Close();
                        holder = null;
                    }
                }
            }
            catch {
                throw;
            }
            return holder;
        }

        // Resolves the connection string to use.
        //
        // When lookupConnectionString is true, specifiedConnectionString is a
        // name looked up in the <connectionStrings> config section (application
        // level config when appLevel is true); null is returned if the name is
        // not found. Otherwise specifiedConnectionString is returned unchanged.
        // Returns null for a null or empty specifiedConnectionString.
        internal static string GetConnectionString(string specifiedConnectionString, bool lookupConnectionString, bool appLevel) {
            System.Web.Util.Debug.Assert((specifiedConnectionString != null) && (specifiedConnectionString.Length != 0));
            if (string.IsNullOrEmpty(specifiedConnectionString))
                return null;

            // Not a lookup: the caller passed the connection string itself
            if (!lookupConnectionString)
                return specifiedConnectionString;

            // Check <connectionStrings> config section for this connection string
            RuntimeConfig config = (appLevel) ? RuntimeConfig.GetAppConfig() : RuntimeConfig.GetConfig();
            ConnectionStringSettings connObj = config.ConnectionStrings.ConnectionStrings[specifiedConnectionString];

            //HandlerBase.CheckAndReadRegistryValue (ref connectionString, true);
            return (connObj != null) ? connObj.ConnectionString : null;
        }

        // Returns the directory that |DataDirectory| maps to.
        //
        // Hosted: <AppDomainAppPath>\<HttpRuntime.DataDirectoryName>.
        // Not hosted: the AppDomain's "DataDirectory" data item if set;
        // otherwise a path derived from the main module's directory (or the
        // current directory), which is then stored back on the AppDomain.
        [PermissionSet(SecurityAction.Assert, Unrestricted = true)]
        internal static string GetDataDirectory() {
            if (HostingEnvironment.IsHosted)
                return Path.Combine(HttpRuntime.AppDomainAppPath, HttpRuntime.DataDirectoryName);

            string dataDir = AppDomain.CurrentDomain.GetData(s_strDataDir) as string;
            if (string.IsNullOrEmpty(dataDir)) {
                string appPath = null;

#if !FEATURE_PAL // FEATURE_PAL does not support ProcessModule
                // Process is IDisposable; release it once the module path has been read
                using (Process p = Process.GetCurrentProcess()) {
                    ProcessModule pm = (p != null ? p.MainModule : null);
                    string exeName = (pm != null ? pm.FileName : null);

                    if (!string.IsNullOrEmpty(exeName))
                        appPath = Path.GetDirectoryName(exeName);
                }
#endif // !FEATURE_PAL

                if (string.IsNullOrEmpty(appPath))
                    appPath = Environment.CurrentDirectory;

                dataDir = Path.Combine(appPath, HttpRuntime.DataDirectoryName);
                AppDomain.CurrentDomain.SetData(s_strDataDir, dataDir, new FileIOPermission(FileIOPermissionAccess.PathDiscovery, dataDir));
            }

            return dataDir;
        }

        // Makes sure the database file referenced via |DataDirectory| exists,
        // creating it when it does not.
        //
        // The connection string is scanned segment by segment to:
        //   - extract the file name following the |DataDirectory| token and
        //     replace that segment with "Pooling=false",
        //   - replace the "Initial Catalog"/"Database" segment with
        //     "Database=master",
        //   - check "User Instance" (not required for LocalDB),
        //   - note whether a "Connect Timeout" was supplied.
        // The rewritten connection string is only used to create the file.
        //
        // Returns without doing anything when a non-LocalDB connection string
        // does not specify "User Instance=true", or when the file already
        // exists. Throws ProviderException when no usable file name is found
        // or when the application lacks High hosting permission.
        private static void EnsureDBFile(string connectionString) {
            string partialFileName = null;
            string fullFileName = null;
            string dataDir = GetDataDirectory();
            bool lookingForDataDir = true;
            bool lookingForDB = true;
            string[] splitedConnStr = connectionString.Split(new char[] { ';' }, StringSplitOptions.RemoveEmptyEntries);
            bool lookingForUserInstance = !connectionString.ToUpperInvariant().Contains(s_localDbName); // We don't require UserInstance=True for LocalDb
            bool lookingForTimeout = true;

            foreach (string str in splitedConnStr) {
                string strUpper = str.ToUpper(CultureInfo.InvariantCulture).Trim();

                if (lookingForDataDir && strUpper.Contains(s_strUpperDataDirWithToken)) {
                    lookingForDataDir = false;

                    // Replace the AttachDBFilename part with "Pooling=false"
                    connectionString = connectionString.Replace(str, "Pooling=false");

                    // Extract the filenames
                    int startPos = strUpper.IndexOf(s_strUpperDataDirWithToken, StringComparison.Ordinal) + s_strUpperDataDirWithToken.Length;
                    partialFileName = strUpper.Substring(startPos).Trim();
                    while (partialFileName.StartsWith("\\", StringComparison.Ordinal))
                        partialFileName = partialFileName.Substring(1);
                    if (partialFileName.Contains("..")) // don't allow it to traverse-up
                        partialFileName = null;
                    else
                        fullFileName = Path.Combine(dataDir, partialFileName);
                    if (!lookingForDB)
                        break; // done
                }
                else if (lookingForDB && (strUpper.StartsWith("INITIAL CATALOG", StringComparison.Ordinal) || strUpper.StartsWith("DATABASE", StringComparison.Ordinal))) {
                    lookingForDB = false;
                    connectionString = connectionString.Replace(str, "Database=master");
                    if (!lookingForDataDir)
                        break; // done
                }
                else if (lookingForUserInstance && strUpper.StartsWith(s_strUpperUserInstance, StringComparison.Ordinal)) {
                    lookingForUserInstance = false;
                    int pos = strUpper.IndexOf('=');
                    if (pos < 0)
                        return;
                    string strTemp = strUpper.Substring(pos + 1).Trim();
                    if (strTemp != "TRUE")
                        return;
                }
                else if (lookingForTimeout && strUpper.StartsWith("CONNECT TIMEOUT", StringComparison.Ordinal)) {
                    lookingForTimeout = false;
                }
            }

            // Still true only for a non-LocalDB string with no "User Instance" segment
            if (lookingForUserInstance)
                return;

            if (fullFileName == null)
                throw new ProviderException(SR.GetString(SR.SqlExpress_file_not_found_in_connection_string));

            if (File.Exists(fullFileName))
                return;

            if (!HttpRuntime.HasAspNetHostingPermission(AspNetHostingPermissionLevel.High))
                throw new ProviderException(SR.GetString(SR.Provider_can_not_create_file_in_this_trust_level));

            if (!connectionString.Contains("Database=master"))
                connectionString += ";Database=master";
            if (lookingForTimeout)
                connectionString += ";Connect Timeout=45";

            using (new ApplicationImpersonationContext()) {
                lock (s_lock) {
                    // Re-check under the lock: another thread may have created the file
                    if (!File.Exists(fullFileName))
                        CreateMdfFile(fullFileName, dataDir, connectionString);
                }
            }
        }

        // Creates the database file at fullFileName.
        //
        // Creates dataDir if needed, installs a temporary database into
        // "<name>_TMP.MDF" via SqlServices.Install, detaches it, then moves
        // (or, failing that, copies) the temporary file to fullFileName and
        // deletes the temporary log file.
        //
        // On failure the original exception is rethrown when there is no
        // HttpContext or custom errors are enabled; otherwise it is wrapped in
        // an HttpException carrying a SQL Express specific error formatter.
        [PermissionSet(SecurityAction.Assert, Unrestricted = true)]
        private static void CreateMdfFile(string fullFileName, string dataDir, string connectionString) {
            bool creatingDir = false;
            HttpContext context = HttpContext.Current;

            try {
                if (!Directory.Exists(dataDir)) {
                    // creatingDir is true only while CreateDirectory runs, so the
                    // catch block below can tell a directory-creation failure apart
                    creatingDir = true;
                    Directory.CreateDirectory(dataDir);
                    creatingDir = false;
                    try {
                        if (context != null)
                            HttpRuntime.RestrictIISFolders(context);
                    }
                    catch {
                        // Best effort: a failure here must not stop file creation
                    }
                }

                fullFileName = fullFileName.ToUpper(CultureInfo.InvariantCulture);

                // Build a name from the file name with every character that is
                // not a letter or digit replaced by '_'
                char[] strippedFileNameChars = Path.GetFileNameWithoutExtension(fullFileName).ToCharArray();
                for (int iter = 0; iter < strippedFileNameChars.Length; iter++) {
                    if (!char.IsLetterOrDigit(strippedFileNameChars[iter]))
                        strippedFileNameChars[iter] = '_';
                }
                string strippedFileName = new string(strippedFileNameChars);

                // Database name: at most 30 characters of the stripped name plus a GUID
                string namePrefix = (strippedFileName.Length > 30) ? strippedFileName.Substring(0, 30) : strippedFileName;
                string databaseName = namePrefix + "_" + Guid.NewGuid().ToString("N", CultureInfo.InvariantCulture);

                string targetDir = Path.GetDirectoryName(fullFileName);
                string tempFileName = Path.Combine(targetDir, strippedFileName + "_TMP" + s_strSqlExprFileExt);

                // Auto create the temporary database
                SqlServices.Install(databaseName, tempFileName, connectionString);
                DetachDB(databaseName, connectionString);
                try {
                    File.Move(tempFileName, fullFileName);
                }
                catch {
                    // Move failed: fall back to copy + delete unless the target already exists
                    if (!File.Exists(fullFileName)) {
                        File.Copy(tempFileName, fullFileName);
                        try {
                            File.Delete(tempFileName);
                        }
                        catch {
                            // Best effort: a leftover temporary file is not fatal
                        }
                    }
                }
                try {
                    // Build the log file path from its parts rather than by string
                    // replacement on the full path, so the directory part is never altered
                    File.Delete(Path.Combine(targetDir, strippedFileName + "_TMP_log.LDF"));
                }
                catch {
                    // Best effort: a leftover temporary log file is not fatal
                }
            }
            catch (Exception e) {
                if (context == null || context.IsCustomErrorEnabled)
                    throw;
                HttpException httpExec = new HttpException(e.Message, e);
                if (e is UnauthorizedAccessException)
                    httpExec.SetFormatter(new SqlExpressConnectionErrorFormatter(creatingDir ? DataConnectionErrorEnum.CanNotCreateDataDir : DataConnectionErrorEnum.CanNotWriteToDataDir));
                else
                    httpExec.SetFormatter(new SqlExpressDBFileAutoCreationErrorFormatter(e));
                throw httpExec;
            }
        }

        // Detaches databaseName by running sp_detach_db (with @skipchecks = "true")
        // against the master database. Failures while opening the connection or
        // executing the commands are deliberately ignored.
        private static void DetachDB(string databaseName, string connectionString) {
            using (SqlConnection connection = new SqlConnection(connectionString)) {
                try {
                    connection.Open();

                    using (SqlCommand useMaster = new SqlCommand("USE master", connection)) {
                        useMaster.ExecuteNonQuery();
                    }

                    using (SqlCommand detach = new SqlCommand("sp_detach_db", connection)) {
                        detach.CommandType = CommandType.StoredProcedure;
                        detach.Parameters.AddWithValue("@dbname", databaseName);
                        detach.Parameters.AddWithValue("@skipchecks", "true");
                        detach.ExecuteNonQuery();
                    }
                }
                catch {
                    // Best effort: the caller proceeds to move the file regardless
                }
            }
        }
    }

    // Wraps a SqlConnection and tracks whether this holder opened it, so that
    // Open and Close can be called repeatedly without error.
    internal sealed class SqlConnectionHolder {
        internal SqlConnection _Connection;
        private bool _Opened;

        internal SqlConnection Connection {
            get { return _Connection; }
        }

        // Creates the underlying SqlConnection. An ArgumentException from the
        // SqlConnection constructor is rethrown with the provider's own
        // connection string error message and the original as inner exception.
        internal SqlConnectionHolder(string connectionString) {
            try {
                _Connection = new SqlConnection(connectionString);
                System.Web.Util.Debug.Assert(_Connection != null);
            }
            catch (ArgumentException e) {
                throw new ArgumentException(SR.GetString(SR.SqlError_Connection_String), "connectionString", e);
            }
        }

        // Opens the connection unless this holder already opened it. When
        // revertImpersonate is true the open happens inside an
        // ApplicationImpersonationContext. The context parameter is not used here.
        internal void Open(HttpContext context, bool revertImpersonate) {
            if (_Opened)
                return; // Already opened

            if (revertImpersonate) {
                using (new ApplicationImpersonationContext()) {
                    Connection.Open();
                }
            }
            else {
                Connection.Open();
            }

            _Opened = true; // Open worked!
        }

        // Closes the connection if this holder opened it; otherwise does nothing.
        internal void Close() {
            if (!_Opened) // Not open!
                return;

            // Close connection
            Connection.Close();
            _Opened = false;
        }
    }
}