1 //------------------------------------------------------------------------------
2 // <copyright file="SqlBatchCommand.cs" company="Microsoft">
3 //     Copyright (c) Microsoft Corporation.  All rights reserved.
4 // </copyright>
5 // <owner current="true" primary="true">Microsoft</owner>
6 // <owner current="true" primary="false">Microsoft</owner>
7 //------------------------------------------------------------------------------
8 
9 namespace System.Data.SqlClient {
10 
11     using System;
12     using System.Collections.Generic;
13     using System.ComponentModel;
14     using System.Data;
15     using System.Data.Common;
16     using System.Diagnostics;
17     using System.Globalization;
18     using System.Text;
19     using System.Text.RegularExpressions;
20 
21     internal sealed class SqlCommandSet {
22 
23         private const string SqlIdentifierPattern = "^@[\\p{Lo}\\p{Lu}\\p{Ll}\\p{Lm}_@#][\\p{Lo}\\p{Lu}\\p{Ll}\\p{Lm}\\p{Nd}\uff3f_@#\\$]*$";
24         private static readonly Regex SqlIdentifierParser = new Regex(SqlIdentifierPattern, RegexOptions.ExplicitCapture|RegexOptions.Singleline);
25 
26         private List<LocalCommand> _commandList = new List<LocalCommand>();
27 
28         private SqlCommand _batchCommand;
29 
30         private static int _objectTypeCount; // Bid counter
31         internal readonly int _objectID = System.Threading.Interlocked.Increment(ref _objectTypeCount);
32 
33         private sealed class LocalCommand {
34             internal readonly string CommandText;
35             internal readonly SqlParameterCollection Parameters;
36             internal readonly int ReturnParameterIndex;
37             internal readonly CommandType CmdType;
38             internal readonly SqlCommandColumnEncryptionSetting ColumnEncryptionSetting;
39 
LocalCommand(string commandText, SqlParameterCollection parameters, int returnParameterIndex, CommandType cmdType, SqlCommandColumnEncryptionSetting columnEncryptionSetting)40             internal LocalCommand(string commandText, SqlParameterCollection parameters,  int returnParameterIndex, CommandType cmdType, SqlCommandColumnEncryptionSetting columnEncryptionSetting) {
41                 Debug.Assert(0 <= commandText.Length, "no text");
42                 this.CommandText = commandText;
43                 this.Parameters = parameters;
44                 this.ReturnParameterIndex = returnParameterIndex;
45                 this.CmdType = cmdType;
46                 this.ColumnEncryptionSetting = columnEncryptionSetting;
47             }
48         }
49 
SqlCommandSet()50         internal SqlCommandSet() : base() {
51             _batchCommand = new SqlCommand();
52         }
53 
54         private SqlCommand BatchCommand {
55             get {
56                 SqlCommand command = _batchCommand;
57                 if (null == command) {
58                     throw ADP.ObjectDisposed(this);
59                 }
60                 return command;
61             }
62         }
63 
64         internal int CommandCount {
65             get {
66                 return CommandList.Count;
67             }
68         }
69 
70         private List<LocalCommand> CommandList {
71             get {
72                 List<LocalCommand> commandList = _commandList;
73                 if (null == commandList) {
74                     throw ADP.ObjectDisposed(this);
75                 }
76                 return commandList;
77             }
78         }
79 
80         internal int CommandTimeout {
81             /*get {
82                 return BatchCommand.CommandTimeout;
83             }*/
84             set {
85                 BatchCommand.CommandTimeout = value;
86             }
87         }
88 
89         internal SqlConnection Connection {
90             get {
91                 return BatchCommand.Connection;
92             }
93             set {
94                 BatchCommand.Connection = value;
95             }
96         }
97 
98         internal SqlTransaction Transaction {
99             /*get {
100                 return BatchCommand.Transaction;
101             }*/
102             set {
103                 BatchCommand.Transaction = value;
104             }
105         }
106 
107         internal int ObjectID {
108             get {
109                 return _objectID;
110             }
111         }
112 
Append(SqlCommand command)113         internal void Append(SqlCommand command) {
114             ADP.CheckArgumentNull(command, "command");
115             Bid.Trace("<sc.SqlCommandSet.Append|API> %d#, command=%d, parameterCount=%d\n", ObjectID, command.ObjectID, command.Parameters.Count);
116 
117             string cmdText = command.CommandText;
118             if (ADP.IsEmpty(cmdText)) {
119                 throw ADP.CommandTextRequired(ADP.Append);
120             }
121 
122             CommandType commandType = command.CommandType;
123             switch(commandType) {
124             case CommandType.Text:
125             case CommandType.StoredProcedure:
126                 break;
127             case CommandType.TableDirect:
128                 Debug.Assert(false, "command.CommandType");
129                 throw System.Data.SqlClient.SQL.NotSupportedCommandType(commandType);
130             default:
131                 Debug.Assert(false, "command.CommandType");
132                 throw ADP.InvalidCommandType(commandType);
133             }
134 
135             SqlParameterCollection parameters = null;
136 
137             SqlParameterCollection collection = command.Parameters;
138             if (0 < collection.Count) {
139                 parameters = new SqlParameterCollection();
140 
141                 // clone parameters so they aren't destroyed
142                 for(int i = 0; i < collection.Count; ++i) {
143                     SqlParameter p = new SqlParameter();
144                     collection[i].CopyTo(p);
145                     parameters.Add(p);
146 
147                     // SQL Injection awarene
148                     if (!SqlIdentifierParser.IsMatch(p.ParameterName)) {
149                         throw ADP.BadParameterName(p.ParameterName);
150                     }
151                 }
152 
153                 foreach(SqlParameter p in parameters) {
154                     // deep clone the parameter value if byte[] or char[]
155                     object obj = p.Value;
156                     byte[] byteValues = (obj as byte[]);
157                     if (null != byteValues) {
158                         int offset = p.Offset;
159                         int size = p.Size;
160                         int countOfBytes = byteValues.Length - offset;
161                         if ((0 != size) && (size < countOfBytes)) {
162                             countOfBytes = size;
163                         }
164                         byte[] copy = new byte[Math.Max(countOfBytes, 0)];
165                         Buffer.BlockCopy(byteValues, offset, copy, 0, copy.Length);
166                         p.Offset = 0;
167                         p.Value = copy;
168                     }
169                     else {
170                         char[] charValues = (obj as char[]);
171                         if (null != charValues) {
172                             int offset = p.Offset;
173                             int size = p.Size;
174                             int countOfChars = charValues.Length - offset;
175                             if ((0 != size) && (size < countOfChars)) {
176                                 countOfChars = size;
177                             }
178                             char[] copy = new char[Math.Max(countOfChars, 0)];
179                             Buffer.BlockCopy(charValues, offset, copy, 0, copy.Length*2);
180                             p.Offset = 0;
181                             p.Value = copy;
182                         }
183                         else {
184                             ICloneable cloneable = (obj as ICloneable);
185                             if (null != cloneable) {
186                                 p.Value = cloneable.Clone();
187                             }
188                         }
189                     }
190                 }
191             }
192 
193             int returnParameterIndex = -1;
194             if (null != parameters) {
195                 for(int i = 0; i < parameters.Count; ++i) {
196                     if (ParameterDirection.ReturnValue == parameters[i].Direction) {
197                         returnParameterIndex = i;
198                         break;
199                     }
200                 }
201             }
202             LocalCommand cmd = new LocalCommand(cmdText, parameters, returnParameterIndex, command.CommandType, command.ColumnEncryptionSetting);
203             CommandList.Add(cmd);
204         }
205 
BuildStoredProcedureName(StringBuilder builder, string part)206         internal static void BuildStoredProcedureName(StringBuilder builder, string part) {
207             if ((null != part) && (0 < part.Length)) {
208                 if ('[' == part[0]) {
209                     int count = 0;
210                     foreach(char c in part) {
211                         if (']' == c) {
212                             count++;
213                         }
214                     }
215                     if (1 == (count%2)) {
216                         builder.Append(part);
217                         return;
218                     }
219                 }
220 
221                 // the part is not escaped, escape it now
222                 SqlServerEscapeHelper.EscapeIdentifier(builder, part);
223             }
224         }
225 
Clear()226         internal void Clear() {
227             Bid.Trace("<sc.SqlCommandSet.Clear|API> %d#\n", ObjectID);
228             DbCommand batchCommand = BatchCommand;
229             if (null != batchCommand) {
230                 batchCommand.Parameters.Clear();
231                 batchCommand.CommandText = null;
232             }
233             List<LocalCommand> commandList = _commandList;
234             if (null != commandList) {
235                 commandList.Clear();
236             }
237         }
238 
Dispose()239         internal void Dispose() {
240             Bid.Trace("<sc.SqlCommandSet.Dispose|API> %d#\n", ObjectID);
241             SqlCommand command = _batchCommand;
242             _commandList = null;
243             _batchCommand = null;
244 
245             if (null != command) {
246                 command.Dispose();
247             }
248         }
249 
ExecuteNonQuery()250         internal int ExecuteNonQuery() {
251             SqlConnection.ExecutePermission.Demand();
252 
253             IntPtr hscp;
254             Bid.ScopeEnter(out hscp, "<sc.SqlCommandSet.ExecuteNonQuery|API> %d#", ObjectID);
255             try {
256                 if (Connection.IsContextConnection) {
257                     throw SQL.BatchedUpdatesNotAvailableOnContextConnection();
258                 }
259                 ValidateCommandBehavior(ADP.ExecuteNonQuery, CommandBehavior.Default);
260                 BatchCommand.BatchRPCMode = true;
261                 BatchCommand.ClearBatchCommand();
262                 BatchCommand.Parameters.Clear();
263                 for (int ii = 0 ; ii < _commandList.Count; ii++) {
264                     LocalCommand cmd = _commandList[ii];
265                     BatchCommand.AddBatchCommand(cmd.CommandText, cmd.Parameters, cmd.CmdType, cmd.ColumnEncryptionSetting);
266                 }
267                 return BatchCommand.ExecuteBatchRPCCommand();
268             }
269             finally {
270                 Bid.ScopeLeave(ref hscp);
271             }
272         }
273 
GetParameter(int commandIndex, int parameterIndex)274         internal SqlParameter GetParameter(int commandIndex, int parameterIndex) {
275             return CommandList[commandIndex].Parameters[parameterIndex];
276         }
277 
GetBatchedAffected(int commandIdentifier, out int recordsAffected, out Exception error)278         internal bool GetBatchedAffected(int commandIdentifier, out int recordsAffected, out Exception error) {
279             error = BatchCommand.GetErrors(commandIdentifier);
280             int? affected = BatchCommand.GetRecordsAffected(commandIdentifier);
281             recordsAffected = affected.GetValueOrDefault();
282             return affected.HasValue;
283         }
284 
GetParameterCount(int commandIndex)285         internal int GetParameterCount(int commandIndex) {
286             return CommandList[commandIndex].Parameters.Count;
287         }
288 
ValidateCommandBehavior(string method, CommandBehavior behavior)289         private void ValidateCommandBehavior(string method, CommandBehavior behavior) {
290             if (0 != (behavior & ~(CommandBehavior.SequentialAccess|CommandBehavior.CloseConnection))) {
291                 ADP.ValidateCommandBehavior(behavior);
292                 throw ADP.NotSupportedCommandBehavior(behavior & ~(CommandBehavior.SequentialAccess|CommandBehavior.CloseConnection), method);
293             }
294         }
295     }
296 }
297 
298